Skip to content
Snippets Groups Projects
Commit 1843ff3f authored by Janet Goldstein's avatar Janet Goldstein
Browse files

Core sampler: make the -e option accept an SDM name rather than an execution_block_id

parent 425baa31
No related branches found
No related tags found
1 merge request!666Core sampler: add ability to accept either a project code or an SDM name
Pipeline #3717 passed
...@@ -28,9 +28,6 @@ import argparse ...@@ -28,9 +28,6 @@ import argparse
import psycopg2 as pg import psycopg2 as pg
from psycopg2 import extras from psycopg2 import extras
# pylint: disable=C0103, E0402, R0201
from pycapo import CapoConfig from pycapo import CapoConfig
from .database import PGTable from .database import PGTable
...@@ -41,6 +38,8 @@ from .row_writer import ( ...@@ -41,6 +38,8 @@ from .row_writer import (
UniquifyingRowWriter, UniquifyingRowWriter,
) )
# pylint: disable=C0103, E0402, R0201, R0903
# stolen shamelessly from aat_wrest # stolen shamelessly from aat_wrest
class MDDBConnector: class MDDBConnector:
"""Use this connection to interrogate this science product locator""" """Use this connection to interrogate this science product locator"""
...@@ -105,16 +104,18 @@ class CoreSampler: ...@@ -105,16 +104,18 @@ class CoreSampler:
self.save(requested) self.save(requested)
self.writer.close() self.writer.close()
def sample_eb(self, exec_block_id: int): def sample_eb(self, sdm_name: str):
""" """
Pull execution block metadata from the archive database. Pull execution block metadata from the archive database.
:param exec_block_id: execution_block_id of interest :param sdm_name: SDM of interest
:return: :return:
""" """
ebs = self.table("execution_blocks") ebs = self.table("execution_blocks")
requested = ebs.fetch({"execution_block_id": exec_block_id}) finder = ExecBlockFinder(self.connection)
eb_id = finder.find_eb_id(sdm_name)
requested = ebs.fetch({"execution_block_id": eb_id})
self.save(requested) self.save(requested)
self.writer.close() self.writer.close()
...@@ -155,6 +156,23 @@ class CoreSampler: ...@@ -155,6 +156,23 @@ class CoreSampler:
rows.write_to(self.writer) rows.write_to(self.writer)
class ExecBlockFinder:
"""Looks up execution block ID for an SDM"""
def __init__(self, connection: MDDBConnector):
self.connection = connection
def find_eb_id(self, sdm_name: str) -> int:
cursor = self.connection.cursor()
sql = """
SELECT execution_block_id FROM execution_blocks
WHERE ngas_fileset_id=%(sdm_name)s
"""
cursor.execute(sql, {"sdm_name": sdm_name})
data = cursor.fetchone()
return data["execution_block_id"]
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
...@@ -163,18 +181,18 @@ def main(): ...@@ -163,18 +181,18 @@ def main():
) )
group.add_argument( group.add_argument(
"-e", "-e",
"--exec_block_id", "--sdm_name",
type=int, type=str,
nargs=1, nargs=1,
help="execution_block_id of SDM for which to pull a core sample", help="name of SDM for which to pull a core sample, e.g., 21A-409.sb39530397.eb39561636.59309.07888592593",
action="store", action="store",
) )
ns = parser.parse_args() ns = parser.parse_args()
if ns.project_code: if ns.project_code:
CoreSampler(MDDBConnector()).sample_project(ns.project_code[0]) CoreSampler(MDDBConnector()).sample_project(ns.project_code[0])
elif ns.exec_block_id: elif ns.sdm_name:
CoreSampler(MDDBConnector()).sample_eb(int(ns.exec_block_id[0])) CoreSampler(MDDBConnector()).sample_eb(ns.sdm_name[0])
return 0 return 0
......
...@@ -16,31 +16,49 @@ ...@@ -16,31 +16,49 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Workspaces. If not, see <https://www.gnu.org/licenses/>. # along with Workspaces. If not, see <https://www.gnu.org/licenses/>.
# #
""" Tests for core sampler""" """ Tests for core sampler """
import logging
import sys import sys
from io import StringIO from io import StringIO
from logging import getLogger
from unittest.mock import patch from unittest.mock import patch
import core_sampler
import pytest
from core_sampler.core_sampler import main from core_sampler.core_sampler import main
logger = logging.getLogger("core_sampler")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
import pytest
def test_gets_eb_as_expected(): def test_gets_eb_as_expected():
""" """
Do we get the output we expect for this 20A-465 execution block? Do we get the output we expect for this 21A-409 execution block?
:return: :return:
""" """
sys.argv.append("-e") sys.argv.append("-e")
sys.argv.append("148839") sys.argv.append("21A-409.sb39530397.eb39561636.59309.07888592593")
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
main() main()
output = fake_out.getvalue() output = fake_out.getvalue()
rows = output.split("\n") rows = output.split("\n")
assert len(rows) == 86
found = False
for row in rows:
if "156759763" in row and "23142018" in row:
found = True
break
assert found
if len(rows) == 181:
# sometimes we get an extra blank line
assert len(rows[-1]) == 0
else:
assert len(rows) == 180
sys.argv.pop() sys.argv.pop()
sys.argv.pop() sys.argv.pop()
...@@ -59,6 +77,17 @@ def test_gets_project_as_expected(): ...@@ -59,6 +77,17 @@ def test_gets_project_as_expected():
main() main()
output = fake_out.getvalue() output = fake_out.getvalue()
rows = output.split("\n") rows = output.split("\n")
assert len(rows) == 704 if len(rows) == 705:
# sometimes we get an extra blank line
assert len(rows[-1]) == 0
else:
assert len(rows) == 704
found = False
for row in rows:
if "uid____evla_bdf_1589728351957.bdf" in row and "121847" in row:
found = True
break
assert found
sys.argv.pop() sys.argv.pop()
sys.argv.pop() sys.argv.pop()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment