Skip to content

Commit

Permalink
Simplifies ScheduleResults.job_verdicts (apple#872)
Browse files Browse the repository at this point in the history
* Simplifies job_verdicts so that its entries are ordered by the global priorities.

* Fixes unittest.

* Fixes unused import.

* Black

* Fixes bastion.py and cleaner.py.

* Adds a comment about the order of entries in ScheduleResults.job_verdicts.

* Adds a comment about the ordering of unscheduled jobs.

* Addresses Floris's comments.
  • Loading branch information
ruomingp authored Dec 9, 2024
1 parent 53f2cbb commit e4ed744
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 101 deletions.
103 changes: 53 additions & 50 deletions axlearn/cloud/common/bastion.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,11 @@ def _append_to_project_history(
):
now = datetime.now(timezone.utc)
for project_id, limits in schedule_results.project_limits.items():
job_verdicts = schedule_results.job_verdicts.get(project_id, {})
job_verdicts = {
job_id: verdict
for job_id, verdict in schedule_results.job_verdicts.items()
if jobs[job_id].project_id == project_id
}
verdicts = []
for job_id, verdict in job_verdicts.items():
verdicts.append((job_id, verdict.should_run(), verdict.metadata))
Expand Down Expand Up @@ -1076,56 +1080,55 @@ def _update_jobs(self):
verbosity=schedule_options["verbosity"],
)
self._append_to_project_history(schedulable_jobs, schedule_results)
for verdicts in schedule_results.job_verdicts.values():
for job_name, verdict in verdicts.items():
job = self._active_jobs[job_name]
assert job.state.status in {JobStatus.PENDING, JobStatus.ACTIVE}

if verdict:
old_tier = job.state.metadata.get("tier")
new_tier = verdict.metadata.get("tier")
changed_tiers = old_tier != new_tier

jobspec_changed = job.state.metadata.get("updated")

# Jobspec changed, trigger a restart of the runner.
if jobspec_changed:
self._append_to_job_history(
job,
msg="UPDATING: Detected updated jobspec. Will restart the runner "
"by sending to PENDING state",
state=JobLifecycleState.UPDATING,
)
job.state.status = JobStatus.PENDING
elif job.state.status == JobStatus.PENDING or not changed_tiers:
# Resume if not running, or keep running if scheduling tier did not change.
job.state.status = JobStatus.ACTIVE
else:
# Job changed scheduling tiers, and must be restarted on the new tier.
# NOTE: this can possibly lead to thrashing of jobs that frequently switch
                        # tiers. One option is to track per-job tier changes and hold off on promoting
# low priority to high priority if it was demoted recently.
# TODO(markblee): Add instrumentation to track frequency of tier changes to
# see whether this is necessary.
assert job.state.status == JobStatus.ACTIVE and changed_tiers
self._append_to_job_history(
job,
msg=f"Rescheduling at a different tier from {old_tier} to {new_tier}",
state=JobLifecycleState.RESCHEDULING,
)
job.state.status = JobStatus.PENDING
for job_name, verdict in schedule_results.job_verdicts.items():
job = self._active_jobs[job_name]
assert job.state.status in {JobStatus.PENDING, JobStatus.ACTIVE}

if verdict:
old_tier = job.state.metadata.get("tier")
new_tier = verdict.metadata.get("tier")
changed_tiers = old_tier != new_tier

jobspec_changed = job.state.metadata.get("updated")

# Jobspec changed, trigger a restart of the runner.
if jobspec_changed:
self._append_to_job_history(
job,
msg="UPDATING: Detected updated jobspec. Will restart the runner "
"by sending to PENDING state",
state=JobLifecycleState.UPDATING,
)
job.state.status = JobStatus.PENDING
elif job.state.status == JobStatus.PENDING or not changed_tiers:
# Resume if not running, or keep running if scheduling tier did not change.
job.state.status = JobStatus.ACTIVE
else:
# Pre-empt/stay queued.
if job.command_proc is not None and _is_proc_complete(job.command_proc):
# As a slight optimization, we avoid pre-empting ACTIVE jobs that are
# complete, since we can directly transition to CLEANING.
job.state.status = JobStatus.ACTIVE
else:
job.state.status = JobStatus.PENDING
# Pending jobs which are not rescheduled should have no tier information.
verdict.metadata.pop("tier", None)

job.state.metadata = verdict.metadata
# Job changed scheduling tiers, and must be restarted on the new tier.
# NOTE: this can possibly lead to thrashing of jobs that frequently switch
                # tiers. One option is to track per-job tier changes and hold off on promoting
# low priority to high priority if it was demoted recently.
# TODO(markblee): Add instrumentation to track frequency of tier changes to
# see whether this is necessary.
assert job.state.status == JobStatus.ACTIVE and changed_tiers
self._append_to_job_history(
job,
msg=f"Rescheduling at a different tier from {old_tier} to {new_tier}",
state=JobLifecycleState.RESCHEDULING,
)
job.state.status = JobStatus.PENDING
else:
# Pre-empt/stay queued.
if job.command_proc is not None and _is_proc_complete(job.command_proc):
# As a slight optimization, we avoid pre-empting ACTIVE jobs that are
# complete, since we can directly transition to CLEANING.
job.state.status = JobStatus.ACTIVE
else:
job.state.status = JobStatus.PENDING
# Pending jobs which are not rescheduled should have no tier information.
verdict.metadata.pop("tier", None)

job.state.metadata = verdict.metadata

# TODO(markblee): Parallelize this.
for job_name, job in self._active_jobs.items():
Expand Down
2 changes: 1 addition & 1 deletion axlearn/cloud/common/cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,6 @@ def sweep(self, jobs: dict[str, JobSpec]) -> Sequence[str]:
schedule_result = scheduler.schedule(
dict(my_job=job_spec.metadata),
)
if schedule_result.job_verdicts[job_spec.metadata.project_id]["my_job"].over_limits:
if schedule_result.job_verdicts["my_job"].over_limits:
result.append(job_name)
return result
30 changes: 12 additions & 18 deletions axlearn/cloud/common/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,15 @@ class ScheduleResults:
Attributes:
project_limits: The effective resource limits.
project_usages: The resource usages.
job_verdicts: A mapping of project_id -> (job_id -> run_or_not).
job_verdicts: A mapping of job_id -> run_or_not.
The entries will be ordered by descending scheduling priorities (not necessarily
JobMetadata.priority), where the higher priority jobs will be scheduled before
lower priority ones. The jobs not getting scheduled will also be ordered.
"""

project_limits: ProjectResourceMap[int]
project_usages: ProjectResourceMap[int]
job_verdicts: dict[str, dict[str, JobVerdict]]
job_verdicts: dict[str, JobVerdict]

def schedule(
self,
Expand Down Expand Up @@ -345,8 +348,8 @@ def traverse_tiers(
break
return tier_usages

job_verdicts = collections.defaultdict(dict)
while not project_queue.empty() and remaining_limits:
job_verdicts = {}
while not project_queue.empty():
project_usage_ratio, _, project_id = project_queue.get()
job_id, job_metadata = project_jobs[project_id].popleft()

Expand Down Expand Up @@ -381,17 +384,10 @@ def traverse_tiers(
"Schedule %s(%s)/%s: %s", project_id, project_usage_ratio, job_id, verdict
)

job_verdicts[project_id][job_id] = verdict
job_verdicts[job_id] = verdict
if project_jobs[project_id]:
project_queue.put(project_queue_item(project_id))

# Remaining jobs are rejected.
for project_id, job_queue in project_jobs.items():
for job_id, job_metadata in job_queue:
job_verdicts[project_id][job_id] = JobVerdict(
over_limits=set(job_metadata.resources.keys())
)

return BaseScheduler.ScheduleResults(
# Treat the usages as the limits.
project_limits=_recursively_to_dict(project_usages),
Expand Down Expand Up @@ -472,14 +468,15 @@ def schedule(
logging.info("")
logging.info("==Begin scheduling report")
logging.info("Total resource limits: %s", resource_limits)
for project_id, project_verdicts in schedule_results.job_verdicts.items():
for project_id, project_job_queue in project_jobs.items():
logging.info(
"Verdicts for Project [%s] Quota [%s] Effective limits [%s]:",
project_id,
project_quotas.get(project_id, {}),
schedule_results.project_limits.get(project_id, {}),
)
for job_name, job_verdict in project_verdicts.items():
for job_name, job_metadata in project_job_queue:
job_verdict = schedule_results.job_verdicts[job_name]
logging.info(
"Job %s: Resources [%s] Over limits [%s] Should Run? [%s] Metadata [%s]",
job_name,
Expand All @@ -500,9 +497,6 @@ def schedule(
schedule_results = BaseScheduler.ScheduleResults(
project_limits=schedule_results.project_limits,
project_usages=project_usages,
job_verdicts={
project_id: {job_name: JobVerdict() for job_name in project_verdicts}
for project_id, project_verdicts in schedule_results.job_verdicts.items()
},
job_verdicts={job_name: JobVerdict() for job_name in schedule_results.job_verdicts},
)
return schedule_results
Loading

0 comments on commit e4ed744

Please sign in to comment.