diff --git a/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go b/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go index 39e7befd5..c12e8a95c 100644 --- a/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go +++ b/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob.go @@ -94,7 +94,7 @@ type Controller struct { // map to record the pods that are recreated recreatedPods sync.Map - fl_selector labels.Selector + flSelector labels.Selector aggServiceHost string @@ -325,8 +325,8 @@ func (c *Controller) sync(key string) (bool, error) { return true, nil } - c.fl_selector, _ = runtime.GenerateSelector(&job) - pods, err := c.podStore.Pods(job.Namespace).List(c.fl_selector) + c.flSelector, _ = runtime.GenerateSelector(&job) + pods, err := c.podStore.Pods(job.Namespace).List(c.flSelector) if err != nil { return false, err } @@ -766,7 +766,7 @@ func (c *Controller) updateJob(old, cur interface{}) { } if oldJob.Generation != curJob.Generation { - pods, err := c.podStore.Pods(curJob.Namespace).List(c.fl_selector) + pods, err := c.podStore.Pods(curJob.Namespace).List(c.flSelector) if err != nil { klog.Errorf("Failed to list pods: %v", err) } @@ -806,4 +806,4 @@ func (c *Controller) createService(job *sednav1.FederatedLearningJob) (err error return err } return nil -} \ No newline at end of file +} diff --git a/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob_test.go b/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob_test.go index d8a6a02f9..e1356bfd6 100644 --- a/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob_test.go +++ b/pkg/globalmanager/controllers/federatedlearning/federatedlearningjob_test.go @@ -1,368 +1,368 @@ -package federatedlearning - -import ( - "context" - "testing" - - corev1 "k8s.io/api/core/v1" - v1 "k8s.io/api/core/v1" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/labels" - "k8s.io/apimachinery/pkg/watch" - kubernetesfake "k8s.io/client-go/kubernetes/fake" - "k8s.io/client-go/kubernetes/scheme" - v1core "k8s.io/client-go/kubernetes/typed/core/v1" - corelisters "k8s.io/client-go/listers/core/v1" - "k8s.io/client-go/tools/record" - "k8s.io/client-go/util/workqueue" - - sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1" - fakeseednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/fake" - "github.com/kubeedge/sedna/pkg/globalmanager/config" - "github.com/kubeedge/sedna/pkg/globalmanager/runtime" -) - -type mockPodLister struct { - pods []*v1.Pod -} - -func (m *mockPodLister) List(selector labels.Selector) (ret []*v1.Pod, err error) { - return m.pods, nil -} - -func (m *mockPodLister) Pods(namespace string) corelisters.PodNamespaceLister { - return mockPodNamespaceLister{pods: m.pods, namespace: namespace} -} - -type mockPodNamespaceLister struct { - pods []*v1.Pod - namespace string -} - -func (m mockPodNamespaceLister) List(selector labels.Selector) ([]*v1.Pod, error) { - var filteredPods []*v1.Pod - for _, pod := range m.pods { - if pod.Namespace == m.namespace { - filteredPods = append(filteredPods, pod) - } - } - return filteredPods, nil -} - -func (m mockPodNamespaceLister) Get(name string) (*v1.Pod, error) { - for _, pod := range m.pods { - if pod.Namespace == m.namespace && pod.Name == name { - return pod, nil - } - } - return nil, k8serrors.NewNotFound(corev1.Resource("pod"), name) -} - -// unit test for deletePod function -func Test_deletePod(t *testing.T) { - t.Run("delete existing pod successfully", func(t *testing.T) { - // Create a fake client - fakeClient := kubernetesfake.NewSimpleClientset() - - // Create a test pod - testPod := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pod", - Namespace: "default", - }, - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test-container", - Image: "test-image", - }, - }, - }, - } - - // Create the pod using the fake client - _, err := fakeClient.CoreV1().Pods("default").Create(context.TODO(), testPod, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create test pod: %v", err) - } - - // Create a controller with the fake client - controller := &Controller{ - kubeClient: fakeClient, - } - - // Call deletePod function - controller.deletePod(testPod) - - // Verify that the pod was recreated - _, err = fakeClient.CoreV1().Pods("default").Get(context.TODO(), "test-pod", metav1.GetOptions{}) - if err != nil { - t.Fatalf("Pod was not recreated") - } - }) - - t.Run("delete non-existent pod", func(t *testing.T) { - // Create a fake client - fakeClient := kubernetesfake.NewSimpleClientset() - - // Create a controller with the fake client - controller := &Controller{ - kubeClient: fakeClient, - } - - // Call deletePod with a non-existent pod - nonExistentPod := corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "non-existent-container", - Image: "non-existent-image", - }, - }, - } - controller.deletePod(nonExistentPod) - // No error should occur, and the function should complete - // verify if the pod is deleted - _, err := fakeClient.CoreV1().Pods("default").Get(context.TODO(), "test-pod", metav1.GetOptions{}) - if err == nil { - t.Fatalf("Pod was not deleted") - } - }) -} - -// unit test for updateJob function -func Test_updateJob(t *testing.T) { - t.Run("update correct job parameter successfully", func(t *testing.T) { - // Create fake clients - fakeSednaClient := fakeseednaclientset.NewSimpleClientset() - - // Create a test job - oldJob := &sednav1.FederatedLearningJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-fl-job", - Namespace: "default", - }, - Spec: sednav1.FLJobSpec{ - AggregationWorker: sednav1.AggregationWorker{ - Model: sednav1.TrainModel{ - Name: "test-model", - }, - Template: v1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-fl-job-aggregation-worker", - }, - Spec: v1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test-container", - Image: "test-image", - }, - }, - }, - }, - }, - TrainingWorkers: []sednav1.TrainingWorker{ - { - Dataset: sednav1.TrainDataset{ - Name: "test-dataset1", - }, - Template: v1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-fl-job-training-worker-0", - }, - Spec: v1.PodSpec{ - NodeName: "test-node1", - Containers: []corev1.Container{ - { - Name: "test-container", - Image: "test-image", - ImagePullPolicy: corev1.PullIfNotPresent, - Env: []corev1.EnvVar{ - { - Name: "batch_size", - Value: "32", - }, - { - Name: "learning_rate", - Value: "0.001", - }, - { - Name: "epochs", - Value: "2", - }, - }, - Resources: corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceMemory: resource.MustParse("2Gi"), - }, - }, - }, - }, - }, - }, - }, - { - Dataset: sednav1.TrainDataset{ - Name: "test-dataset2", - }, - Template: v1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-fl-job-training-worker-1", - }, - Spec: v1.PodSpec{ - NodeName: "test-node2", - Containers: []corev1.Container{ - { - Name: "test-container", - Image: "test-image", - ImagePullPolicy: corev1.PullIfNotPresent, - Env: []corev1.EnvVar{ - { - Name: "batch_size", - Value: "32", - }, - { - Name: "learning_rate", - Value: "0.001", - }, - { - Name: "epochs", - Value: "2", - }, - }, - Resources: corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceMemory: resource.MustParse("2Gi"), - }, - }, - }, - }, - }, - }, - }, - }, - PretrainedModel: sednav1.PretrainedModel{ - Name: "test-pretrained-model", - }, - }, - } - oldJob.ResourceVersion = "1" - // Create the job using the fake client - _, err := fakeSednaClient.SednaV1alpha1().FederatedLearningJobs("default").Create(context.TODO(), oldJob, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create test job: %v", err) - } - fakeKubeClient := kubernetesfake.NewSimpleClientset() - - // Create test pods - testPods := []*v1.Pod{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "test-fl-job-aggregation-worker", - Namespace: "default", - }, - Spec: oldJob.Spec.AggregationWorker.Template.Spec, - }, - { - ObjectMeta: metav1.ObjectMeta{ - Name: "test-fl-job-training-worker-0", - Namespace: "default", - }, - Spec: oldJob.Spec.TrainingWorkers[0].Template.Spec, - }, - { - ObjectMeta: metav1.ObjectMeta{ - Name: "test-fl-job-training-worker-1", - Namespace: "default", - }, - Spec: oldJob.Spec.TrainingWorkers[1].Template.Spec, - }, - } - - // create pretrained model resource - pretrainedModel := &sednav1.Model{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pretrained-model", - Namespace: "default", - }, - } - _, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), pretrainedModel, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create pretrained model: %v", err) - } - // create model resource - model := &sednav1.Model{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-model", - Namespace: "default", - }, - } - _, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), model, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create model: %v", err) - } - - // create dataset1 resource - dataset1 := &sednav1.Dataset{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-dataset1", - Namespace: "default", - }, - } - _, err = fakeSednaClient.SednaV1alpha1().Datasets("default").Create(context.TODO(), dataset1, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create dataset: %v", err) - } - - // create dataset2 resource - dataset2 := &sednav1.Dataset{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-dataset2", - Namespace: "default", - }, - } - _, err = fakeSednaClient.SednaV1alpha1().Datasets("default").Create(context.TODO(), dataset2, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create dataset: %v", err) - } - cfg := &config.ControllerConfig{ - LC: config.LCConfig{ - Server: "http://test-lc-server:8080", - }, - } - eventBroadcaster := record.NewBroadcaster() - eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: fakeKubeClient.CoreV1().Events("")}) - // Create a controller with the fake clients - c := &Controller{ - kubeClient: fakeKubeClient, - client: fakeSednaClient.SednaV1alpha1(), - podStore: &mockPodLister{pods: testPods}, - fl_selector: labels.SelectorFromSet(labels.Set{"federatedlearningjob.sedna.io/job-name": "test-fl-job"}), - queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), "test-fl-job"), - recorder: eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "test-fl-job"}), - cfg: cfg, - sendToEdgeFunc: func(nodeName string, eventType watch.EventType, job interface{}) error { - return nil - }, - } - c.aggServiceHost = "test-fl-job-aggregation.default" - - // Update the job - newJob := oldJob.DeepCopy() - newJob.Spec.TrainingWorkers[0].Template.Spec.Containers[0].Env[0].Value = "16" - newJob.Generation = 2 - newJob.ResourceVersion = "2" - - c.updateJob(oldJob, newJob) - - // Verify that the job was updated - updatedJob, err := fakeSednaClient.SednaV1alpha1().FederatedLearningJobs("default").Get(context.TODO(), "test-fl-job", metav1.GetOptions{}) - if err != nil { - t.Fatalf("Failed to get updated job: %v", err) - } - if updatedJob.Spec.TrainingWorkers[0].Template.Spec.Containers[0].Env[0].Value != "16" { - t.Fatalf("Job was not updated correctly") - } - }) -} +package federatedlearning + +import ( + "context" + "testing" + + corev1 "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/watch" + kubernetesfake "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/kubernetes/scheme" + v1core "k8s.io/client-go/kubernetes/typed/core/v1" + corelisters "k8s.io/client-go/listers/core/v1" + "k8s.io/client-go/tools/record" + "k8s.io/client-go/util/workqueue" + + sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1" + fakeseednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/fake" + "github.com/kubeedge/sedna/pkg/globalmanager/config" + "github.com/kubeedge/sedna/pkg/globalmanager/runtime" +) + +type mockPodLister struct { + pods []*v1.Pod +} + +func (m *mockPodLister) List(selector labels.Selector) (ret []*v1.Pod, err error) { + return m.pods, nil +} + +func (m *mockPodLister) Pods(namespace string) corelisters.PodNamespaceLister { + return mockPodNamespaceLister{pods: m.pods, namespace: namespace} +} + +type mockPodNamespaceLister struct { + pods []*v1.Pod + namespace string +} + +func (m mockPodNamespaceLister) List(selector labels.Selector) ([]*v1.Pod, error) { + var filteredPods []*v1.Pod + for _, pod := range m.pods { + if pod.Namespace == m.namespace { + filteredPods = append(filteredPods, pod) + } + } + return filteredPods, nil +} + +func (m mockPodNamespaceLister) Get(name string) (*v1.Pod, error) { + for _, pod := range m.pods { + if pod.Namespace == m.namespace && pod.Name == name { + return pod, nil + } + } + return nil, k8serrors.NewNotFound(corev1.Resource("pod"), name) +} + +// unit test for deletePod function +func Test_deletePod(t *testing.T) { + t.Run("delete existing pod successfully", func(t *testing.T) { + // Create a fake client + fakeClient := kubernetesfake.NewSimpleClientset() + + // Create a test pod + testPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "default", + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + }, + }, + }, + } + + // Create the pod using the fake client + _, err := fakeClient.CoreV1().Pods("default").Create(context.TODO(), testPod, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test pod: %v", err) + } + + // Create a controller with the fake client + controller := &Controller{ + kubeClient: fakeClient, + } + + // Call deletePod function + controller.deletePod(testPod) + + // Verify that the pod was recreated + _, err = fakeClient.CoreV1().Pods("default").Get(context.TODO(), "test-pod", metav1.GetOptions{}) + if err != nil { + t.Fatalf("Pod was not recreated") + } + }) + + t.Run("delete non-existent pod", func(t *testing.T) { + // Create a fake client + fakeClient := kubernetesfake.NewSimpleClientset() + + // Create a controller with the fake client + controller := &Controller{ + kubeClient: fakeClient, + } + + // Call deletePod with a non-existent pod + nonExistentPod := corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "non-existent-container", + Image: "non-existent-image", + }, + }, + } + controller.deletePod(nonExistentPod) + // No error should occur, and the function should complete + // verify if the pod is deleted + _, err := fakeClient.CoreV1().Pods("default").Get(context.TODO(), "test-pod", metav1.GetOptions{}) + if err == nil { + t.Fatalf("Pod was not deleted") + } + }) +} + +// unit test for updateJob function +func Test_updateJob(t *testing.T) { + t.Run("update correct job parameter successfully", func(t *testing.T) { + // Create fake clients + fakeSednaClient := fakeseednaclientset.NewSimpleClientset() + + // Create a test job + oldJob := &sednav1.FederatedLearningJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-fl-job", + Namespace: "default", + }, + Spec: sednav1.FLJobSpec{ + AggregationWorker: sednav1.AggregationWorker{ + Model: sednav1.TrainModel{ + Name: "test-model", + }, + Template: v1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-fl-job-aggregation-worker", + }, + Spec: v1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + }, + }, + }, + }, + }, + TrainingWorkers: []sednav1.TrainingWorker{ + { + Dataset: sednav1.TrainDataset{ + Name: "test-dataset1", + }, + Template: v1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-fl-job-training-worker-0", + }, + Spec: v1.PodSpec{ + NodeName: "test-node1", + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + ImagePullPolicy: corev1.PullIfNotPresent, + Env: []corev1.EnvVar{ + { + Name: "batch_size", + Value: "32", + }, + { + Name: "learning_rate", + Value: "0.001", + }, + { + Name: "epochs", + Value: "2", + }, + }, + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, + }, + }, + }, + }, + }, + { + Dataset: sednav1.TrainDataset{ + Name: "test-dataset2", + }, + Template: v1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-fl-job-training-worker-1", + }, + Spec: v1.PodSpec{ + NodeName: "test-node2", + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + ImagePullPolicy: corev1.PullIfNotPresent, + Env: []corev1.EnvVar{ + { + Name: "batch_size", + Value: "32", + }, + { + Name: "learning_rate", + Value: "0.001", + }, + { + Name: "epochs", + Value: "2", + }, + }, + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, + }, + }, + }, + }, + }, + }, + PretrainedModel: sednav1.PretrainedModel{ + Name: "test-pretrained-model", + }, + }, + } + oldJob.ResourceVersion = "1" + // Create the job using the fake client + _, err := fakeSednaClient.SednaV1alpha1().FederatedLearningJobs("default").Create(context.TODO(), oldJob, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test job: %v", err) + } + fakeKubeClient := kubernetesfake.NewSimpleClientset() + + // Create test pods + testPods := []*v1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-fl-job-aggregation-worker", + Namespace: "default", + }, + Spec: oldJob.Spec.AggregationWorker.Template.Spec, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-fl-job-training-worker-0", + Namespace: "default", + }, + Spec: oldJob.Spec.TrainingWorkers[0].Template.Spec, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-fl-job-training-worker-1", + Namespace: "default", + }, + Spec: oldJob.Spec.TrainingWorkers[1].Template.Spec, + }, + } + + // create pretrained model resource + pretrainedModel := &sednav1.Model{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pretrained-model", + Namespace: "default", + }, + } + _, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), pretrainedModel, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create pretrained model: %v", err) + } + // create model resource + model := &sednav1.Model{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-model", + Namespace: "default", + }, + } + _, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), model, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create model: %v", err) + } + + // create dataset1 resource + dataset1 := &sednav1.Dataset{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dataset1", + Namespace: "default", + }, + } + _, err = fakeSednaClient.SednaV1alpha1().Datasets("default").Create(context.TODO(), dataset1, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create dataset: %v", err) + } + + // create dataset2 resource + dataset2 := &sednav1.Dataset{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dataset2", + Namespace: "default", + }, + } + _, err = fakeSednaClient.SednaV1alpha1().Datasets("default").Create(context.TODO(), dataset2, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create dataset: %v", err) + } + cfg := &config.ControllerConfig{ + LC: config.LCConfig{ + Server: "http://test-lc-server:8080", + }, + } + eventBroadcaster := record.NewBroadcaster() + eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: fakeKubeClient.CoreV1().Events("")}) + // Create a controller with the fake clients + c := &Controller{ + kubeClient: fakeKubeClient, + client: fakeSednaClient.SednaV1alpha1(), + podStore: &mockPodLister{pods: testPods}, + flSelector: labels.SelectorFromSet(labels.Set{"federatedlearningjob.sedna.io/job-name": "test-fl-job"}), + queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), "test-fl-job"), + recorder: eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "test-fl-job"}), + cfg: cfg, + sendToEdgeFunc: func(nodeName string, eventType watch.EventType, job interface{}) error { + return nil + }, + } + c.aggServiceHost = "test-fl-job-aggregation.default" + + // Update the job + newJob := oldJob.DeepCopy() + newJob.Spec.TrainingWorkers[0].Template.Spec.Containers[0].Env[0].Value = "16" + newJob.Generation = 2 + newJob.ResourceVersion = "2" + + c.updateJob(oldJob, newJob) + + // Verify that the job was updated + updatedJob, err := fakeSednaClient.SednaV1alpha1().FederatedLearningJobs("default").Get(context.TODO(), "test-fl-job", metav1.GetOptions{}) + if err != nil { + t.Fatalf("Failed to get updated job: %v", err) + } + if updatedJob.Spec.TrainingWorkers[0].Template.Spec.Containers[0].Env[0].Value != "16" { + t.Fatalf("Job was not updated correctly") + } + }) +} diff --git a/pkg/globalmanager/controllers/jointinference/jointinferenceservice_test.go b/pkg/globalmanager/controllers/jointinference/jointinferenceservice_test.go index 8c60cbe41..8fc24ed2d 100644 --- a/pkg/globalmanager/controllers/jointinference/jointinferenceservice_test.go +++ b/pkg/globalmanager/controllers/jointinference/jointinferenceservice_test.go @@ -1,307 +1,307 @@ -package jointinference - -import ( - "context" - "testing" - - sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1" - fakeseednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/fake" - "github.com/kubeedge/sedna/pkg/globalmanager/config" - "github.com/kubeedge/sedna/pkg/globalmanager/runtime" - appsv1 "k8s.io/api/apps/v1" - v1 "k8s.io/api/core/v1" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/labels" - "k8s.io/apimachinery/pkg/watch" - kubernetesfake "k8s.io/client-go/kubernetes/fake" - "k8s.io/client-go/kubernetes/scheme" - v1core "k8s.io/client-go/kubernetes/typed/core/v1" - corelisters "k8s.io/client-go/listers/apps/v1" - corelistersv1 "k8s.io/client-go/listers/core/v1" - "k8s.io/client-go/tools/record" - "k8s.io/client-go/util/workqueue" -) - -type mockPodLister struct { - pods []*v1.Pod -} - -type mockPodNamespaceLister struct { - pods []*v1.Pod - namespace string -} - -func (m *mockPodLister) Pods(namespace string) corelistersv1.PodNamespaceLister { - return mockPodNamespaceLister{pods: m.pods, namespace: namespace} -} - -func (m *mockPodLister) List(selector labels.Selector) (ret []*v1.Pod, err error) { - return m.pods, nil -} - -func (m mockPodNamespaceLister) List(selector labels.Selector) ([]*v1.Pod, error) { - var filteredPods []*v1.Pod - for _, pod := range m.pods { - if pod.Namespace == m.namespace { - filteredPods = append(filteredPods, pod) - } - } - return filteredPods, nil -} - -type mockDeploymentLister struct { - deployments []*appsv1.Deployment -} - -func (m *mockDeploymentLister) List(selector labels.Selector) (ret []*appsv1.Deployment, err error) { - return m.deployments, nil -} - -func (m *mockDeploymentLister) Deployments(namespace string) corelisters.DeploymentNamespaceLister { - return mockDeploymentNamespaceLister{deployments: m.deployments, namespace: namespace} -} - -func (m mockPodNamespaceLister) Get(name string) (*v1.Pod, error) { - for _, pod := range m.pods { - if pod.Namespace == m.namespace && pod.Name == name { - return pod, nil - } - } - return nil, k8serrors.NewNotFound(v1.Resource("pod"), name) -} - -type mockDeploymentNamespaceLister struct { - deployments []*appsv1.Deployment - namespace string -} - -func (m mockDeploymentNamespaceLister) List(selector labels.Selector) ([]*appsv1.Deployment, error) { - var filteredDeployments []*appsv1.Deployment - for _, deployment := range m.deployments { - if deployment.Namespace == m.namespace { - filteredDeployments = append(filteredDeployments, deployment) - } - } - return filteredDeployments, nil -} - -func (m mockDeploymentNamespaceLister) Get(name string) (*appsv1.Deployment, error) { - for _, deployment := range m.deployments { - if deployment.Namespace == m.namespace && deployment.Name == name { - return deployment, nil - } - } - return nil, k8serrors.NewNotFound(v1.Resource("deployment"), name) -} - -func Test_updateService(t *testing.T) { - t.Run("update joint inference service successfully", func(t *testing.T) { - // Create fake clients - fakeSednaClient := fakeseednaclientset.NewSimpleClientset() - fakeKubeClient := kubernetesfake.NewSimpleClientset() - - // Create a test joint inference service - oldService := &sednav1.JointInferenceService{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-ji-service", - Namespace: "default", - Generation: 1, - ResourceVersion: "1", - }, - Spec: sednav1.JointInferenceServiceSpec{ - EdgeWorker: sednav1.EdgeWorker{ - Model: sednav1.SmallModel{ - Name: "test-edge-model", - }, - Template: v1.PodTemplateSpec{ - Spec: v1.PodSpec{ - Containers: []v1.Container{ - { - Name: "edge-container", - Image: "edge-image:v1", - }, - }, - }, - }, - HardExampleMining: sednav1.HardExampleMining{ - Name: "test-hem", - Parameters: []sednav1.ParaSpec{ - { - Key: "param1", - Value: "value1", - }, - }, - }, - }, - CloudWorker: sednav1.CloudWorker{ - Model: sednav1.BigModel{ - Name: "test-cloud-model", - }, - Template: v1.PodTemplateSpec{ - Spec: v1.PodSpec{ - Containers: []v1.Container{ - { - Name: "cloud-container", - Image: "cloud-image:v1", - }, - }, - }, - }, - }, - }, - } - - //Create Big Model Resource Object for Cloud - bigModel := &sednav1.Model{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-cloud-model", - Namespace: "default", - }, - } - _, err := fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), bigModel, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create test big model: %v", err) - } - - // Create Small Model Resource Object for Edge - smallModel := &sednav1.Model{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-edge-model", - Namespace: "default", - }, - } - _, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), smallModel, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create test small model: %v", err) - } - - // Create the service using the fake client - _, err = fakeSednaClient.SednaV1alpha1().JointInferenceServices("default").Create(context.TODO(), oldService, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create test service: %v", err) - } - - // Create test deployments - edgeDeployment := &appsv1.Deployment{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-ji-deployment-edge", - Namespace: "default", - }, - Spec: appsv1.DeploymentSpec{ - Template: oldService.Spec.EdgeWorker.Template, - }, - } - cloudDeployment := &appsv1.Deployment{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-ji-deployment-cloud", - Namespace: "default", - }, - Spec: appsv1.DeploymentSpec{ - Template: oldService.Spec.CloudWorker.Template, - }, - } - - _, err = fakeKubeClient.AppsV1().Deployments("default").Create(context.TODO(), edgeDeployment, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create edge deployment: %v", err) - } - _, err = fakeKubeClient.AppsV1().Deployments("default").Create(context.TODO(), cloudDeployment, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create cloud deployment: %v", err) - } - - // Manually create pods for the deployments - edgePod := &v1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-ji-service-edge-pod", - Namespace: "default", - Labels: map[string]string{ - "jointinferenceservice.sedna.io/service-name": "test-ji-service", - }, - OwnerReferences: []metav1.OwnerReference{ - { - APIVersion: "apps/v1", - Kind: "Deployment", - Name: edgeDeployment.Name, - UID: edgeDeployment.UID, - }, - }, - }, - Spec: edgeDeployment.Spec.Template.Spec, - } - cloudPod := &v1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-ji-service-cloud-pod", - Namespace: "default", - Labels: map[string]string{ - "jointinferenceservice.sedna.io/service-name": "test-ji-service", - }, - OwnerReferences: []metav1.OwnerReference{ - { - APIVersion: "apps/v1", - Kind: "Deployment", - Name: cloudDeployment.Name, - UID: cloudDeployment.UID, - }, - }, - }, - Spec: cloudDeployment.Spec.Template.Spec, - } - - // Add pods to the fake client - _, err = fakeKubeClient.CoreV1().Pods("default").Create(context.TODO(), edgePod, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create edge pod: %v", err) - } - _, err = fakeKubeClient.CoreV1().Pods("default").Create(context.TODO(), cloudPod, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("Failed to create cloud pod: %v", err) - } - - cfg := &config.ControllerConfig{ - LC: config.LCConfig{ - Server: "http://test-lc-server:8080", - }, - } - - eventBroadcaster := record.NewBroadcaster() - eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: fakeKubeClient.CoreV1().Events("")}) - - // Create a controller with the fake clients - c := &Controller{ - kubeClient: fakeKubeClient, - client: fakeSednaClient.SednaV1alpha1(), - queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), "test-ji-service"), - recorder: eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "test-ji-service"}), - cfg: cfg, - deploymentsLister: &mockDeploymentLister{deployments: []*appsv1.Deployment{edgeDeployment, cloudDeployment}}, - sendToEdgeFunc: func(nodeName string, eventType watch.EventType, job interface{}) error { - return nil - }, - } - - // Update the service - newService := oldService.DeepCopy() - // change parameter of hard example mining - newService.Spec.EdgeWorker.HardExampleMining.Parameters[0].Value = "value2" - newService.Generation = 2 - newService.ResourceVersion = "2" - // Call updateService function - c.createOrUpdateWorker(newService, jointInferenceForCloud, "test-ji-service.default", 8080, true) - c.createOrUpdateWorker(newService, jointInferenceForEdge, "test-ji-service.default", 8080, true) - // update service in fakeSednaClient - _, err = fakeSednaClient.SednaV1alpha1().JointInferenceServices("default").Update(context.TODO(), newService, metav1.UpdateOptions{}) - if err != nil { - t.Fatalf("Failed to update service: %v", err) - } - // Verify that the services were deleted and recreated - updatedService, err := fakeSednaClient.SednaV1alpha1().JointInferenceServices("default").Get(context.TODO(), "test-ji-service", metav1.GetOptions{}) - if err != nil { - t.Fatalf("Failed to get updated deployment: %v", err) - } - if updatedService.Spec.EdgeWorker.HardExampleMining.Parameters[0].Value != "value2" { - t.Fatalf("Service was not updated correctly") - } - }) -} +package jointinference + +import ( + "context" + "testing" + + sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1" + fakeseednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/fake" + "github.com/kubeedge/sedna/pkg/globalmanager/config" + "github.com/kubeedge/sedna/pkg/globalmanager/runtime" + appsv1 "k8s.io/api/apps/v1" + v1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/watch" + kubernetesfake "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/kubernetes/scheme" + v1core "k8s.io/client-go/kubernetes/typed/core/v1" + corelisters "k8s.io/client-go/listers/apps/v1" + corelistersv1 "k8s.io/client-go/listers/core/v1" + "k8s.io/client-go/tools/record" + "k8s.io/client-go/util/workqueue" +) + +type mockPodLister struct { + pods []*v1.Pod +} + +type mockPodNamespaceLister struct { + pods []*v1.Pod + namespace string +} + +func (m *mockPodLister) Pods(namespace string) corelistersv1.PodNamespaceLister { + return mockPodNamespaceLister{pods: m.pods, namespace: namespace} +} + +func (m *mockPodLister) List(selector labels.Selector) (ret []*v1.Pod, err error) { + return m.pods, nil +} + +func (m mockPodNamespaceLister) List(selector labels.Selector) ([]*v1.Pod, error) { + var filteredPods []*v1.Pod + for _, pod := range m.pods { + if pod.Namespace == m.namespace { + filteredPods = append(filteredPods, pod) + } + } + return filteredPods, nil +} + +type mockDeploymentLister struct { + deployments []*appsv1.Deployment +} + +func (m *mockDeploymentLister) List(selector labels.Selector) (ret []*appsv1.Deployment, err error) { + return m.deployments, nil +} + +func (m *mockDeploymentLister) Deployments(namespace string) corelisters.DeploymentNamespaceLister { + return mockDeploymentNamespaceLister{deployments: m.deployments, namespace: namespace} +} + +func (m mockPodNamespaceLister) Get(name string) (*v1.Pod, error) { + for _, pod := range m.pods { + if pod.Namespace == m.namespace && pod.Name == name { + return pod, nil + } + } + return nil, k8serrors.NewNotFound(v1.Resource("pod"), name) +} + +type mockDeploymentNamespaceLister struct { + deployments []*appsv1.Deployment + namespace string +} + +func (m mockDeploymentNamespaceLister) List(selector labels.Selector) ([]*appsv1.Deployment, error) { + var filteredDeployments []*appsv1.Deployment + for _, deployment := range m.deployments { + if deployment.Namespace == m.namespace { + filteredDeployments = append(filteredDeployments, deployment) + } + } + return filteredDeployments, nil +} + +func (m mockDeploymentNamespaceLister) Get(name string) (*appsv1.Deployment, error) { + for _, deployment := range m.deployments { + if deployment.Namespace == m.namespace && deployment.Name == name { + return deployment, nil + } + } + return nil, k8serrors.NewNotFound(v1.Resource("deployment"), name) +} + +func Test_updateService(t *testing.T) { + t.Run("update joint inference service successfully", func(t *testing.T) { + // Create fake clients + fakeSednaClient := fakeseednaclientset.NewSimpleClientset() + fakeKubeClient := kubernetesfake.NewSimpleClientset() + + // Create a test joint inference service + oldService := &sednav1.JointInferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ji-service", + Namespace: "default", + Generation: 1, + ResourceVersion: "1", + }, + Spec: sednav1.JointInferenceServiceSpec{ + EdgeWorker: sednav1.EdgeWorker{ + Model: sednav1.SmallModel{ + Name: "test-edge-model", + }, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "edge-container", + Image: "edge-image:v1", + }, + }, + }, + }, + HardExampleMining: sednav1.HardExampleMining{ + Name: "test-hem", + Parameters: []sednav1.ParaSpec{ + { + Key: "param1", + Value: "value1", + }, + }, + }, + }, + CloudWorker: sednav1.CloudWorker{ + Model: sednav1.BigModel{ + Name: "test-cloud-model", + }, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "cloud-container", + Image: "cloud-image:v1", + }, + }, + }, + }, + }, + }, + } + + //Create Big Model Resource Object for Cloud + bigModel := &sednav1.Model{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cloud-model", + Namespace: "default", + }, + } + _, err := fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), bigModel, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test big model: %v", err) + } + + // Create Small Model Resource Object for Edge + smallModel := &sednav1.Model{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-edge-model", + Namespace: "default", + }, + } + _, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), smallModel, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test small model: %v", err) + } + + // Create the service using the fake client + _, err = fakeSednaClient.SednaV1alpha1().JointInferenceServices("default").Create(context.TODO(), oldService, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test service: %v", err) + } + + // Create test deployments + edgeDeployment := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ji-deployment-edge", + Namespace: "default", + }, + Spec: appsv1.DeploymentSpec{ + Template: oldService.Spec.EdgeWorker.Template, + }, + } + cloudDeployment := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ji-deployment-cloud", + Namespace: "default", + }, + Spec: appsv1.DeploymentSpec{ + Template: oldService.Spec.CloudWorker.Template, + }, + } + + _, err = fakeKubeClient.AppsV1().Deployments("default").Create(context.TODO(), edgeDeployment, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create edge deployment: %v", err) + } + _, err = fakeKubeClient.AppsV1().Deployments("default").Create(context.TODO(), cloudDeployment, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create cloud deployment: %v", err) + } + + // Manually create pods for the deployments + edgePod := &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ji-service-edge-pod", + Namespace: "default", + Labels: map[string]string{ + "jointinferenceservice.sedna.io/service-name": "test-ji-service", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "Deployment", + Name: edgeDeployment.Name, + UID: edgeDeployment.UID, + }, + }, + }, + Spec: edgeDeployment.Spec.Template.Spec, + } + cloudPod := &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ji-service-cloud-pod", + Namespace: "default", + Labels: map[string]string{ + "jointinferenceservice.sedna.io/service-name": "test-ji-service", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "Deployment", + Name: cloudDeployment.Name, + UID: cloudDeployment.UID, + }, + }, + }, + Spec: cloudDeployment.Spec.Template.Spec, + } + + // Add pods to the fake client + _, err = fakeKubeClient.CoreV1().Pods("default").Create(context.TODO(), edgePod, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create edge pod: %v", err) + } + _, err = fakeKubeClient.CoreV1().Pods("default").Create(context.TODO(), cloudPod, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create cloud pod: %v", err) + } + + cfg := &config.ControllerConfig{ + LC: config.LCConfig{ + Server: "http://test-lc-server:8080", + }, + } + + eventBroadcaster := record.NewBroadcaster() + eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: fakeKubeClient.CoreV1().Events("")}) + + // Create a controller with the fake clients + c := &Controller{ + kubeClient: fakeKubeClient, + client: fakeSednaClient.SednaV1alpha1(), + queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), "test-ji-service"), + recorder: eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "test-ji-service"}), + cfg: cfg, + deploymentsLister: &mockDeploymentLister{deployments: []*appsv1.Deployment{edgeDeployment, cloudDeployment}}, + sendToEdgeFunc: func(nodeName string, eventType watch.EventType, job interface{}) error { + return nil + }, + } + + // Update the service + newService := oldService.DeepCopy() + // change parameter of hard example mining + newService.Spec.EdgeWorker.HardExampleMining.Parameters[0].Value = "value2" + newService.Generation = 2 + newService.ResourceVersion = "2" + // Call updateService function + c.createOrUpdateWorker(newService, jointInferenceForCloud, "test-ji-service.default", 8080, true) + c.createOrUpdateWorker(newService, jointInferenceForEdge, "test-ji-service.default", 8080, true) + // update service in fakeSednaClient + _, err = fakeSednaClient.SednaV1alpha1().JointInferenceServices("default").Update(context.TODO(), newService, metav1.UpdateOptions{}) + if err != nil { + t.Fatalf("Failed to update service: %v", err) + } + // Verify that the services were deleted and recreated + updatedService, err := fakeSednaClient.SednaV1alpha1().JointInferenceServices("default").Get(context.TODO(), "test-ji-service", metav1.GetOptions{}) + if err != nil { + t.Fatalf("Failed to get updated deployment: %v", err) + } + if updatedService.Spec.EdgeWorker.HardExampleMining.Parameters[0].Value != "value2" { + t.Fatalf("Service was not updated correctly") + } + }) +}