From 0dc165caaa12f1dc02b035c7ccec70c74f529c8e Mon Sep 17 00:00:00 2001
From: nhertz <nhertz@nrao.edu>
Date: Mon, 15 Nov 2021 17:02:05 -0700
Subject: [PATCH] Added new method `determine_state` to the `CapabilityRequest`
 class that will determine the state of the request based on the states of its
 versions; replaced instances of setting request state with new method

---
 .../test/test_capability_request.py           | 230 ++++++++++++++++++
 .../workspaces/capability/schema.py           |  24 ++
 .../capability/schema_interfaces.py           |   7 +-
 .../capability/services/capability_info.py    |   5 +-
 .../capability/services/capability_service.py |   6 +-
 5 files changed, 265 insertions(+), 7 deletions(-)
 create mode 100644 shared/workspaces/test/test_capability_request.py

diff --git a/shared/workspaces/test/test_capability_request.py b/shared/workspaces/test/test_capability_request.py
new file mode 100644
index 000000000..191e7312b
--- /dev/null
+++ b/shared/workspaces/test/test_capability_request.py
@@ -0,0 +1,230 @@
+#
+# Copyright (C) 2021 Associated Universities, Inc. Washington DC, USA.
+#
+# This file is part of NRAO Workspaces.
+#
+# Workspaces is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Workspaces is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Workspaces.  If not, see <https://www.gnu.org/licenses/>.
+
+import random
+from typing import Dict
+
+from hypothesis import given
+from hypothesis import strategies as st
+
+from workspaces.capability.enums import CapabilityRequestState, CapabilityVersionState
+from workspaces.capability.schema import CapabilityRequest, CapabilityVersion
+from workspaces.capability.services.capability_info import CapabilityInfo
+
+from .conftest import (
+    SAMPLE_CAPABILITY_NAMES,
+    SQLITE_MAX_INT,
+    SQLITE_MIN_INT,
+    clear_test_database,
+)
+
+pytest_plugins = ["testing.utils.conftest"]
+
+
+# Register CapabilityRequest JSON blueprint as a hypothesis type strategy
+# To use:
+# >>> @given(st.from_type(CapabilityRequest))
+# >>> def test(generated_request_json: Dict):
+# >>>   request = CapabilityRequest.from_json(generated_request_json)
+st.register_type_strategy(
+    CapabilityRequest,
+    st.fixed_dictionaries(
+        {
+            "type": st.just("CapabilityRequest"),
+            "id": st.integers(min_value=SQLITE_MIN_INT, max_value=SQLITE_MAX_INT),
+            "capability_name": st.sampled_from(SAMPLE_CAPABILITY_NAMES),
+            "state": st.sampled_from([name for name, _ in CapabilityRequestState.__members__.items()]),
+            "parameters": st.one_of(st.lists(st.text()), st.none()),
+            "ingested": st.booleans(),
+            "created_at": st.datetimes().map(
+                lambda time: time.isoformat(),
+            ),
+            "updated_at": st.datetimes().map(
+                lambda time: time.isoformat(),
+            ),
+            "current_execution": st.none(),
+        }
+    ),
+)
+
+# Register CapabilityVersion JSON blueprint as a hypothesis type strategy
+# To use:
+# >>> @given(st.from_type(CapabilityVersion))
+# >>> def test(generated_version_json: Dict):
+# >>>   version = CapabilityVersion.from_json(generated_version_json)
+st.register_type_strategy(
+    CapabilityVersion,
+    st.fixed_dictionaries(
+        {
+            "type": st.just("CapabilityVersion"),
+            "capability_request_id": st.integers(min_value=SQLITE_MIN_INT, max_value=SQLITE_MAX_INT),
+            "version_number": st.integers(min_value=0, max_value=50),
+            "parameters": st.one_of(st.lists(st.text()), st.none()),
+            "workflow_metadata": st.one_of(
+                st.none(), st.dictionaries(st.text(min_size=1), st.text(min_size=1), min_size=1)
+            ),
+            "files": st.none(),
+        }
+    ),
+)
+
+
+@given(
+    st.from_type(CapabilityRequest),
+    st.lists(
+        st.from_type(CapabilityVersion), min_size=1, max_size=5, unique_by=lambda version: version["version_number"]
+    ),
+)
+def test_determine_state_complete(mock_capability_info: CapabilityInfo, request_json: Dict, list_of_version_json: Dict):
+    """
+    Given: A capability request with multiple versions
+    When: QA is passed for the current version
+        And: The current version status is set to Complete
+        And: All other versions' statuses are set to Failed
+    Then: The request state should be set to Complete
+    """
+    request = CapabilityRequest.from_json(request_json)
+    versions = [CapabilityVersion.from_json(blob) for blob in list_of_version_json]
+    version_to_pass = random.choice(versions)
+
+    with clear_test_database(mock_capability_info):
+        mock_capability_info.save_entity(request)
+
+        for version in versions:
+            # Associate version with request
+            version.capability_request = request
+            version.capability_request_id = request.id
+
+            # Set one version to Complete, the rest to Failed or Executing
+            if version.version_number == version_to_pass.version_number:
+                print("Setting passing version to Complete")
+                version.state = CapabilityVersionState.Complete.name
+            else:
+                print("Setting other version")
+                version.state = (
+                    CapabilityVersionState.Failed.name
+                    if version.state == CapabilityVersionState.Complete.name or not version.state
+                    else version.state
+                )
+
+            mock_capability_info.save_entity(version)
+
+        request.determine_state()
+        assert request.state == CapabilityRequestState.Complete.name
+
+
+@given(
+    st.from_type(CapabilityRequest),
+    st.lists(
+        st.from_type(CapabilityVersion), min_size=1, max_size=5, unique_by=lambda version: version["version_number"]
+    ),
+)
+def test_determine_state_submitted(
+    mock_capability_info: CapabilityInfo, request_json: Dict, list_of_version_json: Dict
+):
+    """
+    Given: A capability request with multiple versions
+    When: There are no versions in the Complete state
+        And: There is a mix of versions in either the Created, Submitted, or Failed states
+    Then: The request state is set to Executing
+    """
+    request = CapabilityRequest.from_json(request_json)
+    versions = [CapabilityVersion.from_json(blob) for blob in list_of_version_json]
+
+    with clear_test_database(mock_capability_info):
+        mock_capability_info.save_entity(request)
+
+        for version in versions:
+            # Associate version with request
+            version.capability_request = request
+            version.capability_request_id = request.id
+
+            # Get rid of Complete versions (set their state to Executing instead)
+            version.state = (
+                CapabilityVersionState.Running.name
+                if version.state == CapabilityVersionState.Complete.name or not version.state
+                else version.state
+            )
+
+            mock_capability_info.save_entity(version)
+
+        request.determine_state()
+        assert request.state == CapabilityRequestState.Submitted.name
+
+
+@given(
+    st.from_type(CapabilityRequest),
+    st.lists(
+        st.from_type(CapabilityVersion), min_size=1, max_size=5, unique_by=lambda version: version["version_number"]
+    ),
+)
+def test_determine_state_failed(mock_capability_info: CapabilityInfo, request_json: Dict, list_of_version_json: Dict):
+    """
+    Given: A capability request with multiple versions
+    When: All versions for the request are in the Failed state
+    Then: The request state is set to Failed
+    """
+    request = CapabilityRequest.from_json(request_json)
+    versions = [CapabilityVersion.from_json(blob) for blob in list_of_version_json]
+
+    with clear_test_database(mock_capability_info):
+        mock_capability_info.save_entity(request)
+
+        for version in versions:
+            # Associate version with request
+            version.capability_request = request
+            version.capability_request_id = request.id
+
+            # Set all versions to Failed state
+            version.state = CapabilityVersionState.Failed.name
+
+            mock_capability_info.save_entity(version)
+
+        request.determine_state()
+        assert request.state == CapabilityRequestState.Failed.name
+
+
+@given(
+    st.from_type(CapabilityRequest),
+    st.lists(
+        st.from_type(CapabilityVersion), min_size=1, max_size=5, unique_by=lambda version: version["version_number"]
+    ),
+)
+def test_determine_state_created(mock_capability_info: CapabilityInfo, request_json: Dict, list_of_version_json: Dict):
+    """
+    Given: A capability request with multiple versions
+    When: All versions for the request are in the Created state
+    Then: The request state is set to Created
+    """
+    request = CapabilityRequest.from_json(request_json)
+    versions = [CapabilityVersion.from_json(blob) for blob in list_of_version_json]
+
+    with clear_test_database(mock_capability_info):
+        mock_capability_info.save_entity(request)
+        for version in versions:
+            # Associate version with request
+            version.capability_request = request
+            version.capability_request_id = request.id
+
+            # Set all versions to Failed state
+            version.state = CapabilityVersionState.Created.name
+
+            mock_capability_info.save_entity(version)
+
+        request.determine_state()
+        assert request.state == CapabilityRequestState.Created.name
diff --git a/shared/workspaces/workspaces/capability/schema.py b/shared/workspaces/workspaces/capability/schema.py
index 90f9f9af2..7d9032dce 100644
--- a/shared/workspaces/workspaces/capability/schema.py
+++ b/shared/workspaces/workspaces/capability/schema.py
@@ -440,6 +440,30 @@ class CapabilityRequest(Base, CapabilityRequestIF):
     def update_state(self, state: CapabilityRequestState):
         self.state = state.name
 
+    def determine_state(self):
+        """
+        Determine state of request based on the state of its versions and set it accordingly
+
+        RULES:
+        - If there is a complete version, the request is complete
+        - If all versions are failed, the request is failed
+        - If all versions are created, the request is created
+        - Otherwise, it is submitted
+        """
+        version_states = [version.state for version in self.versions]
+
+        if any(state == CapabilityVersionState.Complete.name for state in version_states):
+            # Request has a complete version, so it is complete
+            self.state = CapabilityRequestState.Complete.name
+        elif all(state == CapabilityRequestState.Failed.name for state in version_states):
+            # Request has all failed versions, so it is failed
+            self.state = CapabilityRequestState.Failed.name
+        elif all(state == CapabilityRequestState.Created.name for state in version_states):
+            # Request has no submitted versions, so it is still in the created state
+            self.state = CapabilityRequestState.Created.name
+        else:
+            self.state = CapabilityRequestState.Submitted.name
+
     def __str__(self):
         return f"CapabilityRequest object: {self.__dict__}"
 
diff --git a/shared/workspaces/workspaces/capability/schema_interfaces.py b/shared/workspaces/workspaces/capability/schema_interfaces.py
index 787f037fd..f595174d8 100644
--- a/shared/workspaces/workspaces/capability/schema_interfaces.py
+++ b/shared/workspaces/workspaces/capability/schema_interfaces.py
@@ -20,12 +20,12 @@ from __future__ import annotations
 import pathlib
 from typing import Dict, List
 
-# pylint: disable=C0114, C0115, C0116, R0903
-
 from workspaces.capability.helpers_interfaces import ParameterIF
 from workspaces.products.schema_interfaces import FutureProductIF
 from workspaces.system.schema import JSONSerializable
 
+# pylint: disable=C0114, C0115, C0116, R0903
+
 
 class CapabilityIF(JSONSerializable):
     id: int
@@ -48,6 +48,9 @@ class CapabilityRequestIF(JSONSerializable):
     def current_version(self):
         raise NotImplementedError
 
+    def determine_state(self):
+        pass
+
 
 class CapabilityVersionIF:
     capability_request: CapabilityRequestIF
diff --git a/shared/workspaces/workspaces/capability/services/capability_info.py b/shared/workspaces/workspaces/capability/services/capability_info.py
index e73b061e5..2aea8a830 100644
--- a/shared/workspaces/workspaces/capability/services/capability_info.py
+++ b/shared/workspaces/workspaces/capability/services/capability_info.py
@@ -178,8 +178,9 @@ class CapabilityInfo(CapabilityInfoIF):
         :return: new CapabilityVersion
         """
         request = self.lookup_capability_request(capability_request_id)
-        # Reset request state to Created
-        request.state = CapabilityRequestState.Created.name
+        # Reset request state accordingly
+
+        request.determine_state()
         self.save_entity(request)
         logger.info(f"Parent Request: {request.__json__()}")
 
diff --git a/shared/workspaces/workspaces/capability/services/capability_service.py b/shared/workspaces/workspaces/capability/services/capability_service.py
index 5b24a69a0..00216122b 100644
--- a/shared/workspaces/workspaces/capability/services/capability_service.py
+++ b/shared/workspaces/workspaces/capability/services/capability_service.py
@@ -22,7 +22,7 @@ import transaction
 from messaging.messenger import MessageSender
 from messaging.router import Router, on_message
 
-from workspaces.capability.enums import CapabilityRequestState, CapabilityVersionState
+from workspaces.capability.enums import CapabilityVersionState
 from workspaces.capability.helpers import Parameter
 from workspaces.capability.message_architect import CapabilityMessageArchitect
 from workspaces.capability.schema import CapabilityRequest
@@ -73,7 +73,7 @@ class CapabilityService(CapabilityServiceIF):
         # Set request state to Complete
         execution = message["subject"]
         capability_request = self.capability_info.lookup_capability_request(execution["capability_request_id"])
-        capability_request.state = CapabilityRequestState.Complete.name
+        capability_request.determine_state()
         self.capability_info.save_entity(capability_request)
 
         # Set version state to Complete
@@ -95,7 +95,7 @@ class CapabilityService(CapabilityServiceIF):
         # Set request state to Failed
         # TODO(nhertz): Dynamically calculate request state based on the states of its versions
         capability_request = self.capability_info.lookup_capability_request(execution["capability_request_id"])
-        capability_request.state = CapabilityRequestState.Failed.name
+        capability_request.determine_state()
         self.capability_info.save_entity(capability_request)
 
         # Set version state to Failed
-- 
GitLab