From 425baa3143f219baed840ebfd6926931d44a2d8e Mon Sep 17 00:00:00 2001 From: "Janet L. Goldstein" <jgoldste@nrao.edu> Date: Fri, 19 Nov 2021 16:59:05 -0700 Subject: [PATCH] Core sampler: modified to accept either a project code or an execution_block_id --- .../core_sampler/core_sampler/core_sampler.py | 48 ++++++++++++-- .../core_sampler/test/test_core_sampler.py | 64 +++++++++++++++++++ 2 files changed, 105 insertions(+), 7 deletions(-) create mode 100644 apps/cli/utilities/core_sampler/test/test_core_sampler.py diff --git a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py index ca780b549..bd4c5deb9 100644 --- a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py +++ b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py @@ -27,7 +27,10 @@ The core sampler outputs an SQL file you can use to load the core sample into a import argparse import psycopg2 as pg -import psycopg2.extras as extras +from psycopg2 import extras + +# pylint: disable=C0103, E0402, R0201 + from pycapo import CapoConfig from .database import PGTable @@ -38,7 +41,6 @@ from .row_writer import ( UniquifyingRowWriter, ) - # stolen shamelessly from aat_wrest class MDDBConnector: """Use this connection to interrogate this science product locator""" @@ -89,16 +91,33 @@ class CoreSampler: self.writer = TopologicallySortingRowWriter(UniquifyingRowWriter(PostgresCopyRowWriter())) self.visited = set() - def sample(self, project_code: str): + def sample_project(self, project_code: str): """ - Sample the database. + Get project metadata from the archive database. + + :param project_code: project code of interest + :return: """ + # the first time through, we select from the projects table and get that row projects = self.table("projects") requested = projects.fetch({"project_code": project_code}) self.save(requested) self.writer.close() + def sample_eb(self, exec_block_id: int): + """ + Pull execution block metadata from the archive database. + + :param exec_block_id: execution_block_id of interest + :return: + """ + ebs = self.table("execution_blocks") + + requested = ebs.fetch({"execution_block_id": exec_block_id}) + self.save(requested) + self.writer.close() + def save(self, rows: "RowSet"): """ Save some rows, and then go and fetch their related rows and save them too, recursively. @@ -138,10 +157,25 @@ class CoreSampler: def main(): parser = argparse.ArgumentParser() - parser.add_argument("project_code", type=str, help="Project code to start core sampling from") - + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "-p", "--project_code", type=str, nargs=1, help="Project code from which to start core sampling", action="store" + ) + group.add_argument( + "-e", + "--exec_block_id", + type=int, + nargs=1, + help="execution_block_id of SDM for which to pull a core sample", + action="store", + ) ns = parser.parse_args() - CoreSampler(MDDBConnector()).sample(ns.project_code) + + if ns.project_code: + CoreSampler(MDDBConnector()).sample_project(ns.project_code[0]) + elif ns.exec_block_id: + CoreSampler(MDDBConnector()).sample_eb(int(ns.exec_block_id[0])) + return 0 diff --git a/apps/cli/utilities/core_sampler/test/test_core_sampler.py b/apps/cli/utilities/core_sampler/test/test_core_sampler.py new file mode 100644 index 000000000..e1623fcc2 --- /dev/null +++ b/apps/cli/utilities/core_sampler/test/test_core_sampler.py @@ -0,0 +1,64 @@ +# +# Copyright (C) 2021 Associated Universities, Inc. Washington DC, USA. +# +# This file is part of NRAO Workspaces. +# +# Workspaces is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Workspaces is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Workspaces. If not, see <https://www.gnu.org/licenses/>. +# +""" Tests for core sampler""" + +import sys +from io import StringIO +from unittest.mock import patch + +import core_sampler +import pytest +from core_sampler.core_sampler import main + + +def test_gets_eb_as_expected(): + """ + Do we get the output we expect for this 20A-465 execution block? + + :return: + """ + sys.argv.append("-e") + sys.argv.append("148839") + + with patch("sys.stdout", new=StringIO()) as fake_out: + main() + output = fake_out.getvalue() + rows = output.split("\n") + assert len(rows) == 86 + + sys.argv.pop() + sys.argv.pop() + + +def test_gets_project_as_expected(): + """ + Do we get the output we expect for 20A-465? + + :return: + """ + sys.argv.append("-p") + sys.argv.append("20A-465") + + with patch("sys.stdout", new=StringIO()) as fake_out: + main() + output = fake_out.getvalue() + rows = output.split("\n") + assert len(rows) == 704 + sys.argv.pop() + sys.argv.pop() -- GitLab