diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/config.go b/flyteplugins/go/tasks/plugins/k8s/dask/config.go new file mode 100644 index 00000000000..aac388e116e --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/dask/config.go @@ -0,0 +1,29 @@ +package dask + +import ( + pluginsConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" +) + +//go:generate pflags Config --default-var=defaultConfig + +var ( + defaultConfig = Config{ + Logs: logs.DefaultConfig, + } + + configSection = pluginsConfig.MustRegisterSubSection("dask", &defaultConfig) +) + +// Config is config for 'dask' plugin +type Config struct { + Logs logs.LogConfig `json:"logs,omitempty"` +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} + +func SetConfig(cfg *Config) error { + return configSection.SetConfig(cfg) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/config_flags.go b/flyteplugins/go/tasks/plugins/k8s/dask/config_flags.go new file mode 100755 index 00000000000..03774b772bd --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/dask/config_flags.go @@ -0,0 +1,65 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package dask + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "logs.cloudwatch-enabled"), defaultConfig.Logs.IsCloudwatchEnabled, "Enable Cloudwatch Logging") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.cloudwatch-region"), defaultConfig.Logs.CloudwatchRegion, "AWS region in which Cloudwatch logs are stored.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.cloudwatch-log-group"), defaultConfig.Logs.CloudwatchLogGroup, "Log group to which streams are associated.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.cloudwatch-template-uri"), defaultConfig.Logs.CloudwatchTemplateURI, "Template Uri to use when building cloudwatch log links") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "logs.kubernetes-enabled"), defaultConfig.Logs.IsKubernetesEnabled, "Enable Kubernetes Logging") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.kubernetes-url"), defaultConfig.Logs.KubernetesURL, "Console URL for Kubernetes logs") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.kubernetes-template-uri"), defaultConfig.Logs.KubernetesTemplateURI, "Template Uri to use when building kubernetes log links") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "logs.stackdriver-enabled"), defaultConfig.Logs.IsStackDriverEnabled, "Enable Log-links to stackdriver") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.gcp-project"), defaultConfig.Logs.GCPProjectName, "Name of the project in GCP") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.stackdriver-logresourcename"), defaultConfig.Logs.StackdriverLogResourceName, "Name of the logresource in stackdriver") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.stackdriver-template-uri"), defaultConfig.Logs.StackDriverTemplateURI, "Template Uri to use when building stackdriver log links") + return cmdFlags +} diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/config_flags_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/config_flags_test.go new file mode 100755 index 00000000000..4cd2be2b44e --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/dask/config_flags_test.go @@ -0,0 +1,256 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package dask + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_logs.cloudwatch-enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.cloudwatch-enabled", testValue) + if vBool, err := cmdFlags.GetBool("logs.cloudwatch-enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Logs.IsCloudwatchEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.cloudwatch-region", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.cloudwatch-region", testValue) + if vString, err := cmdFlags.GetString("logs.cloudwatch-region"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Logs.CloudwatchRegion) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.cloudwatch-log-group", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.cloudwatch-log-group", testValue) + if vString, err := cmdFlags.GetString("logs.cloudwatch-log-group"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Logs.CloudwatchLogGroup) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.cloudwatch-template-uri", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.cloudwatch-template-uri", testValue) + if vString, err := cmdFlags.GetString("logs.cloudwatch-template-uri"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Logs.CloudwatchTemplateURI) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.kubernetes-enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.kubernetes-enabled", testValue) + if vBool, err := cmdFlags.GetBool("logs.kubernetes-enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Logs.IsKubernetesEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.kubernetes-url", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.kubernetes-url", testValue) + if vString, err := cmdFlags.GetString("logs.kubernetes-url"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Logs.KubernetesURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.kubernetes-template-uri", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.kubernetes-template-uri", testValue) + if vString, err := cmdFlags.GetString("logs.kubernetes-template-uri"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Logs.KubernetesTemplateURI) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.stackdriver-enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.stackdriver-enabled", testValue) + if vBool, err := cmdFlags.GetBool("logs.stackdriver-enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Logs.IsStackDriverEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.gcp-project", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.gcp-project", testValue) + if vString, err := cmdFlags.GetString("logs.gcp-project"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Logs.GCPProjectName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.stackdriver-logresourcename", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.stackdriver-logresourcename", testValue) + if vString, err := cmdFlags.GetString("logs.stackdriver-logresourcename"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Logs.StackdriverLogResourceName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.stackdriver-template-uri", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.stackdriver-template-uri", testValue) + if vString, err := cmdFlags.GetString("logs.stackdriver-template-uri"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Logs.StackDriverTemplateURI) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go index d3b4ab32f14..4c9a551cbae 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go @@ -279,7 +279,7 @@ func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.Schedule } func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { - logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig()) + logPlugin, err := logs.InitializeLogPlugins(&GetConfig().Logs) if err != nil { return pluginsCore.PhaseInfoUndefined, err }