From 4dd6429f049fe781e0860f8c0d1e48cdf33d7f61 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 7 Dec 2024 12:46:03 -0500 Subject: [PATCH 1/4] Use workflow.init, refactor --- .../safe_message_handlers/workflow.py | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/message_passing/safe_message_handlers/workflow.py b/message_passing/safe_message_handlers/workflow.py index 338a3065..4f7e6134 100644 --- a/message_passing/safe_message_handlers/workflow.py +++ b/message_passing/safe_message_handlers/workflow.py @@ -65,8 +65,17 @@ class ClusterManagerAssignNodesToJobResult: # These updates must run atomically. @workflow.defn class ClusterManagerWorkflow: - def __init__(self) -> None: - self.state = ClusterManagerState() + @workflow.init + def __init__(self, input: ClusterManagerInput) -> None: + if input.state: + self.state = input.state + else: + self.state = ClusterManagerState() + + if input.test_continue_as_new: + self.max_history_length = 120 + self.sleep_interval_seconds = 1 + # Protects workflow state from interleaved access self.nodes_lock = asyncio.Lock() self.max_history_length: Optional[int] = None @@ -202,29 +211,8 @@ async def perform_health_checks(self) -> None: f"Health check failed with error {type(e).__name__}:{e}" ) - # The cluster manager is a long-running "entity" workflow so we need to periodically checkpoint its state and - # continue-as-new. - def init(self, input: ClusterManagerInput) -> None: - if input.state: - self.state = input.state - if input.test_continue_as_new: - self.max_history_length = 120 - self.sleep_interval_seconds = 1 - - def should_continue_as_new(self) -> bool: - if workflow.info().is_continue_as_new_suggested(): - return True - # This is just for ease-of-testing. In production, we trust temporal to tell us when to continue as new. - if ( - self.max_history_length - and workflow.info().get_current_history_length() > self.max_history_length - ): - return True - return False - @workflow.run async def run(self, input: ClusterManagerInput) -> ClusterManagerResult: - self.init(input) await workflow.wait_condition(lambda: self.state.cluster_started) # Perform health checks at intervals. while True: @@ -239,6 +227,8 @@ async def run(self, input: ClusterManagerInput) -> ClusterManagerResult: pass if self.state.cluster_shutdown: break + # The cluster manager is a long-running "entity" workflow so we need to periodically checkpoint its state and + # continue-as-new. if self.should_continue_as_new(): # We don't want to leave any job assignment or deletion handlers half-finished when we continue as new. await workflow.wait_condition(lambda: workflow.all_handlers_finished()) @@ -255,3 +245,14 @@ async def run(self, input: ClusterManagerInput) -> ClusterManagerResult: len(self.get_assigned_nodes()), len(self.get_bad_nodes()), ) + + def should_continue_as_new(self) -> bool: + if workflow.info().is_continue_as_new_suggested(): + return True + # This is just for ease-of-testing. In production, we trust temporal to tell us when to continue as new. + if ( + self.max_history_length + and workflow.info().get_current_history_length() > self.max_history_length + ): + return True + return False From b90b1effa607b95db6cc27703627e6762fc811d9 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 7 Dec 2024 13:49:01 -0500 Subject: [PATCH 2/4] Start cluster automatically; use update to wait until started --- .../safe_message_handlers/README.md | 1 - .../safe_message_handlers/activities.py | 10 +++++++ .../safe_message_handlers/starter.py | 6 +++-- .../safe_message_handlers/worker.py | 8 +++++- .../safe_message_handlers/workflow.py | 20 +++++++++----- .../safe_message_handlers/workflow_test.py | 26 ++++++++++++++----- 6 files changed, 54 insertions(+), 17 deletions(-) diff --git a/message_passing/safe_message_handlers/README.md b/message_passing/safe_message_handlers/README.md index 7d727af3..274d9cdc 100644 --- a/message_passing/safe_message_handlers/README.md +++ b/message_passing/safe_message_handlers/README.md @@ -3,7 +3,6 @@ This sample shows off important techniques for handling signals and updates, aka messages. In particular, it illustrates how message handlers can interleave or not be completed before the workflow completes, and how you can manage that. * Here, using workflow.wait_condition, signal and update handlers will only operate when the workflow is within a certain state--between cluster_started and cluster_shutdown. -* You can run start_workflow with an initializer signal that you want to run before anything else other than the workflow's constructor. This pattern is known as "signal-with-start." * Message handlers can block and their actions can be interleaved with one another and with the main workflow. This can easily cause bugs, so you can use a lock to protect shared state from interleaved access. * An "Entity" workflow, i.e. a long-lived workflow, periodically "continues as new". It must do this to prevent its history from growing too large, and it passes its state to the next workflow. You can check `workflow.info().is_continue_as_new_suggested()` to see when it's time. * Most people want their message handlers to finish before the workflow run completes or continues as new. Use `await workflow.wait_condition(lambda: workflow.all_handlers_finished())` to achieve this. diff --git a/message_passing/safe_message_handlers/activities.py b/message_passing/safe_message_handlers/activities.py index 3a1c9cd2..f3e369ce 100644 --- a/message_passing/safe_message_handlers/activities.py +++ b/message_passing/safe_message_handlers/activities.py @@ -11,6 +11,16 @@ class AssignNodesToJobInput: job_name: str +@dataclass +class ClusterState: + node_ids: List[str] + + +@activity.defn +async def start_cluster() -> ClusterState: + return ClusterState(node_ids=[f"{i}" for i in range(25)]) + + @activity.defn async def assign_nodes_to_job(input: AssignNodesToJobInput) -> None: print(f"Assigning nodes {input.nodes} to job {input.job_name}") diff --git a/message_passing/safe_message_handlers/starter.py b/message_passing/safe_message_handlers/starter.py index cf712163..7ffe13d9 100644 --- a/message_passing/safe_message_handlers/starter.py +++ b/message_passing/safe_message_handlers/starter.py @@ -16,8 +16,10 @@ async def do_cluster_lifecycle(wf: WorkflowHandle, delay_seconds: Optional[int] = None): - - await wf.signal(ClusterManagerWorkflow.start_cluster) + cluster_status = await wf.execute_update( + ClusterManagerWorkflow.wait_until_cluster_started + ) + print(f"Cluster started with {len(cluster_status.nodes)} nodes") print("Assigning jobs to nodes...") allocation_updates = [] diff --git a/message_passing/safe_message_handlers/worker.py b/message_passing/safe_message_handlers/worker.py index a900c7bb..34e71290 100644 --- a/message_passing/safe_message_handlers/worker.py +++ b/message_passing/safe_message_handlers/worker.py @@ -8,6 +8,7 @@ ClusterManagerWorkflow, assign_nodes_to_job, find_bad_nodes, + start_cluster, unassign_nodes_for_job, ) @@ -21,7 +22,12 @@ async def main(): client, task_queue="safe-message-handlers-task-queue", workflows=[ClusterManagerWorkflow], - activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes], + activities=[ + assign_nodes_to_job, + unassign_nodes_for_job, + find_bad_nodes, + start_cluster, + ], ): logging.info("ClusterManagerWorkflow worker started, ctrl+c to exit") await interrupt_event.wait() diff --git a/message_passing/safe_message_handlers/workflow.py b/message_passing/safe_message_handlers/workflow.py index 4f7e6134..437cfb5f 100644 --- a/message_passing/safe_message_handlers/workflow.py +++ b/message_passing/safe_message_handlers/workflow.py @@ -14,6 +14,7 @@ UnassignNodesForJobInput, assign_nodes_to_job, find_bad_nodes, + start_cluster, unassign_nodes_for_job, ) @@ -81,11 +82,10 @@ def __init__(self, input: ClusterManagerInput) -> None: self.max_history_length: Optional[int] = None self.sleep_interval_seconds: int = 600 - @workflow.signal - async def start_cluster(self) -> None: - self.state.cluster_started = True - self.state.nodes = {str(k): None for k in range(25)} - workflow.logger.info("Cluster started") + @workflow.update + async def wait_until_cluster_started(self) -> ClusterManagerState: + await workflow.wait_condition(lambda: self.state.cluster_started) + return self.state @workflow.signal async def shutdown_cluster(self) -> None: @@ -144,7 +144,7 @@ async def _assign_nodes_to_job( self.state.jobs_assigned.add(job_name) # Even though it returns nothing, this is an update because the client may want to track it, for example - # to wait for nodes to be unassignd before reassigning them. + # to wait for nodes to be unassigned before reassigning them. @workflow.update async def delete_job(self, input: ClusterManagerDeleteJobInput) -> None: await workflow.wait_condition(lambda: self.state.cluster_started) @@ -213,7 +213,13 @@ async def perform_health_checks(self) -> None: @workflow.run async def run(self, input: ClusterManagerInput) -> ClusterManagerResult: - await workflow.wait_condition(lambda: self.state.cluster_started) + cluster_state = await workflow.execute_activity( + start_cluster, schedule_to_close_timeout=timedelta(seconds=10) + ) + self.state.nodes = {k: None for k in cluster_state.node_ids} + self.state.cluster_started = True + workflow.logger.info("Cluster started") + # Perform health checks at intervals. while True: await self.perform_health_checks() diff --git a/tests/message_passing/safe_message_handlers/workflow_test.py b/tests/message_passing/safe_message_handlers/workflow_test.py index 92345be7..1f0aa875 100644 --- a/tests/message_passing/safe_message_handlers/workflow_test.py +++ b/tests/message_passing/safe_message_handlers/workflow_test.py @@ -10,6 +10,7 @@ from message_passing.safe_message_handlers.activities import ( assign_nodes_to_job, find_bad_nodes, + start_cluster, unassign_nodes_for_job, ) from message_passing.safe_message_handlers.workflow import ( @@ -19,6 +20,13 @@ ClusterManagerWorkflow, ) +ACTIVITIES = [ + assign_nodes_to_job, + unassign_nodes_for_job, + find_bad_nodes, + start_cluster, +] + async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment): if env.supports_time_skipping: @@ -30,7 +38,7 @@ async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment): client, task_queue=task_queue, workflows=[ClusterManagerWorkflow], - activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes], + activities=ACTIVITIES, ): cluster_manager_handle = await client.start_workflow( ClusterManagerWorkflow.run, @@ -38,7 +46,9 @@ async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment): id=f"ClusterManagerWorkflow-{uuid.uuid4()}", task_queue=task_queue, ) - await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster) + await cluster_manager_handle.execute_update( + ClusterManagerWorkflow.wait_until_cluster_started + ) allocation_updates = [] for i in range(6): @@ -82,7 +92,7 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment): client, task_queue=task_queue, workflows=[ClusterManagerWorkflow], - activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes], + activities=ACTIVITIES, ): cluster_manager_handle = await client.start_workflow( ClusterManagerWorkflow.run, @@ -91,7 +101,9 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment): task_queue=task_queue, ) - await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster) + await cluster_manager_handle.execute_update( + ClusterManagerWorkflow.wait_until_cluster_started + ) result_1 = await cluster_manager_handle.execute_update( ClusterManagerWorkflow.assign_nodes_to_job, @@ -121,7 +133,7 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment): client, task_queue=task_queue, workflows=[ClusterManagerWorkflow], - activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes], + activities=ACTIVITIES, ): cluster_manager_handle = await client.start_workflow( ClusterManagerWorkflow.run, @@ -130,7 +142,9 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment): task_queue=task_queue, ) - await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster) + await cluster_manager_handle.execute_update( + ClusterManagerWorkflow.wait_until_cluster_started + ) await cluster_manager_handle.execute_update( ClusterManagerWorkflow.assign_nodes_to_job, From 295bfde321bc9e5351c1554b74ed425eb27b01b7 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 9 Dec 2024 11:26:24 -0500 Subject: [PATCH 3/4] Fixups / lint --- message_passing/safe_message_handlers/workflow.py | 7 ++++--- .../message_passing/safe_message_handlers/workflow_test.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/message_passing/safe_message_handlers/workflow.py b/message_passing/safe_message_handlers/workflow.py index 437cfb5f..ca549a61 100644 --- a/message_passing/safe_message_handlers/workflow.py +++ b/message_passing/safe_message_handlers/workflow.py @@ -74,13 +74,14 @@ def __init__(self, input: ClusterManagerInput) -> None: self.state = ClusterManagerState() if input.test_continue_as_new: - self.max_history_length = 120 + self.max_history_length: Optional[int] = 120 self.sleep_interval_seconds = 1 + else: + self.max_history_length = None + self.sleep_interval_seconds = 600 # Protects workflow state from interleaved access self.nodes_lock = asyncio.Lock() - self.max_history_length: Optional[int] = None - self.sleep_interval_seconds: int = 600 @workflow.update async def wait_until_cluster_started(self) -> ClusterManagerState: diff --git a/tests/message_passing/safe_message_handlers/workflow_test.py b/tests/message_passing/safe_message_handlers/workflow_test.py index 1f0aa875..8cd303d5 100644 --- a/tests/message_passing/safe_message_handlers/workflow_test.py +++ b/tests/message_passing/safe_message_handlers/workflow_test.py @@ -1,5 +1,6 @@ import asyncio import uuid +from typing import Callable, Sequence import pytest from temporalio.client import Client, WorkflowUpdateFailedError @@ -20,7 +21,7 @@ ClusterManagerWorkflow, ) -ACTIVITIES = [ +ACTIVITIES: Sequence[Callable] = [ assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes, From 99f5300e49b91a904fdfc6697cb83750c048940f Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 9 Dec 2024 11:31:03 -0500 Subject: [PATCH 4/4] Don't require node IDs to parse as ints --- message_passing/safe_message_handlers/activities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/message_passing/safe_message_handlers/activities.py b/message_passing/safe_message_handlers/activities.py index f3e369ce..da8a8be0 100644 --- a/message_passing/safe_message_handlers/activities.py +++ b/message_passing/safe_message_handlers/activities.py @@ -18,7 +18,7 @@ class ClusterState: @activity.defn async def start_cluster() -> ClusterState: - return ClusterState(node_ids=[f"{i}" for i in range(25)]) + return ClusterState(node_ids=[f"node-{i}" for i in range(25)]) @activity.defn @@ -47,7 +47,7 @@ class FindBadNodesInput: @activity.defn async def find_bad_nodes(input: FindBadNodesInput) -> Set[str]: await asyncio.sleep(0.1) - bad_nodes = set([n for n in input.nodes_to_check if int(n) % 5 == 0]) + bad_nodes = set([id for id in input.nodes_to_check if hash(id) % 5 == 0]) if bad_nodes: print(f"Found bad nodes: {bad_nodes}") else: