Skip to content

Commit

Permalink
Update execute_task signature (#1719)
Browse files Browse the repository at this point in the history
`run_dir` is not optional; in a previous commit, it was enforced by an inline
check.  Update the function signature to match that requirement.

While here, test the addition, and also bolster the module's tests in general.
  • Loading branch information
khk-globus committed Nov 14, 2024
1 parent 6f95cb2 commit 55e0759
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 54 deletions.
39 changes: 24 additions & 15 deletions compute_endpoint/globus_compute_endpoint/engines/helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import logging
import os
import pathlib
import time
import typing as t
import uuid

from globus_compute_common import messagepack
Expand All @@ -21,14 +23,16 @@
log = logging.getLogger(__name__)

_serde = ComputeSerializer()
_RESULT_SIZE_LIMIT = 10 * 1024 * 1024 # 10 MiB


def execute_task(
task_id: uuid.UUID,
task_body: bytes,
endpoint_id: t.Optional[uuid.UUID],
result_size_limit: int = 10 * 1024 * 1024,
run_dir: t.Optional[t.Union[str, os.PathLike]] = None,
endpoint_id: uuid.UUID | None = None,
*,
run_dir: str | os.PathLike,
result_size_limit: int = _RESULT_SIZE_LIMIT,
run_in_sandbox: bool = False,
) -> bytes:
"""Execute task is designed to enable any executor to execute a Task payload
Expand All @@ -38,7 +42,7 @@ def execute_task(
----------
task_id: uuid string
task_body: packed message as bytes
endpoint_id: uuid string or None
endpoint_id: uuid.UUID or None
result_size_limit: result size in bytes
run_dir: directory to run function in
run_in_sandbox: if enabled run task under run_dir/<task_uuid>
Expand All @@ -56,21 +60,26 @@ def execute_task(
uuid.UUID | str | tuple[str, str] | list[TaskTransition] | dict[str, str],
]

task_id_str = str(task_id)
os.environ.pop("GC_TASK_SANDBOX_DIR", None)
os.environ["GC_TASK_UUID"] = str(task_id)
os.environ["GC_TASK_UUID"] = task_id_str

if not run_dir or not os.path.isabs(run_dir):
raise RuntimeError(
f"execute_task requires an absolute path for run_dir, got {run_dir=}"
if result_size_limit < 128:
raise ValueError(
f"Invalid result limit; must be at least 128 bytes ({result_size_limit=})"
)

os.makedirs(run_dir, exist_ok=True)
os.chdir(run_dir)
task_dir = pathlib.Path(run_dir)
if not task_dir.is_absolute():
raise ValueError(f"Absolute path required. Received: {run_dir=}")

task_dir = task_dir.resolve() # strict=False (default); path may not exist yet
if run_in_sandbox:
os.makedirs(str(task_id)) # task_id is expected to be unique
os.chdir(str(task_id))
task_dir = task_dir / task_id_str
# Set sandbox dir so that apps can use it
os.environ["GC_TASK_SANDBOX_DIR"] = os.getcwd()
os.environ["GC_TASK_SANDBOX_DIR"] = str(task_dir)
task_dir.mkdir(parents=True, exist_ok=True)
os.chdir(task_dir)

env_details = get_env_details()
try:
Expand Down Expand Up @@ -116,7 +125,7 @@ def execute_task(
return messagepack.pack(Result(**result_message))


def _unpack_messagebody(message: bytes) -> t.Tuple[Task, str]:
def _unpack_messagebody(message: bytes) -> tuple[Task, str]:
"""Unpack messagebody as a messagepack message with
some legacy handling
Parameters
Expand Down
7 changes: 7 additions & 0 deletions compute_endpoint/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import inspect
import os
import pathlib
Expand All @@ -6,6 +7,7 @@
import uuid

import pytest
from globus_compute_endpoint.engines.helper import execute_task
from tests.conftest import randomstring_impl


Expand Down Expand Up @@ -85,3 +87,8 @@ def get_random_of_datatype_impl(cls):
@pytest.fixture
def get_random_of_datatype():
return get_random_of_datatype_impl


@pytest.fixture
def execute_task_runner(task_uuid, tmp_path):
return functools.partial(execute_task, task_uuid, run_dir=tmp_path)
97 changes: 88 additions & 9 deletions compute_endpoint/tests/unit/test_execute_task.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,128 @@
import logging
import os
import random
from unittest import mock

import pytest
from globus_compute_common import messagepack
from globus_compute_endpoint.engines.helper import execute_task
from globus_compute_endpoint.engines.helper import _RESULT_SIZE_LIMIT, execute_task
from globus_compute_sdk.errors import MaxResultSizeExceeded
from tests.utils import divide

logger = logging.getLogger(__name__)

_MOCK_BASE = "globus_compute_endpoint.engines.helper."


@pytest.mark.parametrize("run_dir", ("tmp", None, "$HOME"))
@pytest.mark.parametrize("run_dir", ("", ".", "./", "../", "tmp", "$HOME"))
def test_bad_run_dir(endpoint_uuid, task_uuid, run_dir):
with pytest.raises(RuntimeError):
with pytest.raises(ValueError): # not absolute
execute_task(task_uuid, b"", endpoint_uuid, run_dir=run_dir)

with pytest.raises(TypeError): # not anything, allow-any-type language
execute_task(task_uuid, b"", endpoint_uuid, run_dir=None)

def test_execute_task(endpoint_uuid, serde, task_uuid, ez_pack_task, tmp_path):

def test_happy_path(serde, task_uuid, ez_pack_task, execute_task_runner):
out = random.randint(1, 100_000)
divisor = random.randint(1, 100_000)

task_bytes = ez_pack_task(divide, divisor * out, divisor)

packed_result = execute_task(task_uuid, task_bytes, endpoint_uuid, run_dir=tmp_path)
packed_result = execute_task_runner(task_bytes)
assert isinstance(packed_result, bytes)

result = messagepack.unpack(packed_result)
assert isinstance(result, messagepack.message_types.Result)
assert result.data
assert result.task_id == task_uuid
assert "os" in result.details
assert "python_version" in result.details
assert "dill_version" in result.details
assert "endpoint_id" in result.details
assert serde.deserialize(result.data) == out


def test_execute_task_with_exception(endpoint_uuid, task_uuid, ez_pack_task, tmp_path):
def test_sandbox(ez_pack_task, execute_task_runner, task_uuid, tmp_path):
task_bytes = ez_pack_task(divide, 10, 2)
packed_result = execute_task_runner(task_bytes, run_in_sandbox=True)
result = messagepack.unpack(packed_result)
assert result.task_id == task_uuid
assert result.error_details is None, "Verify test setup: execution successful"

exp_dir = tmp_path / str(task_uuid)
assert os.environ.get("GC_TASK_SANDBOX_DIR") == str(exp_dir), "Share dir w/ func"
assert os.getcwd() == str(exp_dir), "Expect sandbox dir entered"


def test_nested_run_dir(ez_pack_task, task_uuid, tmp_path):
task_bytes = ez_pack_task(divide, 10, 2)
nested_root = tmp_path / "a/"
nested_path = nested_root / "b/c/d"
assert not nested_root.exists(), "Verify test setup"

packed_result = execute_task(task_uuid, task_bytes, run_dir=nested_path)

result = messagepack.unpack(packed_result)
assert result.error_details is None, "Verify test setup: execution successful"

assert nested_path.exists(), "Test namesake"


@pytest.mark.parametrize("size_limit", (128, 256, 1024, 4096, _RESULT_SIZE_LIMIT))
def test_result_size_limit(serde, ez_pack_task, execute_task_runner, size_limit):
task_bytes = ez_pack_task(divide, 10, 2)
exp_data = f"{MaxResultSizeExceeded.__name__}({size_limit + 1}, {size_limit})"
res_data_good = "a" * size_limit
res_data_bad = "a" * (size_limit + 1)

with mock.patch(f"{_MOCK_BASE}_call_user_function") as mock_callfn:
with mock.patch(f"{_MOCK_BASE}log.exception"): # silence tests
mock_callfn.return_value = res_data_good
res_bytes = execute_task_runner(task_bytes, result_size_limit=size_limit)
result = messagepack.unpack(res_bytes)
assert result.data == res_data_good

mock_callfn.return_value = res_data_bad
res_bytes = execute_task_runner(task_bytes, result_size_limit=size_limit)
result = messagepack.unpack(res_bytes)
assert exp_data == result.data
assert result.error_details.code == "MaxResultSizeExceeded"


def test_default_result_size_limit(ez_pack_task, execute_task_runner):
task_bytes = ez_pack_task(divide, 10, 2)
default = _RESULT_SIZE_LIMIT
exp_data = f"{MaxResultSizeExceeded.__name__}({default + 1}, {default})"
res_data_good = "a" * default
res_data_bad = "a" * (default + 1)

with mock.patch(f"{_MOCK_BASE}_call_user_function") as mock_callfn:
with mock.patch(f"{_MOCK_BASE}log.exception"): # silence tests
mock_callfn.return_value = res_data_good
res_bytes = execute_task_runner(task_bytes)
result = messagepack.unpack(res_bytes)
assert result.data == res_data_good

mock_callfn.return_value = res_data_bad
res_bytes = execute_task_runner(task_bytes)
result = messagepack.unpack(res_bytes)
assert exp_data == result.data
assert result.error_details.code == "MaxResultSizeExceeded"


@pytest.mark.parametrize("size_limit", (-5, 0, 1, 65, 127))
def test_invalid_result_size_limit(size_limit):
with pytest.raises(ValueError) as pyt_e:
execute_task("test_tid", b"", run_dir="/", result_size_limit=5)
assert "must be at least" in str(pyt_e.value)


def test_execute_task_with_exception(ez_pack_task, execute_task_runner):
task_bytes = ez_pack_task(divide, 10, 0)

with mock.patch(f"{_MOCK_BASE}log") as mock_log:
packed_result = execute_task(
task_uuid, task_bytes, endpoint_uuid, run_dir=tmp_path
)
packed_result = execute_task_runner(task_bytes)

assert mock_log.exception.called
a, _k = mock_log.exception.call_args
Expand Down
32 changes: 2 additions & 30 deletions compute_endpoint/tests/unit/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import pytest
from globus_compute_common import messagepack
from globus_compute_endpoint.engines.helper import execute_task
from globus_compute_endpoint.engines.high_throughput.messages import Task
from globus_compute_endpoint.engines.high_throughput.worker import Worker

Expand Down Expand Up @@ -124,39 +123,12 @@ def test_execute_failing_function(test_worker):
)


def test_execute_function_exceeding_result_size_limit(
test_worker, endpoint_uuid, task_uuid, ez_pack_task, tmp_path
):
return_size = 10

task_bytes = ez_pack_task(large_result, return_size)

with mock.patch("globus_compute_endpoint.engines.helper.log") as mock_log:
s_result = execute_task(
task_uuid,
task_bytes,
endpoint_uuid,
result_size_limit=return_size - 2,
run_dir=tmp_path,
)
result = messagepack.unpack(s_result)

assert isinstance(result, messagepack.message_types.Result)
assert result.error_details
assert result.task_id == task_uuid
assert result.error_details
assert result.error_details.code == "MaxResultSizeExceeded"
assert mock_log.exception.called


def test_app_timeout(test_worker, endpoint_uuid, task_uuid, ez_pack_task, tmp_path):
def test_app_timeout(test_worker, execute_task_runner, task_uuid, ez_pack_task):
task_bytes = ez_pack_task(sleeper, 1)

with mock.patch("globus_compute_endpoint.engines.helper.log") as mock_log:
with mock.patch.dict(os.environ, {"GC_TASK_TIMEOUT": "0.01"}):
packed_result = execute_task(
task_uuid, task_bytes, endpoint_uuid, run_dir=tmp_path
)
packed_result = execute_task_runner(task_bytes)

result = messagepack.unpack(packed_result)
assert isinstance(result, messagepack.message_types.Result)
Expand Down

0 comments on commit 55e0759

Please sign in to comment.