From 765cba4cb7449255d4b20bfa4b84758b91f0e6d4 Mon Sep 17 00:00:00 2001
From: Daniel K Lyons <dlyons@nrao.edu>
Date: Wed, 21 Jul 2021 15:53:25 -0600
Subject: [PATCH] Cleaning up the state machine example a bit

---
 .../workspaces/capability/statemachine.py     | 58 ++++++++++++++-----
 1 file changed, 42 insertions(+), 16 deletions(-)

diff --git a/shared/workspaces/workspaces/capability/statemachine.py b/shared/workspaces/workspaces/capability/statemachine.py
index cbc5f0e82..0fb063aca 100644
--- a/shared/workspaces/workspaces/capability/statemachine.py
+++ b/shared/workspaces/workspaces/capability/statemachine.py
@@ -6,21 +6,35 @@ reacts to certain events by triggering actions and going into another state.
 """
 import abc
 import json
+from typing import Optional
 
 
 class State(abc.ABC):
     """
     A state that a machine could reside in.
+
+    A state has a suite of transitions to other states. When an event comes in, we match against it; if we find a
+    matching pattern, we perform that transition to another state.
     """
-    @abc.abstractmethod
-    def matches(self, other: "State") -> bool:
-        """
-        This is most likely implemented by doing a string-equality test.
 
-        :param other:  the other state to compare to
-        :return:  true if we and the other state match
-        """
-        pass
+    def __init__(self, transitions: list["TransitionIF"]):
+        # We have a bit of a chicken-and-egg problem here, in that the State needs Transitions to be initialized but
+        # the Transition needs States to be initialized. Going from prototype to production here will mean breaking
+        # this cycle, possibly by introducing a builder of some kind, but for now we can just pretend that they are
+        # built successfully somehow.
+        self.transitions = transitions
+
+    def on_event(self, event: dict) -> Optional["State"]:
+        # Locate the first matching transition
+        matching_transition = None
+        for transition in self.transitions:
+            if transition.matches(event):
+                matching_transition = transition
+                break
+
+        # take this transition
+        if matching_transition is not None:
+            return matching_transition.take()
 
 
 class Action(abc.ABC):
@@ -34,6 +48,7 @@ class Action(abc.ABC):
     - StartWorkflow(workflow_name, additional_args) that starts a workflow with the
       provided name, the event and additional arguments
     """
+
     @abc.abstractmethod
     def execute(self):
         pass
@@ -56,20 +71,21 @@ class TransitionIF(abc.ABC):
     """
     A transition between states
     """
+
     def __init__(self, from_state: State, to_state: State, pattern: Pattern, action: Action):
         self.from_state, self.to_state = from_state, to_state
         self.pattern = pattern
         self.action = action
 
     @abc.abstractmethod
-    def matches(self, state: State, event: dict) -> bool:
+    def matches(self, event: dict) -> bool:
         """
         True if this transition is applicable in the supplied state and matches the supplied event.
         :param state: state to check against
         :param event: event to match against
         :return: true if everything matches
         """
-        return self.from_state.matches(state) and self.pattern.matches(event)
+        return self.pattern.matches(event)
 
     @abc.abstractmethod
     def take(self) -> State:
@@ -88,6 +104,7 @@ class MealyMachine:
     I am a state machine for a given capability. I am responsible for handling events
     and transitioning to other states.
     """
+
     def __init__(self):
         self.transitions = []
         self.current_state: State = None
@@ -122,6 +139,7 @@ class CapabilityInfoForMachines:
     This is a demonstration of the sort of query I expect we'll use to locate executions
     that are active and need to be acted on in response to an event of some kind.
     """
+
     def find_requests_matching_transition(self, event: dict) -> list["CapabilityExecution"]:
         """
         The concept here is to let the database do the heavy lifting and actually tell us
@@ -131,7 +149,8 @@ class CapabilityInfoForMachines:
         :param event:  the event to check
         :return:       a list of matching capability executions
         """
-        return self.session.query("""
+        return self.session.query(
+            """
         SELECT * 
         FROM transitions t
         JOIN machines m ON t.machine_id = m.id
@@ -139,30 +158,36 @@ class CapabilityInfoForMachines:
         JOIN capability_requests cr on cr.capability_name = c.name
         JOIN capability_executions ce on cr.capability_request_id = ce.capability_request_id
         WHERE %(event)s @? t.pattern AND ce.state = t.from_state
-        """, {"event": json.dumps(event)})
+        """,
+            {"event": json.dumps(event)},
+        )
 
     def build_tables(self):
         """
         This is just a demonstration method to hold some SQL to demo the tables I have
         in mind for this system.
         """
-        self.session.execute("""
+        self.session.execute(
+            """
         CREATE TABLE machines(id serial primary key);
         CREATE TABLE actions(id serial primary key, action_type varchar, action_arguments json);
         
         CREATE TABLE transitions (
             id serial primary key, 
-            machine_id integer references(machines),
+            machine_id integer references machines(id),
             from_state varchar, 
             to_state varchar, 
             pattern jsonpath, 
-            action_id integer references(actions)
+            action_id integer references actions(id)
         );
          
-        """)
+        """
+        )
 
 
 class CapabilityExecution:
+    machine: MealyMachine = None
+
     def process(self, event):
         self.machine.on_event(event)
 
@@ -173,6 +198,7 @@ class CapabilityServiceMachineMananger:
     machines alive during the execution of the program. The idea here is to be more
     efficient and more event-driven.
     """
+
     def __init__(self):
         self.info = CapabilityInfoForMachines()
 
-- 
GitLab