Cluster manager: use update to wait for cluster to start #153
Merged (4 commits) on Dec 10, 2024
1 change: 0 additions & 1 deletion message_passing/safe_message_handlers/README.md
@@ -3,7 +3,6 @@
This sample shows off important techniques for handling signals and updates, aka messages. In particular, it illustrates how message handlers can interleave or not be completed before the workflow completes, and how you can manage that.

* Here, using workflow.wait_condition, signal and update handlers will only operate when the workflow is within a certain state--between cluster_started and cluster_shutdown.
* You can run start_workflow with an initializer signal that you want to run before anything else other than the workflow's constructor. This pattern is known as "signal-with-start."
* Message handlers can block and their actions can be interleaved with one another and with the main workflow. This can easily cause bugs, so you can use a lock to protect shared state from interleaved access (see the sketch after this list).
* An "Entity" workflow, i.e. a long-lived workflow, periodically "continues as new". It must do this to prevent its history from growing too large, and it passes its state to the next workflow. You can check `workflow.info().is_continue_as_new_suggested()` to see when it's time.
* Most people want their message handlers to finish before the workflow run completes or continues as new. Use `await workflow.wait_condition(lambda: workflow.all_handlers_finished())` to achieve this.
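A minimal sketch, not part of this PR's diff, of the gating-and-locking pattern the first and third bullets describe; the class name, state fields, and handler names here are illustrative only:

import asyncio

from temporalio import workflow


@workflow.defn
class GatedWorkflow:
    def __init__(self) -> None:
        self.cluster_started = False
        self.cluster_shutdown = False
        # Serializes handlers that touch shared state.
        self.lock = asyncio.Lock()

    @workflow.update
    async def do_work(self) -> None:
        # Handlers only proceed while the workflow is in the allowed state.
        await workflow.wait_condition(
            lambda: self.cluster_started and not self.cluster_shutdown
        )
        async with self.lock:
            # Mutate shared state here; no other handler interleaves
            # while the lock is held.
            pass

    @workflow.signal
    async def shutdown(self) -> None:
        self.cluster_shutdown = True

    @workflow.run
    async def run(self) -> None:
        self.cluster_started = True
        await workflow.wait_condition(lambda: self.cluster_shutdown)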
12 changes: 11 additions & 1 deletion message_passing/safe_message_handlers/activities.py
@@ -11,6 +11,16 @@ class AssignNodesToJobInput:
job_name: str


@dataclass
class ClusterState:
node_ids: List[str]


@activity.defn
async def start_cluster() -> ClusterState:
return ClusterState(node_ids=[f"node-{i}" for i in range(25)])


@activity.defn
async def assign_nodes_to_job(input: AssignNodesToJobInput) -> None:
print(f"Assigning nodes {input.nodes} to job {input.job_name}")
@@ -37,7 +47,7 @@ class FindBadNodesInput:
@activity.defn
async def find_bad_nodes(input: FindBadNodesInput) -> Set[str]:
await asyncio.sleep(0.1)
bad_nodes = set([n for n in input.nodes_to_check if int(n) % 5 == 0])
bad_nodes = set([id for id in input.nodes_to_check if hash(id) % 5 == 0])
if bad_nodes:
print(f"Found bad nodes: {bad_nodes}")
else:
6 changes: 4 additions & 2 deletions message_passing/safe_message_handlers/starter.py
@@ -16,8 +16,10 @@


async def do_cluster_lifecycle(wf: WorkflowHandle, delay_seconds: Optional[int] = None):

await wf.signal(ClusterManagerWorkflow.start_cluster)
cluster_status = await wf.execute_update(
ClusterManagerWorkflow.wait_until_cluster_started
)
print(f"Cluster started with {len(cluster_status.nodes)} nodes")

print("Assigning jobs to nodes...")
allocation_updates = []
8 changes: 7 additions & 1 deletion message_passing/safe_message_handlers/worker.py
@@ -8,6 +8,7 @@
ClusterManagerWorkflow,
assign_nodes_to_job,
find_bad_nodes,
start_cluster,
unassign_nodes_for_job,
)

@@ -21,7 +22,12 @@ async def main():
client,
task_queue="safe-message-handlers-task-queue",
workflows=[ClusterManagerWorkflow],
activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
activities=[
assign_nodes_to_job,
unassign_nodes_for_job,
find_bad_nodes,
start_cluster,
],
):
logging.info("ClusterManagerWorkflow worker started, ctrl+c to exit")
await interrupt_event.wait()
72 changes: 40 additions & 32 deletions message_passing/safe_message_handlers/workflow.py
@@ -14,6 +14,7 @@
UnassignNodesForJobInput,
assign_nodes_to_job,
find_bad_nodes,
start_cluster,
unassign_nodes_for_job,
)

@@ -65,18 +66,27 @@ class ClusterManagerAssignNodesToJobResult:
# These updates must run atomically.
@workflow.defn
class ClusterManagerWorkflow:
def __init__(self) -> None:
self.state = ClusterManagerState()
@workflow.init
def __init__(self, input: ClusterManagerInput) -> None:
if input.state:
self.state = input.state
else:
self.state = ClusterManagerState()

if input.test_continue_as_new:
self.max_history_length: Optional[int] = 120
self.sleep_interval_seconds = 1
else:
self.max_history_length = None
self.sleep_interval_seconds = 600

# Protects workflow state from interleaved access
self.nodes_lock = asyncio.Lock()
self.max_history_length: Optional[int] = None
self.sleep_interval_seconds: int = 600

@workflow.signal
async def start_cluster(self) -> None:
self.state.cluster_started = True
self.state.nodes = {str(k): None for k in range(25)}
workflow.logger.info("Cluster started")
@workflow.update
async def wait_until_cluster_started(self) -> ClusterManagerState:
await workflow.wait_condition(lambda: self.state.cluster_started)
return self.state

@workflow.signal
async def shutdown_cluster(self) -> None:
@@ -135,7 +145,7 @@ async def _assign_nodes_to_job(
self.state.jobs_assigned.add(job_name)

# Even though it returns nothing, this is an update because the client may want to track it, for example
# to wait for nodes to be unassignd before reassigning them.
# to wait for nodes to be unassigned before reassigning them.
@workflow.update
async def delete_job(self, input: ClusterManagerDeleteJobInput) -> None:
await workflow.wait_condition(lambda: self.state.cluster_started)
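A hedged client-side sketch of the tracking the comment above describes: awaiting the delete_job update before reusing the freed nodes. The workflow handle and the input field name are assumptions for illustration; only the handler name and ClusterManagerDeleteJobInput appear in this diff.

# Hypothetical client code (handle and field names assumed): the update only
# returns once the nodes have been unassigned, so callers can sequence work
# after it safely.
await handle.execute_update(
    ClusterManagerWorkflow.delete_job,
    ClusterManagerDeleteJobInput(job_name="job-1"),
)
# At this point the nodes previously assigned to "job-1" are free to reassign.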
@@ -202,30 +212,15 @@ async def perform_health_checks(self) -> None:
f"Health check failed with error {type(e).__name__}:{e}"
)

# The cluster manager is a long-running "entity" workflow so we need to periodically checkpoint its state and
# continue-as-new.
def init(self, input: ClusterManagerInput) -> None:
if input.state:
self.state = input.state
if input.test_continue_as_new:
self.max_history_length = 120
self.sleep_interval_seconds = 1

def should_continue_as_new(self) -> bool:
if workflow.info().is_continue_as_new_suggested():
return True
# This is just for ease-of-testing. In production, we trust temporal to tell us when to continue as new.
if (
self.max_history_length
and workflow.info().get_current_history_length() > self.max_history_length
):
return True
return False

@workflow.run
async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
self.init(input)
await workflow.wait_condition(lambda: self.state.cluster_started)
cluster_state = await workflow.execute_activity(
start_cluster, schedule_to_close_timeout=timedelta(seconds=10)
)
self.state.nodes = {k: None for k in cluster_state.node_ids}
self.state.cluster_started = True
workflow.logger.info("Cluster started")

# Perform health checks at intervals.
while True:
await self.perform_health_checks()
@@ -239,6 +234,8 @@ async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
pass
if self.state.cluster_shutdown:
break
# The cluster manager is a long-running "entity" workflow so we need to periodically checkpoint its state and
# continue-as-new.
if self.should_continue_as_new():
# We don't want to leave any job assignment or deletion handlers half-finished when we continue as new.
await workflow.wait_condition(lambda: workflow.all_handlers_finished())
@@ -255,3 +252,14 @@ async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
len(self.get_assigned_nodes()),
len(self.get_bad_nodes()),
)

def should_continue_as_new(self) -> bool:
if workflow.info().is_continue_as_new_suggested():
return True
# This is just for ease-of-testing. In production, we trust temporal to tell us when to continue as new.
if (
self.max_history_length
and workflow.info().get_current_history_length() > self.max_history_length
):
return True
return False
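The continue-as-new call itself sits in a collapsed part of this diff; a hedged sketch of that checkpoint step, using only names visible above (state, test_continue_as_new, max_history_length), might look like this:

# Sketch: carry the accumulated state into the next run so event history
# stays bounded; the next run's @workflow.init constructor restores it.
workflow.continue_as_new(
    ClusterManagerInput(
        state=self.state,
        test_continue_as_new=self.max_history_length is not None,
    )
)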
27 changes: 21 additions & 6 deletions tests/message_passing/safe_message_handlers/workflow_test.py
@@ -1,5 +1,6 @@
import asyncio
import uuid
from typing import Callable, Sequence

import pytest
from temporalio.client import Client, WorkflowUpdateFailedError
@@ -10,6 +11,7 @@
from message_passing.safe_message_handlers.activities import (
assign_nodes_to_job,
find_bad_nodes,
start_cluster,
unassign_nodes_for_job,
)
from message_passing.safe_message_handlers.workflow import (
@@ -19,6 +21,13 @@
ClusterManagerWorkflow,
)

ACTIVITIES: Sequence[Callable] = [
assign_nodes_to_job,
unassign_nodes_for_job,
find_bad_nodes,
start_cluster,
]


async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment):
if env.supports_time_skipping:
@@ -30,15 +39,17 @@ async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment):
client,
task_queue=task_queue,
workflows=[ClusterManagerWorkflow],
activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
activities=ACTIVITIES,
):
cluster_manager_handle = await client.start_workflow(
ClusterManagerWorkflow.run,
ClusterManagerInput(),
id=f"ClusterManagerWorkflow-{uuid.uuid4()}",
task_queue=task_queue,
)
await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster)
await cluster_manager_handle.execute_update(
ClusterManagerWorkflow.wait_until_cluster_started
)

allocation_updates = []
for i in range(6):
@@ -82,7 +93,7 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment):
client,
task_queue=task_queue,
workflows=[ClusterManagerWorkflow],
activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
activities=ACTIVITIES,
):
cluster_manager_handle = await client.start_workflow(
ClusterManagerWorkflow.run,
Expand All @@ -91,7 +102,9 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment):
task_queue=task_queue,
)

await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster)
await cluster_manager_handle.execute_update(
ClusterManagerWorkflow.wait_until_cluster_started
)

result_1 = await cluster_manager_handle.execute_update(
ClusterManagerWorkflow.assign_nodes_to_job,
@@ -121,7 +134,7 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment):
client,
task_queue=task_queue,
workflows=[ClusterManagerWorkflow],
activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
activities=ACTIVITIES,
):
cluster_manager_handle = await client.start_workflow(
ClusterManagerWorkflow.run,
Expand All @@ -130,7 +143,9 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment):
task_queue=task_queue,
)

await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster)
await cluster_manager_handle.execute_update(
ClusterManagerWorkflow.wait_until_cluster_started
)

await cluster_manager_handle.execute_update(
ClusterManagerWorkflow.assign_nodes_to_job,