From 88e3b6cf9eaaa828bf7d3d83e2c034d4377e2e56 Mon Sep 17 00:00:00 2001 From: machichima Date: Sun, 15 Dec 2024 19:39:52 +0800 Subject: [PATCH] feat: merge default and custom podSpec Signed-off-by: machichima --- .../pluginmachinery/flytek8s/pod_helper.go | 22 ++--- .../flytek8s/pod_helper_test.go | 8 +- .../go/tasks/plugins/k8s/spark/spark.go | 96 +++++++++---------- 3 files changed, 60 insertions(+), 66 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index d7c1f333fa..63d0971cac 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -280,14 +280,14 @@ func BuildRawPod(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v return nil, nil, "", err } - // If primaryContainerName is set in taskTemplate config, use it instead - // of c.Name - if val, ok := taskTemplate.Config[PrimaryContainerKey]; ok{ - primaryContainerName = val - c.Name = primaryContainerName - } else { - primaryContainerName = c.Name - } + // If primaryContainerName is set in taskTemplate config, use it instead + // of c.Name + if val, ok := taskTemplate.Config[PrimaryContainerKey]; ok { + primaryContainerName = val + c.Name = primaryContainerName + } else { + primaryContainerName = c.Name + } podSpec = &v1.PodSpec{ Containers: []v1.Container{ *c, @@ -570,7 +570,7 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio } // merge podSpec with podTemplate - mergedPodSpec, err := mergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, primaryInitContainerName) + mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, primaryInitContainerName) if err != nil { return nil, nil, err } @@ -584,10 +584,10 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio return mergedPodSpec, mergedObjectMeta, nil } -// mergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values +// MergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values // set by the first PodSpec are overwritten by the second in the return value. Additionally, this function applies // container-level configuration from the basePodSpec. -func mergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) { +func MergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) { if basePodSpec == nil || podSpec == nil { return nil, errors.New("neither the basePodSpec or the podSpec can be nil") } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 0a70cdd895..139ee583dc 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -2047,13 +2047,13 @@ func TestMergeWithBasePodTemplate(t *testing.T) { func TestMergePodSpecs(t *testing.T) { var priority int32 = 1 - podSpec1, _ := mergePodSpecs(nil, nil, "foo", "foo-init") + podSpec1, _ := MergePodSpecs(nil, nil, "foo", "foo-init") assert.Nil(t, podSpec1) - podSpec2, _ := mergePodSpecs(&v1.PodSpec{}, nil, "foo", "foo-init") + podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "foo", "foo-init") assert.Nil(t, podSpec2) - podSpec3, _ := mergePodSpecs(nil, &v1.PodSpec{}, "foo", "foo-init") + podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "foo", "foo-init") assert.Nil(t, podSpec3) podSpec := v1.PodSpec{ @@ -2141,7 +2141,7 @@ func TestMergePodSpecs(t *testing.T) { }, } - mergedPodSpec, err := mergePodSpecs(&podTemplateSpec, &podSpec, "primary", "primary-init") + mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "primary", "primary-init") assert.Nil(t, err) // validate a PodTemplate-only field diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 9bced79f1c..df7b1c228c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -25,9 +25,8 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - pluginsUtils "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + pluginsUtils "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flytestdlib/utils" - ) const KindSparkApplication = "SparkApplication" @@ -143,18 +142,30 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string { return name } -func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container, k8sPod *core.K8SPod) *sparkOp.SparkPodSpec { +func createSparkPodSpec( + taskCtx pluginsCore.TaskExecutionContext, + podSpec *v1.PodSpec, + container *v1.Container, + k8sPod *core.K8SPod, +) *sparkOp.SparkPodSpec { + // TODO: check whether merge annotations/labels together or other ways? - annotations := pluginsUtils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) - labels := pluginsUtils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - if k8sPod != nil && k8sPod.Metadata != nil{ - if k8sPod.Metadata.Annotations != nil { - annotations = pluginsUtils.UnionMaps(annotations, k8sPod.Metadata.Annotations) - } - if k8sPod.Metadata.Labels != nil { - labels = pluginsUtils.UnionMaps(labels, k8sPod.Metadata.Labels) - } - } + annotations := pluginsUtils.UnionMaps( + config.GetK8sPluginConfig().DefaultAnnotations, + pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()), + ) + labels := pluginsUtils.UnionMaps( + config.GetK8sPluginConfig().DefaultLabels, + pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()), + ) + if k8sPod != nil && k8sPod.Metadata != nil { + if k8sPod.Metadata.Annotations != nil { + annotations = pluginsUtils.UnionMaps(annotations, k8sPod.Metadata.Annotations) + } + if k8sPod.Metadata.Labels != nil { + labels = pluginsUtils.UnionMaps(labels, k8sPod.Metadata.Labels) + } + } sparkEnv := make([]v1.EnvVar, 0) for _, envVar := range container.Env { @@ -190,10 +201,17 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont return nil, err } - // If DriverPod exist in sparkJob and is primary, use it instead - driverPod := sparkJob.GetDriverPod() + driverPod := sparkJob.GetDriverPod() if driverPod != nil { - podSpec, err = unmarshalK8sPod(podSpec, driverPod, primaryContainerName) + var customPodSpec *v1.PodSpec + + err = utils.UnmarshalStructToObj(driverPod.PodSpec, &customPodSpec) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "Unable to unmarshal pod spec [%v], Err: [%v]", driverPod.PodSpec, err.Error()) + } + + podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainerName, "") if err != nil { return nil, err } @@ -218,38 +236,6 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont return &spec, nil } -// Unmarshal pod spec from K8SPod -// -// Return task's generated pod spec if K8SPod PodSpec is not available -func unmarshalK8sPod(podSpec *v1.PodSpec, k8sPod *core.K8SPod, primaryContainerName string) (*v1.PodSpec, error) { - if k8sPod == nil { - return podSpec, nil - } - - if k8sPod.PodSpec == nil { - return podSpec, nil - } - - var customPodSpec *v1.PodSpec - - err := utils.UnmarshalStructToObj(k8sPod.PodSpec, &customPodSpec) - if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.PodSpec, err.Error()) - } - - primaryContainers := []v1.Container{} - for _, container := range customPodSpec.Containers { - // Only support the primary container for now - if container.Name == primaryContainerName { - primaryContainers = append(primaryContainers, container) - } - } - customPodSpec.Containers = primaryContainers - - return customPodSpec, nil -} - type executorSpec struct { container *v1.Container sparkSpec *sparkOp.ExecutorSpec @@ -262,9 +248,17 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo return nil, err } - // If ExecutorPod exist in sparkJob and is primary, use it instead - if sparkJob.ExecutorPod != nil { - podSpec, err = unmarshalK8sPod(podSpec, sparkJob.ExecutorPod, primaryContainerName) + executorPod := sparkJob.ExecutorPod + if executorPod != nil { + var customPodSpec *v1.PodSpec + + err = utils.UnmarshalStructToObj(executorPod.PodSpec, &customPodSpec) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "Unable to unmarshal pod spec [%v], Err: [%v]", executorPod.PodSpec, err.Error()) + } + + podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainerName, "") if err != nil { return nil, err }