diff --git a/docs/resources/task.md b/docs/resources/task.md index 2889ff0a42..43e07c6ac7 100644 --- a/docs/resources/task.md +++ b/docs/resources/task.md @@ -88,6 +88,7 @@ resource "snowflake_task" "test_task" { - `error_integration` (String) Specifies the name of the notification integration used for error notifications. - `schedule` (String) The schedule for periodically running the task. This can be a cron or interval in minutes. (Conflict with after) - `session_parameters` (Map of String) Specifies session parameters to set for the session when the task runs. A task supports all session parameters. +- `suspend_task_after_num_failures` (Number) Specifies the number of consecutive failed task runs after which the current task is suspended automatically. The default is 0 (no automatic suspension). - `user_task_managed_initial_warehouse_size` (String) Specifies the size of the compute resources to provision for the first run of the task, before a task history is available for Snowflake to determine an ideal size. Once a task has successfully completed a few runs, Snowflake ignores this parameter setting. (Conflicts with warehouse) - `user_task_timeout_ms` (Number) Specifies the time limit on a single run of the task before it times out (in milliseconds). - `warehouse` (String) The warehouse the task will use. Omit this parameter to use Snowflake-managed compute resources for runs of this task. (Conflicts with user_task_managed_initial_warehouse_size) diff --git a/pkg/resources/task.go b/pkg/resources/task.go index 72e42fec09..2134398f29 100644 --- a/pkg/resources/task.go +++ b/pkg/resources/task.go @@ -16,7 +16,6 @@ import ( "golang.org/x/exp/slices" ) -// TODO [SNOW-884987]: add missing SUSPEND_TASK_AFTER_NUM_FAILURES attribute. var taskSchema = map[string]*schema.Schema{ "enabled": { Type: schema.TypeBool, @@ -67,6 +66,13 @@ var taskSchema = map[string]*schema.Schema{ ValidateFunc: validation.IntBetween(0, 86400000), Description: "Specifies the time limit on a single run of the task before it times out (in milliseconds).", }, + "suspend_task_after_num_failures": { + Type: schema.TypeInt, + Optional: true, + Default: 0, + ValidateFunc: validation.IntAtLeast(0), + Description: "Specifies the number of consecutive failed task runs after which the current task is suspended automatically. The default is 0 (no automatic suspension).", + }, "comment": { Type: schema.TypeString, Optional: true, @@ -124,6 +130,19 @@ func difference(a, b map[string]any) map[string]any { return diff } +// differentValue find keys present both in 'a' and 'b' but having different values. +func differentValue(a, b map[string]any) map[string]any { + diff := make(map[string]any) + for k, va := range a { + if vb, ok := b[k]; ok { + if vb != va { + diff[k] = vb + } + } + } + return diff +} + // Task returns a pointer to the resource representing a task. func Task() *schema.Resource { return &schema.Resource{ @@ -214,7 +233,7 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error { } if len(params) > 0 { - sessionParameters := map[string]interface{}{} + sessionParameters := make(map[string]any) fieldParameters := map[string]interface{}{ "user_task_managed_initial_warehouse_size": "", } @@ -233,6 +252,13 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error { } fieldParameters["user_task_timeout_ms"] = timeout + case "SUSPEND_TASK_AFTER_NUM_FAILURES": + num, err := strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return err + } + + fieldParameters["suspend_task_after_num_failures"] = num default: sessionParameters[param.Key] = param.Value } @@ -299,6 +325,10 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error { createRequest.WithUserTaskTimeoutMs(sdk.Int(v.(int))) } + if v, ok := d.GetOk("suspend_task_after_num_failures"); ok { + createRequest.WithSuspendTaskAfterNumFailures(sdk.Int(v.(int))) + } + if v, ok := d.GetOk("comment"); ok { createRequest.WithComment(sdk.String(v.(string))) } @@ -558,6 +588,20 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { } } + if d.HasChange("suspend_task_after_num_failures") { + o, n := d.GetChange("suspend_task_after_num_failures") + alterRequest := sdk.NewAlterTaskRequest(taskId) + if o.(int) > 0 && n.(int) == 0 { + alterRequest.WithUnset(sdk.NewTaskUnsetRequest().WithSuspendTaskAfterNumFailures(sdk.Bool(true))) + } else { + alterRequest.WithSet(sdk.NewTaskSetRequest().WithSuspendTaskAfterNumFailures(sdk.Int(n.(int)))) + } + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating suspend task after num failures on task %s", taskId.FullyQualifiedName()) + } + } + if d.HasChange("comment") { newComment := d.Get("comment") alterRequest := sdk.NewAlterTaskRequest(taskId) @@ -586,7 +630,6 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { } } - // TODO [SNOW-884987]: old implementation does not handle changing parameter value correctly (only finds for parameters to add od remove, not change) if d.HasChange("session_parameters") { o, n := d.GetChange("session_parameters") @@ -601,6 +644,7 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { remove := difference(os, ns) add := difference(ns, os) + change := differentValue(os, ns) if len(remove) > 0 { sessionParametersUnset, err := sdk.GetSessionParametersUnsetFrom(remove) @@ -608,7 +652,7 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { return err } if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithUnset(sdk.NewTaskUnsetRequest().WithSessionParametersUnset(sessionParametersUnset))); err != nil { - return fmt.Errorf("error removing session_parameters on task %v", d.Id()) + return fmt.Errorf("error removing session_parameters on task %v err = %w", d.Id(), err) } } @@ -618,7 +662,17 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { return err } if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithSet(sdk.NewTaskSetRequest().WithSessionParameters(sessionParameters))); err != nil { - return fmt.Errorf("error adding session_parameters to task %v", d.Id()) + return fmt.Errorf("error adding session_parameters to task %v err = %w", d.Id(), err) + } + } + + if len(change) > 0 { + sessionParameters, err := sdk.GetSessionParametersFrom(change) + if err != nil { + return err + } + if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithSet(sdk.NewTaskSetRequest().WithSessionParameters(sessionParameters))); err != nil { + return fmt.Errorf("error updating session_parameters in task %v err = %w", d.Id(), err) } } } diff --git a/pkg/resources/task_acceptance_test.go b/pkg/resources/task_acceptance_test.go index fa7af5e679..91f32876ec 100644 --- a/pkg/resources/task_acceptance_test.go +++ b/pkg/resources/task_acceptance_test.go @@ -8,6 +8,7 @@ import ( "text/template" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" "github.com/hashicorp/terraform-plugin-testing/terraform" @@ -50,6 +51,10 @@ var ( Enabled: true, Schedule: "5 MINUTE", UserTaskTimeoutMs: 1800000, + SessionParams: map[string]string{ + string(sdk.SessionParameterLockTimeout): "1000", + string(sdk.SessionParameterStrictJSONOutput): "true", + }, }, ChildTask: &TaskSettings{ @@ -79,6 +84,10 @@ var ( Enabled: true, Schedule: "5 MINUTE", UserTaskTimeoutMs: 1800000, + SessionParams: map[string]string{ + string(sdk.SessionParameterLockTimeout): "1000", + string(sdk.SessionParameterStrictJSONOutput): "true", + }, }, ChildTask: &TaskSettings{ @@ -95,7 +104,7 @@ var ( When: "TRUE", Enabled: true, SessionParams: map[string]string{ - "TIMESTAMP_INPUT_FORMAT": "YYYY-MM-DD HH24", + string(sdk.SessionParameterTimestampInputFormat): "YYYY-MM-DD HH24", }, Schedule: "5 MINUTE", UserTaskTimeoutMs: 1800000, @@ -113,6 +122,10 @@ var ( Enabled: true, Schedule: "15 MINUTE", UserTaskTimeoutMs: 1800000, + SessionParams: map[string]string{ + string(sdk.SessionParameterLockTimeout): "1000", + string(sdk.SessionParameterStrictJSONOutput): "true", + }, }, ChildTask: &TaskSettings{ @@ -144,6 +157,11 @@ var ( Enabled: false, Schedule: "5 MINUTE", UserTaskTimeoutMs: 1800000, + // Changes session params: one is updated, one is removed, one is added + SessionParams: map[string]string{ + string(sdk.SessionParameterLockTimeout): "2000", + string(sdk.SessionParameterMultiStatementCount): "5", + }, }, ChildTask: &TaskSettings{ @@ -160,7 +178,7 @@ var ( When: "TRUE", Enabled: true, SessionParams: map[string]string{ - "TIMESTAMP_INPUT_FORMAT": "YYYY-MM-DD HH24", + string(sdk.SessionParameterTimestampInputFormat): "YYYY-MM-DD HH24", }, Schedule: "5 MINUTE", UserTaskTimeoutMs: 0, @@ -193,6 +211,9 @@ func TestAcc_Task(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.child_task", "schedule", initialState.ChildTask.Schedule), checkInt64("snowflake_task.root_task", "user_task_timeout_ms", initialState.RootTask.UserTaskTimeoutMs), resource.TestCheckNoResourceAttr("snowflake_task.solo_task", "user_task_timeout_ms"), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 1000), + checkBool("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT", true), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT"), ), }, { @@ -213,6 +234,9 @@ func TestAcc_Task(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.child_task", "schedule", stepOne.ChildTask.Schedule), checkInt64("snowflake_task.root_task", "user_task_timeout_ms", stepOne.RootTask.UserTaskTimeoutMs), checkInt64("snowflake_task.solo_task", "user_task_timeout_ms", stepOne.SoloTask.UserTaskTimeoutMs), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 1000), + checkBool("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT", true), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT"), ), }, { @@ -233,6 +257,9 @@ func TestAcc_Task(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.child_task", "schedule", stepTwo.ChildTask.Schedule), checkInt64("snowflake_task.root_task", "user_task_timeout_ms", stepTwo.RootTask.UserTaskTimeoutMs), checkInt64("snowflake_task.solo_task", "user_task_timeout_ms", stepTwo.SoloTask.UserTaskTimeoutMs), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 1000), + checkBool("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT", true), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT"), ), }, { @@ -253,6 +280,9 @@ func TestAcc_Task(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.child_task", "schedule", stepThree.ChildTask.Schedule), checkInt64("snowflake_task.root_task", "user_task_timeout_ms", stepThree.RootTask.UserTaskTimeoutMs), checkInt64("snowflake_task.solo_task", "user_task_timeout_ms", stepThree.SoloTask.UserTaskTimeoutMs), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 2000), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT"), + checkInt64("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT", 5), ), }, { @@ -279,6 +309,9 @@ func TestAcc_Task(t *testing.T) { // `user_task_timeout_ms` by unsetting the // USER_TASK_TIMEOUT_MS session variable. checkInt64("snowflake_task.solo_task", "user_task_timeout_ms", initialState.ChildTask.UserTaskTimeoutMs), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 1000), + checkBool("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT", true), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT"), ), }, }, @@ -302,12 +335,12 @@ resource "snowflake_task" "root_task" { user_task_timeout_ms = {{ .RootTask.UserTaskTimeoutMs }} {{- end }} - {{ if .ChildTask.SessionParams }} + {{ if .RootTask.SessionParams }} session_parameters = { - {{ range $key, $value := .RootTask.SessionParams}} + {{ range $key, $value := .RootTask.SessionParams}} {{ $key }} = "{{ $value }}", - } {{- end }} + } {{- end }} } resource "snowflake_task" "child_task" { @@ -325,10 +358,10 @@ resource "snowflake_task" "child_task" { {{ if .ChildTask.SessionParams }} session_parameters = { - {{ range $key, $value := .ChildTask.SessionParams}} + {{ range $key, $value := .ChildTask.SessionParams}} {{ $key }} = "{{ $value }}", - } {{- end }} + } {{- end }} } resource "snowflake_task" "solo_task" { @@ -351,8 +384,8 @@ resource "snowflake_task" "solo_task" { session_parameters = { {{ range $key, $value := .SoloTask.SessionParams}} {{ $key }} = "{{ $value }}", - } {{- end }} + } {{- end }} } `) @@ -519,6 +552,7 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.test_task", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task.test_task", "sql_statement", "SELECT 1"), resource.TestCheckResourceAttr("snowflake_task.test_task", "schedule", "5 MINUTE"), + resource.TestCheckResourceAttr("snowflake_task.test_task_root", "suspend_task_after_num_failures", "1"), ), }, { @@ -529,6 +563,7 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.test_task", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task.test_task", "sql_statement", "SELECT 1"), resource.TestCheckResourceAttr("snowflake_task.test_task", "schedule", ""), + resource.TestCheckResourceAttr("snowflake_task.test_task_root", "suspend_task_after_num_failures", "2"), ), }, { @@ -539,6 +574,7 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.test_task", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task.test_task", "sql_statement", "SELECT 1"), resource.TestCheckResourceAttr("snowflake_task.test_task", "schedule", "5 MINUTE"), + resource.TestCheckResourceAttr("snowflake_task.test_task_root", "suspend_task_after_num_failures", "1"), ), }, { @@ -549,6 +585,7 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.test_task", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task.test_task", "sql_statement", "SELECT 1"), resource.TestCheckResourceAttr("snowflake_task.test_task", "schedule", ""), + resource.TestCheckResourceAttr("snowflake_task.test_task_root", "suspend_task_after_num_failures", "0"), ), }, }, @@ -558,12 +595,13 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { func taskConfigManagedScheduled(name string, taskRootName string, databaseName string, schemaName string) string { s := ` resource "snowflake_task" "test_task_root" { - name = "%s" - database = "%s" - schema = "%s" - sql_statement = "SELECT 1" - enabled = true - schedule = "5 MINUTE" + name = "%s" + database = "%s" + schema = "%s" + sql_statement = "SELECT 1" + enabled = true + schedule = "5 MINUTE" + suspend_task_after_num_failures = 1 } resource "snowflake_task" "test_task" { @@ -581,12 +619,13 @@ resource "snowflake_task" "test_task" { func taskConfigManagedScheduled2(name string, taskRootName string, databaseName string, schemaName string) string { s := ` resource "snowflake_task" "test_task_root" { - name = "%s" - database = "%s" - schema = "%s" - sql_statement = "SELECT 1" - enabled = true - schedule = "5 MINUTE" + name = "%s" + database = "%s" + schema = "%s" + sql_statement = "SELECT 1" + enabled = true + schedule = "5 MINUTE" + suspend_task_after_num_failures = 2 } resource "snowflake_task" "test_task" { diff --git a/pkg/sdk/internal/collections/queue.go b/pkg/sdk/internal/collections/queue.go new file mode 100644 index 0000000000..3749f1bc3e --- /dev/null +++ b/pkg/sdk/internal/collections/queue.go @@ -0,0 +1,30 @@ +package collections + +type Queue[T any] struct { + data []T +} + +func (s *Queue[T]) Head() *T { + if len(s.data) == 0 { + return nil + } + return &s.data[0] +} + +func (s *Queue[T]) Pop() *T { + elem := s.Head() + if elem != nil { + s.data = s.data[1:] + } + return elem +} + +func (s *Queue[T]) Push(elem T) { + s.data = append(s.data, elem) +} + +func NewQueue[T any]() Queue[T] { + return Queue[T]{ + data: make([]T, 0), + } +} diff --git a/pkg/sdk/internal/collections/queue_test.go b/pkg/sdk/internal/collections/queue_test.go new file mode 100644 index 0000000000..05387df658 --- /dev/null +++ b/pkg/sdk/internal/collections/queue_test.go @@ -0,0 +1,51 @@ +package collections + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestQueue(t *testing.T) { + t.Run("empty queue initialization", func(t *testing.T) { + queue := NewQueue[int]() + + require.Nil(t, queue.Head()) + require.Nil(t, queue.Pop()) + }) + + t.Run("returns head multiple times", func(t *testing.T) { + queue := NewQueue[int]() + + queue.Push(1) + + require.Equal(t, 1, *queue.Head()) + require.Equal(t, 1, *queue.Head()) + }) + + t.Run("returns empty head after pop", func(t *testing.T) { + queue := NewQueue[int]() + + queue.Pop() + + require.Nil(t, queue.Head()) + }) + + t.Run("multiple operations", func(t *testing.T) { + queue := NewQueue[int]() + + queue.Push(1) + require.Equal(t, 1, *queue.Head()) + + queue.Push(2) + require.Equal(t, 1, *queue.Head()) + + elem := queue.Pop() + require.Equal(t, 1, *elem) + require.Equal(t, 2, *queue.Head()) + + elem = queue.Pop() + require.Equal(t, 2, *elem) + require.Nil(t, queue.Head()) + }) +} diff --git a/pkg/sdk/parameters.go b/pkg/sdk/parameters.go index 6a7318bd1d..25304c5851 100644 --- a/pkg/sdk/parameters.go +++ b/pkg/sdk/parameters.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strconv" + "strings" ) var ( @@ -150,327 +151,31 @@ func (parameters *parameters) SetAccountParameter(ctx context.Context, parameter } func (parameters *parameters) SetSessionParameterOnAccount(ctx context.Context, parameter SessionParameter, value string) error { - opts := AlterAccountOptions{Set: &AccountSet{Parameters: &AccountLevelParameters{SessionParameters: &SessionParameters{}}}} - switch parameter { - case SessionParameterAbortDetachedQuery: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.AbortDetachedQuery = b - case SessionParameterAutocommit: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.Autocommit = b - case SessionParameterBinaryInputFormat: - opts.Set.Parameters.SessionParameters.BinaryInputFormat = Pointer(BinaryInputFormat(value)) - case SessionParameterBinaryOutputFormat: - opts.Set.Parameters.SessionParameters.BinaryOutputFormat = Pointer(BinaryOutputFormat(value)) - case SessionParameterClientMetadataRequestUseConnectionCtx: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ClientMetadataRequestUseConnectionCtx = b - case SessionParameterClientMetadataUseSessionDatabase: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ClientMetadataUseSessionDatabase = b - case SessionParameterClientResultColumnCaseInsensitive: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ClientResultColumnCaseInsensitive = b - case SessionParameterDateInputFormat: - opts.Set.Parameters.SessionParameters.DateInputFormat = &value - case SessionParameterDateOutputFormat: - opts.Set.Parameters.SessionParameters.DateOutputFormat = &value - case SessionParameterErrorOnNondeterministicMerge: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ErrorOnNondeterministicMerge = b - case SessionParameterErrorOnNondeterministicUpdate: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ErrorOnNondeterministicUpdate = b - case SessionParameterGeographyOutputFormat: - opts.Set.Parameters.SessionParameters.GeographyOutputFormat = Pointer(GeographyOutputFormat(value)) - case SessionParameterJSONIndent: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("JSON_INDENT session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.JSONIndent = Pointer(v) - case SessionParameterLockTimeout: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("LOCK_TIMEOUT session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.LockTimeout = Pointer(v) - case SessionParameterMultiStatementCount: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("MULTI_STATEMENT_COUNT session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.MultiStatementCount = Pointer(v) - - case SessionParameterQueryTag: - opts.Set.Parameters.SessionParameters.QueryTag = &value - case SessionParameterQuotedIdentifiersIgnoreCase: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.QuotedIdentifiersIgnoreCase = b - case SessionParameterRowsPerResultset: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("ROWS_PER_RESULTSET session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.RowsPerResultset = Pointer(v) - case SessionParameterSimulatedDataSharingConsumer: - opts.Set.Parameters.SessionParameters.SimulatedDataSharingConsumer = &value - case SessionParameterStatementTimeoutInSeconds: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("STATEMENT_TIMEOUT_IN_SECONDS session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.StatementTimeoutInSeconds = Pointer(v) - case SessionParameterStrictJSONOutput: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.StrictJSONOutput = b - case SessionParameterTimestampDayIsAlways24h: - b, err := parseBooleanParameter(string(parameter), value) + sp := &SessionParameters{} + err := sp.setParam(parameter, value) + if err == nil { + opts := AlterAccountOptions{Set: &AccountSet{Parameters: &AccountLevelParameters{SessionParameters: sp}}} + err = parameters.client.Accounts.Alter(ctx, &opts) if err != nil { return err } - opts.Set.Parameters.SessionParameters.TimestampDayIsAlways24h = b - case SessionParameterTimestampInputFormat: - opts.Set.Parameters.SessionParameters.TimestampInputFormat = &value - case SessionParameterTimestampLTZOutputFormat: - opts.Set.Parameters.SessionParameters.TimestampLTZOutputFormat = &value - case SessionParameterTimestampNTZOutputFormat: - opts.Set.Parameters.SessionParameters.TimestampNTZOutputFormat = &value - case SessionParameterTimestampOutputFormat: - opts.Set.Parameters.SessionParameters.TimestampOutputFormat = &value - case SessionParameterTimestampTypeMapping: - opts.Set.Parameters.SessionParameters.TimestampTypeMapping = &value - case SessionParameterTimestampTZOutputFormat: - opts.Set.Parameters.SessionParameters.TimestampTZOutputFormat = &value - case SessionParameterTimezone: - opts.Set.Parameters.SessionParameters.Timezone = &value - case SessionParameterTimeInputFormat: - opts.Set.Parameters.SessionParameters.TimeInputFormat = &value - case SessionParameterTimeOutputFormat: - opts.Set.Parameters.SessionParameters.TimeOutputFormat = &value - case SessionParameterTransactionDefaultIsolationLevel: - opts.Set.Parameters.SessionParameters.TransactionDefaultIsolationLevel = Pointer(TransactionDefaultIsolationLevel(value)) - case SessionParameterTwoDigitCenturyStart: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("TWO_DIGIT_CENTURY_START session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.TwoDigitCenturyStart = Pointer(v) - case SessionParameterUnsupportedDDLAction: - opts.Set.Parameters.SessionParameters.UnsupportedDDLAction = Pointer(UnsupportedDDLAction(value)) - case SessionParameterUseCachedResult: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.UseCachedResult = b - case SessionParameterWeekOfYearPolicy: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("WEEK_OF_YEAR_POLICY session parameter is an integer, got %v", value) + return nil + } else { + if strings.Contains(err.Error(), "session parameter is not supported") { + return parameters.SetObjectParameterOnAccount(ctx, ObjectParameter(parameter), value) } - opts.Set.Parameters.SessionParameters.WeekOfYearPolicy = Pointer(v) - case SessionParameterWeekStart: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("WEEK_START session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.WeekStart = Pointer(v) - default: - return parameters.SetObjectParameterOnAccount(ctx, ObjectParameter(parameter), value) - } - err := parameters.client.Accounts.Alter(ctx, &opts) - if err != nil { return err } - return nil } func (parameters *parameters) SetSessionParameterOnUser(ctx context.Context, userId AccountObjectIdentifier, parameter SessionParameter, value string) error { - opts := AlterUserOptions{Set: &UserSet{SessionParameters: &SessionParameters{}}} - switch parameter { - case SessionParameterAbortDetachedQuery: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.AbortDetachedQuery = b - case SessionParameterAutocommit: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.Autocommit = b - case SessionParameterBinaryInputFormat: - opts.Set.SessionParameters.BinaryInputFormat = Pointer(BinaryInputFormat(value)) - case SessionParameterBinaryOutputFormat: - opts.Set.SessionParameters.BinaryOutputFormat = Pointer(BinaryOutputFormat(value)) - case SessionParameterClientMetadataRequestUseConnectionCtx: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ClientMetadataRequestUseConnectionCtx = b - case SessionParameterClientMetadataUseSessionDatabase: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ClientMetadataUseSessionDatabase = b - case SessionParameterClientResultColumnCaseInsensitive: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ClientResultColumnCaseInsensitive = b - case SessionParameterDateInputFormat: - opts.Set.SessionParameters.DateInputFormat = &value - case SessionParameterDateOutputFormat: - opts.Set.SessionParameters.DateOutputFormat = &value - case SessionParameterErrorOnNondeterministicMerge: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ErrorOnNondeterministicMerge = b - case SessionParameterErrorOnNondeterministicUpdate: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ErrorOnNondeterministicUpdate = b - case SessionParameterGeographyOutputFormat: - opts.Set.SessionParameters.GeographyOutputFormat = Pointer(GeographyOutputFormat(value)) - case SessionParameterJSONIndent: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("JSON_INDENT session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.JSONIndent = Pointer(v) - case SessionParameterLockTimeout: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("LOCK_TIMEOUT session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.LockTimeout = Pointer(v) - case SessionParameterMultiStatementCount: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("MULTI_STATEMENT_COUNT session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.MultiStatementCount = Pointer(v) - - case SessionParameterQueryTag: - opts.Set.SessionParameters.QueryTag = &value - case SessionParameterQuotedIdentifiersIgnoreCase: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.QuotedIdentifiersIgnoreCase = b - case SessionParameterRowsPerResultset: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("ROWS_PER_RESULTSET session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.RowsPerResultset = Pointer(v) - case SessionParameterSimulatedDataSharingConsumer: - opts.Set.SessionParameters.SimulatedDataSharingConsumer = &value - case SessionParameterStatementTimeoutInSeconds: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("STATEMENT_TIMEOUT_IN_SECONDS session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.StatementTimeoutInSeconds = Pointer(v) - case SessionParameterStrictJSONOutput: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.StrictJSONOutput = b - case SessionParameterTimestampDayIsAlways24h: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.TimestampDayIsAlways24h = b - case SessionParameterTimestampInputFormat: - opts.Set.SessionParameters.TimestampInputFormat = &value - case SessionParameterTimestampLTZOutputFormat: - opts.Set.SessionParameters.TimestampLTZOutputFormat = &value - case SessionParameterTimestampNTZOutputFormat: - opts.Set.SessionParameters.TimestampNTZOutputFormat = &value - case SessionParameterTimestampOutputFormat: - opts.Set.SessionParameters.TimestampOutputFormat = &value - case SessionParameterTimestampTypeMapping: - opts.Set.SessionParameters.TimestampTypeMapping = &value - case SessionParameterTimestampTZOutputFormat: - opts.Set.SessionParameters.TimestampTZOutputFormat = &value - case SessionParameterTimezone: - opts.Set.SessionParameters.Timezone = &value - case SessionParameterTimeInputFormat: - opts.Set.SessionParameters.TimeInputFormat = &value - case SessionParameterTimeOutputFormat: - opts.Set.SessionParameters.TimeOutputFormat = &value - case SessionParameterTransactionDefaultIsolationLevel: - opts.Set.SessionParameters.TransactionDefaultIsolationLevel = Pointer(TransactionDefaultIsolationLevel(value)) - case SessionParameterTwoDigitCenturyStart: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("TWO_DIGIT_CENTURY_START session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.TwoDigitCenturyStart = Pointer(v) - case SessionParameterUnsupportedDDLAction: - opts.Set.SessionParameters.UnsupportedDDLAction = Pointer(UnsupportedDDLAction(value)) - case SessionParameterUseCachedResult: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.UseCachedResult = b - case SessionParameterWeekOfYearPolicy: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("WEEK_OF_YEAR_POLICY session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.WeekOfYearPolicy = Pointer(v) - case SessionParameterWeekStart: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("WEEK_START session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.WeekStart = Pointer(v) - default: - return fmt.Errorf("Invalid session parameter: %v", string(parameter)) + sp := &SessionParameters{} + err := sp.setParam(parameter, value) + if err != nil { + return err } - err := parameters.client.Users.Alter(ctx, userId, &opts) + opts := AlterUserOptions{Set: &UserSet{SessionParameters: sp}} + err = parameters.client.Users.Alter(ctx, userId, &opts) if err != nil { return err } @@ -1013,7 +718,7 @@ type SessionParametersUnset struct { } func (v *SessionParametersUnset) validate() error { - if !anyValueSet(v.AbortDetachedQuery, v.Autocommit, v.BinaryInputFormat, v.BinaryOutputFormat, v.DateInputFormat, v.DateOutputFormat, v.ErrorOnNondeterministicMerge, v.ErrorOnNondeterministicUpdate, v.GeographyOutputFormat, v.JSONIndent, v.LockTimeout, v.QueryTag, v.RowsPerResultset, v.SimulatedDataSharingConsumer, v.StatementTimeoutInSeconds, v.StrictJSONOutput, v.TimestampDayIsAlways24h, v.TimestampInputFormat, v.TimestampLTZOutputFormat, v.TimestampNTZOutputFormat, v.TimestampOutputFormat, v.TimestampTypeMapping, v.TimestampTZOutputFormat, v.Timezone, v.TimeInputFormat, v.TimeOutputFormat, v.TransactionDefaultIsolationLevel, v.TwoDigitCenturyStart, v.UnsupportedDDLAction, v.UseCachedResult, v.WeekOfYearPolicy, v.WeekStart) { + if !anyValueSet(v.AbortDetachedQuery, v.Autocommit, v.BinaryInputFormat, v.BinaryOutputFormat, v.ClientMetadataRequestUseConnectionCtx, v.ClientMetadataUseSessionDatabase, v.ClientResultColumnCaseInsensitive, v.DateInputFormat, v.DateOutputFormat, v.ErrorOnNondeterministicMerge, v.ErrorOnNondeterministicUpdate, v.GeographyOutputFormat, v.JSONIndent, v.LockTimeout, v.MultiStatementCount, v.QueryTag, v.QuotedIdentifiersIgnoreCase, v.RowsPerResultset, v.SimulatedDataSharingConsumer, v.StatementTimeoutInSeconds, v.StrictJSONOutput, v.TimestampDayIsAlways24h, v.TimestampInputFormat, v.TimestampLTZOutputFormat, v.TimestampNTZOutputFormat, v.TimestampOutputFormat, v.TimestampTypeMapping, v.TimestampTZOutputFormat, v.Timezone, v.TimeInputFormat, v.TimeOutputFormat, v.TransactionDefaultIsolationLevel, v.TwoDigitCenturyStart, v.UnsupportedDDLAction, v.UseCachedResult, v.WeekOfYearPolicy, v.WeekStart) { return errors.Join(errAtLeastOneOf("SessionParametersUnset", "AbortDetachedQuery", "Autocommit", "BinaryInputFormat", "BinaryOutputFormat", "DateInputFormat", "DateOutputFormat", "ErrorOnNondeterministicMerge", "ErrorOnNondeterministicUpdate", "GeographyOutputFormat", "JSONIndent", "LockTimeout", "QueryTag", "RowsPerResultset", "SimulatedDataSharingConsumer", "StatementTimeoutInSeconds", "StrictJSONOutput", "TimestampDayIsAlways24h", "TimestampInputFormat", "TimestampLTZOutputFormat", "TimestampNTZOutputFormat", "TimestampOutputFormat", "TimestampTypeMapping", "TimestampTZOutputFormat", "Timezone", "TimeInputFormat", "TimeOutputFormat", "TransactionDefaultIsolationLevel", "TwoDigitCenturyStart", "UnsupportedDDLAction", "UseCachedResult", "WeekOfYearPolicy", "WeekStart")) } return nil diff --git a/pkg/sdk/parameters_impl.go b/pkg/sdk/parameters_impl.go index 110efbeff3..235f88c4b8 100644 --- a/pkg/sdk/parameters_impl.go +++ b/pkg/sdk/parameters_impl.go @@ -20,8 +20,6 @@ func GetSessionParametersFrom(params map[string]any) (*SessionParameters, error) return sessionParameters, nil } -// TODO [SNOW-884987]: use this method in SetSessionParameterOnAccount and in SetSessionParameterOnUser -// TODO [SNOW-884987]: unit test this method func (sessionParameters *SessionParameters) setParam(parameter SessionParameter, value string) error { switch parameter { case SessionParameterAbortDetachedQuery: diff --git a/pkg/sdk/parameters_impl_test.go b/pkg/sdk/parameters_impl_test.go new file mode 100644 index 0000000000..59a6bdc04b --- /dev/null +++ b/pkg/sdk/parameters_impl_test.go @@ -0,0 +1,117 @@ +package sdk + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSessionParameters_setParam(t *testing.T) { + tests := []struct { + parameter SessionParameter + value string + expectedValue any + accessor func(*SessionParameters) any + }{ + {parameter: SessionParameterAbortDetachedQuery, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.AbortDetachedQuery }}, + {parameter: SessionParameterAutocommit, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.Autocommit }}, + {parameter: SessionParameterBinaryInputFormat, value: "some", expectedValue: BinaryInputFormat("some"), accessor: func(sp *SessionParameters) any { return *sp.BinaryInputFormat }}, + {parameter: SessionParameterBinaryOutputFormat, value: "some", expectedValue: BinaryOutputFormat("some"), accessor: func(sp *SessionParameters) any { return *sp.BinaryOutputFormat }}, + {parameter: SessionParameterClientMetadataRequestUseConnectionCtx, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ClientMetadataRequestUseConnectionCtx }}, + {parameter: SessionParameterClientMetadataUseSessionDatabase, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ClientMetadataUseSessionDatabase }}, + {parameter: SessionParameterClientResultColumnCaseInsensitive, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ClientResultColumnCaseInsensitive }}, + {parameter: SessionParameterDateInputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.DateInputFormat }}, + {parameter: SessionParameterGeographyOutputFormat, value: "some", expectedValue: GeographyOutputFormat("some"), accessor: func(sp *SessionParameters) any { return *sp.GeographyOutputFormat }}, + {parameter: SessionParameterDateOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.DateOutputFormat }}, + {parameter: SessionParameterErrorOnNondeterministicMerge, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ErrorOnNondeterministicMerge }}, + {parameter: SessionParameterErrorOnNondeterministicUpdate, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ErrorOnNondeterministicUpdate }}, + {parameter: SessionParameterJSONIndent, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.JSONIndent }}, + {parameter: SessionParameterLockTimeout, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.LockTimeout }}, + {parameter: SessionParameterMultiStatementCount, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.MultiStatementCount }}, + {parameter: SessionParameterQueryTag, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.QueryTag }}, + {parameter: SessionParameterQuotedIdentifiersIgnoreCase, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.QuotedIdentifiersIgnoreCase }}, + {parameter: SessionParameterRowsPerResultset, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.RowsPerResultset }}, + {parameter: SessionParameterSimulatedDataSharingConsumer, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.SimulatedDataSharingConsumer }}, + {parameter: SessionParameterStatementTimeoutInSeconds, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.StatementTimeoutInSeconds }}, + {parameter: SessionParameterStrictJSONOutput, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.StrictJSONOutput }}, + {parameter: SessionParameterTimeInputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimeInputFormat }}, + {parameter: SessionParameterTimeOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimeOutputFormat }}, + {parameter: SessionParameterTimestampDayIsAlways24h, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.TimestampDayIsAlways24h }}, + {parameter: SessionParameterTimestampInputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampInputFormat }}, + {parameter: SessionParameterTimestampLTZOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampLTZOutputFormat }}, + {parameter: SessionParameterTimestampNTZOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampNTZOutputFormat }}, + {parameter: SessionParameterTimestampOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampOutputFormat }}, + {parameter: SessionParameterTimestampTypeMapping, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampTypeMapping }}, + {parameter: SessionParameterTimestampTZOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampTZOutputFormat }}, + {parameter: SessionParameterTimezone, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.Timezone }}, + {parameter: SessionParameterTransactionDefaultIsolationLevel, value: "some", expectedValue: TransactionDefaultIsolationLevel("some"), accessor: func(sp *SessionParameters) any { return *sp.TransactionDefaultIsolationLevel }}, + {parameter: SessionParameterTwoDigitCenturyStart, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.TwoDigitCenturyStart }}, + {parameter: SessionParameterUnsupportedDDLAction, value: "some", expectedValue: UnsupportedDDLAction("some"), accessor: func(sp *SessionParameters) any { return *sp.UnsupportedDDLAction }}, + {parameter: SessionParameterUseCachedResult, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.UseCachedResult }}, + {parameter: SessionParameterWeekOfYearPolicy, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.WeekOfYearPolicy }}, + {parameter: SessionParameterWeekStart, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.WeekStart }}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("test valid value '%s' for parameter %s", tt.value, tt.parameter), func(t *testing.T) { + sessionParameters := &SessionParameters{} + + err := sessionParameters.setParam(tt.parameter, tt.value) + + require.NoError(t, err) + require.Equal(t, tt.expectedValue, tt.accessor(sessionParameters)) + }) + } + + invalidCases := []struct { + parameter SessionParameter + value string + }{ + {parameter: SessionParameterAbortDetachedQuery, value: "true123"}, + {parameter: SessionParameterAutocommit, value: "true123"}, + // {parameter: SessionParameterBinaryInputFormat, value: "some"}, // add validation + // {parameter: SessionParameterBinaryOutputFormat, value: "some"}, // add validation + {parameter: SessionParameterClientMetadataRequestUseConnectionCtx, value: "true123"}, + {parameter: SessionParameterClientMetadataUseSessionDatabase, value: "true123"}, + {parameter: SessionParameterClientResultColumnCaseInsensitive, value: "true123"}, + // {parameter: SessionParameterDateInputFormat, value: "some"}, // add validation + // {parameter: SessionParameterGeographyOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterDateOutputFormat, value: "some"}, // add validation + {parameter: SessionParameterErrorOnNondeterministicMerge, value: "true123"}, + {parameter: SessionParameterErrorOnNondeterministicUpdate, value: "true123"}, + {parameter: SessionParameterJSONIndent, value: "aaa"}, + {parameter: SessionParameterLockTimeout, value: "aaa"}, + {parameter: SessionParameterMultiStatementCount, value: "aaa"}, + // {parameter: SessionParameterQueryTag, value: "some"}, // add validation + {parameter: SessionParameterQuotedIdentifiersIgnoreCase, value: "true123"}, + {parameter: SessionParameterRowsPerResultset, value: "aaa"}, + // {parameter: SessionParameterSimulatedDataSharingConsumer, value: "some"}, // add validation + {parameter: SessionParameterStatementTimeoutInSeconds, value: "aaa"}, + {parameter: SessionParameterStrictJSONOutput, value: "true123"}, + // {parameter: SessionParameterTimeInputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimeOutputFormat, value: "some"}, // add validation + {parameter: SessionParameterTimestampDayIsAlways24h, value: "true123"}, + // {parameter: SessionParameterTimestampInputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimestampLTZOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimestampNTZOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimestampOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimestampTypeMapping, value: "some"}, // add validation + // {parameter: SessionParameterTimestampTZOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimezone, value: "some"}, // add validation + // {parameter: SessionParameterTransactionDefaultIsolationLevel, value: "some"}, // add validation + {parameter: SessionParameterTwoDigitCenturyStart, value: "aaa"}, + // {parameter: SessionParameterUnsupportedDDLAction, value: "some"}, // add validation + {parameter: SessionParameterUseCachedResult, value: "true123"}, + {parameter: SessionParameterWeekOfYearPolicy, value: "aaa"}, + {parameter: SessionParameterWeekStart, value: "aaa"}, + } + for _, tt := range invalidCases { + t.Run(fmt.Sprintf("test invalid value '%s' for parameter %s", tt.value, tt.parameter), func(t *testing.T) { + sessionParameters := &SessionParameters{} + + err := sessionParameters.setParam(tt.parameter, tt.value) + + require.Error(t, err) + }) + } +} diff --git a/pkg/sdk/tasks_impl_gen.go b/pkg/sdk/tasks_impl_gen.go index dd6f3070aa..6a4cb38d4c 100644 --- a/pkg/sdk/tasks_impl_gen.go +++ b/pkg/sdk/tasks_impl_gen.go @@ -3,8 +3,10 @@ package sdk import ( "context" "encoding/json" - "fmt" "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" + "golang.org/x/exp/slices" ) var _ Tasks = (*tasks)(nil) @@ -66,40 +68,38 @@ func (v *tasks) Execute(ctx context.Context, request *ExecuteTaskRequest) error // GetRootTasks is a way to get all root tasks for the given tasks. // Snowflake does not have (yet) a method to do it without traversing the task graph manually. // Task DAG should have a single root but this is checked when the root task is being resumed; that's why we return here multiple roots. -// Cycles should not be possible in a task DAG but it is checked when the root task is being resumed; that's why this method has to be cycle-proof. -// TODO [SNOW-884987]: handle cycles +// Cycles should not be possible in a task DAG, but it is checked when the root task is being resumed; that's why this method has to be cycle-proof. func GetRootTasks(v Tasks, ctx context.Context, id SchemaObjectIdentifier) ([]Task, error) { - task, err := v.ShowByID(ctx, id) - if err != nil { - return nil, err - } + tasksToExamine := collections.NewQueue[SchemaObjectIdentifier]() + alreadyExaminedTasksNames := make([]string, 0) + rootTasks := make([]Task, 0) - predecessors := task.Predecessors - // no predecessors mean this is a root task - if len(predecessors) == 0 { - return []Task{*task}, nil - } + tasksToExamine.Push(id) + + for tasksToExamine.Head() != nil { + current := tasksToExamine.Pop() - rootTasks := make([]Task, 0, len(predecessors)) - for _, predecessor := range predecessors { - predecessorTasks, err := GetRootTasks(v, ctx, predecessor) + if slices.Contains(alreadyExaminedTasksNames, current.Name()) { + continue + } + + task, err := v.ShowByID(ctx, *current) if err != nil { - return nil, fmt.Errorf("unable to get predecessors for task %s err = %w", predecessor.FullyQualifiedName(), err) + return nil, err } - rootTasks = append(rootTasks, predecessorTasks...) - } - // TODO [SNOW-884987]: extract unique function in our collection helper (if cycle-proof algorithm still needs it) - keys := make(map[string]bool) - uniqueRootTasks := make([]Task, 0, len(rootTasks)) - for _, rootTask := range rootTasks { - if _, exists := keys[rootTask.ID().FullyQualifiedName()]; !exists { - keys[rootTask.ID().FullyQualifiedName()] = true - uniqueRootTasks = append(uniqueRootTasks, rootTask) + predecessors := task.Predecessors + if len(predecessors) == 0 { + rootTasks = append(rootTasks, *task) + } else { + for _, p := range predecessors { + tasksToExamine.Push(p) + } } + alreadyExaminedTasksNames = append(alreadyExaminedTasksNames, current.Name()) } - return uniqueRootTasks, nil + return rootTasks, nil } func (r *CreateTaskRequest) toOpts() *CreateTaskOptions { diff --git a/pkg/sdk/tasks_test.go b/pkg/sdk/tasks_test.go index 9d9f27d7d5..0faf8bf20e 100644 --- a/pkg/sdk/tasks_test.go +++ b/pkg/sdk/tasks_test.go @@ -58,7 +58,8 @@ func TestTasks_GetRootTasks(t *testing.T) { {"t1": {}, "t2": {}, "initial": {"t1"}, "expected": {"t1"}}, {"t1": {"t2", "t3", "t4"}, "t2": {}, "t3": {}, "t4": {}, "initial": {"t1"}, "expected": {"t2", "t3", "t4"}}, {"t1": {"t2", "t3", "t4"}, "t2": {}, "t3": {"t2"}, "t4": {"t3"}, "initial": {"t1"}, "expected": {"t2"}}, - // {"r": {}, "t1": {"t2", "r"}, "t2": {"t3"}, "t3": {"t1"}, "initial": {"t1"}, "expected": {"r"}}, // cycle -> failing for current (old) implementation + {"r": {}, "t1": {"t2", "r"}, "t2": {"t3"}, "t3": {"t1"}, "initial": {"t1"}, "expected": {"r"}}, // cycle -> failing for the old implementation + {"r": {}, "t1": {"t2", "r"}, "t2": {"t3"}, "t3": {"t1"}, "initial": {"t3"}, "expected": {"r"}}, // cycle -> failing for the old implementation } for i, tt := range tests { t.Run(fmt.Sprintf("test case [%v]", i), func(t *testing.T) { diff --git a/pkg/sdk/testint/tasks_gen_integration_test.go b/pkg/sdk/testint/tasks_gen_integration_test.go index f6dd6d9e5e..f68847fa47 100644 --- a/pkg/sdk/testint/tasks_gen_integration_test.go +++ b/pkg/sdk/testint/tasks_gen_integration_test.go @@ -257,6 +257,12 @@ func TestInt_Tasks(t *testing.T) { err = client.Tasks.Alter(ctx, alterRequest) require.NoError(t, err) + // can get the root task even with cycle + rootTasks, err = sdk.GetRootTasks(client.Tasks, ctx, t3Id) + require.NoError(t, err) + require.Len(t, rootTasks, 1) + require.Equal(t, rootId, rootTasks[0].ID()) + // we get an error when trying to start alterRequest = sdk.NewAlterTaskRequest(rootId).WithResume(sdk.Bool(true)) err = client.Tasks.Alter(ctx, alterRequest)