Remove health check and "bad" node concept from ClusterManager #132

Open · wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion poetry.lock

Generated file; diff not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -17,7 +17,7 @@ packages = [

[tool.poetry.dependencies]
python = "^3.8"
temporalio = "^1.5.0"
temporalio = "^1.6.0"

[tool.poetry.dev-dependencies]
black = "^22.3.0"
12 changes: 5 additions & 7 deletions tests/updates_and_signals/safe_message_handlers/workflow_test.py
@@ -9,7 +9,6 @@

from updates_and_signals.safe_message_handlers.activities import (
assign_nodes_to_job,
-find_bad_nodes,
unassign_nodes_for_job,
)
from updates_and_signals.safe_message_handlers.workflow import (
@@ -30,7 +29,7 @@ async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment):
client,
task_queue=task_queue,
workflows=[ClusterManagerWorkflow],
-activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+activities=[assign_nodes_to_job, unassign_nodes_for_job],
):
cluster_manager_handle = await client.start_workflow(
ClusterManagerWorkflow.run,
@@ -82,7 +81,7 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment):
client,
task_queue=task_queue,
workflows=[ClusterManagerWorkflow],
-activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+activities=[assign_nodes_to_job, unassign_nodes_for_job],
):
cluster_manager_handle = await client.start_workflow(
ClusterManagerWorkflow.run,
@@ -106,8 +105,7 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment):
total_num_nodes=5, job_name="jobby-job"
),
)
-# the second call should not assign more nodes (it may return fewer if the health check finds bad nodes
-# in between the two signals.)
+# the second call should not assign more nodes
assert result_1.nodes_assigned >= result_2.nodes_assigned


@@ -121,7 +119,7 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment):
client,
task_queue=task_queue,
workflows=[ClusterManagerWorkflow],
-activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+activities=[assign_nodes_to_job, unassign_nodes_for_job],
):
cluster_manager_handle = await client.start_workflow(
ClusterManagerWorkflow.run,
@@ -152,4 +150,4 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment):
finally:
await cluster_manager_handle.signal(ClusterManagerWorkflow.shutdown_cluster)
result = await cluster_manager_handle.result()
-assert result.num_currently_assigned_nodes + result.num_bad_nodes == 24
+assert result.num_currently_assigned_nodes == 24
16 changes: 0 additions & 16 deletions updates_and_signals/safe_message_handlers/activities.py
@@ -27,19 +27,3 @@ class UnassignNodesForJobInput:
async def unassign_nodes_for_job(input: UnassignNodesForJobInput) -> None:
print(f"Deallocating nodes {input.nodes} from job {input.job_name}")
await asyncio.sleep(0.1)


-@dataclass
-class FindBadNodesInput:
-nodes_to_check: Set[str]
-
-
-@activity.defn
-async def find_bad_nodes(input: FindBadNodesInput) -> Set[str]:
-await asyncio.sleep(0.1)
-bad_nodes = set([n for n in input.nodes_to_check if int(n) % 5 == 0])
-if bad_nodes:
-print(f"Found bad nodes: {bad_nodes}")
-else:
-print("No new bad nodes found.")
-return bad_nodes
3 changes: 1 addition & 2 deletions updates_and_signals/safe_message_handlers/worker.py
@@ -7,7 +7,6 @@
from updates_and_signals.safe_message_handlers.workflow import (
ClusterManagerWorkflow,
assign_nodes_to_job,
-find_bad_nodes,
unassign_nodes_for_job,
)

@@ -22,7 +21,7 @@ async def main():
client,
task_queue="safe-message-handlers-task-queue",
workflows=[ClusterManagerWorkflow],
-activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+activities=[assign_nodes_to_job, unassign_nodes_for_job],
):
# Wait until interrupted
logging.info("ClusterManagerWorkflow worker started, ctrl+c to exit")
61 changes: 8 additions & 53 deletions updates_and_signals/safe_message_handlers/workflow.py
@@ -5,15 +5,12 @@
from typing import Dict, List, Optional, Set

from temporalio import workflow
-from temporalio.common import RetryPolicy
from temporalio.exceptions import ApplicationError

from updates_and_signals.safe_message_handlers.activities import (
AssignNodesToJobInput,
-FindBadNodesInput,
UnassignNodesForJobInput,
assign_nodes_to_job,
-find_bad_nodes,
unassign_nodes_for_job,
)

@@ -37,7 +34,6 @@ class ClusterManagerInput:
@dataclass
class ClusterManagerResult:
num_currently_assigned_nodes: int
-num_bad_nodes: int


# Be in the habit of storing message inputs and outputs in serializable structures.
@@ -70,7 +66,6 @@ def __init__(self) -> None:
# Protects workflow state from interleaved access
self.nodes_lock = asyncio.Lock()
self.max_history_length: Optional[int] = None
-self.sleep_interval_seconds: int = 600

@workflow.signal
async def start_cluster(self) -> None:
@@ -116,7 +111,7 @@ async def assign_nodes_to_job(
)
nodes_to_assign = unassigned_nodes[: input.total_num_nodes]
# This await would be dangerous without nodes_lock because it yields control and allows interleaving
-# with delete_job and perform_health_checks, which both touch self.state.nodes.
+# with delete_job, which touches self.state.nodes.
await self._assign_nodes_to_job(nodes_to_assign, input.job_name)
return ClusterManagerAssignNodesToJobResult(
nodes_assigned=self.get_assigned_nodes(job_name=input.job_name)
@@ -150,7 +145,7 @@ async def delete_job(self, input: ClusterManagerDeleteJobInput) -> None:
k for k, v in self.state.nodes.items() if v == input.job_name
]
# This await would be dangerous without nodes_lock because it yields control and allows interleaving
-# with assign_nodes_to_job and perform_health_checks, which all touch self.state.nodes.
+# with assign_nodes_to_job, which touches self.state.nodes.
await self._unassign_nodes_for_job(nodes_to_unassign, input.job_name)

async def _unassign_nodes_for_job(
@@ -167,40 +162,11 @@ async def _unassign_nodes_for_job(
def get_unassigned_nodes(self) -> List[str]:
return [k for k, v in self.state.nodes.items() if v is None]

-def get_bad_nodes(self) -> Set[str]:
-return set([k for k, v in self.state.nodes.items() if v == "BAD!"])

def get_assigned_nodes(self, *, job_name: Optional[str] = None) -> Set[str]:
if job_name:
return set([k for k, v in self.state.nodes.items() if v == job_name])
else:
-return set(
-[
-k
-for k, v in self.state.nodes.items()
-if v is not None and v != "BAD!"
-]
-)
-
-async def perform_health_checks(self) -> None:
-async with self.nodes_lock:
-assigned_nodes = self.get_assigned_nodes()
-try:
-# This await would be dangerous without nodes_lock because it yields control and allows interleaving
-# with assign_nodes_to_job and delete_job, which both touch self.state.nodes.
-bad_nodes = await workflow.execute_activity(
-find_bad_nodes,
-FindBadNodesInput(nodes_to_check=assigned_nodes),
-start_to_close_timeout=timedelta(seconds=10),
-# This health check is optional, and our lock would block the whole workflow if we let it retry forever.
-retry_policy=RetryPolicy(maximum_attempts=1),
-)
-for node in bad_nodes:
-self.state.nodes[node] = "BAD!"
-except Exception as e:
-workflow.logger.warn(
-f"Health check failed with error {type(e).__name__}:{e}"
-)
+return set([k for k, v in self.state.nodes.items() if v is not None])

# The cluster manager is a long-running "entity" workflow so we need to periodically checkpoint its state and
# continue-as-new.
@@ -209,7 +175,6 @@ def init(self, input: ClusterManagerInput) -> None:
self.state = input.state
if input.test_continue_as_new:
self.max_history_length = 120
-self.sleep_interval_seconds = 1

def should_continue_as_new(self) -> bool:
# We don't want to continue-as-new if we're in the middle of an update
@@ -228,29 +193,19 @@ def should_continue_as_new(self) -> bool:
@workflow.run
async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
self.init(input)
await workflow.wait_condition(lambda: self.state.cluster_started)
-# Perform health checks at intervals.
while True:

Review comment (Member):
There is no value in this being a loop anymore and no value in the wait condition having a timeout (and therefore no value in a sleep interval setting)

Reply from @dandavison (Contributor, Author), Jul 24, 2024:
Ah, very good point! (Fixed the same bug in the TypeScript version.) I think this actually improves the sample a lot, since it may not be obvious to users that they can implement the main "loop" so simply, and it makes continue-as-new usage clearer and less intimidating.

-await self.perform_health_checks()
-try:
-await workflow.wait_condition(
-lambda: self.state.cluster_shutdown
-or self.should_continue_as_new(),
-timeout=timedelta(seconds=self.sleep_interval_seconds),
-)
-except asyncio.TimeoutError:
-pass
+await workflow.wait_condition(
+lambda: self.state.cluster_shutdown or self.should_continue_as_new()
+)
if self.state.cluster_shutdown:
break
if self.should_continue_as_new():
await workflow.wait_condition(lambda: workflow.all_handlers_finished())
workflow.logger.info("Continuing as new")
workflow.continue_as_new(
ClusterManagerInput(
state=self.state,
test_continue_as_new=input.test_continue_as_new,
)
)
-return ClusterManagerResult(
-len(self.get_assigned_nodes()),
-len(self.get_bad_nodes()),
-)
+return ClusterManagerResult(len(self.get_assigned_nodes()))
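
Taken together, the workflow.py hunks leave the main routine of ClusterManagerWorkflow looking roughly like the sketch below. This is reassembled from the diff above with indentation restored, not copied from the repository verbatim; the health-check pass, the sleep interval, and the num_bad_nodes bookkeeping are all gone:

    @workflow.run
    async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
        self.init(input)
        await workflow.wait_condition(lambda: self.state.cluster_started)
        while True:
            # Nothing happens between messages anymore: just wait for shutdown or for
            # the history to grow large enough to warrant continue-as-new.
            await workflow.wait_condition(
                lambda: self.state.cluster_shutdown or self.should_continue_as_new()
            )
            if self.state.cluster_shutdown:
                break
            if self.should_continue_as_new():
                # Drain in-flight signal/update handlers before checkpointing state.
                await workflow.wait_condition(lambda: workflow.all_handlers_finished())
                workflow.logger.info("Continuing as new")
                workflow.continue_as_new(
                    ClusterManagerInput(
                        state=self.state,
                        test_continue_as_new=input.test_continue_as_new,
                    )
                )
        return ClusterManagerResult(len(self.get_assigned_nodes()))

As the review thread above points out, with the health check removed the loop and the wait-condition timeout no longer add anything, so this could be flattened further into a single wait_condition followed by an optional continue-as-new.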