diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go
index 316d1f1d7b..47063ecb9f 100644
--- a/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go
+++ b/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go
@@ -7,6 +7,7 @@ package plugins
 
 import (
+	core "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
 	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
 	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
 	structpb "google.golang.org/protobuf/types/known/structpb"
@@ -132,6 +133,9 @@ type SparkJob struct {
 	// Domain name of your deployment. Use the form <account>.cloud.databricks.com.
 	// This instance name can be set in either flytepropeller or flytekit.
 	DatabricksInstance string `protobuf:"bytes,9,opt,name=databricksInstance,proto3" json:"databricksInstance,omitempty"`
+
+	DriverPod   *core.K8SPod `protobuf:"bytes,10,opt,name=driverPod,proto3" json:"driverPod,omitempty"`
+	ExecutorPod *core.K8SPod `protobuf:"bytes,11,opt,name=executorPod,proto3" json:"executorPod,omitempty"`
 }
 
 func (x *SparkJob) Reset() {
diff --git a/flyteidl/protos/flyteidl/plugins/spark.proto b/flyteidl/protos/flyteidl/plugins/spark.proto
index 666ea311b2..308c96c302 100644
--- a/flyteidl/protos/flyteidl/plugins/spark.proto
+++ b/flyteidl/protos/flyteidl/plugins/spark.proto
@@ -2,6 +2,7 @@ syntax = "proto3";
 package flyteidl.plugins;
 
 import "google/protobuf/struct.proto";
+import "flyteidl/core/tasks.proto";
 
 option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins";
 
@@ -31,4 +32,10 @@ message SparkJob {
    // Domain name of your deployment. Use the form <account>.cloud.databricks.com.
    // This instance name can be set in either flytepropeller or flytekit.
    string databricksInstance = 9;
+
+    // Pod spec for the Spark driver pod
+    core.K8sPod driverPod = 10;
+
+    // Pod spec for the Spark executor pod
+    core.K8sPod executorPod = 11;
 }
diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go
index 6873fc2257..58f12382a3 100644
--- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go
+++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go
@@ -75,11 +75,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
 	}
 
 	sparkConfig := getSparkConfig(taskCtx, &sparkJob)
-	driverSpec, err := createDriverSpec(ctx, taskCtx, sparkConfig)
+	driverSpec, err := createDriverSpec(ctx, taskCtx, sparkConfig, &sparkJob)
 	if err != nil {
 		return nil, err
 	}
-	executorSpec, err := createExecutorSpec(ctx, taskCtx, sparkConfig)
+	executorSpec, err := createExecutorSpec(ctx, taskCtx, sparkConfig, &sparkJob)
 	if err != nil {
 		return nil, err
 	}
@@ -141,9 +141,10 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string {
 	return name
 }
 
-func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container) *sparkOp.SparkPodSpec {
-	annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
-	labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))
+func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container, podAnnotations map[string]string, podLabels map[string]string) *sparkOp.SparkPodSpec {
+	// TODO: confirm whether merging the pod-level annotations/labels into the defaults is the desired behavior, or whether they should be handled differently.
+	annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()), podAnnotations)
+	labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()), podLabels)
 
 	sparkEnv := make([]v1.EnvVar, 0)
 	for _, envVar := range container.Env {
@@ -171,18 +172,33 @@ type driverSpec struct {
 	sparkSpec *sparkOp.DriverSpec
 }
 
-func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*driverSpec, error) {
+func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*driverSpec, error) {
 	// Spark driver pods should always run as non-interruptible
+	// NOTE: this only marks the task as non-interruptible; it does not appear to change the generated pod spec itself.
 	nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false))
 	podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
 	if err != nil {
 		return nil, err
 	}
+
+	// TODO: validate that this override behavior is correct.
+	// If a driver pod spec is set on the SparkJob, use it (restricted to the primary container) instead of the generated one.
+	var podAnnotations map[string]string
+	var podLabels map[string]string
+	if sparkJob.DriverPod != nil {
+		podSpec, err = unmarshalK8sPod(podSpec, sparkJob.DriverPod, primaryContainerName)
+		if err != nil {
+			return nil, err
+		}
+		podAnnotations = sparkJob.DriverPod.GetMetadata().GetAnnotations()
+		podLabels = sparkJob.DriverPod.GetMetadata().GetLabels()
+	}
+
 	primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName)
 	if err != nil {
 		return nil, err
 	}
-	sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer)
+	sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer, podAnnotations, podLabels)
 	serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata())
 	spec := driverSpec{
 		&sparkOp.DriverSpec{
@@ -197,22 +213,69 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
 	return &spec, nil
 }
 
+// unmarshalK8sPod unmarshals the pod spec carried by a K8SPod override, keeping
+// only the primary container.
+// It returns the task's generated pod spec unchanged if the K8SPod or its PodSpec is not set.
+func unmarshalK8sPod(podSpec *v1.PodSpec, k8sPod *core.K8SPod, primaryContainerName string) (*v1.PodSpec, error) {
+	if k8sPod == nil {
+		return podSpec, nil
+	}
+
+	if k8sPod.PodSpec == nil {
+		return podSpec, nil
+	}
+
+	var customPodSpec *v1.PodSpec
+
+	err := utils.UnmarshalStructToObj(k8sPod.PodSpec, &customPodSpec)
+	if err != nil {
+		return nil, errors.Errorf(errors.BadTaskSpecification,
+			"Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.PodSpec, err.Error())
+	}
+
+	primaryContainers := []v1.Container{}
+	for _, container := range customPodSpec.Containers {
+		// Only the primary container is supported for now; drop any other containers.
+		if container.Name == primaryContainerName {
+			primaryContainers = append(primaryContainers, container)
+		}
+	}
+	customPodSpec.Containers = primaryContainers
+
+	return customPodSpec, nil
+}
+
+
 type executorSpec struct {
 	container          *v1.Container
 	sparkSpec          *sparkOp.ExecutorSpec
 	serviceAccountName string
 }
 
-func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*executorSpec, error) {
+func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*executorSpec, error) {
 	podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
 	if err != nil {
 		return nil, err
 	}
+
+	// TODO: validate that this override behavior is correct.
+	// If an executor pod spec is set on the SparkJob, use it (restricted to the primary container) instead of the generated one.
+	var podAnnotations map[string]string
+	var podLabels map[string]string
+	if sparkJob.ExecutorPod != nil {
+		podSpec, err = unmarshalK8sPod(podSpec, sparkJob.ExecutorPod, primaryContainerName)
+		if err != nil {
+			return nil, err
+		}
+		podAnnotations = sparkJob.ExecutorPod.GetMetadata().GetAnnotations()
+		podLabels = sparkJob.ExecutorPod.GetMetadata().GetLabels()
+	}
+
 	primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName)
 	if err != nil {
 		return nil, err
 	}
-	sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer)
+	sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer, podAnnotations, podLabels)
 	serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata())
 	spec := executorSpec{
 		primaryContainer,
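
Note: the sketch below is not part of the diff. It is a minimal companion unit test for unmarshalK8sPod, assuming it sits in the spark package so the unexported function is reachable, and that utils.MarshalObjToStruct from the pluginmachinery utils package is available (as used by other k8s plugin tests). Container names, images, labels, and annotations are placeholders. It also relies on the nil-safe generated getters (GetMetadata/GetAnnotations/GetLabels) used above, and on utils.UnionMaps applying later maps over earlier ones so the pod-level annotations/labels win over the defaults, which appears to be the behavior the TODO in createSparkPodSpec is asking to confirm.

package spark

import (
	"testing"

	"github.com/stretchr/testify/assert"
	v1 "k8s.io/api/core/v1"

	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
)

func TestUnmarshalK8sPodKeepsOnlyPrimaryContainer(t *testing.T) {
	// Stand-in for the pod spec generated by flytek8s.ToK8sPodSpec.
	generated := &v1.PodSpec{
		Containers: []v1.Container{{Name: "primary", Image: "spark:3.5.0"}},
	}

	// Hypothetical driver pod override carrying an extra sidecar container,
	// which unmarshalK8sPod is expected to drop.
	override := &v1.PodSpec{
		NodeSelector: map[string]string{"gpu": "true"},
		Containers: []v1.Container{
			{Name: "primary", Image: "spark-custom:latest"},
			{Name: "sidecar", Image: "busybox"},
		},
	}
	overrideStruct, err := utils.MarshalObjToStruct(override)
	assert.NoError(t, err)

	driverPod := &core.K8SPod{
		PodSpec: overrideStruct,
		Metadata: &core.K8SObjectMetadata{
			Labels:      map[string]string{"team": "data"},
			Annotations: map[string]string{"sidecar.istio.io/inject": "false"},
		},
	}

	// The override replaces the generated spec, restricted to the primary container.
	podSpec, err := unmarshalK8sPod(generated, driverPod, "primary")
	assert.NoError(t, err)
	assert.Len(t, podSpec.Containers, 1)
	assert.Equal(t, "spark-custom:latest", podSpec.Containers[0].Image)
	assert.Equal(t, "true", podSpec.NodeSelector["gpu"])

	// The generated getters are nil-safe, so createDriverSpec/createExecutorSpec
	// can read labels/annotations even when Metadata is not set.
	assert.Equal(t, map[string]string{"team": "data"}, driverPod.GetMetadata().GetLabels())
	assert.Nil(t, (&core.K8SPod{}).GetMetadata().GetAnnotations())

	// A K8SPod without a PodSpec leaves the generated spec untouched.
	passthrough, err := unmarshalK8sPod(generated, &core.K8SPod{}, "primary")
	assert.NoError(t, err)
	assert.Equal(t, generated, passthrough)
}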