diff --git a/update/job_runner_I2.py b/update/job_runner_I2.py
index 47f0a513..76dbde83 100644
--- a/update/job_runner_I2.py
+++ b/update/job_runner_I2.py
@@ -1,8 +1,10 @@
 import asyncio
+from collections import OrderedDict
 from dataclasses import dataclass
 from datetime import datetime, timedelta
+from enum import Enum
 import logging
-from typing import Optional
+from typing import Awaitable, Callable, Optional
 
 from temporalio import common, workflow, activity
 from temporalio.client import Client, WorkflowHandle
@@ -29,6 +31,23 @@ class JobOutput:
     stderr: str
 
 
+class TaskStatus(Enum):
+    BLOCKED = 1
+    UNBLOCKED = 2
+
+
+@dataclass
+class Task:
+    input: Job
+    handler: Callable[["JobRunner", Job], Awaitable[JobOutput]]
+    status: TaskStatus = TaskStatus.BLOCKED
+    output: Optional[JobOutput] = None
+
+    @property
+    def blocked(self) -> bool:
+        return self.status == TaskStatus.BLOCKED
+
+
 @workflow.defn
 class JobRunner:
     """
@@ -36,15 +55,73 @@ class JobRunner:
     not before `job.after_time`.
     """
 
+    def __init__(self) -> None:
+        self.task_queue = OrderedDict[JobID, Task]()
+        self.completed_tasks = set[JobID]()
+
+    def all_handlers_completed(self):
+        # We are considering adding an API like `all_handlers_completed` to SDKs. In this particular
+        # case, the user doesn't actually need the new API, since they are forced to track pending
+        # tasks in their queue implementation.
+        return not self.task_queue
+
+    # Note some undesirable things:
+    # 1. The update handler functions have become generic enqueuers; the "real" handler functions
+    #    are some other methods that don't have the @workflow.update decorator.
+    # 2. The update handler functions have to store a reference to the real handler in the queue.
+    # 3. The workflow `run` method is *much* more complicated and bug-prone here, compared to
+    #    I1:WaitUntilReadyToExecuteHandler
+
     @workflow.run
     async def run(self):
-        await workflow.wait_condition(
-            lambda: workflow.info().is_continue_as_new_suggested()
-        )
+        """
+        Process all tasks in the queue serially, in the main workflow coroutine.
+        """
+        # Note: there are many mistakes a user will make while trying to implement this workflow.
+        while not (
+            workflow.info().is_continue_as_new_suggested()
+            and self.all_handlers_completed()
+        ):
+            await workflow.wait_condition(lambda: bool(self.task_queue))
+            for id, task in list(self.task_queue.items()):
+                if task.status == TaskStatus.UNBLOCKED:
+                    task.output = await task.handler(self, task.input)
+                    del self.task_queue[id]
+                    self.completed_tasks.add(id)
+            for id, task in self.task_queue.items():
+                if task.status == TaskStatus.BLOCKED and self.ready_to_execute(
+                    task.input
+                ):
+                    task.status = TaskStatus.UNBLOCKED
         workflow.continue_as_new()
 
+    def ready_to_execute(self, job: Job) -> bool:
+        if not set(job.depends_on) <= self.completed_tasks:
+            return False
+        if after_time := job.after_time:
+            if float(after_time) > workflow.now().timestamp():
+                return False
+        return True
+
+    async def _enqueue_job_and_wait_for_result(
+        self, job: Job, handler: Callable[["JobRunner", Job], Awaitable[JobOutput]]
+    ) -> JobOutput:
+        task = Task(job, handler)
+        self.task_queue[job.id] = task
+        await workflow.wait_condition(lambda: task.output is not None)
+        # Footgun: a user might well think that they can record task completion here, but in fact it
+        # deadlocks.
+        # self.completed_tasks.add(job.id)
+        assert task.output
+        return task.output
+
     @workflow.update
     async def run_shell_script_job(self, job: Job) -> JobOutput:
+        return await self._enqueue_job_and_wait_for_result(
+            job, JobRunner._actually_run_shell_script_job
+        )
+
+    async def _actually_run_shell_script_job(self, job: Job) -> JobOutput:
         if security_errors := await workflow.execute_activity(
             run_shell_script_security_linter,
             args=[job.run],
@@ -58,6 +135,11 @@ async def run_shell_script_job(self, job: Job) -> JobOutput:
 
     @workflow.update
     async def run_python_job(self, job: Job) -> JobOutput:
+        return await self._enqueue_job_and_wait_for_result(
+            job, JobRunner._actually_run_python_job
+        )
+
+    async def _actually_run_python_job(self, job: Job) -> JobOutput:
         if not await workflow.execute_activity(
             check_python_interpreter_version,
             args=[job.python_interpreter_version],
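
# Usage sketch (supplementary; not part of the diff above). A minimal, hedged example of
# how a client might drive the JobRunner workflow defined in update/job_runner_I2.py,
# assuming a worker for it is already running. The import path, workflow id, task queue
# name ("jobs"), and the Job constructor fields are illustrative assumptions: the diff
# only shows that Job exposes id, depends_on, after_time, run, and
# python_interpreter_version, and it may have other required fields. The point
# illustrated: `execute_update` returns only after the workflow's main loop has dequeued
# the task, run the real handler, and recorded its JobOutput.

import asyncio

from temporalio.client import Client

from update.job_runner_I2 import Job, JobRunner  # assumed import path


async def main() -> None:
    client = await Client.connect("localhost:7233")  # assumes a local dev server
    handle = await client.start_workflow(
        JobRunner.run,
        id="job-runner",      # hypothetical workflow id
        task_queue="jobs",    # hypothetical task queue; must match the worker
    )
    # Send a job via the update handler and block until its output is available.
    output = await handle.execute_update(
        JobRunner.run_python_job,
        Job(
            id="job-1",
            depends_on=[],
            after_time=None,
            run="print('hello')",
            python_interpreter_version="3.11",
        ),
    )
    print(output)


if __name__ == "__main__":
    asyncio.run(main())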