diff --git a/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go b/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go
index ec3ddff17..c12e8a95c 100644
--- a/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go
+++ b/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go
@@ -20,11 +20,13 @@ import (
 	"context"
 	"fmt"
 	"strconv"
+	"sync"
 	"time"
 
 	v1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/labels"
 	utilrand "k8s.io/apimachinery/pkg/util/rand"
 	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
 	"k8s.io/apimachinery/pkg/util/wait"
@@ -88,6 +90,15 @@ type Controller struct {
 	cfg *config.ControllerConfig
 
 	sendToEdgeFunc runtime.DownstreamSendFunc
+
+	// recreatedPods records pods recreated by deletePod, so a pod is not recreated twice
+	recreatedPods sync.Map
+
+	flSelector labels.Selector // label selector of the job's worker pods
+
+	aggServiceHost string // host of the aggregation worker's EdgeMesh service
+
+	preventRecreation bool // suppresses pod recreation while the job spec is being updated
 }
 
 // Run starts the main goroutine responsible for watching and syncing jobs.
@@ -190,6 +201,49 @@ func (c *Controller) deletePod(obj interface{}) {
 		}
 	}
 	c.enqueueByPod(pod, true)
+
+	// when the job spec is being updated, updateJob deletes and recreates the pods itself, so skip recreation here
+	if c.preventRecreation {
+		return
+	}
+	// if the pod was deleted manually, recreate it:
+	// first check that the pod is owned by a FederatedLearningJob
+	controllerRef := metav1.GetControllerOf(pod)
+	if controllerRef == nil || controllerRef.Kind != Kind.Kind {
+		return
+	}
+
+	// then check whether the pod has already been recreated
+	if _, exists := c.recreatedPods.Load(pod.Name); exists {
+		return
+	}
+
+	// if not, recreate it
+	klog.Infof("Pod %s/%s deleted, recreating...", pod.Namespace, pod.Name)
+	// Create a deep copy of the old pod
+	newPod := pod.DeepCopy()
+	// Reset the resource version and UID as they are unique to each object
+	newPod.ResourceVersion = ""
+	newPod.UID = ""
+	// Clear the status
+	newPod.Status = v1.PodStatus{}
+	// Remove the deletion timestamp
+	newPod.DeletionTimestamp = nil
+	// Remove the deletion grace period seconds
+	newPod.DeletionGracePeriodSeconds = nil
+	_, err := c.kubeClient.CoreV1().Pods(pod.Namespace).Create(context.TODO(), newPod, metav1.CreateOptions{})
+	if err != nil {
+		klog.Errorf("Failed to recreate pod %s/%s: %v", pod.Namespace, pod.Name, err)
+		return
+	}
+	klog.Infof("Successfully recreated pod %s/%s", newPod.Namespace, newPod.Name)
+	// mark the pod as recreated
+	c.recreatedPods.Store(newPod.Name, true)
+	// drop the record after a short grace period so later manual deletions trigger recreation again
+	go func() {
+		time.Sleep(5 * time.Second)
+		c.recreatedPods.Delete(pod.Name)
+	}()
 }
 
 // obj could be an *sednav1.FederatedLearningJob, or a DeletionFinalStateUnknown marker item,
@@ -271,14 +325,16 @@ func (c *Controller) sync(key string) (bool, error) {
 		return true, nil
 	}
 
-	selector, _ := runtime.GenerateSelector(&job)
-	pods, err := c.podStore.Pods(job.Namespace).List(selector)
+	c.flSelector, _ = runtime.GenerateSelector(&job)
+	pods, err := c.podStore.Pods(job.Namespace).List(c.flSelector)
 	if err != nil {
 		return false, err
 	}
 
 	activePods := k8scontroller.FilterActivePods(pods)
 	active := int32(len(activePods))
+	var activeAgg int32
+	var activeTrain int32
 	succeeded, failed := countPods(pods)
 	conditions := len(job.Status.Conditions)
@@ -289,6 +345,8 @@
 	}
 
 	var manageJobErr error
+	var manageAggErr error
+	var manageTrainErr error
 	jobFailed := false
 	var failureReason string
 	var failureMessage string
@@ -307,7 +365,13 @@ func (c *Controller) sync(key string) (bool, error) {
 	} else {
 		// in the First time, we create the pods
 		if len(pods) == 0 {
-			active, manageJobErr = c.createPod(&job)
+			activeAgg, manageAggErr = c.createAggPod(&job)
+			createServiceErr := c.createService(&job)
+			if createServiceErr != nil {
+				return false, createServiceErr
+			}
+			activeTrain, manageTrainErr = c.createTrainPod(&job)
+			active = activeAgg + activeTrain
 		}
 		complete := false
 		if succeeded > 0 && active == 0 {
@@ -324,6 +388,10 @@
 		}
 	}
 
+	// Combine manageAggErr and manageTrainErr into a single error
+	if manageAggErr != nil || manageTrainErr != nil {
+		manageJobErr = fmt.Errorf("aggregator error: %v, training error: %v", manageAggErr, manageTrainErr)
+	}
 	forget := false
 	// Check if the number of jobs succeeded increased since the last check. If yes "forget" should be true
 	// This logic is linked to the issue: https://github.com/kubernetes/kubernetes/issues/56853 that aims to
@@ -499,8 +567,7 @@ func (c *Controller) addTransmitterToWorkerParam(param *runtime.WorkerParam, job
 	return nil
 }
 
-
-func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
+func (c *Controller) createAggPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
 	active = 0
 
 	ctx := context.Background()
@@ -513,7 +580,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
 	modelName := job.Spec.AggregationWorker.Model.Name
 	model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
 	if err != nil {
-		return active, err
+		return active, fmt.Errorf("failed to get aggregation model: %w", err)
 	}
 
 	participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
@@ -524,6 +591,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
 	// Configure aggregation worker's mounts and envs
 	var aggPort int32 = 7363
 	var aggWorkerParam runtime.WorkerParam
+
 	aggWorkerParam.Env = map[string]string{
 		"NAMESPACE":   job.Namespace,
 		"WORKER_NAME": "aggworker-" + utilrand.String(5),
@@ -534,7 +602,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
 	}
 
 	if err := c.addTransmitterToWorkerParam(&aggWorkerParam, job); err != nil {
-		return active, err
+		return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
 	}
 
 	aggWorkerParam.WorkerType = jobStageAgg
@@ -547,19 +615,36 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
 		c.addWorkerMount(&aggWorkerParam, pretrainedModel.Spec.URL, "PRETRAINED_MODEL_URL",
 			pretrainedModelSecret, true)
 	}
-
+	aggWorker.Template.Name = fmt.Sprintf("%s-aggworker", job.Name)
 	// create aggpod based on configured parameters
 	_, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &aggWorker.Template, &aggWorkerParam)
 	if err != nil {
 		return active, fmt.Errorf("failed to create aggregation worker: %w", err)
 	}
+	klog.Infof("successfully created aggregation worker pod for job %s/%s", job.Namespace, job.Name)
 	active++
+	return
+}
+
+func (c *Controller) createTrainPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
+	active = 0
+	ctx := context.Background()
 
-	aggServiceHost, err := runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
+	pretrainedModelName := job.Spec.PretrainedModel.Name
+	pretrainedModel, pretrainedModelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, pretrainedModelName)
 	if err != nil {
-		return active, err
+		return active, fmt.Errorf("failed to get pretrained model: %w", err)
+	}
+
+	modelName := job.Spec.AggregationWorker.Model.Name
+	model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
+	if err != nil {
+		return active, fmt.Errorf("failed to get aggregation model: %w", err)
 	}
+	var aggPort int32 = 7363
+	participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
+
 	// deliver pod for training worker
 	for i, trainingWorker := range job.Spec.TrainingWorkers {
 		// Configure training worker's mounts and envs
@@ -583,7 +668,7 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
 
 		workerParam.Env = map[string]string{
 			"AGG_PORT": strconv.Itoa(int(aggPort)),
-			"AGG_IP":   aggServiceHost,
+			"AGG_IP":   c.aggServiceHost,
 
 			"WORKER_NAME": "trainworker-" + utilrand.String(5),
 			"JOB_NAME":    job.Name,
@@ -593,14 +678,15 @@ func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32,
 			"DATASET_NAME": datasetName,
 			"LC_SERVER":    c.cfg.LC.Server,
 		}
+
 		workerParam.WorkerType = runtime.TrainPodType
 		workerParam.HostNetwork = true
 		workerParam.RestartPolicy = v1.RestartPolicyOnFailure
 
 		if err := c.addTransmitterToWorkerParam(&workerParam, job); err != nil {
-			return active, err
+			return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
 		}
-
+		trainingWorker.Template.Name = fmt.Sprintf("%s-trainworker-%d", job.Name, i)
 		// create training worker based on configured parameters
 		_, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &trainingWorker.Template, &workerParam)
 		if err != nil {
@@ -640,13 +726,8 @@ func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
 			// send it to edge's LC.
 			fc.syncToEdge(watch.Added, obj)
 		},
-		UpdateFunc: func(old, cur interface{}) {
-			fc.enqueueController(cur, true)
+		UpdateFunc: fc.updateJob,
-
-			// when a federated learning job is updated,
-			// send it to edge's LC as Added event.
-			fc.syncToEdge(watch.Added, cur)
-		},
 		DeleteFunc: func(obj interface{}) {
 			fc.enqueueController(obj, true)
@@ -669,3 +750,67 @@
 
 	return fc, nil
 }
+
+func (c *Controller) updateJob(old, cur interface{}) {
+	oldJob, ok := old.(*sednav1.FederatedLearningJob)
+	if !ok {
+		return
+	}
+	curJob, ok := cur.(*sednav1.FederatedLearningJob)
+	if !ok {
+		return
+	}
+
+	if oldJob.ResourceVersion == curJob.ResourceVersion {
+		return
+	}
+
+	// a changed generation means the job spec itself was modified
+	if oldJob.Generation != curJob.Generation {
+		pods, err := c.podStore.Pods(curJob.Namespace).List(c.flSelector)
+		if err != nil {
+			klog.Errorf("Failed to list pods: %v", err)
+			return
+		}
+		c.preventRecreation = true
+		// the job spec changed, so delete all existing worker pods
+		for _, pod := range pods {
+			if err := c.kubeClient.CoreV1().Pods(pod.Namespace).Delete(context.TODO(), pod.Name, metav1.DeleteOptions{}); err != nil {
+				klog.Errorf("Failed to delete pod %s/%s: %v", pod.Namespace, pod.Name, err)
+				continue
+			}
+			klog.Infof("job spec modified, deleted pod %s/%s", pod.Namespace, pod.Name)
+		}
+		klog.Infof("job spec modified, deleted all worker pods; creating new ones")
+		curJob.SetGroupVersionKind(Kind)
+		_, err = c.createAggPod(curJob)
+		if err != nil {
+			klog.Errorf("Failed to create aggregation worker: %v", err)
+		}
+		_, err = c.createTrainPod(curJob)
+		if err != nil {
+			klog.Errorf("Failed to create training workers: %v", err)
+		}
+		// persist the updated job
+		if _, err := c.client.FederatedLearningJobs(curJob.Namespace).Update(context.TODO(), curJob, metav1.UpdateOptions{}); err != nil {
+			klog.Errorf("Failed to update job %s/%s: %v", curJob.Namespace, curJob.Name, err)
+		}
+	}
+
+	c.preventRecreation = false
+	c.enqueueController(curJob, true)
+
+	// when a federated learning job is updated,
+	// send it to edge's LC as Added event.
+	c.syncToEdge(watch.Added, curJob)
+}
+
+// createService creates the EdgeMesh service that exposes the aggregation worker
+func (c *Controller) createService(job *sednav1.FederatedLearningJob) (err error) {
+	var aggPort int32 = 7363
+	c.aggServiceHost, err = runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
+	if err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob_test.go b/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob_test.go
new file mode 100644
index 000000000..e1356bfd6
--- /dev/null
+++ b/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob_test.go
@@ -0,0 +1,368 @@
+package federatedlearning
+
+import (
+	"context"
+	"testing"
+
+	corev1 "k8s.io/api/core/v1"
+	v1 "k8s.io/api/core/v1"
+	k8serrors "k8s.io/apimachinery/pkg/api/errors"
+	"k8s.io/apimachinery/pkg/api/resource"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/labels"
+	"k8s.io/apimachinery/pkg/watch"
+	kubernetesfake "k8s.io/client-go/kubernetes/fake"
+	"k8s.io/client-go/kubernetes/scheme"
+	v1core "k8s.io/client-go/kubernetes/typed/core/v1"
+	corelisters "k8s.io/client-go/listers/core/v1"
+	"k8s.io/client-go/tools/record"
+	"k8s.io/client-go/util/workqueue"
+
+	sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1"
+	fakesednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/fake"
+	"github.com/kubeedge/sedna/pkg/globalmanager/config"
+	"github.com/kubeedge/sedna/pkg/globalmanager/runtime"
+)
+
+type mockPodLister struct {
+	pods []*v1.Pod
+}
+
+func (m *mockPodLister) List(selector labels.Selector) (ret []*v1.Pod, err error) {
+	return m.pods, nil
+}
+
+func (m *mockPodLister) Pods(namespace string) corelisters.PodNamespaceLister {
+	return mockPodNamespaceLister{pods: m.pods, namespace: namespace}
+}
+
+type mockPodNamespaceLister struct {
+	pods      []*v1.Pod
+	namespace string
+}
+
+func (m mockPodNamespaceLister) List(selector labels.Selector) ([]*v1.Pod, error) {
+	var filteredPods []*v1.Pod
+	for _, pod := range m.pods {
+		if pod.Namespace == m.namespace {
+			filteredPods = append(filteredPods, pod)
+		}
+	}
+	return filteredPods, nil
+}
+
+func (m mockPodNamespaceLister) Get(name string) (*v1.Pod, error) {
+	for _, pod := range m.pods {
+		if pod.Namespace == m.namespace && pod.Name == name {
+			return pod, nil
+		}
+	}
+	return nil, k8serrors.NewNotFound(corev1.Resource("pod"), name)
+}
+
+// unit test for the deletePod handler
+func Test_deletePod(t *testing.T) {
+	t.Run("handle delete event for a pod not owned by a job", func(t *testing.T) {
+		// Create a fake client
+		fakeClient := kubernetesfake.NewSimpleClientset()
+
+		// Create a test pod
+		testPod := &corev1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-pod",
+				Namespace: "default",
+			},
+			Spec: corev1.PodSpec{
+				Containers: []corev1.Container{
+					{
+						Name:  "test-container",
+						Image: "test-image",
+					},
+				},
+			},
+		}
+
+		// Create the pod using the fake client
+		_, err := fakeClient.CoreV1().Pods("default").Create(context.TODO(), testPod, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create test pod: %v", err)
+		}
+
+		// Create a controller with the fake client
+		controller := &Controller{
+			kubeClient: fakeClient,
+		}
+
+		// Call the deletePod handler; without a FederatedLearningJob owner reference the pod must be left untouched
+		controller.deletePod(testPod)
+
+		// Verify that the pod still exists and the handler did not panic
+		_, err = fakeClient.CoreV1().Pods("default").Get(context.TODO(), "test-pod", metav1.GetOptions{})
+		if err != nil {
+			t.Fatalf("Pod should still exist: %v", err)
+		}
+	})
+
+	t.Run("ignore object that is not a pod", func(t *testing.T) {
+		// Create a fake client
+		fakeClient := kubernetesfake.NewSimpleClientset()
+
+		// Create a controller with the fake client
+		controller := &Controller{
+			kubeClient: fakeClient,
+		}
+
+		// Call deletePod with an object that is not a *v1.Pod
+		notAPod := corev1.PodSpec{
+			Containers: []corev1.Container{
+				{
+					Name:  "non-existent-container",
+					Image: "non-existent-image",
+				},
+			},
+		}
+		controller.deletePod(notAPod)
+		// The handler should ignore the object without panicking,
+		// and no pod may appear in the cluster as a side effect
+		_, err := fakeClient.CoreV1().Pods("default").Get(context.TODO(), "test-pod", metav1.GetOptions{})
+		if err == nil {
+			t.Fatalf("unexpected pod found: deletePod must not create pods for non-pod objects")
+		}
+	})
+}
+
+// unit test for updateJob function
+func Test_updateJob(t *testing.T) {
+	t.Run("update correct job parameter successfully", func(t *testing.T) {
+		// Create fake clients
+		fakeSednaClient := fakesednaclientset.NewSimpleClientset()
+
+		// Create a test job
+		oldJob := &sednav1.FederatedLearningJob{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-fl-job",
+				Namespace: "default",
+			},
+			Spec: sednav1.FLJobSpec{
+				AggregationWorker: sednav1.AggregationWorker{
+					Model: sednav1.TrainModel{
+						Name: "test-model",
+					},
+					Template: v1.PodTemplateSpec{
+						ObjectMeta: metav1.ObjectMeta{
+							Name: "test-fl-job-aggregation-worker",
+						},
+						Spec: v1.PodSpec{
+							Containers: []corev1.Container{
+								{
+									Name:  "test-container",
+									Image: "test-image",
+								},
+							},
+						},
+					},
+				},
+				TrainingWorkers: []sednav1.TrainingWorker{
+					{
+						Dataset: sednav1.TrainDataset{
+							Name: "test-dataset1",
+						},
+						Template: v1.PodTemplateSpec{
+							ObjectMeta: metav1.ObjectMeta{
+								Name: "test-fl-job-training-worker-0",
+							},
+							Spec: v1.PodSpec{
+								NodeName: "test-node1",
+								Containers: []corev1.Container{
+									{
+										Name:            "test-container",
+										Image:           "test-image",
+										ImagePullPolicy: corev1.PullIfNotPresent,
+										Env: []corev1.EnvVar{
+											{
+												Name:  "batch_size",
+												Value: "32",
+											},
+											{
+												Name:  "learning_rate",
+												Value: "0.001",
+											},
+											{
+												Name:  "epochs",
+												Value: "2",
+											},
+										},
+										Resources: corev1.ResourceRequirements{
+											Requests: corev1.ResourceList{
+												corev1.ResourceMemory: resource.MustParse("2Gi"),
+											},
+										},
+									},
+								},
+							},
+						},
+					},
+					{
+						Dataset: sednav1.TrainDataset{
+							Name: "test-dataset2",
+						},
+						Template: v1.PodTemplateSpec{
+							ObjectMeta: metav1.ObjectMeta{
+								Name: "test-fl-job-training-worker-1",
+							},
+							Spec: v1.PodSpec{
+								NodeName: "test-node2",
+								Containers: []corev1.Container{
+									{
+										Name:            "test-container",
+										Image:           "test-image",
+										ImagePullPolicy: corev1.PullIfNotPresent,
+										Env: []corev1.EnvVar{
+											{
+												Name:  "batch_size",
+												Value: "32",
+											},
+											{
+												Name:  "learning_rate",
+												Value: "0.001",
+											},
+											{
+												Name:  "epochs",
+												Value: "2",
+											},
+										},
+										Resources: corev1.ResourceRequirements{
+											Requests: corev1.ResourceList{
+												corev1.ResourceMemory: resource.MustParse("2Gi"),
+											},
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+				PretrainedModel: sednav1.PretrainedModel{
+					Name: "test-pretrained-model",
+				},
+			},
+		}
+		oldJob.ResourceVersion = "1"
+		// Create the job using the fake client
+		_, err := fakeSednaClient.SednaV1alpha1().FederatedLearningJobs("default").Create(context.TODO(), oldJob, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create test job: %v", err)
+		}
+		fakeKubeClient := kubernetesfake.NewSimpleClientset()
+
+		// Create test pods
+		testPods := []*v1.Pod{
+			{
+				ObjectMeta: metav1.ObjectMeta{
+					Name:      "test-fl-job-aggregation-worker",
+					Namespace: "default",
+				},
+				Spec: oldJob.Spec.AggregationWorker.Template.Spec,
+			},
+			{
+				ObjectMeta: metav1.ObjectMeta{
+					Name:      "test-fl-job-training-worker-0",
+					Namespace: "default",
+				},
+				Spec: oldJob.Spec.TrainingWorkers[0].Template.Spec,
+			},
+			{
+				ObjectMeta: metav1.ObjectMeta{
+					Name:      "test-fl-job-training-worker-1",
+					Namespace: "default",
+				},
+				Spec: oldJob.Spec.TrainingWorkers[1].Template.Spec,
+			},
+		}
+
+		// create pretrained model resource
+		pretrainedModel := &sednav1.Model{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-pretrained-model",
+				Namespace: "default",
+			},
+		}
+		_, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), pretrainedModel, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create pretrained model: %v", err)
+		}
+		// create model resource
+		model := &sednav1.Model{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-model",
+				Namespace: "default",
+			},
+		}
+		_, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), model, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create model: %v", err)
+		}
+
+		// create dataset1 resource
+		dataset1 := &sednav1.Dataset{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-dataset1",
+				Namespace: "default",
+			},
+		}
+		_, err = fakeSednaClient.SednaV1alpha1().Datasets("default").Create(context.TODO(), dataset1, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create dataset: %v", err)
+		}
+
+		// create dataset2 resource
+		dataset2 := &sednav1.Dataset{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-dataset2",
+				Namespace: "default",
+			},
+		}
+		_, err = fakeSednaClient.SednaV1alpha1().Datasets("default").Create(context.TODO(), dataset2, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create dataset: %v", err)
+		}
+		cfg := &config.ControllerConfig{
+			LC: config.LCConfig{
+				Server: "http://test-lc-server:8080",
+			},
+		}
+		eventBroadcaster := record.NewBroadcaster()
+		eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: fakeKubeClient.CoreV1().Events("")})
+		// Create a controller with the fake clients
+		c := &Controller{
+			kubeClient: fakeKubeClient,
+			client:     fakeSednaClient.SednaV1alpha1(),
+			podStore:   &mockPodLister{pods: testPods},
+			flSelector: labels.SelectorFromSet(labels.Set{"federatedlearningjob.sedna.io/job-name": "test-fl-job"}),
+			queue:      workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), "test-fl-job"),
+			recorder:   eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "test-fl-job"}),
+			cfg:        cfg,
+			sendToEdgeFunc: func(nodeName string, eventType watch.EventType, job interface{}) error {
+				return nil
+			},
+		}
+		c.aggServiceHost = "test-fl-job-aggregation.default"
+
+		// Update the job
+		newJob := oldJob.DeepCopy()
+		newJob.Spec.TrainingWorkers[0].Template.Spec.Containers[0].Env[0].Value = "16"
+		newJob.Generation = 2
+		newJob.ResourceVersion = "2"
+
+		c.updateJob(oldJob, newJob)
+
+		// Verify that the job was updated
+		updatedJob, err := fakeSednaClient.SednaV1alpha1().FederatedLearningJobs("default").Get(context.TODO(), "test-fl-job", metav1.GetOptions{})
+		if err != nil {
+			t.Fatalf("Failed to get updated job: %v", err)
+		}
+		if updatedJob.Spec.TrainingWorkers[0].Template.Spec.Containers[0].Env[0].Value != "16" {
+			t.Fatalf("Job was not updated correctly")
+		}
+	})
+}
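Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of the recreate-once bookkeeping that the new deletePod handler relies on. It mirrors the recreatedPods sync.Map and the 5-second grace period above; the guard type and all names here are illustrative only.

package main

import (
	"fmt"
	"sync"
	"time"
)

// guard mimics the controller's recreatedPods bookkeeping: a pod name is
// remembered for a grace period after its first recreation, so a second
// delete event observed within that window does not trigger another one.
type guard struct {
	recreated sync.Map
}

func (g *guard) tryRecreate(podName string) bool {
	if _, exists := g.recreated.Load(podName); exists {
		return false // already recreated recently, suppress
	}
	g.recreated.Store(podName, true)
	go func() {
		time.Sleep(5 * time.Second)
		g.recreated.Delete(podName) // allow future manual deletions to recreate again
	}()
	return true
}

func main() {
	g := &guard{}
	fmt.Println(g.tryRecreate("aggworker-abc12")) // true: first delete event recreates
	fmt.Println(g.tryRecreate("aggworker-abc12")) // false: suppressed within the grace period
}

As in the patch, the Load followed by a separate Store is not atomic, so two concurrent delete events could in principle both pass the check; sync.Map's LoadOrStore would close that window.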
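A second sketch (also not part of the patch, names illustrative) of the update-classification rule that updateJob applies: ResourceVersion changes on every write to the object, including status updates, while the API server bumps Generation only when the spec changes, which is why worker pods are deleted and recreated only on a Generation mismatch.

package main

import (
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// specChanged mirrors updateJob's two checks: skip no-op resyncs (equal
// ResourceVersion) and rebuild worker pods only when Generation moved.
func specChanged(old, cur metav1.ObjectMeta) (skip, rebuild bool) {
	if old.ResourceVersion == cur.ResourceVersion {
		return true, false // resync: nothing changed
	}
	return false, old.Generation != cur.Generation
}

func main() {
	old := metav1.ObjectMeta{ResourceVersion: "1", Generation: 1}
	statusWrite := metav1.ObjectMeta{ResourceVersion: "2", Generation: 1}
	specEdit := metav1.ObjectMeta{ResourceVersion: "3", Generation: 2}

	fmt.Println(specChanged(old, old))         // true false: ignore the event
	fmt.Println(specChanged(old, statusWrite)) // false false: no pod rebuild
	fmt.Println(specChanged(old, specEdit))    // false true: delete and recreate pods
}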