Skip to content
Snippets Groups Projects
Commit a3ce86a5 authored by Janet Goldstein's avatar Janet Goldstein
Browse files

Core sampler: add ability to accept either a project code or an SDM name

parent e85ddc9b
No related branches found
No related tags found
1 merge request!666Core sampler: add ability to accept either a project code or an SDM name
Pipeline #3718 passed
......@@ -27,7 +27,7 @@ The core sampler outputs an SQL file you can use to load the core sample into a
import argparse
import psycopg2 as pg
import psycopg2.extras as extras
from psycopg2 import extras
from pycapo import CapoConfig
from .database import PGTable
......@@ -38,6 +38,7 @@ from .row_writer import (
UniquifyingRowWriter,
)
# pylint: disable=C0103, E0402, R0201, R0903
# stolen shamelessly from aat_wrest
class MDDBConnector:
......@@ -89,16 +90,35 @@ class CoreSampler:
self.writer = TopologicallySortingRowWriter(UniquifyingRowWriter(PostgresCopyRowWriter()))
self.visited = set()
def sample(self, project_code: str):
def sample_project(self, project_code: str):
"""
Sample the database.
Get project metadata from the archive database.
:param project_code: project code of interest
:return:
"""
# the first time through, we select from the projects table and get that row
projects = self.table("projects")
requested = projects.fetch({"project_code": project_code})
self.save(requested)
self.writer.close()
def sample_eb(self, sdm_name: str):
"""
Pull execution block metadata from the archive database.
:param sdm_name: SDM of interest
:return:
"""
ebs = self.table("execution_blocks")
finder = ExecBlockFinder(self.connection)
eb_id = finder.find_eb_id(sdm_name)
requested = ebs.fetch({"execution_block_id": eb_id})
self.save(requested)
self.writer.close()
def save(self, rows: "RowSet"):
"""
Save some rows, and then go and fetch their related rows and save them too, recursively.
......@@ -136,12 +156,44 @@ class CoreSampler:
rows.write_to(self.writer)
class ExecBlockFinder:
"""Looks up execution block ID for an SDM"""
def __init__(self, connection: MDDBConnector):
self.connection = connection
def find_eb_id(self, sdm_name: str) -> int:
cursor = self.connection.cursor()
sql = """
SELECT execution_block_id FROM execution_blocks
WHERE ngas_fileset_id=%(sdm_name)s
"""
cursor.execute(sql, {"sdm_name": sdm_name})
data = cursor.fetchone()
return data["execution_block_id"]
def main():
parser = argparse.ArgumentParser()
parser.add_argument("project_code", type=str, help="Project code to start core sampling from")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-p", "--project_code", type=str, nargs=1, help="Project code from which to start core sampling", action="store"
)
group.add_argument(
"-e",
"--sdm_name",
type=str,
nargs=1,
help="name of SDM for which to pull a core sample, e.g., 21A-409.sb39530397.eb39561636.59309.07888592593",
action="store",
)
ns = parser.parse_args()
CoreSampler(MDDBConnector()).sample(ns.project_code)
if ns.project_code:
CoreSampler(MDDBConnector()).sample_project(ns.project_code[0])
elif ns.sdm_name:
CoreSampler(MDDBConnector()).sample_eb(ns.sdm_name[0])
return 0
......
#
# Copyright (C) 2021 Associated Universities, Inc. Washington DC, USA.
#
# This file is part of NRAO Workspaces.
#
# Workspaces is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Workspaces is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Workspaces. If not, see <https://www.gnu.org/licenses/>.
#
""" Tests for core sampler """
import logging
import sys
from io import StringIO
from logging import getLogger
from unittest.mock import patch
from core_sampler.core_sampler import main
logger = logging.getLogger("core_sampler")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
import pytest
def test_gets_eb_as_expected():
"""
Do we get the output we expect for this 21A-409 execution block?
:return:
"""
sys.argv.append("-e")
sys.argv.append("21A-409.sb39530397.eb39561636.59309.07888592593")
with patch("sys.stdout", new=StringIO()) as fake_out:
main()
output = fake_out.getvalue()
rows = output.split("\n")
found = False
for row in rows:
if "156759763" in row and "23142018" in row:
found = True
break
assert found
if len(rows) == 181:
# sometimes we get an extra blank line
assert len(rows[-1]) == 0
else:
assert len(rows) == 180
sys.argv.pop()
sys.argv.pop()
def test_gets_project_as_expected():
"""
Do we get the output we expect for 20A-465?
:return:
"""
sys.argv.append("-p")
sys.argv.append("20A-465")
with patch("sys.stdout", new=StringIO()) as fake_out:
main()
output = fake_out.getvalue()
rows = output.split("\n")
if len(rows) == 705:
# sometimes we get an extra blank line
assert len(rows[-1]) == 0
else:
assert len(rows) == 704
found = False
for row in rows:
if "uid____evla_bdf_1589728351957.bdf" in row and "121847" in row:
found = True
break
assert found
sys.argv.pop()
sys.argv.pop()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment