Skip to content

Commit

Permalink
Adding v2 trainjob validation webhook
Browse files Browse the repository at this point in the history
fixing runtime

Signed-off-by: Akshay Chitneni <[email protected]>
  • Loading branch information
Akshay Chitneni committed Nov 11, 2024
1 parent 9e46f9d commit 241c4f1
Show file tree
Hide file tree
Showing 19 changed files with 435 additions and 80 deletions.
6 changes: 6 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ const (
// JobInitializer is the Job name for the initializer.
JobInitializer string = "initializer"

// JobExporter is the Job name for the exporter.
JobExporter string = "exporter"

// ContainerModelInitializer is the container name for the model initializer.
ContainerModelInitializer string = "model-initializer"

Expand All @@ -51,6 +54,9 @@ const (

// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TorchEnvNamePrefix is the env name prefix for the distributed envs for torchrun.
TorchEnvNamePrefix = "PET_"
)

var (
Expand Down
16 changes: 3 additions & 13 deletions pkg/controller.v2/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,21 @@ package controllerv2

import (
"context"
"errors"
"fmt"

"github.com/go-logr/logr"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
jobruntimes "github.com/kubeflow/training-operator/pkg/runtime.v2"
runtimeutil "github.com/kubeflow/training-operator/pkg/util.v2/runtime"
)

var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")

type TrainJobReconciler struct {
log logr.Logger
client client.Client
Expand Down Expand Up @@ -73,10 +70,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
log := ctrl.LoggerFrom(ctx)

runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtimeRefGK := runtimeutil.RuntimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtime, ok := r.runtimes[runtimeRefGK]
if !ok {
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
return fmt.Errorf("%w: %s", runtimeutil.ErrorUnsupportedRuntime, runtimeRefGK)
}
objs, err := runtime.NewObjects(ctx, trainJob)
if err != nil {
Expand Down Expand Up @@ -117,13 +114,6 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k
return nil
}

func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Kind: ptr.Deref(runtimeRef.Kind, ""),
}
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
b := ctrl.NewControllerManagedBy(mgr).
For(&kubeflowv2.TrainJob{})
Expand Down
16 changes: 11 additions & 5 deletions pkg/runtime.v2/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
Expand Down Expand Up @@ -64,14 +65,19 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu
}

func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
clusterTrainingRuntime := &kubeflowv2.ClusterTrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
}, &kubeflowv2.ClusterTrainingRuntime{}); err != nil {
Namespace: new.Namespace,
Name: new.Spec.RuntimeRef.Name,
}, clusterTrainingRuntime); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "RuntimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.runtimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy, clusterTrainingRuntime.Spec.PodGroupPolicy)
jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: clusterTrainingRuntime.Spec.Template.Spec,
}
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
}
49 changes: 31 additions & 18 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,26 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.T
func (r *TrainingRuntime) buildObjects(
ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy,
) ([]client.Object, error) {

info := r.runtimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy)
if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
}

func (r *TrainingRuntime) runtimeInfo(
ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy) *runtime.Info {

propagationLabels := jobSetTemplateSpec.Labels
if propagationLabels == nil && trainJob.Spec.Labels != nil {
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels))
Expand Down Expand Up @@ -112,19 +132,7 @@ func (r *TrainingRuntime) buildObjects(

info := runtime.NewInfo(opts...)

if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
return info
}

func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
Expand All @@ -136,14 +144,19 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
}

func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
trainingRuntime := &kubeflowv2.TrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
}, &kubeflowv2.TrainingRuntime{}); err != nil {
Namespace: new.Namespace,
Name: new.Spec.RuntimeRef.Name,
}, trainingRuntime); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.runtimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy)
jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: trainingRuntime.Spec.Template.Spec,
}
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
}
4 changes: 2 additions & 2 deletions pkg/runtime.v2/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(info *runtime.Info, trainJob
return nil
}

func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
func (f *Framework) RunCustomValidationPlugins(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
var aggregatedWarnings admission.Warnings
var aggregatedErrors field.ErrorList
for _, plugin := range f.customValidationPlugins {
warnings, errs := plugin.Validate(oldObj, newObj)
warnings, errs := plugin.Validate(runtimeJobTemplate, info, oldObj, newObj)
if len(warnings) != 0 {
aggregatedWarnings = append(aggregatedWarnings, warnings...)
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/runtime.v2/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func TestNew(t *testing.T) {
customValidationPlugins: []framework.CustomValidationPlugin{
&mpi.MPI{},
&torch.Torch{},
&jobset.JobSet{},
},
watchExtensionPlugins: []framework.WatchExtensionPlugin{
&coscheduling.CoScheduling{},
Expand Down Expand Up @@ -364,7 +365,9 @@ func TestRunCustomValidationPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
runtimeInfo := runtime.NewInfo()
jobSetTemplate := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test")
warnings, errs := fwk.RunCustomValidationPlugins(jobSetTemplate, runtimeInfo, tc.oldObj, tc.newObj)
if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime.v2/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type EnforceMLPolicyPlugin interface {

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
}

type ComponentBuilderPlugin interface {
Expand Down
100 changes: 100 additions & 0 deletions pkg/runtime.v2/framework/plugins/jobset/jobset.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,23 @@ import (
"context"
"fmt"
"maps"
"slices"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/equality"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
Expand All @@ -50,6 +54,7 @@ type JobSet struct {

var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
var _ framework.CustomValidationPlugin = (*JobSet)(nil)

const Name = constants.JobSetKind

Expand Down Expand Up @@ -140,3 +145,98 @@ func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {
},
}
}

func (j *JobSet) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {

var allErrs field.ErrorList
specPath := field.NewPath("spec")
runtimeRefPath := specPath.Child("runtimeRef")

jobSet, ok := runtimeJobTemplate.(*jobsetv1alpha2.JobSet)
if !ok {
return nil, nil
}

if newObj.Spec.ModelConfig != nil && newObj.Spec.ModelConfig.Input != nil {
if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool {
return x.Name == constants.JobInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have %s job when trainJob is configured with input modelConfig", constants.JobInitializer)))
} else {
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobInitializer {
if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool {
return x.Name == constants.ContainerModelInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have container with name - %s in the %s job", constants.ContainerModelInitializer, constants.JobInitializer)))
}
}
}
}
}

if newObj.Spec.DatasetConfig != nil {
if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool {
return x.Name == constants.JobInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have %s job when trainJob is configured with input datasetConfig", constants.JobInitializer)))
} else {
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobInitializer {
if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool {
return x.Name == constants.ContainerDatasetInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have container with name - %s in the %s job", constants.ContainerDatasetInitializer, constants.JobInitializer)))
}
}
}
}
}

if len(newObj.Spec.PodSpecOverrides) != 0 {
podSpecOverridesPath := specPath.Child("podSpecOverrides")
jobsMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
jobsMap[job.Name] = true
}
// validate if jobOverrides are valid
for idx, override := range newObj.Spec.PodSpecOverrides {
for _, job := range override.TargetJobs {
if _, found := jobsMap[job.Name]; !found {
allErrs = append(allErrs, field.Invalid(podSpecOverridesPath, newObj.Spec.PodSpecOverrides, fmt.Sprintf("job: %s, configured in the podOverride should be present in the referenced training runtime", job)))
}
}
if len(override.Containers) != 0 {
// validate if containerOverrides are valid
containerMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
for _, container := range job.Template.Spec.Template.Spec.Containers {
containerMap[container.Name] = true
}
}
containerOverridePath := podSpecOverridesPath.Index(idx)
for _, container := range override.Containers {
if _, found := containerMap[container.Name]; !found {
allErrs = append(allErrs, field.Invalid(containerOverridePath, override.Containers, fmt.Sprintf("container: %s, configured in the containerOverride should be present in the referenced training runtime", container.Name)))
}
}
}
if len(override.InitContainers) != 0 {
// validate if initContainerOverrides are valid
initContainerMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
for _, initContainer := range job.Template.Spec.Template.Spec.InitContainers {
initContainerMap[initContainer.Name] = true
}
}
initContainerOverridePath := podSpecOverridesPath.Index(idx)
for _, container := range override.Containers {
if _, found := initContainerMap[container.Name]; !found {
allErrs = append(allErrs, field.Invalid(initContainerOverridePath, override.InitContainers, fmt.Sprintf("initContainer: %s, configured in the initContainerOverride should be present in the referenced training runtime", container.Name)))
}
}
}
}
}
return nil, allErrs
}
16 changes: 13 additions & 3 deletions pkg/runtime.v2/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mpi

import (
"context"
"strconv"

"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -55,7 +56,16 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *kubeflowv2.TrainJob)
return nil
}

// TODO: Need to implement validations for MPIJob.
func (m *MPI) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
func (m *MPI) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldJobObj, newJobObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
specPath := field.NewPath("spec")
if newJobObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.RuntimePolicy.MLPolicy != nil && runtimeInfo.RuntimePolicy.MLPolicy.MPI != nil {
if _, err := strconv.Atoi(*newJobObj.Spec.Trainer.NumProcPerNode); err != nil {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value"))
}
}
}
return nil, allErrs
}
Loading

0 comments on commit 241c4f1

Please sign in to comment.