diff --git a/shared/channels/src/channels/amqp_helpers.py b/shared/channels/src/channels/amqp_helpers.py index 82a554bbaff80dc96be0192492df2cec8e921d48..b075ac372cf7c30bc0ab23d76efbe43bad654665 100644 --- a/shared/channels/src/channels/amqp_helpers.py +++ b/shared/channels/src/channels/amqp_helpers.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Any, Optional, Union, TypeVar, Protocol +from typing import Callable, Any, Optional, Union, TypeVar, Protocol, Generic import pika from marshmallow import Schema @@ -8,10 +8,10 @@ from pycapo import CapoConfig from workspaces.json import WorkflowEventSchema from workspaces.schema import WorkflowEvent -T = TypeVar('T') +T = TypeVar('T', contravariant=True) -class ChannelDefinition(ABC, Protocol[T]): +class ChannelDefinition(Protocol[T]): @abstractmethod def routing_key_for(self, message: T) -> str: pass @@ -49,10 +49,13 @@ class WorkflowEventChannel(ChannelDefinition[WorkflowEvent]): return self.EXCHANGE -class Channel(Protocol[T]): - CONNECTION = None - CHANNEL = None - CONFIG = None +ChannelDef = TypeVar('ChannelDef', bound=ChannelDefinition, covariant=True) + + +class Channel(Generic[ChannelDef]): + CONNECTION : pika.BlockingConnection = None + CHANNEL : BlockingChannel = None + CONFIG : CapoConfig = None def __init__(self, definition: ChannelDefinition[T]): self.definition = definition @@ -137,4 +140,4 @@ class Channel(Protocol[T]): # Predefined event channels for ease of use -workflow_events = Channel(WorkflowEventChannel()) +workflow_events : Channel[WorkflowEvent] = Channel(WorkflowEventChannel())