From eb452fbafa3ba0a8a29e53f62155aaeccaa7e93c Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Tue, 26 Nov 2024 18:22:50 +0100 Subject: [PATCH] Use typed event handlers and predicates in job controllers Signed-off-by: Antonin Stefanutti --- pkg/common/util/reconciler.go | 114 +----------------- pkg/common/util/reconciler_generic.go | 26 ++-- pkg/common/util/reconciler_test.go | 59 --------- pkg/controller.v1/jax/jaxjob_controller.go | 30 ++--- pkg/controller.v1/mpi/mpijob_controller.go | 42 +++---- .../paddlepaddle/paddlepaddle_controller.go | 30 ++--- .../pytorch/pytorchjob_controller.go | 30 ++--- .../tensorflow/tfjob_controller.go | 30 ++--- .../xgboost/xgboostjob_controller.go | 30 ++--- 9 files changed, 99 insertions(+), 292 deletions(-) delete mode 100644 pkg/common/util/reconciler_test.go diff --git a/pkg/common/util/reconciler.go b/pkg/common/util/reconciler.go index f11d11eef2..73b223441d 100644 --- a/pkg/common/util/reconciler.go +++ b/pkg/common/util/reconciler.go @@ -15,17 +15,10 @@ package util import ( - "fmt" - "reflect" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "sigs.k8s.io/controller-runtime/pkg/event" - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" "github.com/kubeflow/training-operator/pkg/controller.v1/expectation" - commonutil "github.com/kubeflow/training-operator/pkg/util" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) // SatisfiedExpectations returns true if the required adds/dels for the given job have been observed. @@ -45,82 +38,6 @@ func SatisfiedExpectations(exp expectation.ControllerExpectationsInterface, jobK return satisfied } -// OnDependentCreateFunc modify expectations when dependent (pod/service) creation observed. -func OnDependentCreateFunc(exp expectation.ControllerExpectationsInterface) func(event.CreateEvent) bool { - return func(e event.CreateEvent) bool { - rtype := e.Object.GetLabels()[kubeflowv1.ReplicaTypeLabel] - if len(rtype) == 0 { - return false - } - - //logrus.Info("Update on create function ", ptjr.ControllerName(), " create object ", e.Object.GetName()) - if controllerRef := metav1.GetControllerOf(e.Object); controllerRef != nil { - jobKey := fmt.Sprintf("%s/%s", e.Object.GetNamespace(), controllerRef.Name) - var expectKey string - switch e.Object.(type) { - case *corev1.Pod: - expectKey = expectation.GenExpectationPodsKey(jobKey, rtype) - case *corev1.Service: - expectKey = expectation.GenExpectationServicesKey(jobKey, rtype) - default: - return false - } - exp.CreationObserved(expectKey) - return true - } - - return true - } -} - -// OnDependentUpdateFunc modify expectations when dependent (pod/service) update observed. -func OnDependentUpdateFunc(jc *common.JobController) func(updateEvent event.UpdateEvent) bool { - return func(e event.UpdateEvent) bool { - newObj := e.ObjectNew - oldObj := e.ObjectOld - if newObj.GetResourceVersion() == oldObj.GetResourceVersion() { - // Periodic resync will send update events for all known pods. - // Two different versions of the same pod will always have different RVs. - return false - } - - kind := jc.Controller.GetAPIGroupVersionKind().Kind - var logger = LoggerForGenericKind(newObj, kind) - - switch obj := newObj.(type) { - case *corev1.Pod: - logger = commonutil.LoggerForPod(obj, jc.Controller.GetAPIGroupVersionKind().Kind) - case *corev1.Service: - logger = commonutil.LoggerForService(newObj.(*corev1.Service), jc.Controller.GetAPIGroupVersionKind().Kind) - default: - return false - } - - newControllerRef := metav1.GetControllerOf(newObj) - oldControllerRef := metav1.GetControllerOf(oldObj) - controllerRefChanged := !reflect.DeepEqual(newControllerRef, oldControllerRef) - - if controllerRefChanged && oldControllerRef != nil { - // The ControllerRef was changed. Sync the old controller, if any. - if job := resolveControllerRef(jc, oldObj.GetNamespace(), oldControllerRef); job != nil { - logger.Infof("pod/service controller ref updated: %v, %v", newObj, oldObj) - return true - } - } - - // If it has a controller ref, that's all that matters. - if newControllerRef != nil { - job := resolveControllerRef(jc, newObj.GetNamespace(), newControllerRef) - if job == nil { - return false - } - logger.Debugf("pod/service has a controller ref: %v, %v", newObj, oldObj) - return true - } - return false - } -} - // resolveControllerRef returns the job referenced by a ControllerRef, // or nil if the ControllerRef could not be resolved to a matching job // of the correct Kind. @@ -141,32 +58,3 @@ func resolveControllerRef(jc *common.JobController, namespace string, controller } return job } - -// OnDependentDeleteFunc modify expectations when dependent (pod/service) deletion observed. -func OnDependentDeleteFunc(exp expectation.ControllerExpectationsInterface) func(event.DeleteEvent) bool { - return func(e event.DeleteEvent) bool { - - rtype := e.Object.GetLabels()[kubeflowv1.ReplicaTypeLabel] - if len(rtype) == 0 { - return false - } - - // logrus.Info("Update on deleting function ", xgbr.ControllerName(), " delete object ", e.Object.GetName()) - if controllerRef := metav1.GetControllerOf(e.Object); controllerRef != nil { - jobKey := fmt.Sprintf("%s/%s", e.Object.GetNamespace(), controllerRef.Name) - var expectKey string - switch e.Object.(type) { - case *corev1.Pod: - expectKey = expectation.GenExpectationPodsKey(jobKey, rtype) - case *corev1.Service: - expectKey = expectation.GenExpectationServicesKey(jobKey, rtype) - default: - return false - } - exp.DeletionObserved(expectKey) - return true - } - - return true - } -} diff --git a/pkg/common/util/reconciler_generic.go b/pkg/common/util/reconciler_generic.go index 4b3f737436..e955e16577 100644 --- a/pkg/common/util/reconciler_generic.go +++ b/pkg/common/util/reconciler_generic.go @@ -19,13 +19,14 @@ import ( "reflect" "strings" - log "github.com/sirupsen/logrus" - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" "github.com/kubeflow/training-operator/pkg/controller.v1/expectation" + log "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/predicate" ) // GenExpectationGenericKey generates an expectation key for {Kind} of a job @@ -50,9 +51,17 @@ func LoggerForGenericKind(obj metav1.Object, kind string) *log.Entry { }) } +func OnDependentFuncs[T client.Object](expectations expectation.ControllerExpectationsInterface, jobController *common.JobController) predicate.TypedFuncs[T] { + return predicate.TypedFuncs[T]{ + CreateFunc: OnDependentCreateFuncGeneric[T](expectations), + UpdateFunc: OnDependentUpdateFuncGeneric[T](jobController), + DeleteFunc: OnDependentDeleteFuncGeneric[T](expectations), + } +} + // OnDependentCreateFuncGeneric modify expectations when dependent (pod/service) creation observed. -func OnDependentCreateFuncGeneric(exp expectation.ControllerExpectationsInterface) func(event.CreateEvent) bool { - return func(e event.CreateEvent) bool { +func OnDependentCreateFuncGeneric[T client.Object](exp expectation.ControllerExpectationsInterface) func(createEvent event.TypedCreateEvent[T]) bool { + return func(e event.TypedCreateEvent[T]) bool { rtype := e.Object.GetLabels()[kubeflowv1.ReplicaTypeLabel] if len(rtype) == 0 { return false @@ -72,8 +81,8 @@ func OnDependentCreateFuncGeneric(exp expectation.ControllerExpectationsInterfac } // OnDependentUpdateFuncGeneric modify expectations when dependent (pod/service) update observed. -func OnDependentUpdateFuncGeneric(jc *common.JobController) func(updateEvent event.UpdateEvent) bool { - return func(e event.UpdateEvent) bool { +func OnDependentUpdateFuncGeneric[T client.Object](jc *common.JobController) func(updateEvent event.TypedUpdateEvent[T]) bool { + return func(e event.TypedUpdateEvent[T]) bool { newObj := e.ObjectNew oldObj := e.ObjectOld if newObj.GetResourceVersion() == oldObj.GetResourceVersion() { @@ -111,9 +120,8 @@ func OnDependentUpdateFuncGeneric(jc *common.JobController) func(updateEvent eve } // OnDependentDeleteFuncGeneric modify expectations when dependent (pod/service) deletion observed. -func OnDependentDeleteFuncGeneric(exp expectation.ControllerExpectationsInterface) func(event.DeleteEvent) bool { - return func(e event.DeleteEvent) bool { - +func OnDependentDeleteFuncGeneric[T client.Object](exp expectation.ControllerExpectationsInterface) func(event.TypedDeleteEvent[T]) bool { + return func(e event.TypedDeleteEvent[T]) bool { rtype := e.Object.GetLabels()[kubeflowv1.ReplicaTypeLabel] if len(rtype) == 0 { return false diff --git a/pkg/common/util/reconciler_test.go b/pkg/common/util/reconciler_test.go deleted file mode 100644 index 5442216889..0000000000 --- a/pkg/common/util/reconciler_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package util - -import ( - "testing" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/event" - - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - "github.com/kubeflow/training-operator/pkg/controller.v1/expectation" -) - -func TestOnDependentXXXFunc(t *testing.T) { - createfunc := OnDependentCreateFunc(expectation.NewControllerExpectations()) - deletefunc := OnDependentDeleteFunc(expectation.NewControllerExpectations()) - - for _, testCase := range []struct { - object client.Object - expect bool - }{ - { - // pod object with label is allowed - object: &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - kubeflowv1.ReplicaTypeLabel: "Worker", - }, - }, - }, - expect: true, - }, - { - // service object without label is not allowed - object: &corev1.Service{}, - expect: false, - }, - { - // objects other than pod/service are not allowed - object: &corev1.ConfigMap{}, - expect: false, - }, - } { - ret := createfunc(event.CreateEvent{ - Object: testCase.object, - }) - if ret != testCase.expect { - t.Errorf("expect %t, but get %t", testCase.expect, ret) - } - ret = deletefunc(event.DeleteEvent{ - Object: testCase.object, - }) - if ret != testCase.expect { - t.Errorf("expect %t, but get %t", testCase.expect, ret) - } - - } -} diff --git a/pkg/controller.v1/jax/jaxjob_controller.go b/pkg/controller.v1/jax/jaxjob_controller.go index eb2016059e..22ab85b90b 100644 --- a/pkg/controller.v1/jax/jaxjob_controller.go +++ b/pkg/controller.v1/jax/jaxjob_controller.go @@ -182,33 +182,25 @@ func (r *JAXJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads ); err != nil { return err } - - // eventHandler for owned object - eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()) - predicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFunc(r.Expectations), - UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), - DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), - } - // Create generic predicates - genericPredicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations), - UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController), - DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations), - } // inject watching for job related pod - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Pod{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Pod](mgr.GetCache(), &corev1.Pod{}, + handler.TypedEnqueueRequestForOwner[*corev1.Pod](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Pod](r.Expectations, &r.JobController))); err != nil { return err } // inject watching for job related service - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Service{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Service](mgr.GetCache(), &corev1.Service{}, + handler.TypedEnqueueRequestForOwner[*corev1.Service](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Service](r.Expectations, &r.JobController))); err != nil { return err } // skip watching volcano PodGroup if volcano PodGroup is not installed if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"}, v1beta1.SchemeGroupVersion.Version); err == nil { // inject watching for job related volcano PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &v1beta1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*v1beta1.PodGroup](mgr.GetCache(), &v1beta1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*v1beta1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*v1beta1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } @@ -216,7 +208,9 @@ func (r *JAXJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"}, schedulerpluginsv1alpha1.SchemeGroupVersion.Version); err == nil { // inject watching for job related scheduler-plugins PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*schedulerpluginsv1alpha1.PodGroup](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*schedulerpluginsv1alpha1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*schedulerpluginsv1alpha1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } diff --git a/pkg/controller.v1/mpi/mpijob_controller.go b/pkg/controller.v1/mpi/mpijob_controller.go index 6dcb2e68e3..abefdf290c 100644 --- a/pkg/controller.v1/mpi/mpijob_controller.go +++ b/pkg/controller.v1/mpi/mpijob_controller.go @@ -201,38 +201,34 @@ func (jc *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads ); err != nil { return err } - - // eventHandler for owned objects - eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()) - predicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFunc(jc.Expectations), - UpdateFunc: util.OnDependentUpdateFunc(&jc.JobController), - DeleteFunc: util.OnDependentDeleteFunc(jc.Expectations), - } - // Create generic predicates - genericPredicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFuncGeneric(jc.Expectations), - UpdateFunc: util.OnDependentUpdateFuncGeneric(&jc.JobController), - DeleteFunc: util.OnDependentDeleteFuncGeneric(jc.Expectations), - } // inject watching for job related pod - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Pod{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Pod](mgr.GetCache(), &corev1.Pod{}, + handler.TypedEnqueueRequestForOwner[*corev1.Pod](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Pod](jc.Expectations, &jc.JobController))); err != nil { return err } // inject watching for job related ConfigMap - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.ConfigMap{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.ConfigMap](mgr.GetCache(), &corev1.ConfigMap{}, + handler.TypedEnqueueRequestForOwner[*corev1.ConfigMap](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.ConfigMap](jc.Expectations, &jc.JobController))); err != nil { return err } // inject watching for job related Role - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &rbacv1.Role{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*rbacv1.Role](mgr.GetCache(), &rbacv1.Role{}, + handler.TypedEnqueueRequestForOwner[*rbacv1.Role](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*rbacv1.Role](jc.Expectations, &jc.JobController))); err != nil { return err } // inject watching for job related RoleBinding - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &rbacv1.RoleBinding{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*rbacv1.RoleBinding](mgr.GetCache(), &rbacv1.RoleBinding{}, + handler.TypedEnqueueRequestForOwner[*rbacv1.RoleBinding](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*rbacv1.RoleBinding](jc.Expectations, &jc.JobController))); err != nil { return err } // inject watching for job related ServiceAccount - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.ServiceAccount{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.ServiceAccount](mgr.GetCache(), &corev1.ServiceAccount{}, + handler.TypedEnqueueRequestForOwner[*corev1.ServiceAccount](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.ServiceAccount](jc.Expectations, &jc.JobController))); err != nil { return err } // skip watching volcano PodGroup if volcano PodGroup is not installed @@ -240,7 +236,9 @@ func (jc *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads v1beta1.SchemeGroupVersion.Version, ); err == nil { // inject watching for job related volcano PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &v1beta1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*v1beta1.PodGroup](mgr.GetCache(), &v1beta1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*v1beta1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*v1beta1.PodGroup](jc.Expectations, &jc.JobController))); err != nil { return err } } @@ -250,7 +248,9 @@ func (jc *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads schedulerpluginsv1alpha1.SchemeGroupVersion.Version, ); err == nil { // inject watching for job related scheduler-plugins PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*schedulerpluginsv1alpha1.PodGroup](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*schedulerpluginsv1alpha1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*schedulerpluginsv1alpha1.PodGroup](jc.Expectations, &jc.JobController))); err != nil { return err } } diff --git a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go index 778601e84d..e6c2a84370 100644 --- a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go +++ b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go @@ -190,26 +190,16 @@ func (r *PaddleJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThrea ); err != nil { return err } - - // eventHandler for owned objects - eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PaddleJob{}, handler.OnlyControllerOwner()) - predicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFunc(r.Expectations), - UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), - DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), - } - // Create generic predicates - genericPredicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations), - UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController), - DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations), - } // inject watching for job related pod - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Pod{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Pod](mgr.GetCache(), &corev1.Pod{}, + handler.TypedEnqueueRequestForOwner[*corev1.Pod](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PaddleJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Pod](r.Expectations, &r.JobController))); err != nil { return err } // inject watching for job related service - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Service{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Service](mgr.GetCache(), &corev1.Service{}, + handler.TypedEnqueueRequestForOwner[*corev1.Service](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PaddleJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Service](r.Expectations, &r.JobController))); err != nil { return err } // skip watching volcano PodGroup if volcano PodGroup is not installed @@ -217,7 +207,9 @@ func (r *PaddleJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThrea v1beta1.SchemeGroupVersion.Version, ); err == nil { // inject watching for job related volcano PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &v1beta1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*v1beta1.PodGroup](mgr.GetCache(), &v1beta1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*v1beta1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PaddleJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*v1beta1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } @@ -227,7 +219,9 @@ func (r *PaddleJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThrea schedulerpluginsv1alpha1.SchemeGroupVersion.Version, ); err == nil { // inject watching for job related scheduler-plugins PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*schedulerpluginsv1alpha1.PodGroup](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*schedulerpluginsv1alpha1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PaddleJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*schedulerpluginsv1alpha1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go index eb872f7c28..aabc9d070a 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller.go @@ -194,33 +194,25 @@ func (r *PyTorchJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThre ); err != nil { return err } - - // eventHandler for owned object - eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PyTorchJob{}, handler.OnlyControllerOwner()) - predicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFunc(r.Expectations), - UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), - DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), - } - // Create generic predicates - genericPredicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations), - UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController), - DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations), - } // inject watching for job related pod - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Pod{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Pod](mgr.GetCache(), &corev1.Pod{}, + handler.TypedEnqueueRequestForOwner[*corev1.Pod](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PyTorchJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Pod](r.Expectations, &r.JobController))); err != nil { return err } // inject watching for job related service - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Service{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Service](mgr.GetCache(), &corev1.Service{}, + handler.TypedEnqueueRequestForOwner[*corev1.Service](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PyTorchJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Service](r.Expectations, &r.JobController))); err != nil { return err } // skip watching volcano PodGroup if volcano PodGroup is not installed if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"}, v1beta1.SchemeGroupVersion.Version); err == nil { // inject watching for job related volcano PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &v1beta1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*v1beta1.PodGroup](mgr.GetCache(), &v1beta1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*v1beta1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PyTorchJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*v1beta1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } @@ -228,7 +220,9 @@ func (r *PyTorchJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThre if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"}, schedulerpluginsv1alpha1.SchemeGroupVersion.Version); err == nil { // inject watching for job related scheduler-plugins PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*schedulerpluginsv1alpha1.PodGroup](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*schedulerpluginsv1alpha1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PyTorchJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*schedulerpluginsv1alpha1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } diff --git a/pkg/controller.v1/tensorflow/tfjob_controller.go b/pkg/controller.v1/tensorflow/tfjob_controller.go index c13ac4a874..060e07b6e3 100644 --- a/pkg/controller.v1/tensorflow/tfjob_controller.go +++ b/pkg/controller.v1/tensorflow/tfjob_controller.go @@ -185,33 +185,25 @@ func (r *TFJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads i ); err != nil { return err } - - // eventHandler for owned objects - eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.TFJob{}, handler.OnlyControllerOwner()) - predicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFunc(r.Expectations), - UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), - DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), - } - // Create generic predicates - genericPredicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations), - UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController), - DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations), - } // inject watching for job related pod - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Pod{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Pod](mgr.GetCache(), &corev1.Pod{}, + handler.TypedEnqueueRequestForOwner[*corev1.Pod](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.TFJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Pod](r.Expectations, &r.JobController))); err != nil { return err } // inject watching for job related service - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Service{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Service](mgr.GetCache(), &corev1.Service{}, + handler.TypedEnqueueRequestForOwner[*corev1.Service](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.TFJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Service](r.Expectations, &r.JobController))); err != nil { return err } // skip watching volcano PodGroup if volcano PodGroup is not installed if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"}, v1beta1.SchemeGroupVersion.Version); err == nil { // inject watching for job related volcano PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &v1beta1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*v1beta1.PodGroup](mgr.GetCache(), &v1beta1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*v1beta1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.TFJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*v1beta1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } @@ -219,7 +211,9 @@ func (r *TFJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads i if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"}, schedulerpluginsv1alpha1.SchemeGroupVersion.Version); err == nil { // inject watching for job related scheduler-plugins PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*schedulerpluginsv1alpha1.PodGroup](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*schedulerpluginsv1alpha1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.TFJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*schedulerpluginsv1alpha1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } diff --git a/pkg/controller.v1/xgboost/xgboostjob_controller.go b/pkg/controller.v1/xgboost/xgboostjob_controller.go index 7489d75379..8b71cb70a3 100644 --- a/pkg/controller.v1/xgboost/xgboostjob_controller.go +++ b/pkg/controller.v1/xgboost/xgboostjob_controller.go @@ -188,33 +188,25 @@ func (r *XGBoostJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThre ); err != nil { return err } - - // eventHandler for owned objects - eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.XGBoostJob{}, handler.OnlyControllerOwner()) - predicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFunc(r.Expectations), - UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), - DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), - } - // Create generic predicates - genericPredicates := predicate.Funcs{ - CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations), - UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController), - DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations), - } // inject watching for job related pod - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Pod{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Pod](mgr.GetCache(), &corev1.Pod{}, + handler.TypedEnqueueRequestForOwner[*corev1.Pod](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.XGBoostJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Pod](r.Expectations, &r.JobController))); err != nil { return err } // inject watching for job related service - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Service{}, eventHandler, predicates)); err != nil { + if err = c.Watch(source.Kind[*corev1.Service](mgr.GetCache(), &corev1.Service{}, + handler.TypedEnqueueRequestForOwner[*corev1.Service](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.XGBoostJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*corev1.Service](r.Expectations, &r.JobController))); err != nil { return err } // skip watching volcano PodGroup if volcano PodGroup is not installed if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"}, v1beta1.SchemeGroupVersion.Version); err == nil { // inject watching for job related volcano PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &v1beta1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*v1beta1.PodGroup](mgr.GetCache(), &v1beta1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*v1beta1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.XGBoostJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*v1beta1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } } @@ -222,7 +214,9 @@ func (r *XGBoostJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThre if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"}, schedulerpluginsv1alpha1.SchemeGroupVersion.Version); err == nil { // inject watching for job related scheduler-plugins PodGroup - if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, eventHandler, genericPredicates)); err != nil { + if err = c.Watch(source.Kind[*schedulerpluginsv1alpha1.PodGroup](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, + handler.TypedEnqueueRequestForOwner[*schedulerpluginsv1alpha1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.XGBoostJob{}, handler.OnlyControllerOwner()), + util.OnDependentFuncs[*schedulerpluginsv1alpha1.PodGroup](r.Expectations, &r.JobController))); err != nil { return err } }