diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go index adb2d655bb..95be69699d 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go @@ -218,6 +218,8 @@ type K8sPluginConfig struct { // Extended resources that should be added to the tolerations automatically. AddTolerationsForExtendedResources []string `json:"add-tolerations-for-extended-resources" pflag:",Name of the extended resources for which tolerations should be added."` + + EnableDistributedErrorAggregation bool `json:"enable-distributed-error-aggregation" pflag:",If true, will aggregate errors of different worker pods for distributed tasks."` } // FlyteCoPilotConfig specifies configuration for the Flyte CoPilot system. FlyteCoPilot, allows running flytekit-less containers diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 0ee3f3570f..3dc81ce41a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -18,6 +18,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginsK8s "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + k8sConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" @@ -30,8 +31,14 @@ type pytorchOperatorResourceHandler struct { var _ k8s.Plugin = pytorchOperatorResourceHandler{} func (pytorchOperatorResourceHandler) GetProperties() k8s.PluginProperties { - return k8s.PluginProperties{ - ErrorAggregationStrategy: k8s.EarliestErrorAggregationStrategy, + config := k8sConfig.GetK8sPluginConfig() + + if config.EnableDistributedErrorAggregation { + return k8s.PluginProperties{ + ErrorAggregationStrategy: k8s.EarliestErrorAggregationStrategy, + } + } else { + return k8s.PluginProperties{} } } @@ -47,7 +54,7 @@ func (pytorchOperatorResourceHandler) BuildIdentityResource(ctx context.Context, } // Defines a func to create the full resource object that will be posted to k8s. -func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { +func (p pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { @@ -115,10 +122,13 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx }, }, }) - container.Env = append(container.Env, apiv1.EnvVar{ - Name: pluginsK8s.FlyteInternalDistErrorStrategyEnvVarKey, - Value: k8s.EarliestErrorAggregationStrategy.String(), - }) + + if p.GetProperties().ErrorAggregationStrategy == k8s.EarliestErrorAggregationStrategy { + container.Env = append(container.Env, apiv1.EnvVar{ + Name: pluginsK8s.FlyteInternalDistErrorStrategyEnvVarKey, + Value: k8s.EarliestErrorAggregationStrategy.String(), + }) + } } updateEnvVars(&workerReplicaSpec.Template.Spec.Containers[0]) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 0f38bb2851..ab319561cf 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -26,6 +26,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" pluginsK8s "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" flytek8sConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + k8sConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" @@ -724,8 +725,14 @@ func TestGetLogsElastic(t *testing.T) { } func TestGetProperties(t *testing.T) { + config := k8sConfig.GetK8sPluginConfig() pytorchResourceHandler := pytorchOperatorResourceHandler{} - expected := k8s.PluginProperties{ + + expected := k8s.PluginProperties{} + assert.Equal(t, expected, pytorchResourceHandler.GetProperties()) + + config.EnableDistributedErrorAggregation = true + expected = k8s.PluginProperties{ ErrorAggregationStrategy: k8s.EarliestErrorAggregationStrategy, } assert.Equal(t, expected, pytorchResourceHandler.GetProperties()) @@ -861,6 +868,8 @@ func TestBuildResourcePytorchV1(t *testing.T) { }, } + config := k8sConfig.GetK8sPluginConfig() + config.EnableDistributedErrorAggregation = true pytorchResourceHandler := pytorchOperatorResourceHandler{} taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig)