Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase task_class input validation and log messages #191

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions src/workflow_app/workflow/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from .settings import POSTPROCESS_ERROR, CATALOG_DATA_READY
from .settings import REDUCTION_DATA_READY, REDUCTION_CATALOG_DATA_READY
from .database import transactions

import importlib
import inspect
import json
import logging
import re


class StateAction:
Expand Down Expand Up @@ -47,6 +51,30 @@ def _call_default_task(self, headers, message):
action_cls = globals()[destination]
action_cls(connection=self._send_connection)(headers, message)

def _get_class_from_path(self, class_path: str):
"""
Returns the class given by the class path
:param class_path: the class, e.g. "module_name.ClassName"
:return: class or None
"""
# check that the string is in the format "package_name.module_name.class_name"
pattern = r"^[a-zA-Z0-9_\.]+\.[a-zA-Z0-9_]+$"
if not re.match(pattern, class_path):
logging.error(f"task_class {class_path} does not match pattern module_name.ClassName")
return None
module_name, class_name = class_path.rsplit(".", 1)

# try importing the class
try:
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
if not inspect.isclass(cls):
raise ValueError
return cls
except (ModuleNotFoundError, AttributeError, ValueError):
logging.error(f"task_class {class_path} cannot be imported")
return None

def _call_db_task(self, task_data, headers, message):
"""
:param task_data: JSON-encoded task definition
Expand All @@ -59,14 +87,12 @@ def _call_db_task(self, task_data, headers, message):
and (task_def["task_class"] is not None)
and len(task_def["task_class"].strip()) > 0
):
try:
toks = task_def["task_class"].strip().split(".")
module = ".".join(toks[: len(toks) - 1])
cls = toks[len(toks) - 1]
exec("from %s import %s as action_cls" % (module, cls))
action_cls(connection=self._send_connection)(headers, message) # noqa: F821
except: # noqa: E722
logging.exception("Task [%s] failed:", headers["destination"])
action_cls = self._get_class_from_path(task_def["task_class"])
if action_cls:
try:
action_cls(connection=self._send_connection)(headers, message) # noqa: F821
except: # noqa: E722
logging.exception("Task [%s] failed:", headers["destination"])
if "task_queues" in task_def:
for item in task_def["task_queues"]:
destination = "/queue/%s" % item
Expand Down
64 changes: 64 additions & 0 deletions src/workflow_app/workflow/tests/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,20 @@
_ = [workflow]


class FakeTestClass:
def __init__(self, connection):
pass

def __call__(self, headers, message):
raise ValueError


class StateActionTest(TestCase):

@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self.caplog = caplog

def test_call_default_task(self):
from workflow.states import StateAction

Expand Down Expand Up @@ -47,6 +60,57 @@ def test_call(self, mock_get_task):
sa(headers, message)
assert mock_connection.send.call_count - original_call_count == 2 # one per task queue

@mock.patch("workflow.states.transactions.get_task")
def test_task_class_path(self, mock_get_task):
from workflow.states import StateAction

mock_connection = mock.Mock()
sa = StateAction(connection=mock_connection, use_db_task=True)
headers = {"destination": "test", "message-id": "test-0"}
message = '{"facility": "SNS", "instrument": "arcs", "ipts": "IPTS-5", "run_number": 3, "data_file": "test"}'

# test with task class "-" (inserted by Django admin interface when left empty)
mock_get_task.return_value = '{"task_class": "-"}'
self.caplog.clear()
sa(headers, message)
assert "does not match pattern" in self.caplog.text

# test with task class that does not follow the pattern "module_name.ClassName"
mock_get_task.return_value = '{"task_class": "FakeClass"}'
self.caplog.clear()
sa(headers, message)
assert "does not match pattern" in self.caplog.text

# test with module that does not exist
mock_get_task.return_value = '{"task_class": "fake_module.FakeClass"}'
self.caplog.clear()
sa(headers, message)
assert "cannot be imported" in self.caplog.text

# test with module exists but class does not
mock_get_task.return_value = '{"task_class": "workflow.states.FakeClass"}'
self.caplog.clear()
sa(headers, message)
assert "cannot be imported" in self.caplog.text

# test with module attribute is not a class
mock_get_task.return_value = '{"task_class": "workflow.state_utilities.decode_message"}'
self.caplog.clear()
sa(headers, message)
assert "cannot be imported" in self.caplog.text

# test with calling class fails
mock_get_task.return_value = '{"task_class": "workflow.tests.test_states.FakeTestClass"}'
self.caplog.clear()
sa(headers, message)
assert "Task [test] failed" in self.caplog.text

# test with valid class
mock_get_task.return_value = '{"task_class": "workflow.states.Reduction_request"}'
self.caplog.clear()
sa(headers, message)
assert mock_connection.send.call_count == 1

@mock.patch("workflow.database.transactions.add_status_entry")
def test_send(self, mockAddStatusEntry):
from workflow.states import StateAction
Expand Down