From 1843ff3fd0c27fcf1611ae7b3d75a7011335e753 Mon Sep 17 00:00:00 2001 From: "Janet L. Goldstein" <jgoldste@nrao.edu> Date: Mon, 22 Nov 2021 09:39:08 -0700 Subject: [PATCH] Core sampler: make the -e option accept an SDM name rather than an execution_block_id --- .../core_sampler/core_sampler/core_sampler.py | 40 ++++++++++++----- .../core_sampler/test/test_core_sampler.py | 43 ++++++++++++++++--- 2 files changed, 65 insertions(+), 18 deletions(-) diff --git a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py index bd4c5deb9..c756f0f8f 100644 --- a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py +++ b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py @@ -28,9 +28,6 @@ import argparse import psycopg2 as pg from psycopg2 import extras - -# pylint: disable=C0103, E0402, R0201 - from pycapo import CapoConfig from .database import PGTable @@ -41,6 +38,8 @@ from .row_writer import ( UniquifyingRowWriter, ) +# pylint: disable=C0103, E0402, R0201, R0903 + # stolen shamelessly from aat_wrest class MDDBConnector: """Use this connection to interrogate this science product locator""" @@ -105,16 +104,18 @@ class CoreSampler: self.save(requested) self.writer.close() - def sample_eb(self, exec_block_id: int): + def sample_eb(self, sdm_name: str): """ Pull execution block metadata from the archive database. - :param exec_block_id: execution_block_id of interest + :param sdm_name: SDM of interest :return: """ ebs = self.table("execution_blocks") - requested = ebs.fetch({"execution_block_id": exec_block_id}) + finder = ExecBlockFinder(self.connection) + eb_id = finder.find_eb_id(sdm_name) + requested = ebs.fetch({"execution_block_id": eb_id}) self.save(requested) self.writer.close() @@ -155,6 +156,23 @@ class CoreSampler: rows.write_to(self.writer) +class ExecBlockFinder: + """Looks up execution block ID for an SDM""" + + def __init__(self, connection: MDDBConnector): + self.connection = connection + + def find_eb_id(self, sdm_name: str) -> int: + cursor = self.connection.cursor() + sql = """ +SELECT execution_block_id FROM execution_blocks +WHERE ngas_fileset_id=%(sdm_name)s + """ + cursor.execute(sql, {"sdm_name": sdm_name}) + data = cursor.fetchone() + return data["execution_block_id"] + + def main(): parser = argparse.ArgumentParser() group = parser.add_mutually_exclusive_group(required=True) @@ -163,18 +181,18 @@ def main(): ) group.add_argument( "-e", - "--exec_block_id", - type=int, + "--sdm_name", + type=str, nargs=1, - help="execution_block_id of SDM for which to pull a core sample", + help="name of SDM for which to pull a core sample, e.g., 21A-409.sb39530397.eb39561636.59309.07888592593", action="store", ) ns = parser.parse_args() if ns.project_code: CoreSampler(MDDBConnector()).sample_project(ns.project_code[0]) - elif ns.exec_block_id: - CoreSampler(MDDBConnector()).sample_eb(int(ns.exec_block_id[0])) + elif ns.sdm_name: + CoreSampler(MDDBConnector()).sample_eb(ns.sdm_name[0]) return 0 diff --git a/apps/cli/utilities/core_sampler/test/test_core_sampler.py b/apps/cli/utilities/core_sampler/test/test_core_sampler.py index e1623fcc2..0651b0b89 100644 --- a/apps/cli/utilities/core_sampler/test/test_core_sampler.py +++ b/apps/cli/utilities/core_sampler/test/test_core_sampler.py @@ -16,31 +16,49 @@ # You should have received a copy of the GNU General Public License # along with Workspaces. If not, see <https://www.gnu.org/licenses/>. # -""" Tests for core sampler""" +""" Tests for core sampler """ +import logging import sys from io import StringIO +from logging import getLogger from unittest.mock import patch -import core_sampler -import pytest from core_sampler.core_sampler import main +logger = logging.getLogger("core_sampler") +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +import pytest + def test_gets_eb_as_expected(): """ - Do we get the output we expect for this 20A-465 execution block? + Do we get the output we expect for this 21A-409 execution block? :return: """ sys.argv.append("-e") - sys.argv.append("148839") + sys.argv.append("21A-409.sb39530397.eb39561636.59309.07888592593") with patch("sys.stdout", new=StringIO()) as fake_out: main() output = fake_out.getvalue() rows = output.split("\n") - assert len(rows) == 86 + + found = False + for row in rows: + if "156759763" in row and "23142018" in row: + found = True + break + assert found + + if len(rows) == 181: + # sometimes we get an extra blank line + assert len(rows[-1]) == 0 + else: + assert len(rows) == 180 sys.argv.pop() sys.argv.pop() @@ -59,6 +77,17 @@ def test_gets_project_as_expected(): main() output = fake_out.getvalue() rows = output.split("\n") - assert len(rows) == 704 + if len(rows) == 705: + # sometimes we get an extra blank line + assert len(rows[-1]) == 0 + else: + assert len(rows) == 704 + found = False + for row in rows: + if "uid____evla_bdf_1589728351957.bdf" in row and "121847" in row: + found = True + break + assert found + sys.argv.pop() sys.argv.pop() -- GitLab