Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate pytorchjob workers are configured when elasticpolicy is configured #2320

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions pkg/webhooks/pytorch/pytorchjob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,22 @@ func validatePyTorchJob(oldJob, newJob *trainingoperator.PyTorchJob) (admission.
func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
var warnings admission.Warnings

if spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil {
elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode")
nprocPerNodePath := specPath.Child("nprocPerNode")
warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String()))
if spec.NprocPerNode != nil {
allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath)))
if spec.ElasticPolicy != nil {
workerReplicaSpec, ok := spec.PyTorchReplicaSpecs[trainingoperator.PyTorchJobReplicaTypeWorker]
workerPath := specPath.Child("pytorchReplicaSpecs").Child("Worker")
if !ok {
allErrs = append(allErrs, field.Required(workerPath, "must be configured if elastic policy is used"))
} else if workerReplicaSpec.Replicas != nil && int(*workerReplicaSpec.Replicas) < 1 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need this check ? Isn't the default value for Replicas is 1 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We added this validation to cover the case where replicas is explicitly set to 0 since the default is only set when it is nil https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/defaulting_utils.go#L43

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should have a separate validation that disallow users to set Replicas < 1 since it is invalid option ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that makes sense- we will update it to a separate validation that won't depend on whether elastic policy is set. Is it okay to have both validations in the same PR or would you prefer we separate that one into its own PR?

Copy link
Member

@andreyvelich andreyvelich Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, it is fine to add this simple check in this PR if you can @tarat44!

Copy link
Contributor Author

@tarat44 tarat44 Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andreyvelich, we separated the validation for worker replicas being at least 1. I also realized that we had worker as an attribute, instead of a key in the pytorchReplicaSpecs map, in the error field path so we fixed it to use the map representation. It should be good to re-review now

workerReplicasPath := workerPath.Child("replicas")
allErrs = append(allErrs, field.Forbidden(workerReplicasPath, "must be at least 1 if elastic policy is used"))
tarat44 marked this conversation as resolved.
Show resolved Hide resolved
}
if spec.ElasticPolicy.NProcPerNode != nil {
elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode")
nprocPerNodePath := specPath.Child("nprocPerNode")
warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String()))
if spec.NprocPerNode != nil {
allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath)))
}
}
}
allErrs = append(allErrs, validatePyTorchReplicaSpecs(spec.PyTorchReplicaSpecs)...)
Expand Down
85 changes: 85 additions & 0 deletions pkg/webhooks/pytorch/pytorchjob_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ func TestValidateV1PyTorchJob(t *testing.T) {
RunPolicy: trainingoperator.RunPolicy{
ManagedBy: ptr.To(trainingoperator.KubeflowJobsController),
},
ElasticPolicy: &trainingoperator.ElasticPolicy{
RDZVBackend: ptr.To(trainingoperator.BackendC10D),
},
PyTorchReplicaSpecs: validPyTorchReplicaSpecs,
},
},
Expand Down Expand Up @@ -247,6 +250,19 @@ func TestValidateV1PyTorchJob(t *testing.T) {
},
},
},
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
Expand Down Expand Up @@ -279,6 +295,19 @@ func TestValidateV1PyTorchJob(t *testing.T) {
},
},
},
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
Expand Down Expand Up @@ -335,6 +364,62 @@ func TestValidateV1PyTorchJob(t *testing.T) {
field.Invalid(field.NewPath("spec", "runPolicy", "managedBy"), trainingoperator.MultiKueueController, apivalidation.FieldImmutableErrorMsg),
},
},
"attempt to configure elasticPolicy when no worker is configured": {
pytorchJob: &trainingoperator.PyTorchJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.PyTorchJobSpec{
ElasticPolicy: &trainingoperator.ElasticPolicy{},
PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.PyTorchJobReplicaTypeMaster: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Required(field.NewPath("spec", "pytorchReplicaSpecs", "Worker"), "must be configured if elastic policy is used"),
},
},
"attempt to configure elasticPolicy when worker replicas is 0": {
pytorchJob: &trainingoperator.PyTorchJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.PyTorchJobSpec{
ElasticPolicy: &trainingoperator.ElasticPolicy{},
PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](0),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Forbidden(field.NewPath("spec", "pytorchReplicaSpecs", "Worker", "replicas"), "must be at least 1 if elastic policy is used"),
},
},
}

for name, tc := range testCases {
Expand Down
Loading