Skip to content
Snippets Groups Projects
Commit ff80afec authored by Nathan Hertz's avatar Nathan Hertz
Browse files

Refactored threaded channel/connection to be more elegant

parent 7698be48
No related branches found
No related tags found
No related merge requests found
......@@ -76,10 +76,14 @@ ChannelDef = TypeVar("ChannelDef", bound=ChannelDefinition, covariant=True)
class Channel(Generic[ChannelDef]):
def __init__(self, definition: ChannelDefinition[T]):
def __init__(self, definition: ChannelDefinition[T], threaded: bool = False):
self.definition: ChannelDefinition[T] = definition
self.chan: BlockingChannel = None
self.config: CapoConfig = None
self.threaded = threaded
if self.threaded is True:
# Define connection instance variable
self.connection = None
def connect(self, **kwargs: Union[int, str]):
"""
......@@ -97,7 +101,7 @@ class Channel(Generic[ChannelDef]):
"""
global CONN
if not CONN:
if not CONN or self.threaded is True:
self.config = CapoConfig(profile=kwargs.get("profile", None)).settings(
"edu.nrao.archive.configuration.AmqpServer"
)
......@@ -113,45 +117,14 @@ class Channel(Generic[ChannelDef]):
password=kwargs.get("password", self.config.password),
),
)
CONN = pika.BlockingConnection(connection_parameters)
self.chan = CONN.channel()
connection = pika.BlockingConnection(connection_parameters)
self.chan = connection.channel()
self.definition.declarations(self.chan)
@classmethod
def threaded_connect(cls, definition: ChannelDefinition[T], **kwargs: Union[int, str]):
"""
Initialize connection to AMQP server given a CAPO profile
:param definition: Channel definition
Keyword arguments for the AMQP connection. These do not need to be specified:
:param: hostname: Hostname to connect to
:param: port: Port to connect to
:param: connection_attempts: Number of connection attempts to try
:param: socket_timeout: Time to wait for a socket to connect
:param: retry_delay: Time to wait between retrying the connection
:param: username: Username to connect to as
:param: password: Password to use when connecting
:param: exchange: Exchange to use when connection
"""
self = cls(definition)
self.config = CapoConfig(profile=kwargs.get("profile", None)).settings(
"edu.nrao.archive.configuration.AmqpServer"
)
connection_parameters = pika.ConnectionParameters(
host=kwargs.get("hostname", self.config.hostname),
port=int(kwargs.get("port", self.config.port)),
connection_attempts=int(kwargs.get("connection_attempts", 5)),
socket_timeout=int(kwargs.get("socket_timeout", 5000)),
retry_delay=int(kwargs.get("retry_delay", 500)),
credentials=pika.PlainCredentials(
username=kwargs.get("username", self.config.username),
password=kwargs.get("password", self.config.password),
),
)
connection = pika.BlockingConnection(connection_parameters)
self.chan = connection.channel()
self.definition.declarations(self.chan)
return self
if self.threaded is False:
CONN = connection
else:
self.connection = connection
def close(self):
"""
......@@ -176,7 +149,7 @@ class Channel(Generic[ChannelDef]):
)
def listen(
self, callback: Optional[Callable], pattern: str = "#", auto_ack: bool = False, threaded: bool = False
self, callback: Optional[Callable], pattern: str = "#", auto_ack: bool = False
):
"""
Establishes queue and binds it to a given channel and consumes messages matching the
......@@ -192,8 +165,7 @@ class Channel(Generic[ChannelDef]):
event = self.definition.schema().loads(message)
callback(event)
if not threaded:
self.connect()
self.connect()
queue = self.chan.queue_declare(queue="", exclusive=True).method.queue
self.chan.queue_bind(
queue=queue, exchange=self.definition.exchange(), routing_key=pattern
......
......@@ -213,8 +213,8 @@ class CapabilityService(CapabilityServiceIF):
to update capability executions
:return:
"""
thread_workflow_events = Channel.threaded_connect(WorkflowEventChannel())
thread_workflow_events.listen(callback=self.update_execution, threaded=True)
thread_workflow_events = Channel(WorkflowEventChannel(), threaded=True)
thread_workflow_events.listen(callback=self.update_execution)
class CapabilityInfo(CapabilityInfoIF):
......@@ -504,7 +504,7 @@ class WorkflowService(WorkflowServiceIF):
# Start listening for events from the wf_monitor stream
self.listener = threading.Thread(
target=lambda: Channel.threaded_connect(WorkflowEventChannel()).listen(self.on_workflow_event, threaded=True)
target=lambda: Channel(WorkflowEventChannel(), threaded=True).listen(self.on_workflow_event)
)
self.listener.start()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment