Commit

enhance federated learning controller.

1. Improve the cascade deletion of the federated learning controller.
2. Enable the federated learning pod to rebuild itself if it is manually or wrongly deleted.
3. Enable the pod config to update itself when the federated learning CRD is modified.
4. Add a test file to ensure the correctness of the solution.

Signed-off-by: SherlockShemol <[email protected]>
SherlockShemol committed Oct 24, 2024
1 parent 3342955 commit b37522e
Showing 208 changed files with 21,015 additions and 19 deletions.
176 changes: 157 additions & 19 deletions pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go
@@ -20,11 +20,13 @@ import (
"context"
"fmt"
"strconv"
"sync"
"time"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
utilrand "k8s.io/apimachinery/pkg/util/rand"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
@@ -88,6 +90,15 @@ type Controller struct {
cfg *config.ControllerConfig

sendToEdgeFunc runtime.DownstreamSendFunc

// recreatedPods records pods that were just recreated, so the delete handler does not recreate them again
recreatedPods sync.Map

flSelector labels.Selector

aggServiceHost string

preventRecreation bool
}

// Run starts the main goroutine responsible for watching and syncing jobs.
@@ -190,6 +201,49 @@ func (c *Controller) deletePod(obj interface{}) {
}
}
c.enqueueByPod(pod, true)

// when the job CRD is being updated, updateJob deletes and recreates the pods itself,
// so skip recreation here
if c.preventRecreation {
return
}
// if pod is manually deleted, recreate it
// first check if the pod is owned by a FederatedLearningJob
controllerRef := metav1.GetControllerOf(pod)
if controllerRef == nil || controllerRef.Kind != Kind.Kind {
return
}

// then check if the pod is already in the map
if _, exists := c.recreatedPods.Load(pod.Name); exists {
return
}

// if not, recreate it
klog.Infof("Pod %s/%s deleted, recreating...", pod.Namespace, pod.Name)
// Create a deep copy of the old pod
newPod := pod.DeepCopy()
// Reset the resource version and UID as they are unique to each object
newPod.ResourceVersion = ""
newPod.UID = ""
// Clear the status
newPod.Status = v1.PodStatus{}
// Remove the deletion timestamp
newPod.DeletionTimestamp = nil
// Remove the deletion grace period seconds
newPod.DeletionGracePeriodSeconds = nil
_, err := c.kubeClient.CoreV1().Pods(pod.Namespace).Create(context.TODO(), newPod, metav1.CreateOptions{})
if err != nil {
klog.Errorf("failed to recreate pod %s/%s: %v", pod.Namespace, pod.Name, err)
return
}
klog.Infof("Successfully recreated pod %s/%s", newPod.Namespace, newPod.Name)
// mark the pod as recreated
c.recreatedPods.Store(newPod.Name, true)
// set a timer to delete the record from the map after a while
go func() {
time.Sleep(5 * time.Second)
c.recreatedPods.Delete(pod.Name)
}()
}

// obj could be an *sednav1.FederatedLearningJob, or a DeletionFinalStateUnknown marker item,
@@ -271,14 +325,16 @@ func (c *Controller) sync(key string) (bool, error) {
return true, nil
}

selector, _ := runtime.GenerateSelector(&job)
pods, err := c.podStore.Pods(job.Namespace).List(selector)
c.flSelector, _ = runtime.GenerateSelector(&job)
pods, err := c.podStore.Pods(job.Namespace).List(c.flSelector)
if err != nil {
return false, err
}

activePods := k8scontroller.FilterActivePods(pods)
active := int32(len(activePods))
var activeAgg int32
var activeTrain int32
succeeded, failed := countPods(pods)
conditions := len(job.Status.Conditions)

@@ -289,6 +345,8 @@ func (c *Controller) sync(key string) (bool, error) {
}

var manageJobErr error
var manageAggErr error
var manageTrainErr error
jobFailed := false
var failureReason string
var failureMessage string
@@ -307,7 +365,13 @@ func (c *Controller) sync(key string) (bool, error) {
} else {
// the first time through, when no pods exist yet, create them
if len(pods) == 0 {
active, manageJobErr = c.createPod(&job)
activeAgg, manageAggErr = c.createAggPod(&job)
createServiceErr := c.createService(&job)
if createServiceErr != nil {
return false, createServiceErr
}
activeTrain, manageTrainErr = c.createTrainPod(&job)
active = activeAgg + activeTrain
}
complete := false
if succeeded > 0 && active == 0 {
@@ -324,6 +388,10 @@ func (c *Controller) sync(key string) (bool, error) {
}
}

// Combine manageAggErr and manageTrainErr into a single error
if manageAggErr != nil || manageTrainErr != nil {
manageJobErr = fmt.Errorf("aggregator error: %v, training error: %v", manageAggErr, manageTrainErr)
}
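As a design note on the hunk above: flattening both errors into a single fmt.Errorf string loses the ability to match the originals with errors.Is / errors.As. On Go 1.20 and later, errors.Join is an alternative that keeps them unwrappable. A minimal, self-contained sketch of that alternative, not part of this commit:

package main

import (
    "errors"
    "fmt"
)

func main() {
    manageAggErr := errors.New("aggregator error")
    var manageTrainErr error // nil: the training workers came up fine

    // errors.Join discards nil values and keeps the originals matchable
    // with errors.Is / errors.As, unlike a flattened fmt.Errorf string.
    manageJobErr := errors.Join(manageAggErr, manageTrainErr)

    fmt.Println(manageJobErr)                          // aggregator error
    fmt.Println(errors.Is(manageJobErr, manageAggErr)) // true
}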
forget := false
// Check if the number of jobs succeeded increased since the last check. If yes "forget" should be true
// This logic is linked to the issue: https://github.com/kubernetes/kubernetes/issues/56853 that aims to
@@ -499,8 +567,7 @@ func (c *Controller) addTransmitterToWorkerParam(param *runtime.WorkerParam, job

return nil
}

func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
func (c *Controller) createAggPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
active = 0
ctx := context.Background()

@@ -513,7 +580,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
modelName := job.Spec.AggregationWorker.Model.Name
model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
if err != nil {
return active, err
return active, fmt.Errorf("failed to get aggregation model: %w", err)
}

participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
@@ -524,6 +591,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
// Configure aggregation worker's mounts and envs
var aggPort int32 = 7363
var aggWorkerParam runtime.WorkerParam

aggWorkerParam.Env = map[string]string{
"NAMESPACE": job.Namespace,
"WORKER_NAME": "aggworker-" + utilrand.String(5),
@@ -534,7 +602,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
}

if err := c.addTransmitterToWorkerParam(&aggWorkerParam, job); err != nil {
return active, err
return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
}

aggWorkerParam.WorkerType = jobStageAgg
@@ -547,19 +615,36 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
c.addWorkerMount(&aggWorkerParam, pretrainedModel.Spec.URL, "PRETRAINED_MODEL_URL",
pretrainedModelSecret, true)
}

aggWorker.Template.Name = fmt.Sprintf("%s-aggworker", job.Name)
// create aggpod based on configured parameters
_, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &aggWorker.Template, &aggWorkerParam)
if err != nil {
return active, fmt.Errorf("failed to create aggregation worker: %w", err)
}
klog.Infof("create aggpod success")
active++
return
}

func (c *Controller) createTrainPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
active = 0
ctx := context.Background()

aggServiceHost, err := runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
pretrainedModelName := job.Spec.PretrainedModel.Name
pretrainedModel, pretrainedModelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, pretrainedModelName)
if err != nil {
return active, err
return active, fmt.Errorf("failed to get pretrained model: %w", err)
}

modelName := job.Spec.AggregationWorker.Model.Name
model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
if err != nil {
return active, fmt.Errorf("failed to get aggregation model: %w", err)
}

var aggPort int32 = 7363
participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))

// deliver pod for training worker
for i, trainingWorker := range job.Spec.TrainingWorkers {
// Configure training worker's mounts and envs
@@ -583,7 +668,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,

workerParam.Env = map[string]string{
"AGG_PORT": strconv.Itoa(int(aggPort)),
"AGG_IP": aggServiceHost,
"AGG_IP": c.aggServiceHost,

"WORKER_NAME": "trainworker-" + utilrand.String(5),
"JOB_NAME": job.Name,
@@ -593,14 +678,15 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
"DATASET_NAME": datasetName,
"LC_SERVER": c.cfg.LC.Server,
}

workerParam.WorkerType = runtime.TrainPodType
workerParam.HostNetwork = true
workerParam.RestartPolicy = v1.RestartPolicyOnFailure

if err := c.addTransmitterToWorkerParam(&workerParam, job); err != nil {
return active, err
return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
}

trainingWorker.Template.Name = fmt.Sprintf("%s-trainworker-%d", job.Name, i)
// create training worker based on configured parameters
_, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &trainingWorker.Template, &workerParam)
if err != nil {
@@ -640,13 +726,8 @@ func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
// send it to edge's LC.
fc.syncToEdge(watch.Added, obj)
},
UpdateFunc: func(old, cur interface{}) {
fc.enqueueController(cur, true)
UpdateFunc: fc.updateJob,

// when a federated learning job is updated,
// send it to edge's LC as Added event.
fc.syncToEdge(watch.Added, cur)
},
DeleteFunc: func(obj interface{}) {
fc.enqueueController(obj, true)

@@ -669,3 +750,60 @@ func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {

return fc, nil
}

func (c *Controller) updateJob(old, cur interface{}) {
oldJob, ok := old.(*sednav1.FederatedLearningJob)
if !ok {
return
}
curJob, ok := cur.(*sednav1.FederatedLearningJob)
if !ok {
return
}

if oldJob.ResourceVersion == curJob.ResourceVersion {
return
}

if oldJob.Generation != curJob.Generation {
pods, err := c.podStore.Pods(curJob.Namespace).List(c.flSelector)
if err != nil {
klog.Errorf("Failed to list pods: %v", err)
}
c.preventRecreation = true
for _, pod := range pods {
// delete all pods
c.kubeClient.CoreV1().Pods(pod.Namespace).Delete(context.TODO(), pod.Name, metav1.DeleteOptions{})
klog.Infof("CRD modified, so we deleted pod %s/%s", pod.Namespace, pod.Name)
}
klog.Infof("CRD modified, so we deleted all pods, and will create new pods")
curJob.SetGroupVersionKind(Kind)
_, err = c.createAggPod(curJob)
if err != nil {
klog.Errorf("Failed to create aggregation worker: %v", err)
}
_, err = c.createTrainPod(curJob)
if err != nil {
klog.Errorf("Failed to create training workers: %v", err)
}
// persist the updated job object
if _, err := c.client.FederatedLearningJobs(curJob.Namespace).Update(context.TODO(), curJob, metav1.UpdateOptions{}); err != nil {
klog.Errorf("failed to update federated learning job %s/%s: %v", curJob.Namespace, curJob.Name, err)
}
}

c.preventRecreation = false
c.enqueueController(curJob, true)

// when a federated learning job is updated,
// send it to edge's LC as Added event.
c.syncToEdge(watch.Added, curJob)
}
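For context on the Generation comparison in updateJob above: the API server increments metadata.generation only when an object's spec changes (for custom resources this requires the status subresource, which the FederatedLearningJob CRD is assumed here to enable), so the check distinguishes real spec edits from status or metadata-only updates. A small self-contained sketch of that predicate, with hypothetical values:

package main

import (
    "fmt"

    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// specChanged reports whether an update reflects a spec change: Generation is
// bumped by the API server on spec changes only, not on status or metadata updates.
func specChanged(old, cur metav1.Object) bool {
    return old.GetGeneration() != cur.GetGeneration()
}

func main() {
    old := &metav1.ObjectMeta{Generation: 1, ResourceVersion: "100"}
    statusOnly := &metav1.ObjectMeta{Generation: 1, ResourceVersion: "101"}
    specEdit := &metav1.ObjectMeta{Generation: 2, ResourceVersion: "102"}

    fmt.Println(specChanged(old, statusOnly)) // false: resync or status-only update
    fmt.Println(specChanged(old, specEdit))   // true: the job spec was modified
}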

// createService creates the EdgeMesh service for the job's aggregation worker and records its host
func (c *Controller) createService(job *sednav1.FederatedLearningJob) (err error) {
var aggPort int32 = 7363
c.aggServiceHost, err = runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
if err != nil {
return err
}
return nil
}
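The commit message mentions an accompanying test file; independently of that file, the rebuild-on-manual-deletion behavior above can be sanity-checked against client-go's fake clientset. A minimal sketch, in which the package name, pod name, and image are illustrative assumptions rather than values from this commit:

package federatedlearning_test

import (
    "context"
    "testing"

    v1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/client-go/kubernetes/fake"
)

func TestRecreateDeletedPod(t *testing.T) {
    ctx := context.TODO()
    client := fake.NewSimpleClientset()

    orig := &v1.Pod{
        ObjectMeta: metav1.ObjectMeta{Name: "aggworker-abcde", Namespace: "default"},
        Spec:       v1.PodSpec{Containers: []v1.Container{{Name: "agg", Image: "example/agg:latest"}}},
    }
    created, err := client.CoreV1().Pods("default").Create(ctx, orig, metav1.CreateOptions{})
    if err != nil {
        t.Fatalf("create: %v", err)
    }

    // Simulate a manual deletion of the aggregation worker pod.
    if err := client.CoreV1().Pods("default").Delete(ctx, created.Name, metav1.DeleteOptions{}); err != nil {
        t.Fatalf("delete: %v", err)
    }

    // Rebuild from a deep copy with the server-populated fields cleared,
    // mirroring what the controller's deletePod handler does.
    newPod := created.DeepCopy()
    newPod.ResourceVersion = ""
    newPod.UID = ""
    newPod.Status = v1.PodStatus{}
    newPod.DeletionTimestamp = nil
    newPod.DeletionGracePeriodSeconds = nil

    if _, err := client.CoreV1().Pods("default").Create(ctx, newPod, metav1.CreateOptions{}); err != nil {
        t.Fatalf("recreate: %v", err)
    }
}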
