Skip to content

Commit

Permalink
Rename RuntimeRef in runtime framework
Browse files Browse the repository at this point in the history
Signed-off-by: Andrey Velichkevich <[email protected]>
  • Loading branch information
andreyvelich committed Oct 17, 2024
1 parent b0719cb commit 9cebed1
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 24 deletions.
6 changes: 3 additions & 3 deletions pkg/runtime.v2/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndex

func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) {
var clTrainingRuntime kubeflowv2.ClusterTrainingRuntime
if err := r.client.Get(ctx, client.ObjectKey{Name: trainJob.Spec.TrainingRuntimeRef.Name}, &clTrainingRuntime); err != nil {
if err := r.client.Get(ctx, client.ObjectKey{Name: trainJob.Spec.RuntimeRef.Name}, &clTrainingRuntime); err != nil {
return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedClusterTrainingRuntime, err)
}
return r.buildObjects(ctx, trainJob, clTrainingRuntime.Spec.Template, clTrainingRuntime.Spec.MLPolicy, clTrainingRuntime.Spec.PodGroupPolicy)
Expand All @@ -66,10 +66,10 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu
func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.TrainingRuntimeRef.Name,
Name: old.Spec.RuntimeRef.Name,
}, &kubeflowv2.ClusterTrainingRuntime{}); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "trainingRuntimeRef"), old.Spec.TrainingRuntimeRef,
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/runtime.v2/core/clustertrainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
"succeeded to build JobSet and PodGroup": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
Expand Down Expand Up @@ -93,7 +93,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
"missing trainingRuntime resource": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
Expand Down
10 changes: 5 additions & 5 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ var _ runtime.Runtime = (*TrainingRuntime)(nil)
var trainingRuntimeFactory *TrainingRuntime

func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) {
if err := indexer.IndexField(ctx, &kubeflowv2.TrainJob{}, idxer.TrainJobTrainingRuntimeRefKey, idxer.IndexTrainJobTrainingRuntime); err != nil {
if err := indexer.IndexField(ctx, &kubeflowv2.TrainJob{}, idxer.TrainJobRuntimeRefKey, idxer.IndexTrainJobTrainingRuntime); err != nil {
return nil, fmt.Errorf("setting index on TrainingRuntime for TrainJob: %w", err)
}
if err := indexer.IndexField(ctx, &kubeflowv2.TrainJob{}, idxer.TrainJobClusterTrainingRuntimeRefKey, idxer.IndexTrainJobClusterTrainingRuntime); err != nil {
if err := indexer.IndexField(ctx, &kubeflowv2.TrainJob{}, idxer.TrainJobClusterRuntimeRefKey, idxer.IndexTrainJobClusterTrainingRuntime); err != nil {
return nil, fmt.Errorf("setting index on ClusterTrainingRuntime for TrainJob: %w", err)
}
fwk, err := fwkcore.New(ctx, c, fwkplugins.NewRegistry(), indexer)
Expand All @@ -74,7 +74,7 @@ func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.Fie

func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) {
var trainingRuntime kubeflowv2.TrainingRuntime
err := r.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.TrainingRuntimeRef.Name}, &trainingRuntime)
err := r.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}, &trainingRuntime)
if err != nil {
return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedTrainingRuntime, err)
}
Expand Down Expand Up @@ -139,10 +139,10 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.TrainingRuntimeRef.Name,
Name: old.Spec.RuntimeRef.Name,
}, &kubeflowv2.TrainingRuntime{}); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "trainingRuntimeRef"), old.Spec.TrainingRuntimeRef,
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/runtime.v2/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
"succeeded to build JobSet and PodGroup": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
SpecLabel("conflictLabel", "override").
SpecAnnotation("conflictAnnotation", "override").
Trainer(
Expand Down Expand Up @@ -100,7 +100,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
"missing trainingRuntime resource": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
Expand Down
4 changes: 2 additions & 2 deletions pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ func (h *PodGroupRuntimeClassHandler) queueSuspendedTrainJobs(ctx context.Contex
var trainJobs []kubeflowv2.TrainJob
for _, trainingRuntime := range trainingRuntimes.Items {
var trainJobsWithTrainingRuntime kubeflowv2.TrainJobList
err := h.client.List(ctx, &trainJobsWithTrainingRuntime, client.MatchingFields{runtimeindexer.TrainJobTrainingRuntimeRefKey: trainingRuntime.Name})
err := h.client.List(ctx, &trainJobsWithTrainingRuntime, client.MatchingFields{runtimeindexer.TrainJobRuntimeRefKey: trainingRuntime.Name})
if err != nil {
return err
}
trainJobs = append(trainJobs, trainJobsWithTrainingRuntime.Items...)
}
for _, clusterTrainingRuntime := range clusterTrainingRuntimes.Items {
var trainJobsWithClTrainingRuntime kubeflowv2.TrainJobList
err := h.client.List(ctx, &trainJobsWithClTrainingRuntime, client.MatchingFields{runtimeindexer.TrainJobClusterTrainingRuntimeRefKey: clusterTrainingRuntime.Name})
err := h.client.List(ctx, &trainJobsWithClTrainingRuntime, client.MatchingFields{runtimeindexer.TrainJobClusterRuntimeRefKey: clusterTrainingRuntime.Name})
if err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/runtime.v2/indexer/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@ import (
)

const (
TrainJobTrainingRuntimeRefKey = ".spec.trainingRuntimeRef.kind=trainingRuntime"
TrainJobClusterTrainingRuntimeRefKey = ".spec.trainingRuntimeRef.kind=clusterTrainingRuntime"
TrainJobRuntimeRefKey = ".spec.runtimeRef.kind=trainingRuntime"
TrainJobClusterRuntimeRefKey = ".spec.runtimeRef.kind=clusterTrainingRuntime"
)

func IndexTrainJobTrainingRuntime(obj client.Object) []string {
trainJob, ok := obj.(*kubeflowv2.TrainJob)
if !ok {
return nil
}
if ptr.Deref(trainJob.Spec.TrainingRuntimeRef.APIGroup, "") == kubeflowv2.GroupVersion.Group &&
ptr.Deref(trainJob.Spec.TrainingRuntimeRef.Kind, "") == kubeflowv2.TrainingRuntimeKind {
return []string{trainJob.Spec.TrainingRuntimeRef.Name}
if ptr.Deref(trainJob.Spec.RuntimeRef.APIGroup, "") == kubeflowv2.GroupVersion.Group &&
ptr.Deref(trainJob.Spec.RuntimeRef.Kind, "") == kubeflowv2.TrainingRuntimeKind {
return []string{trainJob.Spec.RuntimeRef.Name}
}
return nil
}
Expand All @@ -45,9 +45,9 @@ func IndexTrainJobClusterTrainingRuntime(obj client.Object) []string {
if !ok {
return nil
}
if ptr.Deref(trainJob.Spec.TrainingRuntimeRef.APIGroup, "") == kubeflowv2.GroupVersion.Group &&
ptr.Deref(trainJob.Spec.TrainingRuntimeRef.Kind, "") == kubeflowv2.ClusterTrainingRuntimeKind {
return []string{trainJob.Spec.TrainingRuntimeRef.Name}
if ptr.Deref(trainJob.Spec.RuntimeRef.APIGroup, "") == kubeflowv2.GroupVersion.Group &&
ptr.Deref(trainJob.Spec.RuntimeRef.Kind, "") == kubeflowv2.ClusterTrainingRuntimeKind {
return []string{trainJob.Spec.RuntimeRef.Name}
}
return nil
}
4 changes: 2 additions & 2 deletions pkg/util.v2/testing/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ func (t *TrainJobWrapper) Trainer(trainer *kubeflowv2.Trainer) *TrainJobWrapper
return t
}

func (t *TrainJobWrapper) TrainingRuntimeRef(gvk schema.GroupVersionKind, name string) *TrainJobWrapper {
t.Spec.TrainingRuntimeRef = kubeflowv2.TrainingRuntimeRef{
func (t *TrainJobWrapper) RuntimeRef(gvk schema.GroupVersionKind, name string) *TrainJobWrapper {
t.Spec.RuntimeRef = kubeflowv2.RuntimeRef{
APIGroup: &gvk.Group,
Kind: &gvk.Kind,
Name: name,
Expand Down

0 comments on commit 9cebed1

Please sign in to comment.