From f4c9f537d53a2730284ac58faf1eb3c506666962 Mon Sep 17 00:00:00 2001
From: Daniel K Lyons <dlyons@nrao.edu>
Date: Wed, 14 Jul 2021 15:24:21 -0600
Subject: [PATCH] Now writing output that can be loaded. Just two problems
 remain.

1. Output is in the wrong order; the tables may need to be topologically sorted by their dependencies.
2. The same data gets rendered more than once.
---
 .../core_sampler/core_sampler/core_sampler.py | 28 +++++++++++++++----
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
index f35137ed2..9ee27d54a 100644
--- a/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
+++ b/apps/cli/utilities/core_sampler/core_sampler/core_sampler.py
@@ -9,6 +9,7 @@ The core sampler outputs an SQL file you can use to load the core sample into a
 
 import argparse
 from typing import Any, Iterable
+import datetime
 
 import psycopg2 as pg
 import psycopg2.extras as extras
@@ -54,7 +55,7 @@ class MDDBConnector:
 
 
 class RowWriter:
-    def write_row(self, table: str, row: dict):
+    def write_rows(self, table: str, rows: list[dict]):  # abstract hook: receives all rows of one table at once
         raise NotImplementedError
 
 
@@ -106,8 +107,24 @@ class CoreSampler(RowWriter):
         """
         rows.write_to(self)
 
-    def write_row(self, table: str, row: dict):
-        pass  # print((table, row))
+    def write_rows(self, table: str, rows: list[dict]):
+        """Print ``rows`` as one PostgreSQL ``COPY ... FROM stdin`` block; an empty row set prints nothing."""
+        columns = list(rows[0]) if rows else []  # column order is taken from the first row
+        lines = ["\t".join(self.copy_format(row[col]) for col in columns) for row in rows]
+        if rows:  # without a first row there is no column list to put in the COPY header
+            print(f"COPY {table} ({', '.join(columns)}) FROM stdin;", *lines, "\\.", sep="\n")
+
+    def copy_format(self, value):
+        """Render one value as a COPY text-format field; raises TypeError for unsupported types."""
+        if value is None:
+            return "\\N"  # COPY's default NULL marker
+        if isinstance(value, str):  # a raw backslash, tab or line break would corrupt the row, so escape them
+            return value.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n").replace("\r", "\\r")
+        if isinstance(value, (int, float)):
+            return str(value)
+        if isinstance(value, datetime.date):
+            return value.isoformat()
+        raise TypeError(f"Unable to figure out what to do with {value} of type {type(value)}")
 
 
 class Table:
@@ -163,7 +180,7 @@ class Table:
         :param columns: columns to consider in the generated WHERE clause
         :return: RowSet for the rows in this table
         """
-
+        # print(f"Fetching from {self.name} according to {','.join(columns)}")
         # 1. Escape the WHERE clause entries. it's important to use the primary keys
         #    to retrieve the values from the previous resultset; the new columns will
         #    appear in the query below.
@@ -243,8 +260,7 @@ class RowSet:
         return self.table.relations()
 
     def write_to(self, writer: RowWriter):
-        for row in self.rows:
-            writer.write_row(self.table.name, row)
+        writer.write_rows(self.table.name, self.rows)  # one call with the whole row list, not one call per row
 
     def __iter__(self):
         return iter(self.rows)
-- 
GitLab