Skip to content

Commit

Permalink
feat: add driverPod/executorPod in Spark
Browse files Browse the repository at this point in the history
Add driverPod/executorPod field in SparkJob class and use them as Spark
driver and executor

Signed-off-by: machichima <[email protected]>
  • Loading branch information
machichima committed Dec 5, 2024
1 parent ba331fd commit d847a63
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 9 deletions.
4 changes: 4 additions & 0 deletions flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions flyteidl/protos/flyteidl/plugins/spark.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ syntax = "proto3";

package flyteidl.plugins;
import "google/protobuf/struct.proto";
import "flyteidl/core/tasks.proto";

option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins";

Expand Down Expand Up @@ -31,4 +32,10 @@ message SparkJob {
// Domain name of your deployment. Use the form <account>.cloud.databricks.com.
// This instance name can be set in either flytepropeller or flytekit.
string databricksInstance = 9;

// Pod Spec for the Spark driver pod
core.K8sPod driverPod = 10;

// Pod Spec for the Spark executor pod
core.K8sPod executorPod = 11;
}
81 changes: 72 additions & 9 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}

sparkConfig := getSparkConfig(taskCtx, &sparkJob)
driverSpec, err := createDriverSpec(ctx, taskCtx, sparkConfig)
driverSpec, err := createDriverSpec(ctx, taskCtx, sparkConfig, &sparkJob)
if err != nil {
return nil, err
}
executorSpec, err := createExecutorSpec(ctx, taskCtx, sparkConfig)
executorSpec, err := createExecutorSpec(ctx, taskCtx, sparkConfig, &sparkJob)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -141,9 +141,10 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string {
return name
}

func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container) *sparkOp.SparkPodSpec {
annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))
func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container, podAnnotations map[string]string, podLabels map[string]string) *sparkOp.SparkPodSpec {
// TODO: check whether merge annotations/labels together or other ways?
annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()), podAnnotations)
labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()), podLabels)

sparkEnv := make([]v1.EnvVar, 0)
for _, envVar := range container.Env {
Expand Down Expand Up @@ -171,18 +172,33 @@ type driverSpec struct {
sparkSpec *sparkOp.DriverSpec
}

func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*driverSpec, error) {
func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*driverSpec, error) {
// Spark driver pods should always run as non-interruptible
// NOTE: This line change task to non-interruptible, but seems not to affect the podSpec things
nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false))
podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
if err != nil {
return nil, err
}

// TODO: Validate whether the following function is correct
// If DriverPod exists in sparkJob and is primary, use it instead
var podAnnotations map[string]string
var podLabels map[string]string
if sparkJob.DriverPod != nil {
podSpec, err = unmarshalK8sPod(podSpec, sparkJob.DriverPod, primaryContainerName)
if err != nil {
return nil, err
}
podAnnotations = sparkJob.DriverPod.Metadata.Annotations
podLabels = sparkJob.DriverPod.Metadata.Labels
}

primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName)
if err != nil {
return nil, err
}
sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer)
sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer, podAnnotations, podLabels)
serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata())
spec := driverSpec{
&sparkOp.DriverSpec{
Expand All @@ -197,22 +213,69 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
return &spec, nil
}

// unmarshalK8sPod decodes the pod spec embedded in k8sPod into a v1.PodSpec.
//
// When k8sPod is nil or carries no PodSpec, the podSpec argument is returned
// unchanged. Otherwise the decoded spec is returned with its Containers list
// narrowed to the entries named primaryContainerName; all other containers
// are dropped because only the primary container is supported for now.
//
// A BadTaskSpecification error is returned if the embedded pod spec cannot be
// unmarshalled.
func unmarshalK8sPod(podSpec *v1.PodSpec, k8sPod *core.K8SPod, primaryContainerName string) (*v1.PodSpec, error) {
	// Nothing to override with: keep the pod spec generated for the task.
	if k8sPod == nil || k8sPod.PodSpec == nil {
		return podSpec, nil
	}

	var decoded *v1.PodSpec
	if err := utils.UnmarshalStructToObj(k8sPod.PodSpec, &decoded); err != nil {
		return nil, errors.Errorf(errors.BadTaskSpecification,
			"Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.PodSpec, err.Error())
	}

	// Keep only containers whose name matches the primary container.
	kept := []v1.Container{}
	for i := range decoded.Containers {
		if decoded.Containers[i].Name != primaryContainerName {
			continue
		}
		kept = append(kept, decoded.Containers[i])
	}
	decoded.Containers = kept

	return decoded, nil
}


type executorSpec struct {
container *v1.Container
sparkSpec *sparkOp.ExecutorSpec
serviceAccountName string
}

func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*executorSpec, error) {
func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*executorSpec, error) {
podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, err
}

// TODO: Validate whether the following function is correct
// If ExecutorPod exists in sparkJob and is primary, use it instead
var podAnnotations map[string]string
var podLabels map[string]string
if sparkJob.ExecutorPod != nil {
podSpec, err = unmarshalK8sPod(podSpec, sparkJob.ExecutorPod, primaryContainerName)
if err != nil {
return nil, err
}
podAnnotations = sparkJob.ExecutorPod.Metadata.Annotations
podLabels = sparkJob.ExecutorPod.Metadata.Labels
}

primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName)
if err != nil {
return nil, err
}
sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer)
sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer, podAnnotations, podLabels)
serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata())
spec := executorSpec{
primaryContainer,
Expand Down

0 comments on commit d847a63

Please sign in to comment.