Skip to content

Commit

Permalink
KEP-2170: Implement ValidateObjects interface to the runtime framework
Browse files Browse the repository at this point in the history
Signed-off-by: Yuki Iwai <[email protected]>
  • Loading branch information
tenzen-y committed Oct 10, 2024
1 parent 22da8af commit 23d7b44
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 6 deletions.
15 changes: 15 additions & 0 deletions pkg/runtime.v2/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import (
"fmt"

"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
Expand Down Expand Up @@ -60,3 +62,16 @@ func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubef
func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
return nil
}

func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.TrainingRuntimeRef.Name,
}, &kubeflowv2.ClusterTrainingRuntime{}); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "trainingRuntimeRef"), old.Spec.TrainingRuntimeRef,
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
}
15 changes: 15 additions & 0 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
Expand Down Expand Up @@ -119,3 +121,16 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
}
return builders
}

func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.TrainingRuntimeRef.Name,
}, &kubeflowv2.TrainingRuntime{}); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "trainingRuntimeRef"), old.Spec.TrainingRuntimeRef,
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
}
4 changes: 2 additions & 2 deletions pkg/runtime.v2/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJo
return nil
}

func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (admission.Warnings, error) {
func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) {
var aggregatedWarnings admission.Warnings
var aggregatedErrors field.ErrorList
for _, plugin := range f.customValidationPlugins {
Expand All @@ -104,7 +104,7 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (ad
if len(aggregatedErrors) == 0 {
return aggregatedWarnings, nil
}
return aggregatedWarnings, aggregatedErrors.ToAggregate()
return aggregatedWarnings, aggregatedErrors
}

func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) {
Expand Down
9 changes: 5 additions & 4 deletions pkg/runtime.v2/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
Expand Down Expand Up @@ -291,7 +292,7 @@ func TestRunCustomValidationPlugins(t *testing.T) {
oldObj client.Object
newObj client.Object
wantWarnings admission.Warnings
wantError error
wantError field.ErrorList
}{
// Need to implement more detail testing after we implement custom validator in any plugins.
"there are not any custom validations": {
Expand All @@ -300,7 +301,7 @@ func TestRunCustomValidationPlugins(t *testing.T) {
oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(),
newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(),
},
"an empty registry": {
"an empty registry": {
trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}},
oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(),
newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(),
Expand All @@ -316,11 +317,11 @@ func TestRunCustomValidationPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
warnings, err := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
if diff := cmp.Diff(tc.wantError, errs, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
})
Expand Down
3 changes: 3 additions & 0 deletions pkg/runtime.v2/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package runtimev2
import (
"context"

"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
)
Expand All @@ -30,4 +32,5 @@ type ReconcilerBuilder func(*builder.Builder, client.Client) *builder.Builder
type Runtime interface {
NewObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob) ([]client.Object, error)
EventHandlerRegistrars() []ReconcilerBuilder
ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
}

0 comments on commit 23d7b44

Please sign in to comment.