Skip to content

Commit

Permalink
Parse runtime as object
Browse files Browse the repository at this point in the history
Signed-off-by: Andrey Velichkevich <[email protected]>
  • Loading branch information
andreyvelich committed Nov 29, 2024
1 parent e33549e commit 7f0df7b
Showing 1 changed file with 30 additions and 34 deletions.
64 changes: 30 additions & 34 deletions sdk_v2/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,61 +96,57 @@ def list_runtimes(self) -> List[types.Runtime]:
constants.CLUSTER_TRAINING_RUNTIME_PLURAL,
async_req=True,
)
# TODO (andreyvelich): We should de-serialize runtime into object.
# For that, we need to import the JobSet models.

response = thread.get(constants.DEFAULT_TIMEOUT)
for item in response["items"]:

runtime = self.api_client.deserialize(
utils.FakeResponse(item),
models.KubeflowOrgV2alpha1ClusterTrainingRuntime,
)
print(runtime)
ml_policy = runtime.spec.ml_policy # type: ignore
metadata = runtime.metadata # type: ignore

# TODO (andreyvelich): Currently, the labels must be presented.
if "labels" in item["metadata"]:
if metadata.labels:
# Get the Trainer container resources.
resources = None
for job in item["spec"]["template"]["spec"]["replicatedJobs"]:
if job["name"] == constants.JOB_TRAINER_NODE:
pod_spec = job["template"]["spec"]["template"]["spec"]
for container in pod_spec["containers"]:
if container["name"] == constants.CONTAINER_TRAINER:
if "resources" in container:
resources = client.V1ResourceRequirements(
**container["resources"]
)
for job in runtime.spec.template.spec.replicated_jobs: # type: ignore
if job.name == constants.JOB_TRAINER_NODE:
pod_spec = job.template.spec.template.spec
for container in pod_spec.containers:
if container.name == constants.CONTAINER_TRAINER:
resources = container.resources

# TODO (andreyvelich): Currently, only Torch is supported for NumProcPerNode.
num_procs = None
if "torch" in item["spec"]["mlPolicy"]:
num_procs = item["spec"]["mlPolicy"]["torch"]["numProcPerNode"]
num_procs = (
ml_policy.torch.num_proc_per_node if ml_policy.torch else None
)

# Get the device count per Trainer node.
# TODO (andreyvelich): Currently, we get the device type from
# the runtime labels.
_, device_count = utils.get_container_devices(resources, num_procs)
if device_count != constants.UNKNOWN:
device_count = str(
int(device_count)
* int(item["spec"]["mlPolicy"]["numNodes"])
device_count = str(int(device_count) * int(ml_policy.num_nodes))

result.append(
types.Runtime(
name=metadata.name,
phase=(
metadata.labels[constants.PHASE_KEY]
if constants.PHASE_KEY in metadata.labels
else constants.UNKNOWN
),
device=(
metadata.labels[constants.DEVICE_KEY]
if constants.DEVICE_KEY in metadata.labels
else constants.UNKNOWN
),
device_count=device_count,
)

runtime = types.Runtime(
name=item["metadata"]["name"],
phase=(
item["metadata"]["labels"][constants.PHASE_KEY]
if constants.PHASE_KEY in item["metadata"]["labels"]
else constants.UNKNOWN
),
device=(
item["metadata"]["labels"][constants.DEVICE_KEY]
if constants.DEVICE_KEY in item["metadata"]["labels"]
else constants.UNKNOWN
),
device_count=device_count,
)

result.append(runtime)
except multiprocessing.TimeoutError:
raise TimeoutError(
f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s "
Expand Down

0 comments on commit 7f0df7b

Please sign in to comment.