Skip to content

Commit

Permalink
Use typed event handlers and predicates in job controllers
Browse files Browse the repository at this point in the history
Signed-off-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
astefanutti committed Nov 26, 2024
1 parent aaa79c0 commit eb452fb
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 292 deletions.
114 changes: 1 addition & 113 deletions pkg/common/util/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,10 @@
package util

import (
"fmt"
"reflect"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/event"

kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"github.com/kubeflow/training-operator/pkg/controller.v1/common"
"github.com/kubeflow/training-operator/pkg/controller.v1/expectation"
commonutil "github.com/kubeflow/training-operator/pkg/util"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// SatisfiedExpectations returns true if the required adds/dels for the given job have been observed.
Expand All @@ -45,82 +38,6 @@ func SatisfiedExpectations(exp expectation.ControllerExpectationsInterface, jobK
return satisfied
}

// OnDependentCreateFunc modify expectations when dependent (pod/service) creation observed.
func OnDependentCreateFunc(exp expectation.ControllerExpectationsInterface) func(event.CreateEvent) bool {
return func(e event.CreateEvent) bool {
rtype := e.Object.GetLabels()[kubeflowv1.ReplicaTypeLabel]
if len(rtype) == 0 {
return false
}

//logrus.Info("Update on create function ", ptjr.ControllerName(), " create object ", e.Object.GetName())
if controllerRef := metav1.GetControllerOf(e.Object); controllerRef != nil {
jobKey := fmt.Sprintf("%s/%s", e.Object.GetNamespace(), controllerRef.Name)
var expectKey string
switch e.Object.(type) {
case *corev1.Pod:
expectKey = expectation.GenExpectationPodsKey(jobKey, rtype)
case *corev1.Service:
expectKey = expectation.GenExpectationServicesKey(jobKey, rtype)
default:
return false
}
exp.CreationObserved(expectKey)
return true
}

return true
}
}

// OnDependentUpdateFunc modify expectations when dependent (pod/service) update observed.
func OnDependentUpdateFunc(jc *common.JobController) func(updateEvent event.UpdateEvent) bool {
return func(e event.UpdateEvent) bool {
newObj := e.ObjectNew
oldObj := e.ObjectOld
if newObj.GetResourceVersion() == oldObj.GetResourceVersion() {
// Periodic resync will send update events for all known pods.
// Two different versions of the same pod will always have different RVs.
return false
}

kind := jc.Controller.GetAPIGroupVersionKind().Kind
var logger = LoggerForGenericKind(newObj, kind)

switch obj := newObj.(type) {
case *corev1.Pod:
logger = commonutil.LoggerForPod(obj, jc.Controller.GetAPIGroupVersionKind().Kind)
case *corev1.Service:
logger = commonutil.LoggerForService(newObj.(*corev1.Service), jc.Controller.GetAPIGroupVersionKind().Kind)
default:
return false
}

newControllerRef := metav1.GetControllerOf(newObj)
oldControllerRef := metav1.GetControllerOf(oldObj)
controllerRefChanged := !reflect.DeepEqual(newControllerRef, oldControllerRef)

if controllerRefChanged && oldControllerRef != nil {
// The ControllerRef was changed. Sync the old controller, if any.
if job := resolveControllerRef(jc, oldObj.GetNamespace(), oldControllerRef); job != nil {
logger.Infof("pod/service controller ref updated: %v, %v", newObj, oldObj)
return true
}
}

// If it has a controller ref, that's all that matters.
if newControllerRef != nil {
job := resolveControllerRef(jc, newObj.GetNamespace(), newControllerRef)
if job == nil {
return false
}
logger.Debugf("pod/service has a controller ref: %v, %v", newObj, oldObj)
return true
}
return false
}
}

// resolveControllerRef returns the job referenced by a ControllerRef,
// or nil if the ControllerRef could not be resolved to a matching job
// of the correct Kind.
Expand All @@ -141,32 +58,3 @@ func resolveControllerRef(jc *common.JobController, namespace string, controller
}
return job
}

// OnDependentDeleteFunc modify expectations when dependent (pod/service) deletion observed.
func OnDependentDeleteFunc(exp expectation.ControllerExpectationsInterface) func(event.DeleteEvent) bool {
return func(e event.DeleteEvent) bool {

rtype := e.Object.GetLabels()[kubeflowv1.ReplicaTypeLabel]
if len(rtype) == 0 {
return false
}

// logrus.Info("Update on deleting function ", xgbr.ControllerName(), " delete object ", e.Object.GetName())
if controllerRef := metav1.GetControllerOf(e.Object); controllerRef != nil {
jobKey := fmt.Sprintf("%s/%s", e.Object.GetNamespace(), controllerRef.Name)
var expectKey string
switch e.Object.(type) {
case *corev1.Pod:
expectKey = expectation.GenExpectationPodsKey(jobKey, rtype)
case *corev1.Service:
expectKey = expectation.GenExpectationServicesKey(jobKey, rtype)
default:
return false
}
exp.DeletionObserved(expectKey)
return true
}

return true
}
}
26 changes: 17 additions & 9 deletions pkg/common/util/reconciler_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ import (
"reflect"
"strings"

log "github.com/sirupsen/logrus"

kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"github.com/kubeflow/training-operator/pkg/controller.v1/common"
"github.com/kubeflow/training-operator/pkg/controller.v1/expectation"
log "github.com/sirupsen/logrus"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/predicate"
)

// GenExpectationGenericKey generates an expectation key for {Kind} of a job
Expand All @@ -50,9 +51,17 @@ func LoggerForGenericKind(obj metav1.Object, kind string) *log.Entry {
})
}

func OnDependentFuncs[T client.Object](expectations expectation.ControllerExpectationsInterface, jobController *common.JobController) predicate.TypedFuncs[T] {
return predicate.TypedFuncs[T]{
CreateFunc: OnDependentCreateFuncGeneric[T](expectations),
UpdateFunc: OnDependentUpdateFuncGeneric[T](jobController),
DeleteFunc: OnDependentDeleteFuncGeneric[T](expectations),
}
}

// OnDependentCreateFuncGeneric modify expectations when dependent (pod/service) creation observed.
func OnDependentCreateFuncGeneric(exp expectation.ControllerExpectationsInterface) func(event.CreateEvent) bool {
return func(e event.CreateEvent) bool {
func OnDependentCreateFuncGeneric[T client.Object](exp expectation.ControllerExpectationsInterface) func(createEvent event.TypedCreateEvent[T]) bool {
return func(e event.TypedCreateEvent[T]) bool {
rtype := e.Object.GetLabels()[kubeflowv1.ReplicaTypeLabel]
if len(rtype) == 0 {
return false
Expand All @@ -72,8 +81,8 @@ func OnDependentCreateFuncGeneric(exp expectation.ControllerExpectationsInterfac
}

// OnDependentUpdateFuncGeneric modify expectations when dependent (pod/service) update observed.
func OnDependentUpdateFuncGeneric(jc *common.JobController) func(updateEvent event.UpdateEvent) bool {
return func(e event.UpdateEvent) bool {
func OnDependentUpdateFuncGeneric[T client.Object](jc *common.JobController) func(updateEvent event.TypedUpdateEvent[T]) bool {
return func(e event.TypedUpdateEvent[T]) bool {
newObj := e.ObjectNew
oldObj := e.ObjectOld
if newObj.GetResourceVersion() == oldObj.GetResourceVersion() {
Expand Down Expand Up @@ -111,9 +120,8 @@ func OnDependentUpdateFuncGeneric(jc *common.JobController) func(updateEvent eve
}

// OnDependentDeleteFuncGeneric modify expectations when dependent (pod/service) deletion observed.
func OnDependentDeleteFuncGeneric(exp expectation.ControllerExpectationsInterface) func(event.DeleteEvent) bool {
return func(e event.DeleteEvent) bool {

func OnDependentDeleteFuncGeneric[T client.Object](exp expectation.ControllerExpectationsInterface) func(event.TypedDeleteEvent[T]) bool {
return func(e event.TypedDeleteEvent[T]) bool {
rtype := e.Object.GetLabels()[kubeflowv1.ReplicaTypeLabel]
if len(rtype) == 0 {
return false
Expand Down
59 changes: 0 additions & 59 deletions pkg/common/util/reconciler_test.go

This file was deleted.

30 changes: 12 additions & 18 deletions pkg/controller.v1/jax/jaxjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,41 +182,35 @@ func (r *JAXJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads
); err != nil {
return err
}

// eventHandler for owned object
eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner())
predicates := predicate.Funcs{
CreateFunc: util.OnDependentCreateFunc(r.Expectations),
UpdateFunc: util.OnDependentUpdateFunc(&r.JobController),
DeleteFunc: util.OnDependentDeleteFunc(r.Expectations),
}
// Create generic predicates
genericPredicates := predicate.Funcs{
CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations),
UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController),
DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations),
}
// inject watching for job related pod
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Pod{}, eventHandler, predicates)); err != nil {
if err = c.Watch(source.Kind[*corev1.Pod](mgr.GetCache(), &corev1.Pod{},
handler.TypedEnqueueRequestForOwner[*corev1.Pod](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*corev1.Pod](r.Expectations, &r.JobController))); err != nil {
return err
}
// inject watching for job related service
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Service{}, eventHandler, predicates)); err != nil {
if err = c.Watch(source.Kind[*corev1.Service](mgr.GetCache(), &corev1.Service{},
handler.TypedEnqueueRequestForOwner[*corev1.Service](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*corev1.Service](r.Expectations, &r.JobController))); err != nil {
return err
}
// skip watching volcano PodGroup if volcano PodGroup is not installed
if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"},
v1beta1.SchemeGroupVersion.Version); err == nil {
// inject watching for job related volcano PodGroup
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &v1beta1.PodGroup{}, eventHandler, genericPredicates)); err != nil {
if err = c.Watch(source.Kind[*v1beta1.PodGroup](mgr.GetCache(), &v1beta1.PodGroup{},
handler.TypedEnqueueRequestForOwner[*v1beta1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*v1beta1.PodGroup](r.Expectations, &r.JobController))); err != nil {
return err
}
}
// skip watching scheduler-plugins PodGroup if scheduler-plugins PodGroup is not installed
if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"},
schedulerpluginsv1alpha1.SchemeGroupVersion.Version); err == nil {
// inject watching for job related scheduler-plugins PodGroup
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, eventHandler, genericPredicates)); err != nil {
if err = c.Watch(source.Kind[*schedulerpluginsv1alpha1.PodGroup](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{},
handler.TypedEnqueueRequestForOwner[*schedulerpluginsv1alpha1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*schedulerpluginsv1alpha1.PodGroup](r.Expectations, &r.JobController))); err != nil {
return err
}
}
Expand Down
42 changes: 21 additions & 21 deletions pkg/controller.v1/mpi/mpijob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,46 +201,44 @@ func (jc *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads
); err != nil {
return err
}

// eventHandler for owned objects
eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner())
predicates := predicate.Funcs{
CreateFunc: util.OnDependentCreateFunc(jc.Expectations),
UpdateFunc: util.OnDependentUpdateFunc(&jc.JobController),
DeleteFunc: util.OnDependentDeleteFunc(jc.Expectations),
}
// Create generic predicates
genericPredicates := predicate.Funcs{
CreateFunc: util.OnDependentCreateFuncGeneric(jc.Expectations),
UpdateFunc: util.OnDependentUpdateFuncGeneric(&jc.JobController),
DeleteFunc: util.OnDependentDeleteFuncGeneric(jc.Expectations),
}
// inject watching for job related pod
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.Pod{}, eventHandler, predicates)); err != nil {
if err = c.Watch(source.Kind[*corev1.Pod](mgr.GetCache(), &corev1.Pod{},
handler.TypedEnqueueRequestForOwner[*corev1.Pod](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*corev1.Pod](jc.Expectations, &jc.JobController))); err != nil {
return err
}
// inject watching for job related ConfigMap
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.ConfigMap{}, eventHandler, genericPredicates)); err != nil {
if err = c.Watch(source.Kind[*corev1.ConfigMap](mgr.GetCache(), &corev1.ConfigMap{},
handler.TypedEnqueueRequestForOwner[*corev1.ConfigMap](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*corev1.ConfigMap](jc.Expectations, &jc.JobController))); err != nil {
return err
}
// inject watching for job related Role
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &rbacv1.Role{}, eventHandler, genericPredicates)); err != nil {
if err = c.Watch(source.Kind[*rbacv1.Role](mgr.GetCache(), &rbacv1.Role{},
handler.TypedEnqueueRequestForOwner[*rbacv1.Role](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*rbacv1.Role](jc.Expectations, &jc.JobController))); err != nil {
return err
}
// inject watching for job related RoleBinding
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &rbacv1.RoleBinding{}, eventHandler, genericPredicates)); err != nil {
if err = c.Watch(source.Kind[*rbacv1.RoleBinding](mgr.GetCache(), &rbacv1.RoleBinding{},
handler.TypedEnqueueRequestForOwner[*rbacv1.RoleBinding](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*rbacv1.RoleBinding](jc.Expectations, &jc.JobController))); err != nil {
return err
}
// inject watching for job related ServiceAccount
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &corev1.ServiceAccount{}, eventHandler, genericPredicates)); err != nil {
if err = c.Watch(source.Kind[*corev1.ServiceAccount](mgr.GetCache(), &corev1.ServiceAccount{},
handler.TypedEnqueueRequestForOwner[*corev1.ServiceAccount](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*corev1.ServiceAccount](jc.Expectations, &jc.JobController))); err != nil {
return err
}
// skip watching volcano PodGroup if volcano PodGroup is not installed
if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"},
v1beta1.SchemeGroupVersion.Version,
); err == nil {
// inject watching for job related volcano PodGroup
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &v1beta1.PodGroup{}, eventHandler, genericPredicates)); err != nil {
if err = c.Watch(source.Kind[*v1beta1.PodGroup](mgr.GetCache(), &v1beta1.PodGroup{},
handler.TypedEnqueueRequestForOwner[*v1beta1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*v1beta1.PodGroup](jc.Expectations, &jc.JobController))); err != nil {
return err
}
}
Expand All @@ -250,7 +248,9 @@ func (jc *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads
schedulerpluginsv1alpha1.SchemeGroupVersion.Version,
); err == nil {
// inject watching for job related scheduler-plugins PodGroup
if err = c.Watch(source.Kind[client.Object](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}, eventHandler, genericPredicates)); err != nil {
if err = c.Watch(source.Kind[*schedulerpluginsv1alpha1.PodGroup](mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{},
handler.TypedEnqueueRequestForOwner[*schedulerpluginsv1alpha1.PodGroup](mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MPIJob{}, handler.OnlyControllerOwner()),
util.OnDependentFuncs[*schedulerpluginsv1alpha1.PodGroup](jc.Expectations, &jc.JobController))); err != nil {
return err
}
}
Expand Down
Loading

0 comments on commit eb452fb

Please sign in to comment.