Skip to content

Commit

Permalink
Merge pull request #446 from SherlockShemol/fl-controller-enhancement
Browse files Browse the repository at this point in the history
Sedna FederatedLearning controller enhancement
  • Loading branch information
kubeedge-bot authored Oct 30, 2024
2 parents cabab5f + b37522e commit 712b62b
Show file tree
Hide file tree
Showing 2 changed files with 525 additions and 19 deletions.
176 changes: 157 additions & 19 deletions pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ import (
"context"
"fmt"
"strconv"
"sync"
"time"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
utilrand "k8s.io/apimachinery/pkg/util/rand"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
Expand Down Expand Up @@ -88,6 +90,15 @@ type Controller struct {
cfg *config.ControllerConfig

sendToEdgeFunc runtime.DownstreamSendFunc

// map to record the pods that are recreated
recreatedPods sync.Map

flSelector labels.Selector

aggServiceHost string

preventRecreation bool
}

// Run starts the main goroutine responsible for watching and syncing jobs.
Expand Down Expand Up @@ -190,6 +201,49 @@ func (c *Controller) deletePod(obj interface{}) {
}
}
c.enqueueByPod(pod, true)

// when the CRD is updated, do not recreate the pod
// if c.preventRecreation is true, do not recreate the pod
if c.preventRecreation {
return
}
// if pod is manually deleted, recreate it
// first check if the pod is owned by a FederatedLearningJob
controllerRef := metav1.GetControllerOf(pod)
if controllerRef == nil || controllerRef.Kind != Kind.Kind {
return
}

// then check if the pod is already in the map
if _, exists := c.recreatedPods.Load(pod.Name); exists {
return
}

// if not, recreate it
klog.Infof("Pod %s/%s deleted, recreating...", pod.Namespace, pod.Name)
// Create a deep copy of the old pod
newPod := pod.DeepCopy()
// Reset the resource version and UID as they are unique to each object
newPod.ResourceVersion = ""
newPod.UID = ""
// Clear the status
newPod.Status = v1.PodStatus{}
// Remove the deletion timestamp
newPod.DeletionTimestamp = nil
// Remove the deletion grace period seconds
newPod.DeletionGracePeriodSeconds = nil
_, err := c.kubeClient.CoreV1().Pods(pod.Namespace).Create(context.TODO(), newPod, metav1.CreateOptions{})
if err != nil {
return
}
klog.Infof("Successfully recreated pod %s/%s", newPod.Namespace, newPod.Name)
// mark the pod as recreated
c.recreatedPods.Store(newPod.Name, true)
// set a timer to delete the record from the map after a while
go func() {
time.Sleep(5 * time.Second)
c.recreatedPods.Delete(pod.Name)
}()
}

// obj could be an *sednav1.FederatedLearningJob, or a DeletionFinalStateUnknown marker item,
Expand Down Expand Up @@ -271,14 +325,16 @@ func (c *Controller) sync(key string) (bool, error) {
return true, nil
}

selector, _ := runtime.GenerateSelector(&job)
pods, err := c.podStore.Pods(job.Namespace).List(selector)
c.flSelector, _ = runtime.GenerateSelector(&job)
pods, err := c.podStore.Pods(job.Namespace).List(c.flSelector)
if err != nil {
return false, err
}

activePods := k8scontroller.FilterActivePods(pods)
active := int32(len(activePods))
var activeAgg int32
var activeTrain int32
succeeded, failed := countPods(pods)
conditions := len(job.Status.Conditions)

Expand All @@ -289,6 +345,8 @@ func (c *Controller) sync(key string) (bool, error) {
}

var manageJobErr error
var manageAggErr error
var manageTrainErr error
jobFailed := false
var failureReason string
var failureMessage string
Expand All @@ -307,7 +365,13 @@ func (c *Controller) sync(key string) (bool, error) {
} else {
// in the First time, we create the pods
if len(pods) == 0 {
active, manageJobErr = c.createPod(&job)
activeAgg, manageAggErr = c.createAggPod(&job)
createServiceErr := c.createService(&job)
if createServiceErr != nil {
return false, createServiceErr
}
activeTrain, manageTrainErr = c.createTrainPod(&job)
active = activeAgg + activeTrain
}
complete := false
if succeeded > 0 && active == 0 {
Expand All @@ -324,6 +388,10 @@ func (c *Controller) sync(key string) (bool, error) {
}
}

// Combine manageAggErr and manageTrainErr into a single error
if manageAggErr != nil || manageTrainErr != nil {
manageJobErr = fmt.Errorf("aggregator error: %v, training error: %v", manageAggErr, manageTrainErr)
}
forget := false
// Check if the number of jobs succeeded increased since the last check. If yes "forget" should be true
// This logic is linked to the issue: https://github.com/kubernetes/kubernetes/issues/56853 that aims to
Expand Down Expand Up @@ -499,8 +567,7 @@ func (c *Controller) addTransmitterToWorkerParam(param *runtime.WorkerParam, job

return nil
}

func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
func (c *Controller) createAggPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
active = 0
ctx := context.Background()

Expand All @@ -513,7 +580,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
modelName := job.Spec.AggregationWorker.Model.Name
model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
if err != nil {
return active, err
return active, fmt.Errorf("failed to get aggregation model: %w", err)
}

participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
Expand All @@ -524,6 +591,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
// Configure aggregation worker's mounts and envs
var aggPort int32 = 7363
var aggWorkerParam runtime.WorkerParam

aggWorkerParam.Env = map[string]string{
"NAMESPACE": job.Namespace,
"WORKER_NAME": "aggworker-" + utilrand.String(5),
Expand All @@ -534,7 +602,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
}

if err := c.addTransmitterToWorkerParam(&aggWorkerParam, job); err != nil {
return active, err
return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
}

aggWorkerParam.WorkerType = jobStageAgg
Expand All @@ -547,19 +615,36 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
c.addWorkerMount(&aggWorkerParam, pretrainedModel.Spec.URL, "PRETRAINED_MODEL_URL",
pretrainedModelSecret, true)
}

aggWorker.Template.Name = fmt.Sprintf("%s-aggworker", job.Name)
// create aggpod based on configured parameters
_, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &aggWorker.Template, &aggWorkerParam)
if err != nil {
return active, fmt.Errorf("failed to create aggregation worker: %w", err)
}
klog.Infof("create aggpod success")
active++
return
}

func (c *Controller) createTrainPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
active = 0
ctx := context.Background()

aggServiceHost, err := runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
pretrainedModelName := job.Spec.PretrainedModel.Name
pretrainedModel, pretrainedModelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, pretrainedModelName)
if err != nil {
return active, err
return active, fmt.Errorf("failed to get pretrained model: %w", err)
}

modelName := job.Spec.AggregationWorker.Model.Name
model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
if err != nil {
return active, fmt.Errorf("failed to get aggregation model: %w", err)
}

var aggPort int32 = 7363
participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))

// deliver pod for training worker
for i, trainingWorker := range job.Spec.TrainingWorkers {
// Configure training worker's mounts and envs
Expand All @@ -583,7 +668,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,

workerParam.Env = map[string]string{
"AGG_PORT": strconv.Itoa(int(aggPort)),
"AGG_IP": aggServiceHost,
"AGG_IP": c.aggServiceHost,

"WORKER_NAME": "trainworker-" + utilrand.String(5),
"JOB_NAME": job.Name,
Expand All @@ -593,14 +678,15 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
"DATASET_NAME": datasetName,
"LC_SERVER": c.cfg.LC.Server,
}

workerParam.WorkerType = runtime.TrainPodType
workerParam.HostNetwork = true
workerParam.RestartPolicy = v1.RestartPolicyOnFailure

if err := c.addTransmitterToWorkerParam(&workerParam, job); err != nil {
return active, err
return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
}

trainingWorker.Template.Name = fmt.Sprintf("%s-trainworker-%d", job.Name, i)
// create training worker based on configured parameters
_, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &trainingWorker.Template, &workerParam)
if err != nil {
Expand Down Expand Up @@ -640,13 +726,8 @@ func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
// send it to edge's LC.
fc.syncToEdge(watch.Added, obj)
},
UpdateFunc: func(old, cur interface{}) {
fc.enqueueController(cur, true)
UpdateFunc: fc.updateJob,

// when a federated learning job is updated,
// send it to edge's LC as Added event.
fc.syncToEdge(watch.Added, cur)
},
DeleteFunc: func(obj interface{}) {
fc.enqueueController(obj, true)

Expand All @@ -669,3 +750,60 @@ func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {

return fc, nil
}

func (c *Controller) updateJob(old, cur interface{}) {
oldJob, ok := old.(*sednav1.FederatedLearningJob)
if !ok {
return
}
curJob, ok := cur.(*sednav1.FederatedLearningJob)
if !ok {
return
}

if oldJob.ResourceVersion == curJob.ResourceVersion {
return
}

if oldJob.Generation != curJob.Generation {
pods, err := c.podStore.Pods(curJob.Namespace).List(c.flSelector)
if err != nil {
klog.Errorf("Failed to list pods: %v", err)
}
c.preventRecreation = true
for _, pod := range pods {
// delete all pods
c.kubeClient.CoreV1().Pods(pod.Namespace).Delete(context.TODO(), pod.Name, metav1.DeleteOptions{})
klog.Infof("CRD modified, so we deleted pod %s/%s", pod.Namespace, pod.Name)
}
klog.Infof("CRD modified, so we deleted all pods, and will create new pods")
curJob.SetGroupVersionKind(Kind)
_, err = c.createAggPod(curJob)
if err != nil {
klog.Errorf("Failed to create aggregation worker: %v", err)
}
_, err = c.createTrainPod(curJob)
if err != nil {
klog.Errorf("Failed to create training workers: %v", err)
}
// update the job status
c.client.FederatedLearningJobs(curJob.Namespace).Update(context.TODO(), curJob, metav1.UpdateOptions{})
}

c.preventRecreation = false
c.enqueueController(curJob, true)

// when a federated learning job is updated,
// send it to edge's LC as Added event.
c.syncToEdge(watch.Added, curJob)
}

// create edgemesh service for the job
func (c *Controller) createService(job *sednav1.FederatedLearningJob) (err error) {
var aggPort int32 = 7363
c.aggServiceHost, err = runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
if err != nil {
return err
}
return nil
}
Loading

0 comments on commit 712b62b

Please sign in to comment.