diff --git a/manifests/base/webhook/manifests.yaml b/manifests/base/webhook/manifests.yaml index ea3fad7c4c..405ce8bbf4 100644 --- a/manifests/base/webhook/manifests.yaml +++ b/manifests/base/webhook/manifests.yaml @@ -24,3 +24,23 @@ webhooks: resources: - pytorchjobs sideEffects: None +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-kubeflow-org-v1-tfjob + failurePolicy: Fail + name: validator.tfjob.training-operator.kubeflow.org + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v1 + operations: + - CREATE + - UPDATE + resources: + - tfjobs + sideEffects: None diff --git a/manifests/base/webhook/patch.yaml b/manifests/base/webhook/patch.yaml index 060dfc8a52..bfce37bef6 100644 --- a/manifests/base/webhook/patch.yaml +++ b/manifests/base/webhook/patch.yaml @@ -1,6 +1,9 @@ - op: replace path: /webhooks/0/clientConfig/service/name value: training-operator +- op: replace + path: /webhooks/1/clientConfig/service/name + value: training-operator - op: replace path: /metadata/name value: validator.training-operator.kubeflow.org diff --git a/pkg/apis/kubeflow.org/v1/tensorflow_types.go b/pkg/apis/kubeflow.org/v1/tensorflow_types.go index 64da305827..7ee0a81d3a 100644 --- a/pkg/apis/kubeflow.org/v1/tensorflow_types.go +++ b/pkg/apis/kubeflow.org/v1/tensorflow_types.go @@ -122,9 +122,14 @@ const ( TFJobReplicaTypeEval ReplicaType = "Evaluator" ) +// IsChiefOrMaster returns true if the type is Master or Chief. +func IsChiefOrMaster(typ ReplicaType) bool { + return typ == TFJobReplicaTypeChief || typ == TFJobReplicaTypeMaster +} + // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object // +resource:path=tfjobs -//+kubebuilder:object:root=true +// +kubebuilder:object:root=true // TFJobList is a list of TFJobs. type TFJobList struct { diff --git a/pkg/apis/kubeflow.org/v1/tensorflow_types_test.go b/pkg/apis/kubeflow.org/v1/tensorflow_types_test.go new file mode 100644 index 0000000000..4f4e289843 --- /dev/null +++ b/pkg/apis/kubeflow.org/v1/tensorflow_types_test.go @@ -0,0 +1,45 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1 + +import "testing" + +func TestIsChiefOrMaster(t *testing.T) { + tc := []struct { + Type ReplicaType + Expected bool + }{ + { + Type: TFJobReplicaTypeChief, + Expected: true, + }, + { + Type: TFJobReplicaTypeMaster, + Expected: true, + }, + { + Type: TFJobReplicaTypeWorker, + Expected: false, + }, + } + for _, c := range tc { + actual := IsChiefOrMaster(c.Type) + if actual != c.Expected { + t.Errorf("Expected %v; Got %v", c.Expected, actual) + } + } +} diff --git a/pkg/apis/kubeflow.org/v1/tensorflow_validation.go b/pkg/apis/kubeflow.org/v1/tensorflow_validation.go deleted file mode 100644 index a73721ae69..0000000000 --- a/pkg/apis/kubeflow.org/v1/tensorflow_validation.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2018 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "fmt" - - log "github.com/sirupsen/logrus" - apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" -) - -func ValidateV1TFJob(tfjob *TFJob) error { - if errors := apimachineryvalidation.NameIsDNS1035Label(tfjob.ObjectMeta.Name, false); errors != nil { - return fmt.Errorf("TFJob name is invalid: %v", errors) - } - if err := validateV1TFReplicaSpecs(tfjob.Spec.TFReplicaSpecs); err != nil { - return err - } - return nil -} - -// IsChieforMaster returns true if the type is Master or Chief. -func IsChieforMaster(typ ReplicaType) bool { - return typ == TFJobReplicaTypeChief || typ == TFJobReplicaTypeMaster -} - -// IsWorker returns true if the type is Worker. -func IsWorker(typ ReplicaType) bool { - return typ == TFJobReplicaTypeWorker -} - -// IsEvaluator returns true if the type is Evaluator. -func IsEvaluator(typ ReplicaType) bool { - return typ == TFJobReplicaTypeEval -} - -func validateV1TFReplicaSpecs(specs map[ReplicaType]*ReplicaSpec) error { - if specs == nil { - return fmt.Errorf("TFJobSpec is not valid") - } - foundChief := 0 - for rType, value := range specs { - if value == nil || len(value.Template.Spec.Containers) == 0 { - return fmt.Errorf("TFJobSpec is not valid: containers definition expected in %v", rType) - } - if IsChieforMaster(rType) { - foundChief++ - } - // Make sure the image is defined in the container. - numNamedTensorflow := 0 - for _, container := range value.Template.Spec.Containers { - if container.Image == "" { - msg := fmt.Sprintf("TFJobSpec is not valid: Image is undefined in the container of %v", rType) - log.Error(msg) - return fmt.Errorf(msg) - } - if container.Name == TFJobDefaultContainerName { - numNamedTensorflow++ - } - } - // Make sure there has at least one container named "tensorflow". - if numNamedTensorflow == 0 { - msg := fmt.Sprintf("TFJobSpec is not valid: There is no container named %s in %v", TFJobDefaultContainerName, rType) - log.Error(msg) - return fmt.Errorf(msg) - } - } - if foundChief > 1 { - return fmt.Errorf("TFJobSpec is not valid: more than 1 chief/master found") - } - return nil -} diff --git a/pkg/apis/kubeflow.org/v1/tensorflow_validation_test.go b/pkg/apis/kubeflow.org/v1/tensorflow_validation_test.go deleted file mode 100644 index db54308547..0000000000 --- a/pkg/apis/kubeflow.org/v1/tensorflow_validation_test.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2018 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "testing" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/ptr" -) - -func TestValidateV1TFJob(t *testing.T) { - validTFReplicaSpecs := map[ReplicaType]*ReplicaSpec{ - TFJobReplicaTypeWorker: { - Replicas: ptr.To[int32](2), - RestartPolicy: RestartPolicyOnFailure, - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "tensorflow", - Image: "kubeflow/tf-mnist-with-summaries:latest", - Command: []string{ - "python", - "/var/tf_mnist/mnist_with_summaries.py", - }, - }}, - }, - }, - }, - } - - testCases := map[string]struct { - tfJob *TFJob - wantErr bool - }{ - "valid tfJob": { - tfJob: &TFJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: TFJobSpec{ - TFReplicaSpecs: validTFReplicaSpecs, - }, - }, - wantErr: false, - }, - "TFJob name does not meet DNS1035": { - tfJob: &TFJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "00test", - }, - Spec: TFJobSpec{ - TFReplicaSpecs: validTFReplicaSpecs, - }, - }, - wantErr: true, - }, - "no containers": { - tfJob: &TFJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: TFJobSpec{ - TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - TFJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "empty image": { - tfJob: &TFJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: TFJobSpec{ - TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - TFJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "tensorflow", - Image: "", - }}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "tfJob default container name doesn't present": { - tfJob: &TFJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: TFJobSpec{ - TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - TFJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "", - Image: "kubeflow/tf-dist-mnist-test:1.0", - }}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "there are more than 2 masterReplica's or ChiefReplica's": { - tfJob: &TFJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: TFJobSpec{ - TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - TFJobReplicaTypeChief: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{}, - }, - }, - }, - TFJobReplicaTypeMaster: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - got := ValidateV1TFJob(tc.tfJob) - if (got != nil) != tc.wantErr { - t.Fatalf("ValidateV1TFJob() error = %v, wantErr %v", got, tc.wantErr) - } - }) - } -} - -func TestIsChieforMaster(t *testing.T) { - tc := []struct { - Type ReplicaType - Expected bool - }{ - { - Type: TFJobReplicaTypeChief, - Expected: true, - }, - { - Type: TFJobReplicaTypeMaster, - Expected: true, - }, - { - Type: TFJobReplicaTypeWorker, - Expected: false, - }, - } - - for _, c := range tc { - actual := IsChieforMaster(c.Type) - if actual != c.Expected { - t.Errorf("Expected %v; Got %v", c.Expected, actual) - } - } -} diff --git a/pkg/cert/cert.go b/pkg/cert/cert.go index 4d1593fab5..cb400c1863 100644 --- a/pkg/cert/cert.go +++ b/pkg/cert/cert.go @@ -18,9 +18,10 @@ package cert import ( "fmt" + "os" + cert "github.com/open-policy-agent/cert-controller/pkg/rotator" "k8s.io/apimachinery/pkg/types" - "os" ctrl "sigs.k8s.io/controller-runtime" ) diff --git a/pkg/controller.v1/tensorflow/suite_test.go b/pkg/controller.v1/tensorflow/suite_test.go index 87cfc4b00d..c1824fea5c 100644 --- a/pkg/controller.v1/tensorflow/suite_test.go +++ b/pkg/controller.v1/tensorflow/suite_test.go @@ -16,7 +16,9 @@ package tensorflow import ( "context" + "crypto/tls" "fmt" + "net" "path/filepath" "testing" "time" @@ -24,6 +26,7 @@ import ( kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" "github.com/kubeflow/training-operator/pkg/util/testutil" + tensorflowwebhook "github.com/kubeflow/training-operator/pkg/webhooks/tensorflow" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -35,6 +38,7 @@ import ( logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "sigs.k8s.io/controller-runtime/pkg/webhook" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -65,6 +69,9 @@ var _ = BeforeSuite(func() { testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")}, ErrorIfCRDPathMissing: true, + WebhookInstallOptions: envtest.WebhookInstallOptions{ + Paths: []string{filepath.Join("..", "..", "..", "manifests", "base", "webhook", "manifests.yaml")}, + }, } cfg, err := testEnv.Start() @@ -86,12 +93,19 @@ var _ = BeforeSuite(func() { Metrics: metricsserver.Options{ BindAddress: "0", }, + WebhookServer: webhook.NewServer( + webhook.Options{ + Host: testEnv.WebhookInstallOptions.LocalServingHost, + Port: testEnv.WebhookInstallOptions.LocalServingPort, + CertDir: testEnv.WebhookInstallOptions.LocalServingCertDir, + }), }) Expect(err).NotTo(HaveOccurred()) gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() reconciler = NewReconciler(mgr, gangSchedulingSetupFunc) Expect(reconciler.SetupWithManager(mgr, 1)).NotTo(HaveOccurred()) + Expect(tensorflowwebhook.SetupWebhook(mgr)).NotTo(HaveOccurred()) go func() { defer GinkgoRecover() @@ -99,6 +113,14 @@ var _ = BeforeSuite(func() { Expect(err).ToNot(HaveOccurred(), "failed to run manager") }() + dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", testEnv.WebhookInstallOptions.LocalServingHost, testEnv.WebhookInstallOptions.LocalServingPort) + Eventually(func(g Gomega) { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(conn.Close()).NotTo(HaveOccurred()) + }).Should(Succeed()) + // This step is introduced to make sure cache starts before running any tests Eventually(func() error { nsList := &corev1.NamespaceList{} diff --git a/pkg/controller.v1/tensorflow/tfjob_controller.go b/pkg/controller.v1/tensorflow/tfjob_controller.go index bc6fa78e3f..a60d65affe 100644 --- a/pkg/controller.v1/tensorflow/tfjob_controller.go +++ b/pkg/controller.v1/tensorflow/tfjob_controller.go @@ -127,13 +127,6 @@ func (r *TFJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl return ctrl.Result{}, client.IgnoreNotFound(err) } - if err = kubeflowv1.ValidateV1TFJob(tfjob); err != nil { - logger.Error(err, "TFJob failed validation") - r.Recorder.Eventf(tfjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobFailedValidationReason), - "TFJob failed validation because %s", err) - return ctrl.Result{}, err - } - // Check if reconciliation is needed jobKey, err := common.KeyFunc(tfjob) if err != nil { @@ -450,7 +443,7 @@ func (r *TFJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubeflow // If the TFJob contains Chief or Master spec, then we will update the status // according to the Chief/Master spec. if ContainsChiefOrMasterSpec(tfJob.Spec.TFReplicaSpecs) { - if kubeflowv1.IsChieforMaster(rtype) { + if kubeflowv1.IsChiefOrMaster(rtype) { if running > 0 { msg := fmt.Sprintf("TFJob %s/%s is running.", tfJob.Namespace, tfJob.Name) commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRunning, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobRunningReason), msg) diff --git a/pkg/webhooks/tensorflow/tfjob_webhook.go b/pkg/webhooks/tensorflow/tfjob_webhook.go new file mode 100644 index 0000000000..9e9629a1b8 --- /dev/null +++ b/pkg/webhooks/tensorflow/tfjob_webhook.go @@ -0,0 +1,121 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tensorflow + +import ( + "context" + "fmt" + "strings" + + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +var ( + specPath = field.NewPath("spec") + tfReplicaSpecPath = specPath.Child("tfReplicaSpecs") +) + +type Webhook struct{} + +func SetupWebhook(mgr ctrl.Manager) error { + return ctrl.NewWebhookManagedBy(mgr). + For(&trainingoperator.TFJob{}). + WithValidator(&Webhook{}). + Complete() +} + +// +kubebuilder:webhook:path=/validate-kubeflow-org-v1-tfjob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=tfjobs,verbs=create;update,versions=v1,name=validator.tfjob.training-operator.kubeflow.org,admissionReviewVersions=v1 + +var _ webhook.CustomValidator = &Webhook{} + +func (w *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { + job := obj.(*trainingoperator.TFJob) + log := ctrl.LoggerFrom(ctx).WithName("tfjob-webhook") + log.V(5).Info("Validating create", "TFJob", klog.KObj(job)) + return nil, validateTFJob(job).ToAggregate() +} + +func (w *Webhook) ValidateUpdate(ctx context.Context, _, newObj runtime.Object) (admission.Warnings, error) { + job := newObj.(*trainingoperator.TFJob) + log := ctrl.LoggerFrom(ctx).WithName("tfjob-webhook") + log.V(5).Info("Validating update", "NewTFJob", klog.KObj(job)) + return nil, validateTFJob(job).ToAggregate() +} + +func (w *Webhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { + return nil, nil +} + +func validateTFJob(job *trainingoperator.TFJob) field.ErrorList { + var allErrs field.ErrorList + if errors := apimachineryvalidation.NameIsDNS1035Label(job.Name, false); len(errors) != 0 { + allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), job.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) + } + allErrs = append(allErrs, validateSpec(job.Spec)...) + return allErrs +} + +func validateSpec(spec trainingoperator.TFJobSpec) field.ErrorList { + return validateTFReplicaSpecs(spec.TFReplicaSpecs) +} + +func validateTFReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { + var allErrs field.ErrorList + + if rSpecs == nil { + allErrs = append(allErrs, field.Required(tfReplicaSpecPath, "must be required")) + } + + chiefOrMaster := 0 + for rType, rSpec := range rSpecs { + rolePath := tfReplicaSpecPath.Key(string(rType)) + containerPath := rolePath.Child("template").Child("spec").Child("containers") + + if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { + allErrs = append(allErrs, field.Required(containerPath, "must be specified")) + } + if trainingoperator.IsChiefOrMaster(rType) { + chiefOrMaster++ + } + // Make sure the image is defined in the container. + defaultContainerPresent := false + for idx, container := range rSpec.Template.Spec.Containers { + if container.Image == "" { + allErrs = append(allErrs, field.Required(containerPath.Index(idx).Child("image"), "must be required")) + } + if container.Name == trainingoperator.TFJobDefaultContainerName { + defaultContainerPresent = true + } + } + // Make sure there has at least one container named "tensorflow". + if !defaultContainerPresent { + allErrs = append(allErrs, field.Required(containerPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.TFJobDefaultContainerName))) + } + } + if chiefOrMaster > 1 { + allErrs = append(allErrs, field.Forbidden(tfReplicaSpecPath, "must not have more than 1 Chief or Master role")) + } + return allErrs +} diff --git a/pkg/webhooks/tensorflow/tfjob_webhook_test.go b/pkg/webhooks/tensorflow/tfjob_webhook_test.go new file mode 100644 index 0000000000..236d613295 --- /dev/null +++ b/pkg/webhooks/tensorflow/tfjob_webhook_test.go @@ -0,0 +1,192 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tensorflow + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/utils/ptr" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +func TestValidateTFJob(t *testing.T) { + validTFReplicaSpecs := map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.TFJobReplicaTypeWorker: { + Replicas: ptr.To[int32](2), + RestartPolicy: trainingoperator.RestartPolicyOnFailure, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "tensorflow", + Image: "kubeflow/tf-mnist-with-summaries:latest", + Command: []string{ + "python", + "/var/tf_mnist/mnist_with_summaries.py", + }, + }}, + }, + }, + }, + } + + testCases := map[string]struct { + tfJob *trainingoperator.TFJob + wantErr field.ErrorList + }{ + "valid tfJob": { + tfJob: &trainingoperator.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.TFJobSpec{ + TFReplicaSpecs: validTFReplicaSpecs, + }, + }, + }, + "TFJob name does not meet DNS1035": { + tfJob: &trainingoperator.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "00test", + }, + Spec: trainingoperator.TFJobSpec{ + TFReplicaSpecs: validTFReplicaSpecs, + }, + }, + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("metadata").Child("name"), "", ""), + }, + }, + "no containers": { + tfJob: &trainingoperator.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.TFJobSpec{ + TFReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.TFJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(tfReplicaSpecPath.Key(string(trainingoperator.TFJobReplicaTypeWorker)).Child("template").Child("spec").Child("containers"), ""), + field.Required(tfReplicaSpecPath.Key(string(trainingoperator.TFJobReplicaTypeWorker)).Child("template").Child("spec").Child("containers"), ""), + }, + }, + "empty image": { + tfJob: &trainingoperator.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.TFJobSpec{ + TFReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.TFJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "tensorflow", + Image: "", + }}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(tfReplicaSpecPath.Key(string(trainingoperator.TFJobReplicaTypeWorker)).Child("template").Child("spec").Child("containers").Index(0).Child("image"), ""), + }, + }, + "tfJob default container name doesn't present": { + tfJob: &trainingoperator.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.TFJobSpec{ + TFReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.TFJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "", + Image: "kubeflow/tf-dist-mnist-test:1.0", + }}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(tfReplicaSpecPath.Key(string(trainingoperator.TFJobReplicaTypeWorker)).Child("template").Child("spec").Child("containers"), ""), + }, + }, + "there are more than 2 masterReplica's or ChiefReplica's": { + tfJob: &trainingoperator.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.TFJobSpec{ + TFReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.TFJobReplicaTypeChief: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "tensorflow", + Image: "kubeflow/tf-dist-mnist-test:1.0", + }}, + }, + }, + }, + trainingoperator.TFJobReplicaTypeMaster: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "tensorflow", + Image: "kubeflow/tf-dist-mnist-test:1.0", + }}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Forbidden(tfReplicaSpecPath, ""), + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := validateTFJob(tc.tfJob) + if diff := cmp.Diff(tc.wantErr, got, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/webhooks/webhooks.go b/pkg/webhooks/webhooks.go index 5e97a3d3f3..09ba671617 100644 --- a/pkg/webhooks/webhooks.go +++ b/pkg/webhooks/webhooks.go @@ -21,6 +21,7 @@ import ( trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/webhooks/pytorch" + "github.com/kubeflow/training-operator/pkg/webhooks/tensorflow" ) type WebhookSetupFunc func(manager manager.Manager) error @@ -28,7 +29,7 @@ type WebhookSetupFunc func(manager manager.Manager) error var ( SupportedSchemeWebhook = map[string]WebhookSetupFunc{ trainingoperator.PyTorchJobKind: pytorch.SetupWebhook, - trainingoperator.TFJobKind: scaffold, + trainingoperator.TFJobKind: tensorflow.SetupWebhook, trainingoperator.MXJobKind: scaffold, trainingoperator.XGBoostJobKind: scaffold, trainingoperator.MPIJobKind: scaffold,