From 8df5502f68b3d51bbe55cb6000eb90ef532195d1 Mon Sep 17 00:00:00 2001
From: Daniel Lyons <dlyons@nrao.edu>
Date: Mon, 19 Jul 2021 16:12:24 +0000
Subject: [PATCH] Core sampler

---
 apps/cli/utilities/core_sampler/README.md     |   4 +
 .../core_sampler/core_sampler/__init__.py     |   0
 .../core_sampler/core_sampler/_version.py     |   2 +
 .../core_sampler/core_sampler/core_sampler.py | 128 ++++++++
 .../core_sampler/core_sampler/database.py     | 288 ++++++++++++++++++
 .../core_sampler/core_sampler/interfaces.py   | 145 +++++++++
 .../core_sampler/core_sampler/row_writer.py   | 139 +++++++++
 apps/cli/utilities/core_sampler/setup.py      |  27 ++
 .../core_sampler/test/test_row_writer.py      | 103 +++++++
 9 files changed, 836 insertions(+)
 create mode 100644 apps/cli/utilities/core_sampler/README.md
 create mode 100644 apps/cli/utilities/core_sampler/core_sampler/__init__.py
 create mode 100644 apps/cli/utilities/core_sampler/core_sampler/_version.py
 create mode 100644 apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
 create mode 100644 apps/cli/utilities/core_sampler/core_sampler/database.py
 create mode 100644 apps/cli/utilities/core_sampler/core_sampler/interfaces.py
 create mode 100644 apps/cli/utilities/core_sampler/core_sampler/row_writer.py
 create mode 100644 apps/cli/utilities/core_sampler/setup.py
 create mode 100644 apps/cli/utilities/core_sampler/test/test_row_writer.py

diff --git a/apps/cli/utilities/core_sampler/README.md b/apps/cli/utilities/core_sampler/README.md
new file mode 100644
index 000000000..9da7cf803
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/README.md
@@ -0,0 +1,4 @@
+# Core Sampler
+
+This program extracts a "core sample" from a database. You supply a database name, a table name, and a primary key.
+The core sampler outputs an SQL file you can use to load the core sample into a copy of the database somewhere else.
diff --git a/apps/cli/utilities/core_sampler/core_sampler/__init__.py b/apps/cli/utilities/core_sampler/core_sampler/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/apps/cli/utilities/core_sampler/core_sampler/_version.py b/apps/cli/utilities/core_sampler/core_sampler/_version.py
new file mode 100644
index 000000000..f27d146a3
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/core_sampler/_version.py
@@ -0,0 +1,2 @@
+""" Version information for this package, don't put anything else here. """
+__version__ = '4.0.0a1.dev1'
diff --git a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
new file mode 100644
index 000000000..a801ea686
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
@@ -0,0 +1,128 @@
+"""
+
+Core Sampler
+
+This program extracts a "core sample" from a database. You supply a database name, a table name, and a primary key.
+The core sampler outputs an SQL file you can use to load the core sample into a copy of the database somewhere else.
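+
+Typical usage from a shell with Capo configured (the project code and target
+database below are hypothetical):
+
+    core_sampler 19A-001 > core_sample.sql
+    psql -d metadata_copy -f core_sample.sql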
+
+"""
+
+import argparse
+
+import psycopg2 as pg
+import psycopg2.extras as extras
+from pycapo import CapoConfig
+
+from .database import PGTable
+from .interfaces import Table, RowSet
+from .row_writer import TopologicallySortingRowWriter, UniquifyingRowWriter, PostgresCopyRowWriter
+
+
+# stolen shamelessly from aat_wrest
+class MDDBConnector:
+    """Use this connection to interrogate the metadata database."""
+
+    def __init__(self):
+        self.connection = self._connect_to_mddb()
+
+    def _connect_to_mddb(self):
+        """
+        Establish a DB connection.
+
+        :return: an open psycopg2 connection
+        """
+        settings = CapoConfig().settings("metadataDatabase")
+
+        host_slash_db = settings["jdbcUrl"][len("jdbc:postgresql://") :]
+        host, dbname = host_slash_db.split("/")
+        port = 5432
+        if ":" in host:
+            host, port = host.split(":")
+        try:
+            conn = pg.connect(
+                host=host,
+                port=port,
+                database=dbname,
+                user=settings.jdbcUsername,
+                password=settings.jdbcPassword,
+            )
+            return conn
+        except Exception as exc:
+            print(f"Unable to connect to database: {exc}")
+            raise exc
+
+    def cursor(self):
+        return self.connection.cursor(cursor_factory=extras.RealDictCursor)
+
+    def close(self):
+        self.connection.close()
+
+
+class CoreSampler:
+    """
+    A device for retrieving core samples from a database.
+    """
+
+    def __init__(self, connection):
+        self.connection = connection
+        self.writer = TopologicallySortingRowWriter(UniquifyingRowWriter(PostgresCopyRowWriter()))
+        self.visited = set()
+
+    def sample(self, project_code: str):
+        """
+        Sample the database.
+        """
+        # the first time through, we select from the projects table and get that row
+        projects = self.table("projects")
+        requested = projects.fetch({"project_code": project_code})
+        self.save(requested)
+        self.writer.close()
+
+    def save(self, rows: "RowSet"):
+        """
+        Save some rows, and then go and fetch their related rows and save them too, recursively.
+
+        :param rows: the seed rows to start from
+        """
+        # bail out if we've already seen this table
+        if rows.table.name in self.visited:
+            return
+
+        self.visited.add(rows.table.name)
+        self.write(rows)
+        for relation in rows.relations():
+            more = relation.fetch_related_to(rows)
+
+            # recurse, if we have some rows
+            if len(more) > 0:
+                self.save(more)
+
+    def table(self, name: str) -> Table:
+        """
+        Return a Table with the given name.
+
+        :param name: name of the table we're talking about
+        :return: a Table with this name
+        """
+        return PGTable(self.connection.cursor(), name)
+
+    def write(self, rows: RowSet):
+        """
+        Record the rows we've found.
+
+        :param rows: the rows to record
+        """
+        rows.write_to(self.writer)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("project_code", type=str, help="Project code to start core sampling from")
+
+    ns = parser.parse_args()
+    CoreSampler(MDDBConnector()).sample(ns.project_code)
+    return 0
+
+
+if __name__ == "__main__":
+    main()
diff --git a/apps/cli/utilities/core_sampler/core_sampler/database.py b/apps/cli/utilities/core_sampler/core_sampler/database.py
new file mode 100644
index 000000000..22aeee2f6
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/core_sampler/database.py
@@ -0,0 +1,288 @@
+from typing import Dict, Any, List, Iterable, Optional
+
+from .interfaces import RowSet, Table, Relationship, RowWriter
+
+
+class PGTable(Table):
+    """
+    I'm a table in PostgreSQL.
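+
+    A minimal usage sketch (the table name and key value are illustrative; the
+    cursor is expected to be a psycopg2 RealDictCursor):
+
+        table = PGTable(connection.cursor(), "projects")
+        rows = table.fetch({"project_code": "19A-001"})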
+    """
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def __init__(
+        self,
+        cursor,
+        name: str,
+    ):
+        self.cursor = cursor
+        self._name = name
+
+    def fetch(self, primary_key: Dict[str, Any]) -> RowSet:
+        # 1. Determine the primary key columns
+        primary_key_columns = self.primary_key_columns()
+
+        # 2. Generate the WHERE clause
+        pkey = {column_name: primary_key[column_name] for column_name in primary_key_columns}
+        whereclause = " AND ".join(f"{name} = %({name})s" for name in pkey.keys())
+
+        # 3. Run the query
+        self.cursor.execute(f"SELECT * FROM {self.name} WHERE {whereclause}", pkey)
+
+        # 4. Manufacture the result
+        return PGRowSet(self, self.cursor.fetchall())
+
+    def primary_key_columns(self):
+        self.cursor.execute(
+            """SELECT c.column_name
+               FROM information_schema.table_constraints tc
+               JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name)
+               JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND
+                                                       tc.table_name = c.table_name AND
+                                                       ccu.column_name = c.column_name
+               WHERE constraint_type = 'PRIMARY KEY' AND
+                     tc.table_name = %(tablename)s""",
+            dict(tablename=self.name),
+        )
+        primary_key_columns = [r["column_name"] for r in self.cursor.fetchall()]
+        return primary_key_columns
+
+    def fetch_according_to(self, rows: RowSet, columns: List[str]):
+        self_join_columns = self.self_join_columns()
+        if self_join_columns:
+            return self.recursive_fetch_according_to(rows, columns, self_join_columns)
+        else:
+            return self.basic_fetch_according_to(rows, columns)
+
+    def basic_fetch_according_to(self, rows: RowSet, columns: List[str]):
+        """
+        Fetch from this table all the rows related to the supplied rowset, using the
+        supplied columns from the originating table.
+
+        :param rows:    the rows whose related rows we want
+        :param columns: the columns in this table that reference the originating table
+        :return:        a RowSet of the related rows
+        """
+        query = self.basic_fetch_according_to_query(rows, columns)
+
+        # if the query doesn't exist, we don't have any escaped rows,
+        # so there's nothing to do
+        if query is None:
+            return PGRowSet(self, [])
+
+        self.cursor.execute(query)
+        return PGRowSet(self, self.cursor.fetchall())
+
+    def basic_fetch_according_to_query(self, rows: RowSet, columns: List[str]) -> Optional[str]:
+        """
+        Returns the query we'll use to obtain some rows from this table according to rows in
+        another RowSet that reference us somehow.
+
+        :param rows:    the rows with references to our table
+        :param columns: the columns in those rows that reference our primary key
+        :return: SQL query string
+        """
+        # 1. Escape the WHERE clause entries. It's important to use the primary keys
+        #    to retrieve the values from the previous resultset; the new columns will
+        #    appear in the query below.
+        key_columns = rows.table.primary_key_columns()
+        escaped = [self.escape(row[column]) for column in key_columns for row in rows if row[column] is not None]
+
+        # If we have no escaped entries, then there won't be anything in the resultset and we can return now
+        if len(escaped) == 0:
+            return None
+
+        # 2. Build the WHERE clause
+        whereclause = f"({','.join(columns)}) IN ({','.join(escaped)})"
+
+        # 3. Send the query off and make the result
+        return f"SELECT * FROM {self.name} WHERE {whereclause}"
+
+    def recursive_fetch_according_to(self, rows: RowSet, columns: List[str], self_columns: List[str]):
+        """
+        For tables that have self-joins, we have to make a slightly more complex query. In
+        fact we need a recursive query that starts from the rows referencing us but encompasses
+        all of the children of those rows. To do that, we'll need to know the self-join columns.
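+
+        For example (the table and column names here are illustrative), a table
+        science_products with a self-join column parent_product_id would produce
+        a query shaped like:
+
+            WITH RECURSIVE recursive_science_products AS (
+                SELECT * FROM science_products WHERE ...
+                UNION
+                SELECT science_products.*
+                FROM science_products
+                JOIN recursive_science_products
+                  ON science_products.parent_product_id =
+                     recursive_science_products.science_product_id
+            )
+            SELECT * FROM recursive_science_products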
+
+        :param rows:         rows in the originating table
+        :param columns:      the columns that reference our primary key in those rows
+        :param self_columns: the self-join columns we have
+        :return:             a RowSet of the related rows and, recursively, their children
+        """
+        basic_select = self.basic_fetch_according_to_query(rows, columns)
+
+        if basic_select is None:
+            return PGRowSet(self, [])
+
+        on_clause = " AND ".join(
+            f"{self.name}.{fk} = recursive_{self.name}.{pk}" for pk, fk in zip(self.primary_key_columns(), self_columns)
+        )
+
+        query = f"""with recursive recursive_{self.name} as (
+            {basic_select}
+            union
+            select {self.name}.* from {self.name}
+            join recursive_{self.name} on {on_clause}
+        )
+        select * from recursive_{self.name}"""
+
+        self.cursor.execute(query)
+        return PGRowSet(self, self.cursor.fetchall())
+
+    @staticmethod
+    def escape(value: Any):
+        """
+        Escape the supplied value so that Postgres can understand it.
+
+        :param value: some value, a string or something
+        :return: an escaped value
+        """
+        if isinstance(value, str):
+            return f"E{repr(value)}"
+        elif value is None:
+            return None
+        else:
+            return str(value)
+
+    def relations(self) -> Iterable[Relationship]:
+        # To obtain the other relationships, we must first retrieve the foreign keys, and then
+        # convert them into objects in this schema.
+        self.cursor.execute(
+            self.foreign_keys_query(self_only=False),
+            dict(tablename=self.name),
+        )
+
+        # now that we have the rows, we can construct our return value
+        result = []
+        for row in self.cursor.fetchall():
+            other_table = PGTable(self.cursor, row["foreign_table"])
+            result.append(PGRelationship(self, other_table, row["fk_columns"].split(",")))
+
+        return result
+
+    def self_join_columns(self) -> List[str]:
+        """
+        Determine if there are any self-join columns on this table, and if so, return them.
+
+        :return: A list of self-join columns, possibly empty.
+        """
+        self.cursor.execute(
+            self.foreign_keys_query(self_only=True),
+            dict(tablename=self.name),
+        )
+        rows = self.cursor.fetchall()
+        if len(rows) > 0:
+            return rows[0]["fk_columns"].split(",")
+        else:
+            return []
+
+    @staticmethod
+    def foreign_keys_query(self_only=False) -> str:
+        """
+        Return a query we can use to find the foreign keys related to this table.
+
+        :param self_only: if True, only self-joins. if False, only other-joins.
+        :return: A string query we can use.
+        """
+
+        # PostgreSQL has extensive metadata about the database stored in the information_schema, which
+        # is an SQL standard suite of views. In Postgres, these views are based on lower-level Postgres
+        # metadata in the pg_catalog schema.
+        #
+        # This precise query gets you the foreign keys in other tables that reference this table:
+        # we look up constraints whose referenced table is this one and group the referencing
+        # columns per referencing table. It explicitly omits self-joins (unless self_only is set),
+        # because they're handled separately in this class when you retrieve.
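+
+        # For a hypothetical call with tablename='projects' and self_only=False, the query
+        # returns rows shaped like {'foreign_table': 'execution_blocks', 'fk_columns':
+        # 'project_code'}: one row per referencing table, with that table's foreign-key
+        # columns comma-joined.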
+
+        return f"""select kcu.table_name as foreign_table,
+                          string_agg(kcu.column_name, ',') as fk_columns
+                   from information_schema.table_constraints tco
+                   join information_schema.key_column_usage kcu
+                        on tco.constraint_schema = kcu.constraint_schema
+                        and tco.constraint_name = kcu.constraint_name
+                   join information_schema.referential_constraints rco
+                        on tco.constraint_schema = rco.constraint_schema
+                        and tco.constraint_name = rco.constraint_name
+                   join information_schema.table_constraints rel_tco
+                        on rco.unique_constraint_schema = rel_tco.constraint_schema
+                        and rco.unique_constraint_name = rel_tco.constraint_name
+                   where tco.constraint_type = 'FOREIGN KEY'
+                     and rel_tco.table_name = %(tablename)s
+                     and kcu.table_name {'=' if self_only else '<>'} %(tablename)s
+                   group by kcu.table_name,
+                            rel_tco.table_name,
+                            rel_tco.table_schema,
+                            kcu.constraint_name
+                   order by kcu.table_name"""
+
+    def __eq__(self, other):
+        return self.name == other.name
+
+    def __hash__(self):
+        return hash(self.name)
+
+    def __repr__(self):
+        return f"<Table {self.name}>"
+
+
+class PGRowSet(RowSet):
+    """
+    I'm a RowSet from a Postgres database.
+    """
+
+    def __init__(self, table: Table, rows: List[Dict]):
+        self._table, self.rows = table, rows
+
+    @property
+    def table(self) -> Table:
+        return self._table
+
+    def relations(self) -> Iterable[Relationship]:
+        """
+        Return the relationships that might pertain to the rows we have here.
+
+        :return: an iterable of Relationships
+        """
+        return self.table.relations()
+
+    def write_to(self, writer: RowWriter):
+        writer.write_rows(self.table, self.rows)
+
+    def __iter__(self) -> Iterable[Dict]:
+        return iter(self.rows)
+
+    def __len__(self) -> int:
+        return len(self.rows)
+
+    def __repr__(self):
+        return f"<RowSet of {self.table} with {len(self.rows)} rows>"
+
+
+class PGRelationship(Relationship):
+    """
+    I'm a relationship between Postgres tables.
+    """
+
+    def __init__(self, source_table: Table, destination_table: Table, foreign_columns: List[str]):
+        self._source_table, self._destination_table = source_table, destination_table
+        self.foreign_columns = foreign_columns
+
+    @property
+    def source_table(self) -> Table:
+        return self._source_table
+
+    @property
+    def destination_table(self) -> Table:
+        return self._destination_table
+
+    def fetch_related_to(self, rows: RowSet) -> RowSet:
+        """
+        Return rows related to the supplied rows based on this relationship.
+        For instance, if this is a relationship from projects to science products,
+        you can supply a RowSet of projects to this relationship, and receive a
+        new RowSet of science products.
+
+        :param rows: source rows
+        :return: rows from the destination
+        """
+        return self.destination_table.fetch_according_to(rows, self.foreign_columns)
+
+    def __repr__(self):
+        return f"<Relationship from {self.source_table} to {self.destination_table}>"
diff --git a/apps/cli/utilities/core_sampler/core_sampler/interfaces.py b/apps/cli/utilities/core_sampler/core_sampler/interfaces.py
new file mode 100644
index 000000000..bc8f3f32d
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/core_sampler/interfaces.py
@@ -0,0 +1,145 @@
+import abc
+from abc import ABC
+from typing import Dict, Any, List, Iterable
+
+
+class Table(ABC):
+    """
+    I'm a table in a relational database.
+    """
+
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        """
+        The name of the table.
+
+        :return: the name of this table
+        """
+        pass
+
+    @abc.abstractmethod
+    def fetch(self, primary_key: Dict[str, Any]) -> "RowSet":
+        """
+        Fetch rows with the associated primary key. The result will be a single row,
+        but contained in a RowSet object.
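+
+        For example (the key value here is illustrative):
+
+            projects.fetch({"project_code": "19A-001"})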
+
+        :param primary_key: the key to look up by
+        :return: a RowSet with the row in it
+        """
+        pass
+
+    @abc.abstractmethod
+    def primary_key_columns(self) -> List[str]:
+        """
+        Return the primary key columns for this table.
+        """
+        pass
+
+    @abc.abstractmethod
+    def fetch_according_to(self, rows: "RowSet", columns: List[str]) -> "RowSet":
+        """
+        Fetch rows according to the supplied rows and columns.
+
+        :param rows: rows to mine for the WHERE clause
+        :param columns: columns to consider in the generated WHERE clause
+        :return: RowSet for the rows in this table
+        """
+        pass
+
+    @abc.abstractmethod
+    def relations(self) -> Iterable["Relationship"]:
+        """
+        Return the relationships that might pertain to the rows we have here.
+
+        :return: an iterable of Relationships
+        """
+        pass
+
+
+class RowWriter:
+    def write_rows(self, table: Table, rows: List[Dict]):
+        raise NotImplementedError
+
+    def close(self):
+        """
+        By default there's nothing to do here, but subclasses can override this
+        if they need to catch a signal that we're done writing.
+        """
+        pass
+
+
+class RowSet(ABC):
+    """
+    I'm a set of rows from a SELECT * or TABLE query, run against a relational database.
+    """
+
+    @property
+    @abc.abstractmethod
+    def table(self) -> Table:
+        pass
+
+    @abc.abstractmethod
+    def relations(self) -> Iterable["Relationship"]:
+        """
+        Return the relationships that might pertain to the rows we have here.
+
+        :return: an iterable of Relationships
+        """
+        pass
+
+    @abc.abstractmethod
+    def write_to(self, writer: RowWriter):
+        """
+        Write this rowset to a particular row writer.
+
+        :param writer: the writer to communicate with
+        """
+        pass
+
+    @abc.abstractmethod
+    def __iter__(self) -> Iterable[Dict]:
+        pass
+
+    @abc.abstractmethod
+    def __len__(self) -> int:
+        pass
+
+
+class Relationship(ABC):
+    """
+    I'm a relationship between two tables, typically a foreign key relationship.
+    """
+
+    @property
+    @abc.abstractmethod
+    def source_table(self) -> Table:
+        """
+        The table being referenced. Consider the relationship
+        execution_blocks(project_code) -> projects(project_code): the crawl holds
+        rows from projects and follows this relationship outward, so the source
+        table here is projects.
+
+        :return: the referenced table
+        """
+        pass
+
+    @property
+    @abc.abstractmethod
+    def destination_table(self) -> Table:
+        """
+        The table with the foreign key. In the relationship
+        execution_blocks(project_code) -> projects(project_code), the
+        destination table is execution_blocks.
+
+        :return: the referencing table
+        """
+        pass
+
+    @abc.abstractmethod
+    def fetch_related_to(self, rows: RowSet) -> RowSet:
+        """
+        Return rows related to the supplied rows based on this relationship.
+        For instance, if this is a relationship from projects to science products,
+        you can supply a RowSet of projects to this relationship, and receive a
+        new RowSet of science products.
+
+        :param rows: source rows
+        :return: rows from the destination
+        """
+        pass
diff --git a/apps/cli/utilities/core_sampler/core_sampler/row_writer.py b/apps/cli/utilities/core_sampler/core_sampler/row_writer.py
new file mode 100644
index 000000000..6c37a3e7f
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/core_sampler/row_writer.py
@@ -0,0 +1,139 @@
+import datetime
+from itertools import chain
+from typing import List, Dict
+
+from .interfaces import RowWriter
+
+
+class PostgresCopyRowWriter(RowWriter):
+    """
+    Write rows in the PostgreSQL COPY format. This is the most efficient way
+    of loading data into Postgres and it's the format used by the pg_dump
+    utility for creating backups.
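+
+    The emitted stream looks like this sketch (the table and values are
+    illustrative; columns are separated by literal tab characters):
+
+        COPY projects (project_code, title) FROM stdin;
+        19A-001	A hypothetical project
+        \.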
+    """
+
+    def write_rows(self, table: "Table", rows: List[Dict]):
+        columns = rows[0].keys()
+        print(f"COPY {table.name} ({', '.join(columns)}) FROM stdin;")
+        for row in rows:
+            print("\t".join([self.copy_format(row[col]) for col in columns]))
+        print("\\.")
+
+    @staticmethod
+    def copy_format(value):
+        if value is None:
+            return "\\N"
+        elif isinstance(value, str):
+            return value
+        elif isinstance(value, int) or isinstance(value, float):
+            return str(value)
+        elif isinstance(value, datetime.date):
+            return value.isoformat()
+        else:
+            raise TypeError(f"Unable to figure out what to do with {value} of type {type(value)}")
+
+
+class UniquifyingRowWriter(RowWriter):
+    """
+    Ensure that only a single instance of each row gets dumped out.
+    """
+
+    def __init__(self, underlying: RowWriter):
+        self.underlying = underlying
+        self.seen = {}
+
+    def write_rows(self, table: "Table", rows: List[Dict]):
+        seen = self.seen.setdefault(table.name, [])
+        new_rows = [row for row in rows if row not in seen]
+        if len(new_rows) > 0:
+            self.underlying.write_rows(table, new_rows)
+            self.seen[table.name] += new_rows
+
+    def close(self):
+        self.underlying.close()
+
+
+class TopologicalSortFailed(Exception):
+    pass
+
+
+class TopologicallySortingRowWriter(RowWriter):
+    """
+    Topologically sorts the tables before outputting their rows.
+
+    Because you can have an arbitrary graph topology of foreign keys in your database, you have to worry
+    about the order in which rows are loaded. If you try to load a row before you have loaded a row
+    that it depends on, the load will fail. For instance, science product locators point to
+    projects, but execblocks point to both projects and science product locators. Starting from the
+    project, the program would want to output execblocks first due to alphabetical order. This would be
+    difficult to prevent in a query, so instead we do it here, because we have all the objects we need
+    to figure out what order things should go in.
+
+    The algorithm that gives you this ordering across a digraph is called topological sort. This implementation
+    was created while reading the notes in the slides at:
+    https://courses.cs.washington.edu/courses/cse326/03wi/lectures/RaoLect20.pdf
+    """
+
+    def __init__(self, underlying: RowWriter):
+        self.tables_rows: Dict["Table", List[Dict]] = {}
+        self.underlying = underlying
+
+    def write_rows(self, table: "Table", rows: List[Dict]):
+        # We've been asked to output some rows for this table.
+        # What we must now do is keep track of this table and the in-bound relationships with it.
+        # Knowing which tables we need to output will give us the vertices of the graph we need to sort.
+        self.tables_rows[table] = rows
+
+    def close(self):
+        # Now that we know all the tables we're going to be asked to output, we can compute their edges.
+        # We have an edge for each relationship from that table to another table we've been asked to output.
+        vertices = self.tables_rows.keys()
+
+        # There is an object called graphlib.TopologicalSorter in Python 3.9+ that can replace
+        # the code below (see the sketch further down).
+
+        # Now we can get the exact set of relations we actually care about by getting all
+        # of the relations between all of the tables and filtering out relations that pertain
+        # to tables we do not have data for. We also ignore self-joins here because they increase
+        # the in-degree by one but cannot ever be decreased, so we avoid an infinite loop by
+        # ignoring them. We can assume that, to whatever extent this program handles self-joins,
+        # it will obtain upper-level rows before their dependencies programmatically, so we don't
+        # have to handle it manually here.
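+
+        # A sketch of the graphlib-based alternative, given the vertices above and the
+        # relations list computed just below (assumes Python 3.9+; not used here because
+        # this package targets 3.8):
+        #
+        #     from graphlib import TopologicalSorter
+        #     sorter = TopologicalSorter({vertex: set() for vertex in vertices})
+        #     for r in relations:
+        #         sorter.add(r.destination_table, r.source_table)
+        #     for table in sorter.static_order():
+        #         self.underlying.write_rows(table, self.tables_rows[table])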
+        all_relations = chain.from_iterable(v.relations() for v in vertices)
+        relations = [
+            r
+            for r in all_relations
+            if r.source_table in vertices and r.destination_table in vertices and r.source_table != r.destination_table
+        ]
+
+        # we can now compute the in-degree values for each vertex by computing how many relationships
+        # from other tables we care about point to this table
+        in_degrees = {}
+        for vertex_table in vertices:
+            in_degrees[vertex_table] = sum(1 for relation in relations if relation.destination_table == vertex_table)
+
+        # now that we know the in_degrees we can proceed with the sort
+        remaining = list(vertices)
+        while len(remaining) > 0:
+            # find a vertex in remaining whose in-degree is 0
+            table = next((candidate for candidate in remaining if in_degrees[candidate] == 0), None)
+
+            # if we didn't find one, the graph has a cycle and cannot be sorted
+            if table is None:
+                raise TopologicalSortFailed("Unable to locate a table with no references!")
+
+            # pass along the rows we found
+            self.underlying.write_rows(table, self.tables_rows[table])
+
+            # now we can remove this vertex from the remaining vertices
+            # and reduce its neighbors' in-degrees by one
+            remaining.remove(table)
+            for relation in relations:
+                if relation.source_table == table and relation.destination_table in in_degrees:
+                    in_degrees[relation.destination_table] -= 1
+
+        # if we made it here, we ran out of remaining tables and we can conclude the process
+        self.underlying.close()
diff --git a/apps/cli/utilities/core_sampler/setup.py b/apps/cli/utilities/core_sampler/setup.py
new file mode 100644
index 000000000..68c0dd2e4
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/setup.py
@@ -0,0 +1,27 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+VERSION = open("core_sampler/_version.py").readlines()[-1].split()[-1].strip("\"'")
+README = Path("README.md").read_text()
+
+requires = ["pycapo", "psycopg2"]
+
+setup(
+    name="ssa-" + Path().absolute().name,
+    version=VERSION,
+    description="Workspaces database core sampler",
+    long_description=README,
+    author="NRAO SSA Team",
+    author_email="dms-ssa@nrao.edu",
+    url="TBD",
+    license="GPL",
+    install_requires=requires,
+    keywords=[],
+    packages=find_packages(),
+    classifiers=["Programming Language :: Python :: 3.8"],
+    entry_points={"console_scripts": ["core_sampler = core_sampler.core_sampler:main"]},
+)
diff --git a/apps/cli/utilities/core_sampler/test/test_row_writer.py b/apps/cli/utilities/core_sampler/test/test_row_writer.py
new file mode 100644
index 000000000..c5c035c98
--- /dev/null
+++ b/apps/cli/utilities/core_sampler/test/test_row_writer.py
@@ -0,0 +1,103 @@
+import datetime
+import itertools
+from typing import List, Dict
+from unittest.mock import MagicMock, Mock
+
+import pytest
+
+from core_sampler.database import PGRowSet
+from core_sampler.interfaces import Table, RowSet, RowWriter, Relationship
+from core_sampler.row_writer import PostgresCopyRowWriter, UniquifyingRowWriter, TopologicallySortingRowWriter
+
+
+@pytest.fixture
+def project_rowset() -> RowSet:
+    projects = MagicMock(Table)
+    projects.name = "projects"
+
+    row = dict(
+        id=0,
+        name="this",
+        # a fixed naive datetime keeps the expected output timezone-independent
+        timestamp=datetime.datetime(2021, 7, 16, 16, 16, 50, 10410),
+        existential_crisis=None,
+    )
+
+    return PGRowSet(projects, [row])
+
+
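+# An extra spot-check added for illustration: exercise copy_format's scalar
+# conversions directly (the cases here are chosen by hand, not exhaustive).
+def test_copy_format_scalars():
+    assert PostgresCopyRowWriter.copy_format(None) == "\\N"
+    assert PostgresCopyRowWriter.copy_format(42) == "42"
+    assert PostgresCopyRowWriter.copy_format("this") == "this"
+    assert PostgresCopyRowWriter.copy_format(datetime.date(2021, 7, 16)) == "2021-07-16"
+
+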
+def test_copy_output(capsys, project_rowset: RowSet):
+    """
+    Are we formatting the output properly for PostgreSQL?
+    """
+    rw = PostgresCopyRowWriter()
+    project_rowset.write_to(rw)
+    lines = capsys.readouterr().out.strip().split("\n")
+    assert len(lines) == 3
+    assert lines[0] == "COPY projects (id, name, timestamp, existential_crisis) FROM stdin;"
+    assert lines[1] == "\t".join(["0", "this", "2021-07-16T16:16:50.010410", "\\N"])
+    assert lines[2] == "\\."
+
+
+def test_uniquifier(capsys, project_rowset: RowSet):
+    fake_writer = MagicMock(RowWriter)
+    uniq = UniquifyingRowWriter(fake_writer)
+
+    # write the rows twice
+    project_rowset.write_to(uniq)
+    project_rowset.write_to(uniq)
+
+    # we should only get called once though
+    assert fake_writer.write_rows.call_count == 1
+
+
+def test_topological_sort(capsys):
+    # The problem here can probably arise with just two tables, but I want to show it with three.
+    # The basic idea is that if A -> B and B -> C and A -> C, then the topological sort is A, B, C.
+    # That order must be respected no matter what order we happen to find rows we want to output.
+    # So to prove it works, we're going to generate a single row in three tables, A, B, and C, with
+    # the relationships as described above, and then try outputting in every permutation, and make
+    # sure that we get them in ABC order regardless of which permutation we choose.
+
+    a = MagicMock(Table)
+    a.name = "a"
+    a.rowset = PGRowSet(a, [dict(name="a")])
+
+    b = MagicMock(Table)
+    b.name = "b"
+    b.rowset = PGRowSet(b, [dict(name="b")])
+
+    c = MagicMock(Table)
+    c.name = "c"
+    c.rowset = PGRowSet(c, [dict(name="c")])
+
+    def make_relationship(source, dest):
+        r = MagicMock(Relationship)
+        r.source_table = source
+        r.destination_table = dest
+        return r
+
+    a.relations = Mock("relationships")
+    b.relations = Mock("relationships")
+    c.relations = Mock("relationships")
+
+    a.relations.return_value = [make_relationship(a, b), make_relationship(a, c)]
+    b.relations.return_value = [make_relationship(b, c)]
+    c.relations.return_value = []
+
+    class TableNameCapturer(RowWriter):
+        def __init__(self):
+            self.tables = []
+
+        def write_rows(self, table: Table, rows: List[Dict]):
+            # we aren't worried about the rows for the purposes of the test
+            self.tables.append(table.name)
+
+    # for every permutation of these things, we want to see the same output order of a, b, c
+    for t1, t2, t3 in itertools.permutations([a, b, c]):
+        capturer = TableNameCapturer()
+        writer = TopologicallySortingRowWriter(capturer)
+        writer.write_rows(t1, t1.rowset)
+        writer.write_rows(t2, t2.rowset)
+        writer.write_rows(t3, t3.rowset)
+        writer.close()
+        assert capturer.tables == ["a", "b", "c"]
-- 
GitLab