diff --git a/go/tasks/plugins/array/array_tests_base.go b/go/tasks/plugins/array/array_tests_base.go index dd8551bc0..90ed5f531 100644 --- a/go/tasks/plugins/array/array_tests_base.go +++ b/go/tasks/plugins/array/array_tests_base.go @@ -6,10 +6,6 @@ import ( "github.com/flyteorg/flyteplugins/tests" idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" - "github.com/flyteorg/flytestdlib/utils" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "context" @@ -49,13 +45,11 @@ func RunArrayTestsEndToEnd(t *testing.T, executor core.Plugin, iter AdvanceItera } var err error - template.Custom, err = utils.MarshalPbToStruct(&plugins.ArrayJob{ - Parallelism: 10, - Size: 1, - SuccessCriteria: &plugins.ArrayJob_MinSuccesses{ - MinSuccesses: 1, - }, - }) + template.Config = map[string]string{ + "Parallelism": "10", + "Size": "1", + "MinSuccesses": "1", + } assert.NoError(t, err) @@ -83,16 +77,11 @@ func RunArrayTestsEndToEnd(t *testing.T, executor core.Plugin, iter AdvanceItera }, } - var err error - template.Custom, err = utils.MarshalPbToStruct(&plugins.ArrayJob{ - Parallelism: 10, - Size: 2, - SuccessCriteria: &plugins.ArrayJob_MinSuccesses{ - MinSuccesses: 1, - }, - }) - - assert.NoError(t, err) + template.Config = map[string]string{ + "Parallelism": "10", + "Size": "2", + "MinSuccesses": "1", + } expectedOutputs := coreutils.MustMakeLiteral(map[string]interface{}{ "x": []interface{}{5, 5}, diff --git a/go/tasks/plugins/array/awsbatch/executor.go b/go/tasks/plugins/array/awsbatch/executor.go index ec4b5bc3c..7f7c0e64b 100644 --- a/go/tasks/plugins/array/awsbatch/executor.go +++ b/go/tasks/plugins/array/awsbatch/executor.go @@ -189,10 +189,10 @@ func init() { pluginmachinery.PluginRegistry().RegisterCorePlugin( core.PluginEntry{ ID: executorName, - RegisteredTaskTypes: []core.TaskType{arrayTaskType, array.AwsBatchTaskType}, + RegisteredTaskTypes: []core.TaskType{arrayTaskType, arrayCore.AwsBatchTaskType}, LoadPlugin: createNewExecutorPlugin, IsDefault: false, - DefaultForTaskTypes: []core.TaskType{arrayTaskType, array.AwsBatchTaskType}, + DefaultForTaskTypes: []core.TaskType{arrayTaskType, arrayCore.AwsBatchTaskType}, }) } diff --git a/go/tasks/plugins/array/awsbatch/transformer_test.go b/go/tasks/plugins/array/awsbatch/transformer_test.go index 29fc8022c..ebbbbb1bd 100644 --- a/go/tasks/plugins/array/awsbatch/transformer_test.go +++ b/go/tasks/plugins/array/awsbatch/transformer_test.go @@ -22,7 +22,6 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config" v12 "k8s.io/api/core/v1" @@ -30,7 +29,7 @@ import ( "github.com/aws/aws-sdk-go/service/batch" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/stretchr/testify/assert" ) @@ -137,7 +136,7 @@ func TestArrayJobToBatchInput(t *testing.T) { }, } - input := &plugins.ArrayJob{ + input := &arrayCore.ArrayJob{ Size: 10, Parallelism: 5, } @@ -207,7 +206,7 @@ func TestArrayJobToBatchInput(t *testing.T) { assert.NotNil(t, batchInput) assert.Equal(t, *expectedBatchInput, *batchInput) - taskTemplate.Type = array.AwsBatchTaskType + taskTemplate.Type = arrayCore.AwsBatchTaskType tr.OnReadMatch(mock.Anything).Return(taskTemplate, nil) taskCtx.OnTaskReader().Return(tr) diff --git a/go/tasks/plugins/array/catalog.go b/go/tasks/plugins/array/catalog.go index 3695304b3..4712b1b40 100644 --- a/go/tasks/plugins/array/catalog.go +++ b/go/tasks/plugins/array/catalog.go @@ -6,8 +6,6 @@ import ( "math" "strconv" - idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" - arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/flyteorg/flytestdlib/bitarray" @@ -23,8 +21,6 @@ import ( idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" ) -const AwsBatchTaskType = "aws-batch" - // DetermineDiscoverability checks if there are any previously cached tasks. If there are we will only submit an // ArrayJob for the non-cached tasks. The ArrayJob is now a different size, and each task will get a new index location // which is different than their original location. To find the original index we construct an indexLookup array. @@ -42,18 +38,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex } // Extract the custom plugin pb - var arrayJob *idlPlugins.ArrayJob - if taskTemplate.Type == AwsBatchTaskType { - arrayJob = &idlPlugins.ArrayJob{ - Parallelism: 1, - Size: 1, - SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{ - MinSuccesses: 1, - }, - } - } else { - arrayJob, err = arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) - } + arrayJob, err := arrayCore.ToArrayJob(taskTemplate, taskTemplate.TaskTypeVersion) if err != nil { return state, err } @@ -96,7 +81,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex return state, errors.Errorf(errors.BadTaskSpecification, "Unable to determine array size from inputs") } - minSuccesses := math.Ceil(float64(arrayJob.GetMinSuccessRatio()) * float64(size)) + minSuccesses := math.Ceil(arrayJob.GetMinSuccessRatio() * float64(size)) logger.Debugf(ctx, "Computed state: size [%d] and minSuccesses [%d]", int64(size), int64(minSuccesses)) state = state.SetOriginalArraySize(int64(size)) diff --git a/go/tasks/plugins/array/catalog_test.go b/go/tasks/plugins/array/catalog_test.go index 3edb82056..d4ccb3858 100644 --- a/go/tasks/plugins/array/catalog_test.go +++ b/go/tasks/plugins/array/catalog_test.go @@ -177,7 +177,7 @@ func TestDetermineDiscoverability(t *testing.T) { t.Run("Run AWS Batch single job", func(t *testing.T) { toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(1), 1) - template.Type = AwsBatchTaskType + template.Type = arrayCore.AwsBatchTaskType runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{ CurrentPhase: arrayCore.PhasePreLaunch, PhaseVersion: core2.DefaultPhaseVersion, @@ -258,14 +258,9 @@ func TestDiscoverabilityTaskType1(t *testing.T) { download.OnGetCachedResults().Return(bitarray.NewBitSet(1)).Once() toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(uint(3)), uint(3)) - arrayJob := &plugins.ArrayJob{ - SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{ - MinSuccessRatio: 0.5, - }, + arrayJob := map[string]string{ + "MinSuccessRatio": "0.5", } - var arrayJobCustom structpb.Struct - err := utils.MarshalStruct(arrayJob, &arrayJobCustom) - assert.NoError(t, err) templateType1 := &core.TaskTemplate{ Id: &core.Identifier{ ResourceType: core.ResourceType_TASK, @@ -290,8 +285,30 @@ func TestDiscoverabilityTaskType1(t *testing.T) { }, }, TaskTypeVersion: 1, - Custom: &arrayJobCustom, + Config: arrayJob, + } + + runDetermineDiscoverabilityTest(t, templateType1, f, &arrayCore.State{ + CurrentPhase: arrayCore.PhasePreLaunch, + PhaseVersion: core2.DefaultPhaseVersion, + ExecutionArraySize: 3, + OriginalArraySize: 3, + OriginalMinSuccesses: 2, + IndexesToCache: toCache, + Reason: "Task is not discoverable.", + }, nil) + + // Get ArrayJob information from taskTemplate.config + arrayJobProto := &plugins.ArrayJob{ + SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{ + MinSuccessRatio: 0.5, + }, } + var arrayJobCustom structpb.Struct + err := utils.MarshalStruct(arrayJobProto, &arrayJobCustom) + assert.NoError(t, err) + templateType1.Config = nil + templateType1.Custom = &arrayJobCustom runDetermineDiscoverabilityTest(t, templateType1, f, &arrayCore.State{ CurrentPhase: arrayCore.PhasePreLaunch, diff --git a/go/tasks/plugins/array/core/array_job.go b/go/tasks/plugins/array/core/array_job.go new file mode 100644 index 000000000..3289a1ce7 --- /dev/null +++ b/go/tasks/plugins/array/core/array_job.go @@ -0,0 +1,24 @@ +package core + +type ArrayJob struct { + Parallelism int64 + Size int64 + MinSuccesses int64 + MinSuccessRatio float64 +} + +func (a ArrayJob) GetParallelism() int64 { + return a.Parallelism +} + +func (a ArrayJob) GetSize() int64 { + return a.Size +} + +func (a ArrayJob) GetMinSuccesses() int64 { + return a.MinSuccesses +} + +func (a ArrayJob) GetMinSuccessRatio() float64 { + return a.MinSuccessRatio +} diff --git a/go/tasks/plugins/array/core/state.go b/go/tasks/plugins/array/core/state.go index 67fea6514..787634ace 100644 --- a/go/tasks/plugins/array/core/state.go +++ b/go/tasks/plugins/array/core/state.go @@ -3,8 +3,11 @@ package core import ( "context" "fmt" + "strconv" "time" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" @@ -13,9 +16,7 @@ import ( idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flytestdlib/logger" - structpb "github.com/golang/protobuf/ptypes/struct" ) //go:generate mockery -all -case=underscore @@ -38,6 +39,8 @@ const ( PhasePermanentFailure ) +const AwsBatchTaskType = "aws-batch" + type State struct { CurrentPhase Phase `json:"phase"` PhaseVersion uint32 `json:"phaseVersion"` @@ -139,30 +142,53 @@ const ( ErrorK8sArrayGeneric errors.ErrorCode = "ARRAY_JOB_GENERIC_FAILURE" ) -func ToArrayJob(structObj *structpb.Struct, taskTypeVersion int32) (*idlPlugins.ArrayJob, error) { - if structObj == nil { - if taskTypeVersion == 0 { - - return &idlPlugins.ArrayJob{ - Parallelism: 1, - Size: 1, - SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{ - MinSuccesses: 1, - }, - }, nil +func ToArrayJob(taskTemplate *idlCore.TaskTemplate, taskTypeVersion int32) (*ArrayJob, error) { + if taskTemplate != nil && taskTemplate.GetConfig() != nil { + config := taskTemplate.GetConfig() + arrayJob := &ArrayJob{} + var err error + if len(config["Parallelism"]) != 0 { + arrayJob.Parallelism, err = strconv.ParseInt(config["Parallelism"], 10, 64) + } + if len(config["Size"]) != 0 { + arrayJob.Size, err = strconv.ParseInt(config["Size"], 10, 64) + } + if len(config["MinSuccesses"]) != 0 { + arrayJob.MinSuccesses, err = strconv.ParseInt(config["MinSuccesses"], 10, 64) + } + if len(config["MinSuccessRatio"]) != 0 { + arrayJob.MinSuccessRatio, err = strconv.ParseFloat(config["MinSuccessRatio"], 64) + } + return arrayJob, err + } + + // Keep backward compatibility for those who use arrayJob proto + if taskTemplate != nil && taskTemplate.GetCustom() != nil { + arrayJob := &idlPlugins.ArrayJob{} + err := utils.UnmarshalStruct(taskTemplate.GetCustom(), arrayJob) + if err != nil { + return nil, err } - return &idlPlugins.ArrayJob{ - Parallelism: 1, - Size: 1, - SuccessCriteria: &idlPlugins.ArrayJob_MinSuccessRatio{ - MinSuccessRatio: 1.0, - }, + return &ArrayJob{ + Parallelism: arrayJob.GetParallelism(), + Size: arrayJob.GetSize(), + MinSuccessRatio: float64(arrayJob.GetMinSuccessRatio()), + MinSuccesses: arrayJob.GetMinSuccesses(), }, nil } - arrayJob := &idlPlugins.ArrayJob{} - err := utils.UnmarshalStruct(structObj, arrayJob) - return arrayJob, err + if taskTypeVersion == 0 || (taskTemplate != nil && taskTemplate.Type == AwsBatchTaskType) { + return &ArrayJob{ + Parallelism: 1, + Size: 1, + MinSuccesses: 1, + }, nil + } + return &ArrayJob{ + Parallelism: 1, + Size: 1, + MinSuccessRatio: 1.0, + }, nil } func GetPhaseVersionOffset(currentPhase Phase, length int64) uint32 { diff --git a/go/tasks/plugins/array/core/state_test.go b/go/tasks/plugins/array/core/state_test.go index 879e02d45..6a3423c59 100644 --- a/go/tasks/plugins/array/core/state_test.go +++ b/go/tasks/plugins/array/core/state_test.go @@ -5,13 +5,14 @@ import ( "fmt" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" - "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" - "github.com/flyteorg/flytestdlib/bitarray" + idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/flyteorg/flytestdlib/utils" + idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/stretchr/testify/assert" ) @@ -294,25 +295,67 @@ func TestToArrayJob(t *testing.T) { t.Run("task_type_version == 0", func(t *testing.T) { arrayJob, err := ToArrayJob(nil, 0) assert.NoError(t, err) - assert.True(t, proto.Equal(arrayJob, &plugins.ArrayJob{ - Parallelism: 1, - Size: 1, - SuccessCriteria: &plugins.ArrayJob_MinSuccesses{ - MinSuccesses: 1, - }, - })) + assert.True(t, *arrayJob == ArrayJob{ + Parallelism: 1, + Size: 1, + MinSuccesses: 1, + }) }) t.Run("task_type_version == 1", func(t *testing.T) { arrayJob, err := ToArrayJob(nil, 1) assert.NoError(t, err) - assert.True(t, proto.Equal(arrayJob, &plugins.ArrayJob{ - Parallelism: 1, - Size: 1, - SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{ + assert.True(t, *arrayJob == ArrayJob{ + Parallelism: 1, + Size: 1, + MinSuccessRatio: 1.0, + }) + }) + + t.Run("task_type_version == AwsBatchTaskType", func(t *testing.T) { + taskTemplate := &idlCore.TaskTemplate{Type: AwsBatchTaskType} + arrayJob, err := ToArrayJob(taskTemplate, 1) + assert.NoError(t, err) + assert.True(t, *arrayJob == ArrayJob{ + Parallelism: 1, + Size: 1, + MinSuccesses: 1, + }) + }) + + t.Run("ToArrayJob with config", func(t *testing.T) { + config := map[string]string{ + "Parallelism": "10", + "Size": "10", + "MinSuccesses": "1", + "MinSuccessRatio": "1.0", + } + taskTemplate := &idlCore.TaskTemplate{Config: config} + arrayJob, err := ToArrayJob(taskTemplate, 0) + assert.NoError(t, err) + assert.Equal(t, arrayJob.GetParallelism(), int64(10)) + assert.Equal(t, arrayJob.GetSize(), int64(10)) + assert.Equal(t, arrayJob.GetMinSuccesses(), int64(1)) + assert.Equal(t, arrayJob.GetMinSuccessRatio(), 1.0) + }) + + t.Run("ToArrayJob with custom", func(t *testing.T) { + arrayJobProto := &idlPlugins.ArrayJob{ + Parallelism: 10, + Size: 10, + SuccessCriteria: &idlPlugins.ArrayJob_MinSuccessRatio{ MinSuccessRatio: 1.0, }, - })) + } + custom, err := utils.MarshalPbToStruct(arrayJobProto) + assert.NoError(t, err) + taskTemplate := &idlCore.TaskTemplate{Custom: custom} + arrayJob, err := ToArrayJob(taskTemplate, 0) + assert.NoError(t, err) + assert.Equal(t, arrayJob.GetParallelism(), int64(10)) + assert.Equal(t, arrayJob.GetSize(), int64(10)) + assert.Equal(t, arrayJob.GetMinSuccesses(), int64(0)) + assert.Equal(t, arrayJob.GetMinSuccessRatio(), 1.0) }) } diff --git a/go/tasks/plugins/array/inputs.go b/go/tasks/plugins/array/inputs.go index ac0c2acb5..aa525ae6c 100644 --- a/go/tasks/plugins/array/inputs.go +++ b/go/tasks/plugins/array/inputs.go @@ -6,6 +6,7 @@ import ( idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/flyteorg/flytestdlib/storage" ) @@ -20,7 +21,7 @@ func (i arrayJobInputReader) GetInputPath() storage.DataReference { } func GetInputReader(tCtx core.TaskExecutionContext, taskTemplate *idlCore.TaskTemplate) io.InputReader { - if taskTemplate.GetTaskTypeVersion() == 0 && taskTemplate.Type != AwsBatchTaskType { + if taskTemplate.GetTaskTypeVersion() == 0 && taskTemplate.Type != arrayCore.AwsBatchTaskType { // Prior to task type version == 1, dynamic type tasks (including array tasks) would write input files for each // individual array task instance. In this case we use a modified input reader to only pass in the parent input // directory. diff --git a/go/tasks/plugins/array/inputs_test.go b/go/tasks/plugins/array/inputs_test.go index 42b0f8703..1026c3aa8 100644 --- a/go/tasks/plugins/array/inputs_test.go +++ b/go/tasks/plugins/array/inputs_test.go @@ -6,6 +6,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -38,4 +39,14 @@ func TestGetInputReader(t *testing.T) { }) assert.Equal(t, inputReader.GetInputPath().String(), "test-data-reference") }) + + t.Run("task_type_version == AwsBatchTaskType", func(t *testing.T) { + taskCtx := &pluginsCoreMock.TaskExecutionContext{} + taskCtx.On("InputReader").Return(inputReader) + + inputReader := GetInputReader(taskCtx, &core.TaskTemplate{ + Type: arrayCore.AwsBatchTaskType, + }) + assert.Equal(t, inputReader.GetInputPath().String(), "test-data-reference") + }) } diff --git a/go/tasks/plugins/array/k8s/transformer.go b/go/tasks/plugins/array/k8s/transformer.go index 8b10c54d3..9eb5819ec 100644 --- a/go/tasks/plugins/array/k8s/transformer.go +++ b/go/tasks/plugins/array/k8s/transformer.go @@ -14,10 +14,10 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" core2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -97,7 +97,7 @@ func buildPodMapTask(task *idlCore.TaskTemplate, metadata core.TaskExecutionMeta // FlyteArrayJobToK8sPodTemplate returns a pod template for the given task context. Note that Name is not set on the // result object. It's up to the caller to set the Name before creating the object in K8s. func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionContext, namespaceTemplate string) ( - podTemplate v1.Pod, job *idlPlugins.ArrayJob, err error) { + podTemplate v1.Pod, job *arrayCore.ArrayJob, err error) { // Check that the taskTemplate is valid taskTemplate, err := tCtx.TaskReader().Read(ctx) @@ -117,12 +117,9 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC arrayInputReader: array.GetInputReader(tCtx, taskTemplate), } - var arrayJob *idlPlugins.ArrayJob - if taskTemplate.GetCustom() != nil { - arrayJob, err = core2.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) - if err != nil { - return v1.Pod{}, nil, err - } + arrayJob, err := core2.ToArrayJob(taskTemplate, taskTemplate.TaskTypeVersion) + if err != nil { + return v1.Pod{}, nil, err } annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, tCtx.TaskExecutionMetadata().GetAnnotations()) diff --git a/go/tasks/plugins/array/k8s/transformer_test.go b/go/tasks/plugins/array/k8s/transformer_test.go index 822e3c544..b3a97a368 100644 --- a/go/tasks/plugins/array/k8s/transformer_test.go +++ b/go/tasks/plugins/array/k8s/transformer_test.go @@ -11,10 +11,9 @@ import ( "github.com/flyteorg/flytestdlib/storage" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -42,7 +41,7 @@ var podSpec = v1.PodSpec{ }, } -var arrayJob = idlPlugins.ArrayJob{ +var arrayJob = arrayCore.ArrayJob{ Size: 100, } @@ -57,15 +56,11 @@ func getK8sPodTask(t *testing.T, annotations map[string]string) *core.TaskTempla t.Fatal(err) } - custom := &structpb.Struct{} - if err := utils.MarshalStruct(&arrayJob, custom); err != nil { - t.Fatal(err) - } - return &core.TaskTemplate{ TaskTypeVersion: 2, Config: map[string]string{ primaryContainerKey: testPrimaryContainerName, + "Size": "100", }, Target: &core.TaskTemplate_K8SPod{ K8SPod: &core.K8SPod{ @@ -78,7 +73,6 @@ func getK8sPodTask(t *testing.T, annotations map[string]string) *core.TaskTempla }, }, }, - Custom: custom, } } diff --git a/go/tasks/plugins/array/outputs.go b/go/tasks/plugins/array/outputs.go index 4177fcb61..0d7b2ece1 100644 --- a/go/tasks/plugins/array/outputs.go +++ b/go/tasks/plugins/array/outputs.go @@ -198,7 +198,7 @@ func AssembleFinalOutputs(ctx context.Context, assemblyQueue OutputAssembler, tC finalPhases: finalPhases, outputPaths: tCtx.OutputWriter(), dataStore: tCtx.DataStore(), - isAwsSingleJob: taskTemplate.Type == AwsBatchTaskType, + isAwsSingleJob: taskTemplate.Type == arrayCore.AwsBatchTaskType, }) if err != nil {