From 425baa3143f219baed840ebfd6926931d44a2d8e Mon Sep 17 00:00:00 2001
From: "Janet L. Goldstein" <jgoldste@nrao.edu>
Date: Fri, 19 Nov 2021 16:59:05 -0700
Subject: [PATCH] Core sampler: modified to accept either a project code or an
 execution_block_id

---
 .../core_sampler/core_sampler/core_sampler.py | 48 ++++++++++++--
 .../core_sampler/test/test_core_sampler.py    | 64 +++++++++++++++++++
 2 files changed, 105 insertions(+), 7 deletions(-)
 create mode 100644 apps/cli/utilities/core_sampler/test/test_core_sampler.py

diff --git a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
index ca780b549..bd4c5deb9 100644
--- a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
+++ b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
@@ -27,7 +27,10 @@ The core sampler outputs an SQL file you can use to load the core sample into a
 import argparse
 
 import psycopg2 as pg
-import psycopg2.extras as extras
+from psycopg2 import extras
+
+# pylint: disable=C0103, E0402, R0201
+
 from pycapo import CapoConfig
 
 from .database import PGTable
@@ -38,7 +41,6 @@ from .row_writer import (
     UniquifyingRowWriter,
 )
 
-
 # stolen shamelessly from aat_wrest
 class MDDBConnector:
     """Use this connection to interrogate this science product locator"""
@@ -89,16 +91,33 @@ class CoreSampler:
         self.writer = TopologicallySortingRowWriter(UniquifyingRowWriter(PostgresCopyRowWriter()))
         self.visited = set()
 
-    def sample(self, project_code: str):
+    def sample_project(self, project_code: str):
         """
-        Sample the database.
+        Core-sample the archive database starting from the projects table.
+
+        :param project_code: project code used to fetch the starting projects rows
+        :return: None; fetched rows are handed to save() and the writer is closed
         """
+
         # the first time through, we select from the projects table and get that row
         projects = self.table("projects")
         requested = projects.fetch({"project_code": project_code})
         self.save(requested)
         self.writer.close()
 
+    def sample_eb(self, exec_block_id: int):
+        """
+        Core-sample the archive database starting from the execution_blocks table.
+
+        :param exec_block_id: execution_block_id used to fetch the starting rows
+        :return: None; fetched rows are handed to save() and the writer is closed
+        """
+        # start from the execution_blocks rows matching the requested id
+        starting_rows = self.table("execution_blocks").fetch({"execution_block_id": exec_block_id})
+
+        self.save(starting_rows)
+        self.writer.close()
+
     def save(self, rows: "RowSet"):
         """
         Save some rows, and then go and fetch their related rows and save them too, recursively.
@@ -138,10 +157,25 @@ class CoreSampler:
 
 def main():
+    """Parse the command line and core-sample from a project code or an execution_block_id."""
     parser = argparse.ArgumentParser()
-    parser.add_argument("project_code", type=str, help="Project code to start core sampling from")
-
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("-p", "--project_code", type=str, help="Project code from which to start core sampling")
+    group.add_argument(
+        "-e",
+        "--exec_block_id",
+        type=int,
+        help="execution_block_id of SDM for which to pull a core sample",
+    )
     ns = parser.parse_args()
-    CoreSampler(MDDBConnector()).sample(ns.project_code)
+
+    # Without nargs, argparse stores a plain str / int (None when the option is absent),
+    # so test against None: exec_block_id 0 and an empty project code are still honoured.
+    sampler = CoreSampler(MDDBConnector())
+    if ns.project_code is not None:
+        sampler.sample_project(ns.project_code)
+    elif ns.exec_block_id is not None:
+        sampler.sample_eb(ns.exec_block_id)
+
     return 0
 
 
diff --git a/apps/cli/utilities/core_sampler/test/test_core_sampler.py b/apps/cli/utilities/core_sampler/test/test_core_sampler.py
new file mode 100644
index 000000000..e1623fcc2
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/test/test_core_sampler.py
@@ -0,0 +1,64 @@
+#
+# Copyright (C) 2021 Associated Universities, Inc. Washington DC, USA.
+#
+# This file is part of NRAO Workspaces.
+#
+# Workspaces is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Workspaces is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Workspaces.  If not, see <https://www.gnu.org/licenses/>.
+#
+""" Tests for core sampler"""
+
+import sys
+from io import StringIO
+from unittest.mock import patch
+
+import core_sampler
+import pytest
+from core_sampler.core_sampler import main
+
+
+def test_gets_eb_as_expected():
+    """
+    Do we get the output we expect for this 20A-465 execution block?
+
+    sys.argv is replaced via patch.object so it is restored even if an
+    assertion fails and pytest's own arguments never reach argparse.
+
+    :return:
+    """
+    argv = ["core_sampler", "-e", "148839"]
+
+    with patch.object(sys, "argv", argv), patch("sys.stdout", new=StringIO()) as fake_out:
+        main()
+        output = fake_out.getvalue()
+
+    rows = output.split("\n")
+    assert len(rows) == 86
+
+
+def test_gets_project_as_expected():
+    """
+    Do we get the output we expect for 20A-465?
+
+    sys.argv is replaced via patch.object so it is restored even if an
+    assertion fails and pytest's own arguments never reach argparse.
+
+    :return:
+    """
+    argv = ["core_sampler", "-p", "20A-465"]
+    with patch.object(sys, "argv", argv), patch("sys.stdout", new=StringIO()) as fake_out:
+        main()
+        output = fake_out.getvalue()
+
+    rows = output.split("\n")
+    assert len(rows) == 704
-- 
GitLab