diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook.go b/pkg/webhooks/pytorch/pytorchjob_webhook.go index 2459815935..a0e6b05409 100644 --- a/pkg/webhooks/pytorch/pytorchjob_webhook.go +++ b/pkg/webhooks/pytorch/pytorchjob_webhook.go @@ -93,13 +93,19 @@ func validatePyTorchJob(oldJob, newJob *trainingoperator.PyTorchJob) (admission. func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, field.ErrorList) { var allErrs field.ErrorList var warnings admission.Warnings - - if spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil { - elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode") - nprocPerNodePath := specPath.Child("nprocPerNode") - warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String())) - if spec.NprocPerNode != nil { - allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath))) + if spec.ElasticPolicy != nil { + workerReplicaSpec, ok := spec.PyTorchReplicaSpecs[trainingoperator.PyTorchJobReplicaTypeWorker] + if !ok || (workerReplicaSpec.Replicas != nil && int(*workerReplicaSpec.Replicas) < 1) { + workerPath := specPath.Child("pytorchReplicaSpecs").Child("Worker") + allErrs = append(allErrs, field.Required(workerPath, "at least one worker must be configured if elastic policy is used")) + } + if spec.ElasticPolicy.NProcPerNode != nil { + elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode") + nprocPerNodePath := specPath.Child("nprocPerNode") + warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String())) + if spec.NprocPerNode != nil { + allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath))) + } } } allErrs = append(allErrs, validatePyTorchReplicaSpecs(spec.PyTorchReplicaSpecs)...) diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook_test.go b/pkg/webhooks/pytorch/pytorchjob_webhook_test.go index 7757e36b3e..ac457c808b 100644 --- a/pkg/webhooks/pytorch/pytorchjob_webhook_test.go +++ b/pkg/webhooks/pytorch/pytorchjob_webhook_test.go @@ -88,6 +88,9 @@ func TestValidateV1PyTorchJob(t *testing.T) { RunPolicy: trainingoperator.RunPolicy{ ManagedBy: ptr.To(trainingoperator.KubeflowJobsController), }, + ElasticPolicy: &trainingoperator.ElasticPolicy{ + RDZVBackend: ptr.To(trainingoperator.BackendC10D), + }, PyTorchReplicaSpecs: validPyTorchReplicaSpecs, }, }, @@ -247,6 +250,19 @@ func TestValidateV1PyTorchJob(t *testing.T) { }, }, }, + trainingoperator.PyTorchJobReplicaTypeWorker: { + Replicas: ptr.To[int32](1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "pytorch", + Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", + }, + }, + }, + }, + }, }, }, }, @@ -279,6 +295,19 @@ func TestValidateV1PyTorchJob(t *testing.T) { }, }, }, + trainingoperator.PyTorchJobReplicaTypeWorker: { + Replicas: ptr.To[int32](1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "pytorch", + Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", + }, + }, + }, + }, + }, }, }, }, @@ -335,6 +364,62 @@ func TestValidateV1PyTorchJob(t *testing.T) { field.Invalid(field.NewPath("spec", "runPolicy", "managedBy"), trainingoperator.MultiKueueController, apivalidation.FieldImmutableErrorMsg), }, }, + "attempt to configure elasticPolicy when no worker is configured": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + ElasticPolicy: &trainingoperator.ElasticPolicy{}, + PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeMaster: { + Replicas: ptr.To[int32](1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "pytorch", + Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(field.NewPath("spec", "pytorchReplicaSpecs", "Worker"), "at least one worker must be configured if elastic policy is used"), + }, + }, + "attempt to configure elasticPolicy when worker replicas is 0": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + ElasticPolicy: &trainingoperator.ElasticPolicy{}, + PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeWorker: { + Replicas: ptr.To[int32](0), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "pytorch", + Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(field.NewPath("spec", "pytorchReplicaSpecs", "Worker"), "at least one worker must be configured if elastic policy is used"), + }, + }, } for name, tc := range testCases {