Skip to content
Snippets Groups Projects
Commit 374edb24 authored by Daniel Lyons's avatar Daniel Lyons Committed by Daniel Lyons
Browse files

Can recur down into the filegroup structure but there things get iffy

parent d80eb420
No related branches found
No related tags found
1 merge request!343Core sampler
......@@ -8,11 +8,11 @@ The core sampler outputs an SQL file you can use to load the core sample into a
"""
import argparse
from typing import Union, Any, Iterable
from pycapo import CapoConfig
from typing import Any, Iterable
import psycopg2 as pg
import psycopg2.extras as extras
from pycapo import CapoConfig
# stolen shamelessly from aat_wrest
......@@ -30,10 +30,13 @@ class MDDBConnector:
"""
settings = CapoConfig().settings("metadataDatabase")
host, dbname = settings["jdbcUrl"].split(":")[2][2:].split("/")
host_slash_db = settings["jdbcUrl"][len("jdbc:postgresql://") :]
host, dbname = host_slash_db.split("/")
host, port = host.split(":")
try:
conn = pg.connect(
host=host,
port=port,
database=dbname,
user=settings.jdbcUsername,
password=settings.jdbcPassword,
......@@ -44,13 +47,18 @@ class MDDBConnector:
raise exc
def cursor(self):
return self.connection.cursor()
return self.connection.cursor(cursor_factory=extras.RealDictCursor)
def close(self):
self.connection.close()
class CoreSampler:
class RowWriter:
def write_row(self, table: str, row: dict):
raise NotImplementedError
class CoreSampler(RowWriter):
"""
A device for retrieving core samples from a database.
"""
......@@ -64,7 +72,7 @@ class CoreSampler:
"""
# the first time through, we select from the projects table and get that row
projects = self.table("projects")
requested = projects.fetch(project_code)
requested = projects.fetch({"project_code": project_code})
self.save(requested)
def save(self, rows: "RowSet"):
......@@ -76,7 +84,10 @@ class CoreSampler:
self.write(rows)
for relation in rows.relations():
more = relation.fetch_related_to(rows)
self.save(more)
# recur, if we have some rows
if len(more) > 0:
self.save(more)
def table(self, name: str) -> "Table":
"""
......@@ -85,7 +96,7 @@ class CoreSampler:
:param name: name of the table we're talking about
:return: a Table with this name
"""
raise NotImplementedError
return Table(self.connection.cursor(), name)
def write(self, rows: "RowSet"):
"""
......@@ -93,7 +104,10 @@ class CoreSampler:
:param rows: the rows to record
"""
raise NotImplementedError
rows.write_to(self)
def write_row(self, table: str, row: dict):
pass # print((table, row))
class Table:
......@@ -101,7 +115,11 @@ class Table:
I'm a table in a relational database.
"""
def fetch(self, primary_key: Any) -> "RowSet":
def __init__(self, cursor, name: str):
self.cursor = cursor
self.name = name
def fetch(self, primary_key: dict[str, Any]) -> "RowSet":
"""
Fetch rows with the associated primary key. The result will be a single row,
but contained in a RowSet object.
......@@ -109,7 +127,96 @@ class Table:
:param primary_key: the key to look up by
:return: a RowSet with the row in it
"""
raise NotImplementedError
# 1. Determine the primary key columns
self.cursor.execute(
"""SELECT c.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name)
JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND
tc.table_name = c.table_name AND
ccu.column_name = c.column_name
WHERE constraint_type = 'PRIMARY KEY' AND
tc.table_name = %(tablename)s""",
dict(tablename=self.name),
)
primary_key_columns = [r["column_name"] for r in self.cursor.fetchall()]
# 2. Generate the WHERE clause
pkey = {column_name: primary_key[column_name] for column_name in primary_key_columns}
whereclause = " AND ".join(f"{name} = %({name})s" for name in pkey.keys())
# 3. Run the query
self.cursor.execute(f"SELECT * FROM {self.name} WHERE {whereclause}", pkey)
# 4. Manufacture the result
return RowSet(self.cursor, self, self.cursor.fetchall())
def fetch_according_to(self, rows: "RowSet", columns: list[str]):
"""
Fetch rows according to the supplied rows and columns
:param rows: rows to mine for the WHERE clause
:param columns: columns to consider in the generated WHERE clause
:return: RowSet for the rows in this table
"""
print(f"fetching from {self.name}")
# 1. generate the where clause
escaped = [self.escape(row[column]) for column in columns for row in rows if row[column] is not None]
if len(escaped) == 0:
return RowSet(self.cursor, self, [])
whereclause = f"({','.join(columns)}) IN ({','.join(escaped)})"
# 2. send the query off and make the result
self.cursor.execute(f"SELECT * FROM {self.name} WHERE {whereclause}")
return RowSet(self.cursor, self, self.cursor.fetchall())
def escape(self, value: Any):
if isinstance(value, str):
return f"E{repr(value)}"
elif value is None:
return None
else:
return str(value)
def relations(self) -> Iterable["Relationship"]:
"""
Return the relationships that might pertain to the rows we have here.
:return:
"""
self.cursor.execute(
"""select kcu.table_name as foreign_table,
string_agg(kcu.column_name, ', ') as fk_columns
from information_schema.table_constraints tco
join information_schema.key_column_usage kcu
on tco.constraint_schema = kcu.constraint_schema
and tco.constraint_name = kcu.constraint_name
join information_schema.referential_constraints rco
on tco.constraint_schema = rco.constraint_schema
and tco.constraint_name = rco.constraint_name
join information_schema.table_constraints rel_tco
on rco.unique_constraint_schema = rel_tco.constraint_schema
and rco.unique_constraint_name = rel_tco.constraint_name
where tco.constraint_type = 'FOREIGN KEY'
and rel_tco.table_name = %(tablename)s
group by kcu.table_name,
rel_tco.table_name,
rel_tco.table_schema,
kcu.constraint_name
order by kcu.table_name""",
dict(tablename=self.name),
)
result = []
for row in self.cursor.fetchall():
other_table = Table(self.cursor, row["foreign_table"])
result.append(Relationship(self, other_table, row["fk_columns"].split(", ")))
return result
class RowSet:
......@@ -117,16 +224,33 @@ class RowSet:
I'm a set of rows from a SELECT * or TABLE query, run against a relational database.
"""
def __init__(self, cursor, table: Table, rows: list[dict]):
self.cursor, self.table, self.rows = cursor, table, rows
def relations(self) -> Iterable["Relationship"]:
"""
Return the relationships that might pertain to the rows we have here.
:return:
"""
raise NotImplementedError
return self.table.relations()
def write_to(self, writer: RowWriter):
for row in self.rows:
writer.write_row(self.table.name, row)
def __iter__(self):
return iter(self.rows)
def __len__(self):
return len(self.rows)
class Relationship:
def __init__(self, source_table: Table, destination_table: Table, foreign_columns: list[str]):
self.source_table, self.destination_table = source_table, destination_table
self.foreign_columns = foreign_columns
def fetch_related_to(self, rows: RowSet) -> RowSet:
"""
Return rows related to the supplied rows based on this relationship.
......@@ -137,7 +261,7 @@ class Relationship:
:param rows: source rows
:return: rows from the destination
"""
raise NotImplementedError
return self.destination_table.fetch_according_to(rows, self.foreign_columns)
def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment