diff --git a/src/workflow_app/workflow/states.py b/src/workflow_app/workflow/states.py index b90104c2..44a4ceef 100644 --- a/src/workflow_app/workflow/states.py +++ b/src/workflow_app/workflow/states.py @@ -9,8 +9,12 @@ from .settings import POSTPROCESS_ERROR, CATALOG_DATA_READY from .settings import REDUCTION_DATA_READY, REDUCTION_CATALOG_DATA_READY from .database import transactions + +import importlib +import inspect import json import logging +import re class StateAction: @@ -47,6 +51,30 @@ def _call_default_task(self, headers, message): action_cls = globals()[destination] action_cls(connection=self._send_connection)(headers, message) + def _get_class_from_path(self, class_path: str): + """ + Returns the class given by the class path + :param class_path: the class, e.g. "module_name.ClassName" + :return: class or None + """ + # check that the string is in the format "package_name.module_name.class_name" + pattern = r"^[a-zA-Z0-9_\.]+\.[a-zA-Z0-9_]+$" + if not re.match(pattern, class_path): + logging.error(f"task_class {class_path} does not match pattern module_name.ClassName") + return None + module_name, class_name = class_path.rsplit(".", 1) + + # try importing the class + try: + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + if not inspect.isclass(cls): + raise ValueError + return cls + except (ModuleNotFoundError, AttributeError, ValueError): + logging.error(f"task_class {class_path} cannot be imported") + return None + def _call_db_task(self, task_data, headers, message): """ :param task_data: JSON-encoded task definition @@ -59,14 +87,12 @@ def _call_db_task(self, task_data, headers, message): and (task_def["task_class"] is not None) and len(task_def["task_class"].strip()) > 0 ): - try: - toks = task_def["task_class"].strip().split(".") - module = ".".join(toks[: len(toks) - 1]) - cls = toks[len(toks) - 1] - exec("from %s import %s as action_cls" % (module, cls)) - action_cls(connection=self._send_connection)(headers, message) # noqa: F821 - except: # noqa: E722 - logging.exception("Task [%s] failed:", headers["destination"]) + action_cls = self._get_class_from_path(task_def["task_class"]) + if action_cls: + try: + action_cls(connection=self._send_connection)(headers, message) # noqa: F821 + except: # noqa: E722 + logging.exception("Task [%s] failed:", headers["destination"]) if "task_queues" in task_def: for item in task_def["task_queues"]: destination = "/queue/%s" % item diff --git a/src/workflow_app/workflow/tests/test_states.py b/src/workflow_app/workflow/tests/test_states.py index 899fda85..8f275a0e 100644 --- a/src/workflow_app/workflow/tests/test_states.py +++ b/src/workflow_app/workflow/tests/test_states.py @@ -6,7 +6,20 @@ _ = [workflow] +class FakeTestClass: + def __init__(self, connection): + pass + + def __call__(self, headers, message): + raise ValueError + + class StateActionTest(TestCase): + + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self.caplog = caplog + def test_call_default_task(self): from workflow.states import StateAction @@ -47,6 +60,57 @@ def test_call(self, mock_get_task): sa(headers, message) assert mock_connection.send.call_count - original_call_count == 2 # one per task queue + @mock.patch("workflow.states.transactions.get_task") + def test_task_class_path(self, mock_get_task): + from workflow.states import StateAction + + mock_connection = mock.Mock() + sa = StateAction(connection=mock_connection, use_db_task=True) + headers = {"destination": "test", "message-id": "test-0"} + message = '{"facility": "SNS", "instrument": "arcs", "ipts": "IPTS-5", "run_number": 3, "data_file": "test"}' + + # test with task class "-" (inserted by Django admin interface when left empty) + mock_get_task.return_value = '{"task_class": "-"}' + self.caplog.clear() + sa(headers, message) + assert "does not match pattern" in self.caplog.text + + # test with task class that does not follow the pattern "module_name.ClassName" + mock_get_task.return_value = '{"task_class": "FakeClass"}' + self.caplog.clear() + sa(headers, message) + assert "does not match pattern" in self.caplog.text + + # test with module that does not exist + mock_get_task.return_value = '{"task_class": "fake_module.FakeClass"}' + self.caplog.clear() + sa(headers, message) + assert "cannot be imported" in self.caplog.text + + # test with module exists but class does not + mock_get_task.return_value = '{"task_class": "workflow.states.FakeClass"}' + self.caplog.clear() + sa(headers, message) + assert "cannot be imported" in self.caplog.text + + # test with module attribute is not a class + mock_get_task.return_value = '{"task_class": "workflow.state_utilities.decode_message"}' + self.caplog.clear() + sa(headers, message) + assert "cannot be imported" in self.caplog.text + + # test with calling class fails + mock_get_task.return_value = '{"task_class": "workflow.tests.test_states.FakeTestClass"}' + self.caplog.clear() + sa(headers, message) + assert "Task [test] failed" in self.caplog.text + + # test with valid class + mock_get_task.return_value = '{"task_class": "workflow.states.Reduction_request"}' + self.caplog.clear() + sa(headers, message) + assert mock_connection.send.call_count == 1 + @mock.patch("workflow.database.transactions.add_status_entry") def test_send(self, mockAddStatusEntry): from workflow.states import StateAction