Skip to content

Commit

Permalink
Adding cel validations on trainjob crd
Browse files Browse the repository at this point in the history
Signed-off-by: Akshay Chitneni <[email protected]>
  • Loading branch information
Akshay Chitneni committed Oct 16, 2024
1 parent 126110f commit 9d739ea
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 8 deletions.
22 changes: 14 additions & 8 deletions manifests/v2/base/crds/kubeflow.org_trainjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,14 @@ spec:
They will be merged with the TrainingRuntime values.
type: object
managedBy:
description: |-
ManagedBy is used to indicate the controller or entity that manages a TrainJob.
The value must be either an empty, `kubeflow.org/trainjob-controller` or
`kueue.x-k8s.io/multikueue`. The built-in TrainJob controller reconciles TrainJob which
don't have this field at all or the field value is the reserved string
`kubeflow.org/trainjob-controller`, but delegates reconciling TrainJobs
with a 'kueue.x-k8s.io/multikueue' to the Kueue. The field is immutable.
Defaults to `kubeflow.org/trainjob-controller`
default: kubeflow.org/trainjob-controller
type: string
x-kubernetes-validations:
- message: ManagedBy must be kubeflow.org/trainjob-controller or kueue.x-k8s.io/multikueue
if set
rule: self in ['kubeflow.org/trainjob-controller', 'kueue.x-k8s.io/multikueue']
- message: ManagedBy value is immutable
rule: self == oldSelf
modelConfig:
description: Configuration of the pre-trained and trained model.
properties:
Expand Down Expand Up @@ -2733,6 +2732,7 @@ spec:
type: object
type: array
suspend:
default: false
description: |-
Whether the controller should suspend the running TrainJob.
Defaults to false.
Expand Down Expand Up @@ -2941,16 +2941,22 @@ spec:
description: Reference to the training runtime.
properties:
apiGroup:
default: kubeflow.org
description: |-
APIGroup of the runtime being referenced.
Defaults to `kubeflow.org`.
type: string
kind:
default: ClusterTrainingRuntime
description: |-
Kind of the runtime being referenced.
It must be one of TrainingRuntime or ClusterTrainingRuntime.
Defaults to ClusterTrainingRuntime.
type: string
x-kubernetes-validations:
- message: Kind must be ClusterTrainingRuntime or TrainingRuntime
if set
rule: self in ['ClusterTrainingRuntime', 'TrainingRuntime']
name:
description: |-
Name of the runtime being referenced.
Expand Down
6 changes: 6 additions & 0 deletions pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package v2alpha1
import (
autoscalingv2 "k8s.io/api/autoscaling/v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"reflect"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

Expand Down Expand Up @@ -210,6 +211,11 @@ const (
MPIImplementationMPICH MPIImplementation = "MPICH"
)

var (
TrainingRuntimeKind = reflect.TypeOf(TrainingRuntime{}).Name()
ClusterTrainingRuntimeKind = reflect.TypeOf(ClusterTrainingRuntime{}).Name()
)

func init() {
SchemeBuilder.Register(&ClusterTrainingRuntime{}, &ClusterTrainingRuntimeList{}, &TrainingRuntime{}, &TrainingRuntimeList{})
}
13 changes: 13 additions & 0 deletions pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package v2alpha1
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"reflect"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

Expand Down Expand Up @@ -82,6 +83,7 @@ type TrainJobSpec struct {

// Whether the controller should suspend the running TrainJob.
// Defaults to false.
// +kubebuilder:default=false
Suspend *bool `json:"suspend,omitempty"`

// ManagedBy is used to indicate the controller or entity that manages a TrainJob.
Expand All @@ -91,6 +93,10 @@ type TrainJobSpec struct {
// `kubeflow.org/trainjob-controller`, but delegates reconciling TrainJobs
// with a 'kueue.x-k8s.io/multikueue' to the Kueue. The field is immutable.
// Defaults to `kubeflow.org/trainjob-controller`

// +kubebuilder:default="kubeflow.org/trainjob-controller"
// +kubebuilder:validation:XValidation:rule="self in ['kubeflow.org/trainjob-controller', 'kueue.x-k8s.io/multikueue']", message="ManagedBy must be kubeflow.org/trainjob-controller or kueue.x-k8s.io/multikueue if set"
// +kubebuilder:validation:XValidation:rule="self == oldSelf", message="ManagedBy value is immutable"
ManagedBy *string `json:"managedBy,omitempty"`
}

Expand All @@ -103,11 +109,14 @@ type TrainingRuntimeRef struct {

// APIGroup of the runtime being referenced.
// Defaults to `kubeflow.org`.
// +kubebuilder:default="kubeflow.org"
APIGroup *string `json:"apiGroup,omitempty"`

// Kind of the runtime being referenced.
// It must be one of TrainingRuntime or ClusterTrainingRuntime.
// Defaults to ClusterTrainingRuntime.
// +kubebuilder:default="ClusterTrainingRuntime"
// +kubebuilder:validation:XValidation:rule="self in ['ClusterTrainingRuntime', 'TrainingRuntime']", message="Kind must be ClusterTrainingRuntime or TrainingRuntime if set"
Kind *string `json:"kind,omitempty"`
}

Expand Down Expand Up @@ -251,6 +260,10 @@ type TrainJobStatus struct {
ReplicatedJobsStatus []jobsetv1alpha2.ReplicatedJobStatus `json:"replicatedJobsStatus,omitempty"`
}

var (
TrainJobKind = reflect.TypeOf(TrainJob{}).Name()
)

func init() {
SchemeBuilder.Register(&TrainJob{}, &TrainJobList{})
}
117 changes: 117 additions & 0 deletions test/integration/controller.v2/trainjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() {
var ns *corev1.Namespace
apiGroup := kubeflowv2.GroupVersion.Group

ginkgo.BeforeAll(func() {
fwk = &framework.Framework{}
Expand Down Expand Up @@ -71,4 +72,120 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() {
gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed())
})
})

ginkgo.When("TrainJob CR Validation", func() {
ginkgo.AfterEach(func() {
gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.TrainJob{}, client.InNamespace(ns.Name))).Should(
gomega.Succeed())
})

ginkgo.It("Should succeed in creating TrainJob", func() {

managedBy := "kubeflow.org/trainjob-controller"

trainingRuntimeRef := kubeflowv2.TrainingRuntimeRef{
Name: "InvalidRuntimeRef",
APIGroup: &apiGroup,
Kind: &kubeflowv2.TrainingRuntimeKind,
}
jobSpec := kubeflowv2.TrainJobSpec{
TrainingRuntimeRef: trainingRuntimeRef,
ManagedBy: &managedBy,
}
trainJob := &kubeflowv2.TrainJob{
TypeMeta: metav1.TypeMeta{
APIVersion: kubeflowv2.SchemeGroupVersion.String(),
Kind: kubeflowv2.TrainJobKind,
},
ObjectMeta: metav1.ObjectMeta{
GenerateName: "valid-trainjob-",
Namespace: ns.Name,
},
Spec: jobSpec,
}

err := k8sClient.Create(ctx, trainJob)
gomega.Expect(err).Should(gomega.Succeed())
})

ginkgo.It("Should fail in creating TrainJob with invalid spec.trainingRuntimeRef", func() {

kind := "InvalidRuntime"

trainingRuntimeRef := kubeflowv2.TrainingRuntimeRef{
Name: "InvalidRuntimeRef",
APIGroup: &apiGroup,
Kind: &kind,
}
jobSpec := kubeflowv2.TrainJobSpec{
TrainingRuntimeRef: trainingRuntimeRef,
}
trainJob := &kubeflowv2.TrainJob{
TypeMeta: metav1.TypeMeta{
APIVersion: kubeflowv2.SchemeGroupVersion.String(),
Kind: kubeflowv2.TrainJobKind,
},
ObjectMeta: metav1.ObjectMeta{
GenerateName: "invalid-trainjob-",
Namespace: ns.Name,
},
Spec: jobSpec,
}
gomega.Expect(k8sClient.Create(ctx, trainJob)).To(gomega.MatchError(
gomega.ContainSubstring("spec.trainingRuntimeRef.kind: Invalid value")))
})

ginkgo.It("Should fail in creating TrainJob with invalid spec.managedBy", func() {
managedBy := "invalidManagedBy"
jobSpec := kubeflowv2.TrainJobSpec{
ManagedBy: &managedBy,
}
trainJob := &kubeflowv2.TrainJob{
TypeMeta: metav1.TypeMeta{
APIVersion: kubeflowv2.SchemeGroupVersion.String(),
Kind: kubeflowv2.TrainJobKind,
},
ObjectMeta: metav1.ObjectMeta{
Name: "invalid-trainjob",
Namespace: ns.Name,
},
Spec: jobSpec,
}
gomega.Expect(k8sClient.Create(ctx, trainJob)).To(gomega.MatchError(
gomega.ContainSubstring("spec.managedBy: Invalid value")))
})

ginkgo.It("Should fail in updating spec.managedBy", func() {

managedBy := "kubeflow.org/trainjob-controller"

trainingRuntimeRef := kubeflowv2.TrainingRuntimeRef{
Name: "InvalidRuntimeRef",
APIGroup: &apiGroup,
Kind: &kubeflowv2.TrainingRuntimeKind,
}
jobSpec := kubeflowv2.TrainJobSpec{
TrainingRuntimeRef: trainingRuntimeRef,
ManagedBy: &managedBy,
}
trainJob := &kubeflowv2.TrainJob{
TypeMeta: metav1.TypeMeta{
APIVersion: kubeflowv2.SchemeGroupVersion.String(),
Kind: kubeflowv2.TrainJobKind,
},
ObjectMeta: metav1.ObjectMeta{
Name: "job-with-failed-update",
Namespace: ns.Name,
},
Spec: jobSpec,
}

gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed())
updatedManagedBy := "kueue.x-k8s.io/multikueue"
jobSpec.ManagedBy = &updatedManagedBy
trainJob.Spec = jobSpec
gomega.Expect(k8sClient.Update(ctx, trainJob)).To(gomega.MatchError(
gomega.ContainSubstring("ManagedBy value is immutable")))
})
})
})

0 comments on commit 9d739ea

Please sign in to comment.