Skip to content
Snippets Groups Projects
Commit 425baa31 authored by Janet Goldstein's avatar Janet Goldstein
Browse files

Core sampler: modified to accept either a project code or an execution_block_id

parent e85ddc9b
No related branches found
No related tags found
1 merge request!666Core sampler: add ability to accept either a project code or an SDM name
Pipeline #3716 passed
This commit is part of merge request !666. Comments created here will be created in the context of that merge request.
...@@ -27,7 +27,10 @@ The core sampler outputs an SQL file you can use to load the core sample into a ...@@ -27,7 +27,10 @@ The core sampler outputs an SQL file you can use to load the core sample into a
import argparse import argparse
import psycopg2 as pg import psycopg2 as pg
import psycopg2.extras as extras from psycopg2 import extras
# pylint: disable=C0103, E0402, R0201
from pycapo import CapoConfig from pycapo import CapoConfig
from .database import PGTable from .database import PGTable
...@@ -38,7 +41,6 @@ from .row_writer import ( ...@@ -38,7 +41,6 @@ from .row_writer import (
UniquifyingRowWriter, UniquifyingRowWriter,
) )
# stolen shamelessly from aat_wrest # stolen shamelessly from aat_wrest
class MDDBConnector: class MDDBConnector:
"""Use this connection to interrogate this science product locator""" """Use this connection to interrogate this science product locator"""
...@@ -89,16 +91,33 @@ class CoreSampler: ...@@ -89,16 +91,33 @@ class CoreSampler:
self.writer = TopologicallySortingRowWriter(UniquifyingRowWriter(PostgresCopyRowWriter())) self.writer = TopologicallySortingRowWriter(UniquifyingRowWriter(PostgresCopyRowWriter()))
self.visited = set() self.visited = set()
def sample(self, project_code: str): def sample_project(self, project_code: str):
""" """
Sample the database. Get project metadata from the archive database.
:param project_code: project code of interest
:return:
""" """
# the first time through, we select from the projects table and get that row # the first time through, we select from the projects table and get that row
projects = self.table("projects") projects = self.table("projects")
requested = projects.fetch({"project_code": project_code}) requested = projects.fetch({"project_code": project_code})
self.save(requested) self.save(requested)
self.writer.close() self.writer.close()
def sample_eb(self, exec_block_id: int):
"""
Pull execution block metadata from the archive database.
:param exec_block_id: execution_block_id of interest
:return:
"""
ebs = self.table("execution_blocks")
requested = ebs.fetch({"execution_block_id": exec_block_id})
self.save(requested)
self.writer.close()
def save(self, rows: "RowSet"): def save(self, rows: "RowSet"):
""" """
Save some rows, and then go and fetch their related rows and save them too, recursively. Save some rows, and then go and fetch their related rows and save them too, recursively.
...@@ -138,10 +157,25 @@ class CoreSampler: ...@@ -138,10 +157,25 @@ class CoreSampler:
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("project_code", type=str, help="Project code to start core sampling from") group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-p", "--project_code", type=str, nargs=1, help="Project code from which to start core sampling", action="store"
)
group.add_argument(
"-e",
"--exec_block_id",
type=int,
nargs=1,
help="execution_block_id of SDM for which to pull a core sample",
action="store",
)
ns = parser.parse_args() ns = parser.parse_args()
CoreSampler(MDDBConnector()).sample(ns.project_code)
if ns.project_code:
CoreSampler(MDDBConnector()).sample_project(ns.project_code[0])
elif ns.exec_block_id:
CoreSampler(MDDBConnector()).sample_eb(int(ns.exec_block_id[0]))
return 0 return 0
......
#
# Copyright (C) 2021 Associated Universities, Inc. Washington DC, USA.
#
# This file is part of NRAO Workspaces.
#
# Workspaces is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Workspaces is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Workspaces. If not, see <https://www.gnu.org/licenses/>.
#
""" Tests for core sampler"""
import sys
from io import StringIO
from unittest.mock import patch
import core_sampler
import pytest
from core_sampler.core_sampler import main
def test_gets_eb_as_expected():
"""
Do we get the output we expect for this 20A-465 execution block?
:return:
"""
sys.argv.append("-e")
sys.argv.append("148839")
with patch("sys.stdout", new=StringIO()) as fake_out:
main()
output = fake_out.getvalue()
rows = output.split("\n")
assert len(rows) == 86
sys.argv.pop()
sys.argv.pop()
def test_gets_project_as_expected():
"""
Do we get the output we expect for 20A-465?
:return:
"""
sys.argv.append("-p")
sys.argv.append("20A-465")
with patch("sys.stdout", new=StringIO()) as fake_out:
main()
output = fake_out.getvalue()
rows = output.split("\n")
assert len(rows) == 704
sys.argv.pop()
sys.argv.pop()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment