Skip to content
Snippets Groups Projects
Commit 1843ff3f authored by Janet Goldstein's avatar Janet Goldstein
Browse files

Core sampler: make the -e option accept an SDM name rather than an execution_block_id

parent 425baa31
No related branches found
No related tags found
1 merge request!666Core sampler: add ability to accept either a project code or an SDM name
Pipeline #3717 passed
......@@ -28,9 +28,6 @@ import argparse
import psycopg2 as pg
from psycopg2 import extras
# pylint: disable=C0103, E0402, R0201
from pycapo import CapoConfig
from .database import PGTable
......@@ -41,6 +38,8 @@ from .row_writer import (
UniquifyingRowWriter,
)
# pylint: disable=C0103, E0402, R0201, R0903
# stolen shamelessly from aat_wrest
class MDDBConnector:
"""Use this connection to interrogate this science product locator"""
......@@ -105,16 +104,18 @@ class CoreSampler:
self.save(requested)
self.writer.close()
def sample_eb(self, exec_block_id: int):
def sample_eb(self, sdm_name: str):
"""
Pull execution block metadata from the archive database.
:param exec_block_id: execution_block_id of interest
:param sdm_name: SDM of interest
:return:
"""
ebs = self.table("execution_blocks")
requested = ebs.fetch({"execution_block_id": exec_block_id})
finder = ExecBlockFinder(self.connection)
eb_id = finder.find_eb_id(sdm_name)
requested = ebs.fetch({"execution_block_id": eb_id})
self.save(requested)
self.writer.close()
......@@ -155,6 +156,23 @@ class CoreSampler:
rows.write_to(self.writer)
class ExecBlockFinder:
"""Looks up execution block ID for an SDM"""
def __init__(self, connection: MDDBConnector):
self.connection = connection
def find_eb_id(self, sdm_name: str) -> int:
cursor = self.connection.cursor()
sql = """
SELECT execution_block_id FROM execution_blocks
WHERE ngas_fileset_id=%(sdm_name)s
"""
cursor.execute(sql, {"sdm_name": sdm_name})
data = cursor.fetchone()
return data["execution_block_id"]
def main():
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True)
......@@ -163,18 +181,18 @@ def main():
)
group.add_argument(
"-e",
"--exec_block_id",
type=int,
"--sdm_name",
type=str,
nargs=1,
help="execution_block_id of SDM for which to pull a core sample",
help="name of SDM for which to pull a core sample, e.g., 21A-409.sb39530397.eb39561636.59309.07888592593",
action="store",
)
ns = parser.parse_args()
if ns.project_code:
CoreSampler(MDDBConnector()).sample_project(ns.project_code[0])
elif ns.exec_block_id:
CoreSampler(MDDBConnector()).sample_eb(int(ns.exec_block_id[0]))
elif ns.sdm_name:
CoreSampler(MDDBConnector()).sample_eb(ns.sdm_name[0])
return 0
......
......@@ -16,31 +16,49 @@
# You should have received a copy of the GNU General Public License
# along with Workspaces. If not, see <https://www.gnu.org/licenses/>.
#
""" Tests for core sampler"""
""" Tests for core sampler """
import logging
import sys
from io import StringIO
from logging import getLogger
from unittest.mock import patch
import core_sampler
import pytest
from core_sampler.core_sampler import main
logger = logging.getLogger("core_sampler")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
import pytest
def test_gets_eb_as_expected():
"""
Do we get the output we expect for this 20A-465 execution block?
Do we get the output we expect for this 21A-409 execution block?
:return:
"""
sys.argv.append("-e")
sys.argv.append("148839")
sys.argv.append("21A-409.sb39530397.eb39561636.59309.07888592593")
with patch("sys.stdout", new=StringIO()) as fake_out:
main()
output = fake_out.getvalue()
rows = output.split("\n")
assert len(rows) == 86
found = False
for row in rows:
if "156759763" in row and "23142018" in row:
found = True
break
assert found
if len(rows) == 181:
# sometimes we get an extra blank line
assert len(rows[-1]) == 0
else:
assert len(rows) == 180
sys.argv.pop()
sys.argv.pop()
......@@ -59,6 +77,17 @@ def test_gets_project_as_expected():
main()
output = fake_out.getvalue()
rows = output.split("\n")
assert len(rows) == 704
if len(rows) == 705:
# sometimes we get an extra blank line
assert len(rows[-1]) == 0
else:
assert len(rows) == 704
found = False
for row in rows:
if "uid____evla_bdf_1589728351957.bdf" in row and "121847" in row:
found = True
break
assert found
sys.argv.pop()
sys.argv.pop()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment