Skip to content

Commit

Permalink
feat: merge default and custom podSpec
Browse files Browse the repository at this point in the history
Signed-off-by: machichima <[email protected]>
  • Loading branch information
machichima committed Dec 15, 2024
1 parent 1faaa00 commit 88e3b6c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 66 deletions.
22 changes: 11 additions & 11 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,14 @@ func BuildRawPod(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v
return nil, nil, "", err
}

// If primaryContainerName is set in taskTemplate config, use it instead
// of c.Name
if val, ok := taskTemplate.Config[PrimaryContainerKey]; ok{
primaryContainerName = val
c.Name = primaryContainerName
} else {
primaryContainerName = c.Name
}
// If primaryContainerName is set in taskTemplate config, use it instead
// of c.Name
if val, ok := taskTemplate.Config[PrimaryContainerKey]; ok {
primaryContainerName = val
c.Name = primaryContainerName
} else {
primaryContainerName = c.Name
}
podSpec = &v1.PodSpec{
Containers: []v1.Container{
*c,
Expand Down Expand Up @@ -570,7 +570,7 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio
}

// merge podSpec with podTemplate
mergedPodSpec, err := mergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, primaryInitContainerName)
mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, primaryInitContainerName)
if err != nil {
return nil, nil, err
}
Expand All @@ -584,10 +584,10 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio
return mergedPodSpec, mergedObjectMeta, nil
}

// mergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values
// MergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values
// set by the first PodSpec are overwritten by the second in the return value. Additionally, this function applies
// container-level configuration from the basePodSpec.
func mergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) {
func MergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) {
if basePodSpec == nil || podSpec == nil {
return nil, errors.New("neither the basePodSpec or the podSpec can be nil")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2047,13 +2047,13 @@ func TestMergeWithBasePodTemplate(t *testing.T) {
func TestMergePodSpecs(t *testing.T) {
var priority int32 = 1

podSpec1, _ := mergePodSpecs(nil, nil, "foo", "foo-init")
podSpec1, _ := MergePodSpecs(nil, nil, "foo", "foo-init")
assert.Nil(t, podSpec1)

podSpec2, _ := mergePodSpecs(&v1.PodSpec{}, nil, "foo", "foo-init")
podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "foo", "foo-init")
assert.Nil(t, podSpec2)

podSpec3, _ := mergePodSpecs(nil, &v1.PodSpec{}, "foo", "foo-init")
podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "foo", "foo-init")
assert.Nil(t, podSpec3)

podSpec := v1.PodSpec{
Expand Down Expand Up @@ -2141,7 +2141,7 @@ func TestMergePodSpecs(t *testing.T) {
},
}

mergedPodSpec, err := mergePodSpecs(&podTemplateSpec, &podSpec, "primary", "primary-init")
mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "primary", "primary-init")
assert.Nil(t, err)

// validate a PodTemplate-only field
Expand Down
96 changes: 45 additions & 51 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
pluginsUtils "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
pluginsUtils "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyte/flytestdlib/utils"

)

const KindSparkApplication = "SparkApplication"
Expand Down Expand Up @@ -143,18 +142,30 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string {
return name
}

func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container, k8sPod *core.K8SPod) *sparkOp.SparkPodSpec {
func createSparkPodSpec(
taskCtx pluginsCore.TaskExecutionContext,
podSpec *v1.PodSpec,
container *v1.Container,
k8sPod *core.K8SPod,
) *sparkOp.SparkPodSpec {

// TODO: check whether merge annotations/labels together or other ways?
annotations := pluginsUtils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
labels := pluginsUtils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))
if k8sPod != nil && k8sPod.Metadata != nil{
if k8sPod.Metadata.Annotations != nil {
annotations = pluginsUtils.UnionMaps(annotations, k8sPod.Metadata.Annotations)
}
if k8sPod.Metadata.Labels != nil {
labels = pluginsUtils.UnionMaps(labels, k8sPod.Metadata.Labels)
}
}
annotations := pluginsUtils.UnionMaps(
config.GetK8sPluginConfig().DefaultAnnotations,
pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()),
)
labels := pluginsUtils.UnionMaps(
config.GetK8sPluginConfig().DefaultLabels,
pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()),
)
if k8sPod != nil && k8sPod.Metadata != nil {
if k8sPod.Metadata.Annotations != nil {
annotations = pluginsUtils.UnionMaps(annotations, k8sPod.Metadata.Annotations)
}
if k8sPod.Metadata.Labels != nil {
labels = pluginsUtils.UnionMaps(labels, k8sPod.Metadata.Labels)
}
}

sparkEnv := make([]v1.EnvVar, 0)
for _, envVar := range container.Env {
Expand Down Expand Up @@ -190,10 +201,17 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
return nil, err
}

// If DriverPod exist in sparkJob and is primary, use it instead
driverPod := sparkJob.GetDriverPod()
driverPod := sparkJob.GetDriverPod()
if driverPod != nil {
podSpec, err = unmarshalK8sPod(podSpec, driverPod, primaryContainerName)
var customPodSpec *v1.PodSpec

err = utils.UnmarshalStructToObj(driverPod.PodSpec, &customPodSpec)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"Unable to unmarshal pod spec [%v], Err: [%v]", driverPod.PodSpec, err.Error())
}

podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainerName, "")
if err != nil {
return nil, err
}
Expand All @@ -218,38 +236,6 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
return &spec, nil
}

// Unmarshal pod spec from K8SPod
//
// Return task's generated pod spec if K8SPod PodSpec is not available
func unmarshalK8sPod(podSpec *v1.PodSpec, k8sPod *core.K8SPod, primaryContainerName string) (*v1.PodSpec, error) {
if k8sPod == nil {
return podSpec, nil
}

if k8sPod.PodSpec == nil {
return podSpec, nil
}

var customPodSpec *v1.PodSpec

err := utils.UnmarshalStructToObj(k8sPod.PodSpec, &customPodSpec)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.PodSpec, err.Error())
}

primaryContainers := []v1.Container{}
for _, container := range customPodSpec.Containers {
// Only support the primary container for now
if container.Name == primaryContainerName {
primaryContainers = append(primaryContainers, container)
}
}
customPodSpec.Containers = primaryContainers

return customPodSpec, nil
}

type executorSpec struct {
container *v1.Container
sparkSpec *sparkOp.ExecutorSpec
Expand All @@ -262,9 +248,17 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo
return nil, err
}

// If ExecutorPod exist in sparkJob and is primary, use it instead
if sparkJob.ExecutorPod != nil {
podSpec, err = unmarshalK8sPod(podSpec, sparkJob.ExecutorPod, primaryContainerName)
executorPod := sparkJob.ExecutorPod
if executorPod != nil {
var customPodSpec *v1.PodSpec

err = utils.UnmarshalStructToObj(executorPod.PodSpec, &customPodSpec)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"Unable to unmarshal pod spec [%v], Err: [%v]", executorPod.PodSpec, err.Error())
}

podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainerName, "")
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 88e3b6c

Please sign in to comment.