From d229706659bb79103c85baaab34e8bae109c1856 Mon Sep 17 00:00:00 2001
From: nhertz <nhertz@nrao.edu>
Date: Fri, 11 Sep 2020 10:16:23 -0600
Subject: [PATCH] Small changes to amqp_helpers.py

---
 shared/channels/src/channels/amqp_helpers.py | 72 ++++++++++----------
 1 file changed, 35 insertions(+), 37 deletions(-)

diff --git a/shared/channels/src/channels/amqp_helpers.py b/shared/channels/src/channels/amqp_helpers.py
index 302750ca4..82a554bba 100644
--- a/shared/channels/src/channels/amqp_helpers.py
+++ b/shared/channels/src/channels/amqp_helpers.py
@@ -8,12 +8,6 @@ from pycapo import CapoConfig
 from workspaces.json import WorkflowEventSchema
 from workspaces.schema import WorkflowEvent
 
-
-CONNECTION = None
-CHANNEL = None
-CONFIG = None
-
-
 T = TypeVar('T')
 
 
@@ -56,69 +50,73 @@ class WorkflowEventChannel(ChannelDefinition[WorkflowEvent]):
 
 
 class Channel(Protocol[T]):
+    CONNECTION = None
+    CHANNEL = None
+    CONFIG = None
+
     def __init__(self, definition: ChannelDefinition[T]):
         self.definition = definition
 
-    def connect(self, **kwargs: Union[int, str]) -> pika.BlockingConnection:
+    def connect(self, **kwargs: Union[int, str]):
         """
         Initialize connection to AMQP server given a CAPO profile
-        :return: Established connection
-        """
-        global CONNECTION
-        global CHANNEL
-        global CONFIG
 
-        if not CONNECTION:
-            CONFIG = CapoConfig(
+        Keyword arguments for the AMQP connection. These do not need to be specified:
+        :param: hostname: Hostname to connect to
+        :param: port: Port to connect to
+        :param: connection_attempts: Number of connection attempts to try
+        :param: socket_timeout: Time to wait for a socket to connect
+        :param: retry_delay: Time to wait between retrying the connection
+        :param: username: Username to connect to as
+        :param: password: Password to use when connecting
+        :param: exchange: Exchange to use when connection
+        """
+        if not self.CONNECTION:
+            self.CONFIG = CapoConfig(
                 profile=kwargs.get('profile', None)
             ).settings('edu.nrao.archive.configuration.AmqpServer')
 
             connection_parameters = pika.ConnectionParameters(
-                host=kwargs.get('hostname', CONFIG.hostname),
-                port=int(kwargs.get('port', CONFIG.port)),
-                # FIXME: Copied from events. Do these args need to be cast to int?
-                connection_attempts=kwargs.get('connection_attempts', 5),
-                socket_timeout=kwargs.get('socket_timeout', 5000),
-                retry_delay=kwargs.get('retry_delay', 500),
-                #
+                host=kwargs.get('hostname', self.CONFIG.hostname),
+                port=int(kwargs.get('port', self.CONFIG.port)),
+                connection_attempts=int(kwargs.get('connection_attempts', 5)),
+                socket_timeout=int(kwargs.get('socket_timeout', 5000)),
+                retry_delay=int(kwargs.get('retry_delay', 500)),
                 credentials=pika.PlainCredentials(
-                    username=kwargs.get('username', CONFIG.username),
-                    password=kwargs.get('password', CONFIG.password)
+                    username=kwargs.get('username', self.CONFIG.username),
+                    password=kwargs.get('password', self.CONFIG.password)
                 )
             )
-            CONNECTION = pika.BlockingConnection(connection_parameters)
-            CHANNEL = CONNECTION.channel()
-            self.definition.declarations(CHANNEL)
+            self.CONNECTION = pika.BlockingConnection(connection_parameters)
+            self.CHANNEL = self.CONNECTION.channel()
+            self.definition.declarations(self.CHANNEL)
 
     def close(self):
-        if CONNECTION:
-            CONNECTION.close()
+        if self.CONNECTION:
+            self.CONNECTION.close()
 
     def send(self, event: WorkflowEventSchema):
         rendered = self.definition.schema.dump(event)
         routing_key = self.definition.routing_key_for(event)
-        CHANNEL.basic_publish(routing_key=routing_key, body=rendered)
+        self.CHANNEL.basic_publish(routing_key=routing_key, body=rendered)
 
-    def listen(self, callback: Optional[Callable], pattern: str = '#', **kwargs: Union[str, bool]):
+    def listen(self, callback: Optional[Callable], pattern: str = '#', auto_ack: bool = False):
         """
         Establishes queue and binds it to a given channel and consumes messages matching the
         routing key from given exchange
         :param callback: Callback function for when a message is consumed
         :param pattern: Pattern to be used as routing key
         Optional keyword arguments
-        :param exchange: AMQP exchange name
         :param auto_ack: If true, consumer automatically acknowledges when it consumes a message
         """
-        auto_ack = kwargs.get('auto_ack', False)
-
         def unwrapping_callback(message):
             event = self.definition.schema.load(message)
             callback(event)
 
-        queue = CHANNEL.queue_declare(queue='', exclusive=True).method.queue
-        CHANNEL.queue_bind(queue=queue, exchange=self.definition.exchange(), routing_key=pattern)
-        CHANNEL.basic_consume(queue=queue, on_message_callback=callback, auto_ack=auto_ack)
-        CHANNEL.start_consuming()
+        queue = self.CHANNEL.queue_declare(queue='', exclusive=True).method.queue
+        self.CHANNEL.queue_bind(queue=queue, exchange=self.definition.exchange(), routing_key=pattern)
+        self.CHANNEL.basic_consume(queue=queue, on_message_callback=callback, auto_ack=auto_ack)
+        self.CHANNEL.start_consuming()
 
     def __enter__(self, **kwargs: Union[int, str]):
         """
-- 
GitLab