
Core sampler

Merged: Daniel Lyons requested to merge core-sampler into main
1 unresolved thread
4 files  +475 −328
@@ -8,14 +8,15 @@ The core sampler outputs an SQL file you can use to load the core sample into a
"""
import argparse
from itertools import chain
from typing import Any, Iterable, Dict, List
import datetime
import psycopg2 as pg
import psycopg2.extras as extras
from pycapo import CapoConfig
from .database import PGTable
from .interfaces import Table, RowSet
from .row_writer import TopologicallySortingRowWriter, UniquifyingRowWriter, PostgresCopyRowWriter
# stolen shamelessly from aat_wrest
class MDDBConnector:
@@ -55,149 +56,6 @@ class MDDBConnector:
self.connection.close()
class RowWriter:
def write_rows(self, table: "Table", rows: List[Dict]):
raise NotImplementedError
def close(self):
"""
By default there's nothing to do here, but subclasses can override this
if they need to catch a signal that we're done writing.
"""
pass
class PostgresCopyRowWriter(RowWriter):
"""
Write rows in the PostgreSQL COPY TABLE format. This is the most efficient
way of loading data into Postgres and it's the way used by the pg_dump
utility for creating backups.
"""
def write_rows(self, table: "Table", rows: List[Dict]):
    if not rows:
        return  # nothing to emit, and rows[0] below would fail on an empty batch
    columns = rows[0].keys()
print(f"COPY {table.name} ({', '.join(columns)}) FROM stdin;")
for row in rows:
print("\t".join([self.copy_format(row[col]) for col in columns]))
print("\.")
@staticmethod
def copy_format(value):
if value is None:
return "\\N"
elif isinstance(value, str):
return value
elif isinstance(value, (int, float)):
return str(value)
elif isinstance(value, datetime.date):
return value.isoformat()
else:
raise TypeError(f"Unable to figure out what to do with {value} of type {type(value)}")
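# For illustration only (this table and these rows are hypothetical, not part
# of this MR): two rows written through this writer come out as a COPY block,
# with columns separated by tabs and SQL NULL rendered as \N:
#
#   COPY projects (project_code, title) FROM stdin;
#   13B-014	First project
#   13B-015	\N
#   \.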
class UniquifyingRowWriter(RowWriter):
"""
Ensure that only a single instance of each row gets dumped out.
"""
def __init__(self, underlying: RowWriter):
self.underlying = underlying
self.seen = {}
def write_rows(self, table: "Table", rows: List[Dict]):
seen = self.seen.setdefault(table.name, [])
new_rows = [row for row in rows if row not in seen]
if len(new_rows) > 0:
self.underlying.write_rows(table, new_rows)
self.seen[table.name] += new_rows
def close(self):
self.underlying.close()
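# A sketch of one plausible composition of these writers; the actual wiring
# lives in CoreSampler, so the order here is an assumption, not this MR's code:
#
#   writer = UniquifyingRowWriter(TopologicallySortingRowWriter(PostgresCopyRowWriter()))
#   writer.write_rows(table, rows)   # deduplicate, then buffer for sorting
#   writer.close()                   # sort the tables, emit COPY blocks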
class TopologicalSortFailed(Exception):
pass
class TopologicallySortingRowWriter(RowWriter):
"""
Topologically sorts the tables before outputting their rows.
Because the foreign keys in your database can form an arbitrary directed graph, you have to worry
about the order in which you output rows as you read them. If you try to load a row before the
rows it depends on, the load will fail. For instance, science product locators point to projects,
but execblocks point to both projects and science product locators. Starting from the project,
the program would want to output execblocks first, due to alphabetical order. This ordering would
be difficult to achieve in a query, so we do it here instead, where we have all the objects we
need to figure out what order things should go in.
The algorithm that produces this ordering across a digraph is called topological sort. This
implementation was created while reading the notes in the slides at:
https://courses.cs.washington.edu/courses/cse326/03wi/lectures/RaoLect20.pdf
"""
def __init__(self, underlying: RowWriter):
self.tables_rows: Dict["Table", List[Dict]] = {}
self.underlying = underlying
def write_rows(self, table: "Table", rows: List[Dict]):
# We've been asked to output some rows for this table.
# What we must now do is keep track of this table and the in-bound relationships with it.
# Knowing which tables we need to output will give us the vertices of the graph we need to sort.
# accumulate rather than assign, so repeated writes for the same table keep earlier rows
self.tables_rows.setdefault(table, []).extend(rows)
def close(self):
# Now that we know all the tables we're going to be asked to output, we can compute their edges.
# We have an edge for each relationship from that table to another table we've been asked to output.
vertices = self.tables_rows.keys()
# Now we can get the exact set of relations we actually care about by getting all
# of the relations between all of the tables and filtering out relations that pertain
# to tables we do not have data for. We also ignore self-joins here: a self-join raises a
# table's in-degree by one, but that contribution can never be decremented, so keeping it
# would cause an infinite loop. We can assume that, to whatever extent this program handles
# self-joins, it obtains upper-level rows before the rows that depend on them, so we don't
# have to handle the ordering manually here.
all_relations = chain.from_iterable(v.relations() for v in vertices)
relations = [
r
for r in all_relations
if r.source_table in vertices and r.destination_table in vertices and r.source_table != r.destination_table
]
# we can now compute the in-degree of each vertex by counting how many of the
# relationships we care about point to it from other tables
in_degrees = {}
for vertex_table in vertices:
in_degrees[vertex_table] = sum(1 for relation in relations if relation.destination_table == vertex_table)
# now that we know the in_degrees we can proceed with the sort
remaining = list(vertices)
while len(remaining) > 0:
# find a vertex in remaining whose in-degree is 0, defaulting to None so that
# a cyclic remainder of tables is detectable
table = next((t for t in remaining if in_degrees[t] == 0), None)
# if we didn't find one, throw an exception
if table is None:
    raise TopologicalSortFailed("Unable to locate a table with no references!")
# pass along the rows we found
self.underlying.write_rows(table, self.tables_rows[table])
# now we can remove this vertex from the remaining vertices
# and reduce each of its neighbors' in-degrees by one
remaining.remove(table)
for relation in relations:
if relation.source_table == table and relation.destination_table in in_degrees:
in_degrees[relation.destination_table] -= 1
# if we made it here, we ran out of remaining tables and we can conclude the process
self.underlying.close()
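# A self-contained sketch of the same in-degree (Kahn-style) algorithm on a toy
# digraph, separate from the Table machinery above; the vertex names and edges
# are illustrative only, not tables from this MR:
def _kahn_order(vertices: List[str], edges: List[tuple]) -> List[str]:
    """Return vertices in dependency order; raise ValueError on a cycle."""
    in_degree = {v: 0 for v in vertices}
    for _, destination in edges:
        in_degree[destination] += 1
    order = []
    remaining = list(vertices)
    while remaining:
        ready = next((v for v in remaining if in_degree[v] == 0), None)
        if ready is None:
            raise ValueError("dependency cycle detected")
        order.append(ready)
        remaining.remove(ready)
        for source, destination in edges:
            if source == ready:
                in_degree[destination] -= 1
    return order

# _kahn_order(["execblocks", "projects", "science_products"],
#             [("projects", "execblocks"), ("projects", "science_products"),
#              ("science_products", "execblocks")])
# returns ['projects', 'science_products', 'execblocks']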
class CoreSampler:
"""
A device for retrieving core samples from a database.
@@ -237,16 +95,16 @@ class CoreSampler:
if len(more) > 0:
self.save(more)
def table(self, name: str) -> "Table":
def table(self, name: str) -> Table:
"""
Return a Table with the given name.
:param name: name of the table we're talking about
:return: a Table with this name
"""
return Table(self.connection.cursor(), name)
return PGTable(self.connection.cursor(), name)
def write(self, rows: "RowSet"):
def write(self, rows: RowSet):
"""
Record the rows we've found.
@@ -255,185 +113,6 @@ class CoreSampler:
rows.write_to(self.writer)
class Table:
"""
I'm a table in a relational database.
"""
def __init__(
self,
cursor,
name: str,
):
self.cursor = cursor
self.name = name
def fetch(self, primary_key: Dict[str, Any]) -> "RowSet":
"""
Fetch rows with the associated primary key. The result will be a single row,
but contained in a RowSet object.
:param primary_key: the key to look up by
:return: a RowSet with the row in it
"""
# 1. Determine the primary key columns
primary_key_columns = self.primary_key_columns()
# 2. Generate the WHERE clause
pkey = {column_name: primary_key[column_name] for column_name in primary_key_columns}
whereclause = " AND ".join(f"{name} = %({name})s" for name in pkey.keys())
# 3. Run the query
self.cursor.execute(f"SELECT * FROM {self.name} WHERE {whereclause}", pkey)
# 4. Manufacture the result
return RowSet(self.cursor, self, self.cursor.fetchall())
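# Illustrative only: for a hypothetical table whose primary key is the single
# column project_code, the query above renders as
#
#   SELECT * FROM projects WHERE project_code = %(project_code)s
#
# and psycopg2 binds the value from the pkey dict when it executes the query.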
def primary_key_columns(self):
self.cursor.execute(
"""SELECT c.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name)
JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND
tc.table_name = c.table_name AND
ccu.column_name = c.column_name
WHERE constraint_type = 'PRIMARY KEY' AND
tc.table_name = %(tablename)s""",
dict(tablename=self.name),
)
primary_key_columns = [r["column_name"] for r in self.cursor.fetchall()]
return primary_key_columns
def fetch_according_to(self, rows: "RowSet", columns: List[str]):
"""
Fetch rows according to the supplied rows and columns
:param rows: rows to mine for the WHERE clause
:param columns: columns to consider in the generated WHERE clause
:return: RowSet for the rows in this table
"""
# print(f"Fetching from {self.name} according to {','.join(columns)}")
# 1. Escape the WHERE clause entries. It's important to use the primary keys
#    to retrieve the values from the previous resultset; the new columns will
#    appear in the query below.
key_columns = rows.table.primary_key_columns()
# Group the escaped values per row so that composite keys render as row
# constructors, e.g. (a, b) IN ((1, 2), (3, 4)), not as one flat value list.
escaped = []
for row in rows:
    values = [self.escape(row[column]) for column in key_columns]
    if None in values:
        continue
    escaped.append(values[0] if len(values) == 1 else f"({','.join(values)})")
# If we have no escaped entries, then there won't be anything in the resultset and we can return now
if len(escaped) == 0:
    return RowSet(self.cursor, self, [])
# 2. Build the WHERE clause
whereclause = f"({','.join(columns)}) IN ({','.join(escaped)})"
# 3. Send the query off and make the result
self.cursor.execute(f"SELECT * FROM {self.name} WHERE {whereclause}")
return RowSet(self.cursor, self, self.cursor.fetchall())
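# Illustrative only: fetching execblocks according to a RowSet of projects
# keyed on the single column project_code sends SQL along the lines of
#
#   SELECT * FROM execblocks WHERE (project_code) IN (E'13B-014',E'13B-015')
#
# where the project codes are hypothetical values escaped by escape() below.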
def escape(self, value: Any):
if isinstance(value, str):
return f"E{repr(value)}"
elif value is None:
return None
else:
return str(value)
def relations(self) -> Iterable["Relationship"]:
"""
Return the relationships from other tables that reference this table.
:return: an iterable of Relationship objects
"""
self.cursor.execute(
"""select kcu.table_name as foreign_table,
string_agg(kcu.column_name, ', ') as fk_columns
from information_schema.table_constraints tco
join information_schema.key_column_usage kcu
on tco.constraint_schema = kcu.constraint_schema
and tco.constraint_name = kcu.constraint_name
join information_schema.referential_constraints rco
on tco.constraint_schema = rco.constraint_schema
and tco.constraint_name = rco.constraint_name
join information_schema.table_constraints rel_tco
on rco.unique_constraint_schema = rel_tco.constraint_schema
and rco.unique_constraint_name = rel_tco.constraint_name
where tco.constraint_type = 'FOREIGN KEY'
and rel_tco.table_name = %(tablename)s
group by kcu.table_name,
rel_tco.table_name,
rel_tco.table_schema,
kcu.constraint_name
order by kcu.table_name""",
dict(tablename=self.name),
)
result = []
for row in self.cursor.fetchall():
other_table = Table(self.cursor, row["foreign_table"])
result.append(Relationship(self, other_table, row["fk_columns"].split(", ")))
return result
def __eq__(self, other):
return self.name == other.name
def __hash__(self):
return hash(self.name)
def __repr__(self):
return f"<Table {self.name}>"
class RowSet:
"""
I'm a set of rows from a SELECT * or TABLE query, run against a relational database.
"""
def __init__(self, cursor, table: Table, rows: List[Dict]):
self.cursor, self.table, self.rows = cursor, table, rows
def relations(self) -> Iterable["Relationship"]:
"""
Return the relationships that might pertain to the rows we have here.
:return: an iterable of Relationship objects
"""
return self.table.relations()
def write_to(self, writer: RowWriter):
writer.write_rows(self.table, self.rows)
def __iter__(self):
return iter(self.rows)
def __len__(self):
return len(self.rows)
def __repr__(self):
return f"<RowSet of {self.table} with {len(self.rows)} rows>"
class Relationship:
def __init__(self, source_table: Table, destination_table: Table, foreign_columns: List[str]):
self.source_table, self.destination_table = source_table, destination_table
self.foreign_columns = foreign_columns
def fetch_related_to(self, rows: RowSet) -> RowSet:
"""
Return rows related to the supplied rows based on this relationship.
For instance, if this is a relationship from projects to science products,
you can supply a RowSet of projects to this relationship, and receive a
new RowSet of science products.
:param rows: source rows
:return: rows from the destination
"""
return self.destination_table.fetch_according_to(rows, self.foreign_columns)
def __repr__(self):
return f"<Relationship from {self.source_table} to {self.destination_table}>"
def main():
parser = argparse.ArgumentParser()
parser.add_argument("project_code", type=str, help="Project code to start core sampling from")