Skip to content

Commit

Permalink
Update PyTorchJob webhook to reject an elastic policy when no worker is configured
Browse files (browse the repository at this point in the history)
Co-authored-by: ricardov1 <[email protected]>
Co-authored-by: alenawang <[email protected]>
  • Loading branch information
3 people committed Oct 31, 2024
1 parent 7c5ea70 commit 603bf22
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 7 deletions.
20 changes: 13 additions & 7 deletions pkg/webhooks/pytorch/pytorchjob_webhook.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -93,13 +93,19 @@ func validatePyTorchJob(oldJob, newJob *trainingoperator.PyTorchJob) (admission.
func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
var warnings admission.Warnings

if spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil {
elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode")
nprocPerNodePath := specPath.Child("nprocPerNode")
warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String()))
if spec.NprocPerNode != nil {
allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath)))
if spec.ElasticPolicy != nil {
workerReplicaSpec, ok := spec.PyTorchReplicaSpecs[trainingoperator.PyTorchJobReplicaTypeWorker]
if !ok || (workerReplicaSpec.Replicas != nil && int(*workerReplicaSpec.Replicas) < 1) {
workerPath := specPath.Child("pytorchReplicaSpecs").Child("Worker")
allErrs = append(allErrs, field.Required(workerPath, "at least one worker must be configured if elastic policy is used"))
}
if spec.ElasticPolicy.NProcPerNode != nil {
elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode")
nprocPerNodePath := specPath.Child("nprocPerNode")
warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String()))
if spec.NprocPerNode != nil {
allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath)))
}
}
}
allErrs = append(allErrs, validatePyTorchReplicaSpecs(spec.PyTorchReplicaSpecs)...)
Expand Down
85 changes: 85 additions & 0 deletions pkg/webhooks/pytorch/pytorchjob_webhook_test.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -88,6 +88,9 @@ func TestValidateV1PyTorchJob(t *testing.T) {
RunPolicy: trainingoperator.RunPolicy{
ManagedBy: ptr.To(trainingoperator.KubeflowJobsController),
},
ElasticPolicy: &trainingoperator.ElasticPolicy{
RDZVBackend: ptr.To(trainingoperator.BackendC10D),
},
PyTorchReplicaSpecs: validPyTorchReplicaSpecs,
},
},
Expand Down Expand Up @@ -247,6 +250,19 @@ func TestValidateV1PyTorchJob(t *testing.T) {
},
},
},
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
Expand Down Expand Up @@ -279,6 +295,19 @@ func TestValidateV1PyTorchJob(t *testing.T) {
},
},
},
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
Expand Down Expand Up @@ -335,6 +364,62 @@ func TestValidateV1PyTorchJob(t *testing.T) {
field.Invalid(field.NewPath("spec", "runPolicy", "managedBy"), trainingoperator.MultiKueueController, apivalidation.FieldImmutableErrorMsg),
},
},
"attempt to configure elasticPolicy when no worker is configured": {
pytorchJob: &trainingoperator.PyTorchJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.PyTorchJobSpec{
ElasticPolicy: &trainingoperator.ElasticPolicy{},
PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.PyTorchJobReplicaTypeMaster: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Required(field.NewPath("spec", "pytorchReplicaSpecs", "Worker"), "at least one worker must be configured if elastic policy is used"),
},
},
"attempt to configure elasticPolicy when worker replicas is 0": {
pytorchJob: &trainingoperator.PyTorchJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.PyTorchJobSpec{
ElasticPolicy: &trainingoperator.ElasticPolicy{},
PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](0),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Required(field.NewPath("spec", "pytorchReplicaSpecs", "Worker"), "at least one worker must be configured if elastic policy is used"),
},
},
}

for name, tc := range testCases {
Expand Down

0 comments on commit 603bf22

Please sign in to comment.