Skip to content
Snippets Groups Projects

Core sampler

Merged Daniel Lyons requested to merge core-sampler into main
1 unresolved thread
1 file
+ 22
15
Compare changes
  • Side-by-side
  • Inline
"""
Core Sampler
This program extracts a "core sample" from a database. You supply a database name, a table name, and a primary key.
The core sampler outputs an SQL file you can use to load the core sample into a copy of the database somewhere else.
"""
import argparse
import psycopg2 as pg
import psycopg2.extras as extras
from pycapo import CapoConfig
from .database import PGTable
from .interfaces import Table, RowSet
from .row_writer import TopologicallySortingRowWriter, UniquifyingRowWriter, PostgresCopyRowWriter
# stolen shamelessly from aat_wrest
class MDDBConnector:
"""Use this connection to interrogate this science product locator"""
def __init__(self):
self.connection = self._connect_to_mddb()
def _connect_to_mddb(self):
"""
Establish a DB connection
:return:
"""
settings = CapoConfig().settings("metadataDatabase")
host_slash_db = settings["jdbcUrl"][len("jdbc:postgresql://") :]
host, dbname = host_slash_db.split("/")
port = 5432
if ":" in host:
host, port = host.split(":")
try:
conn = pg.connect(
host=host,
port=port,
database=dbname,
user=settings.jdbcUsername,
password=settings.jdbcPassword,
)
return conn
except Exception as exc:
print(f"Unable to connect to database: {exc}")
raise exc
def cursor(self):
return self.connection.cursor(cursor_factory=extras.RealDictCursor)
def close(self):
self.connection.close()
class CoreSampler:
"""
A device for retrieving core samples from a database.
"""
def __init__(self, connection):
self.connection = connection
self.writer = TopologicallySortingRowWriter(UniquifyingRowWriter(PostgresCopyRowWriter()))
self.visited = set()
def sample(self, project_code: str):
"""
Sample the database.
"""
# the first time through, we select from the projects table and get that row
projects = self.table("projects")
requested = projects.fetch({"project_code": project_code})
self.save(requested)
self.writer.close()
def save(self, rows: "RowSet"):
"""
Save some rows, and then go and fetch their related rows and save them too, recursively.
:param rows: the seed rows to start from
"""
# bail out if we've already seen this table
if rows.table.name in self.visited:
return
self.visited.add(rows.table.name)
self.write(rows)
for relation in rows.relations():
more = relation.fetch_related_to(rows)
# recur, if we have some rows
if len(more) > 0:
self.save(more)
def table(self, name: str) -> Table:
"""
Return a Table with the given name.
:param name: name of the table we're talking about
:return: a Table with this name
"""
return PGTable(self.connection.cursor(), name)
def write(self, rows: RowSet):
"""
Record the rows we've found.
:param rows: the rows to record
"""
rows.write_to(self.writer)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("project_code", type=str, help="Project code to start core sampling from")
ns = parser.parse_args()
CoreSampler(MDDBConnector()).sample(ns.project_code)
return 0
if __name__ == "__main__":
main()
Loading