Skip to content

Commit

Permalink
test: test custom podSpec in driver and executor
Browse files Browse the repository at this point in the history
Signed-off-by: machichima <[email protected]>
  • Loading branch information
machichima committed Dec 15, 2024
1 parent 88e3b6c commit 394c269
Showing 1 changed file with 36 additions and 23 deletions.
59 changes: 36 additions & 23 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (

sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
// NOTE: this import also use things inside google.golang structpb one
// NOTE: this import also use things inside google.golang structpb one
// structpb "github.com/golang/protobuf/ptypes/struct"
"google.golang.org/protobuf/types/known/structpb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/types/known/structpb"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

Expand Down Expand Up @@ -301,6 +301,7 @@ func dummySparkCustomObjDriverExecutor(sparkConf map[string]string, driverPod *c

func dummyPodSpec() *corev1.PodSpec {
return &corev1.PodSpec{
// TODO: test adding custom resources
InitContainers: []corev1.Container{
{
Name: "init",
Expand Down Expand Up @@ -367,8 +368,6 @@ func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string
Target: &core.TaskTemplate_Container{
Container: &core.Container{
Image: testImage,
Args: testArgs,
Env: dummyEnvVars,
},
},
Config: map[string]string{
Expand Down Expand Up @@ -974,13 +973,6 @@ func TestGetPropertiesSpark(t *testing.T) {
}

func TestBuildResourceCustomK8SPod(t *testing.T) {
// TODO: edit below tests for custom driver and executor
// the TestBuildResourcePodTemplate test whether the custom Toleration is displayed

// create dummy driver and executor pod
// dummy sparkJob that takes in dummy driver and executor pod
// see whether the driver and worker podSpec is what we set
// what properties to test

defaultConfig := defaultPluginConfig()
assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))
Expand All @@ -1007,9 +999,17 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {

driverK8SPod := &core.K8SPod{
PodSpec: transformStructToStructPB(t, driverPodSpec),
Metadata: &core.K8SObjectMetadata{
Annotations: map[string]string{"annotation-driver": "val-driver"},
Labels: map[string]string{"label-driver": "val-driver"},
},
}
executorK8SPod := &core.K8SPod{
PodSpec: transformStructToStructPB(t, executorPodSpec),
Metadata: &core.K8SObjectMetadata{
Annotations: map[string]string{"annotation-executor": "val-executor"},
Labels: map[string]string{"label-executor": "val-executor"},
},
}
// put the driver/executor podspec (add custom tolerations) to below function
taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod)
Expand Down Expand Up @@ -1042,19 +1042,27 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile)

// Driver
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels)
assert.Equal(t, utils.UnionMaps(
defaultConfig.DefaultAnnotations, map[string]string{
"annotation-1": "val1",
"annotation-driver": "val-driver",
},
), sparkApp.Spec.Driver.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{
"label-1": "val1",
"label-driver": "val-driver",
}), sparkApp.Spec.Driver.Labels)
assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1)
// assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
// assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET"))
// assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env))
assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env))
assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image)
assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount)
// assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt)
// assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig)
// assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork)
// assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName)
assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt)
assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig)
assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork)
assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName)
assert.Equal(t, []corev1.Toleration{
defaultConfig.DefaultTolerations[0],
driverExtraToleration,
Expand All @@ -1080,8 +1088,14 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)

// // Executor
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{
"annotation-1": "val1",
"annotation-executor": "val-executor",
}), sparkApp.Spec.Executor.Annotations)
assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{
"label-1": "val1",
"label-executor": "val-executor",
}), sparkApp.Spec.Executor.Labels)
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET"))
Expand Down Expand Up @@ -1120,7 +1134,6 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
}


func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
data, err := json.Marshal(obj)
assert.Nil(t, err)
Expand Down

0 comments on commit 394c269

Please sign in to comment.