From 07b9b4fee800fe3f34890783cc463d4fc5904717 Mon Sep 17 00:00:00 2001 From: Scott Winkler Date: Thu, 19 Oct 2023 07:18:25 -0700 Subject: [PATCH 01/20] fix: provider config (#2136) * provider config fix * fix * fix deprecated error * update docs --- docs/index.md | 2 +- pkg/provider/provider.go | 36 +++++++++++++++++------------------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/docs/index.md b/docs/index.md index 4041d491f2..2063a87819 100644 --- a/docs/index.md +++ b/docs/index.md @@ -92,7 +92,7 @@ provider "snowflake" { - `token` (String, Sensitive) Token to use for OAuth and other forms of token based auth. Can also be sourced from the `SNOWFLAKE_TOKEN` environment variable. - `token_accessor` (Block List, Max: 1) (see [below for nested schema](#nestedblock--token_accessor)) - `user` (String) Username. Can also be sourced from the `SNOWFLAKE_USER` environment variable. Required unless using `profile`. -- `username` (String, Deprecated) Username for username+password authentication. Can also be sourced from the `SNOWFLAKE_USER` environment variable. Required unless using `profile`. +- `username` (String, Deprecated) Username for username+password authentication. Can also be sourced from the `SNOWFLAKE_USERNAME` environment variable. Required unless using `profile`. - `validate_default_parameters` (Boolean) If true, disables the validation checks for Database, Schema, Warehouse and Role at the time a connection is established. Can also be sourced from the `SNOWFLAKE_VALIDATE_DEFAULT_PARAMETERS` environment variable. - `warehouse` (String) Specifies the virtual warehouse to use by default for queries, loading, etc. in the client session. Can also be sourced from the `SNOWFLAKE_WAREHOUSE` environment variable. diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 15385e7316..226ff9c5d2 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -33,10 +33,10 @@ func Provider() *schema.Provider { }, "username": { Type: schema.TypeString, - Description: "Username for username+password authentication. Can also be sourced from the `SNOWFLAKE_USER` environment variable. Required unless using `profile`.", + Description: "Username for username+password authentication. Can also be sourced from the `SNOWFLAKE_USERNAME` environment variable. 
Required unless using `profile`.", Optional: true, - DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_USER", nil), - Deprecated: "Use `user` instead", + DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_USERNAME", nil), + Deprecated: "Use `user` instead of `username`", }, "password": { Type: schema.TypeString, @@ -540,17 +540,18 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { Application: "terraform-provider-snowflake", } - if v, ok := s.GetOk("account"); ok { + if v, ok := s.GetOk("account"); ok && v.(string) != "" { config.Account = v.(string) } - if v, ok := s.GetOk("user"); ok { + // backwards compatibility until we can remove this + if v, ok := s.GetOk("username"); ok && v.(string) != "" { config.User = v.(string) } - // backwards compatibility until we can remove this - if v, ok := s.GetOk("username"); ok { + if v, ok := s.GetOk("user"); ok && v.(string) != "" { config.User = v.(string) } - if v, ok := s.GetOk("password"); ok { + + if v, ok := s.GetOk("password"); ok && v.(string) != "" { config.Password = v.(string) } @@ -591,7 +592,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { config.ClientIP = net.ParseIP(v.(string)) } - if v, ok := s.GetOk("protocol"); ok { + if v, ok := s.GetOk("protocol"); ok && v.(string) != "" { config.Protocol = v.(string) } @@ -599,7 +600,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { config.Host = v.(string) } - if v, ok := s.GetOk("port"); ok { + if v, ok := s.GetOk("port"); ok && v.(int) > 0 { config.Port = v.(int) } @@ -645,27 +646,27 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { config.OktaURL = oktaURL } - if v, ok := s.GetOk("login_timeout"); ok { + if v, ok := s.GetOk("login_timeout"); ok && v.(int) > 0 { config.LoginTimeout = time.Second * time.Duration(int64(v.(int))) } - if v, ok := s.GetOk("request_timeout"); ok { + if v, ok := s.GetOk("request_timeout"); ok && v.(int) > 0 { config.RequestTimeout = time.Second * time.Duration(int64(v.(int))) } - if v, ok := s.GetOk("jwt_expire_timeout"); ok { + if v, ok := s.GetOk("jwt_expire_timeout"); ok && v.(int) > 0 { config.JWTExpireTimeout = time.Second * time.Duration(int64(v.(int))) } - if v, ok := s.GetOk("client_timeout"); ok { + if v, ok := s.GetOk("client_timeout"); ok && v.(int) > 0 { config.ClientTimeout = time.Second * time.Duration(int64(v.(int))) } - if v, ok := s.GetOk("jwt_client_timeout"); ok { + if v, ok := s.GetOk("jwt_client_timeout"); ok && v.(int) > 0 { config.JWTClientTimeout = time.Second * time.Duration(int64(v.(int))) } - if v, ok := s.GetOk("external_browser_timeout"); ok { + if v, ok := s.GetOk("external_browser_timeout"); ok && v.(int) > 0 { config.ExternalBrowserTimeout = time.Second * time.Duration(int64(v.(int))) } @@ -737,9 +738,6 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { profile := v.(string) if profile == "default" { defaultConfig := sdk.DefaultConfig() - if defaultConfig.Account == "" || defaultConfig.User == "" { - return "", errors.New("account and User must be set in provider config, ~/.snowflake/config, or as an environment variable") - } config = sdk.MergeConfig(config, defaultConfig) } else { profileConfig, err := sdk.ProfileConfig(profile) From 1e6e54f828efa60edd258b316709fc4dfd370f93 Mon Sep 17 00:00:00 2001 From: Dane Rieber Date: Fri, 20 Oct 2023 05:08:45 -0500 Subject: [PATCH 02/20] feat: add parse_header option to file format resource (#2132) Co-authored-by: Artur Sawicki --- docs/resources/file_format.md | 1 + 
pkg/resources/file_format.go | 14 ++++++++++++++ pkg/sdk/file_format.go | 2 +- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/resources/file_format.md b/docs/resources/file_format.md index 3d80e3ac59..466db3c425 100644 --- a/docs/resources/file_format.md +++ b/docs/resources/file_format.md @@ -52,6 +52,7 @@ resource "snowflake_file_format" "example_file_format" { - `file_extension` (String) Specifies the extension for files unloaded to a stage. - `ignore_utf8_errors` (Boolean) Boolean that specifies whether UTF-8 encoding errors produce error conditions. - `null_if` (List of String) String used to convert to and from SQL NULL. +- `parse_header` (Boolean) Boolean that specifies whether to use the first row headers in the data files to determine column names. - `preserve_space` (Boolean) Boolean that specifies whether the XML parser preserves leading and trailing spaces in element content. - `record_delimiter` (String) Specifies one or more singlebyte or multibyte characters that separate records in an input file (data loading) or unloaded file (data unloading). - `replace_invalid_characters` (Boolean) Boolean that specifies whether to replace invalid UTF-8 characters with the Unicode replacement character (�). diff --git a/pkg/resources/file_format.go b/pkg/resources/file_format.go index d03e1d7583..9caab6e0d1 100644 --- a/pkg/resources/file_format.go +++ b/pkg/resources/file_format.go @@ -26,6 +26,7 @@ var formatTypeOptions = map[string][]string{ "record_delimiter", "field_delimiter", "file_extension", + "parse_header", "skip_header", "skip_blank_lines", "date_format", @@ -135,6 +136,11 @@ var fileFormatSchema = map[string]*schema.Schema{ Optional: true, Description: "Specifies the extension for files unloaded to a stage.", }, + "parse_header": { + Type: schema.TypeBool, + Optional: true, + Description: "Boolean that specifies whether to use the first row headers in the data files to determine column names.", + }, "skip_header": { Type: schema.TypeInt, Optional: true, @@ -345,6 +351,7 @@ func CreateFileFormat(d *schema.ResourceData, meta interface{}) error { if v, ok := d.GetOk("file_extension"); ok { opts.CSVFileExtension = sdk.String(v.(string)) } + opts.CSVParseHeader = sdk.Bool(d.Get("parse_header").(bool)) opts.CSVSkipHeader = sdk.Int(d.Get("skip_header").(int)) opts.CSVSkipBlankLines = sdk.Bool(d.Get("skip_blank_lines").(bool)) if v, ok := d.GetOk("date_format"); ok { @@ -565,6 +572,9 @@ func ReadFileFormat(d *schema.ResourceData, meta interface{}) error { if err := d.Set("file_extension", fileFormat.Options.CSVFileExtension); err != nil { return err } + if err := d.Set("parse_header", fileFormat.Options.CSVParseHeader); err != nil { + return err + } if err := d.Set("skip_header", fileFormat.Options.CSVSkipHeader); err != nil { return err } @@ -788,6 +798,10 @@ func UpdateFileFormat(d *schema.ResourceData, meta interface{}) error { v := d.Get("file_extension").(string) opts.Set.CSVFileExtension = &v } + if d.HasChange("parse_header") { + v := d.Get("parse_header").(bool) + opts.Set.CSVParseHeader = &v + } if d.HasChange("skip_header") { v := d.Get("skip_header").(int) opts.Set.CSVSkipHeader = &v diff --git a/pkg/sdk/file_format.go b/pkg/sdk/file_format.go index 67cd16558b..c3faf62920 100644 --- a/pkg/sdk/file_format.go +++ b/pkg/sdk/file_format.go @@ -737,7 +737,7 @@ func (v *fileFormats) Describe(ctx context.Context, id SchemaObjectIdentifier) ( case "PARSE_HEADER": b, err := strconv.ParseBool(v) if err != nil { - return nil, fmt.Errorf(`cannot cast SKIP_HEADER 
value "%s" to bool: %w`, v, err) + return nil, fmt.Errorf(`cannot cast PARSE_HEADER value "%s" to bool: %w`, v, err) } details.Options.CSVParseHeader = &b case "DATE_FORMAT": From de23f2ba939eb368d9734217e1bb2d4ebc75eef4 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Fri, 20 Oct 2023 14:34:06 +0200 Subject: [PATCH 03/20] feat: Use task from SDK in resource and data source (#2140) * Get list of predecessors in tasks * Move get root tasks implementation to SDK * Using sdk in tasks resource (WIP) * Using sdk in tasks resource 2 (WIP) * Using sdk in tasks resource 3 (WIP) * Using sdk in tasks resource 4 (WIP) * Remove old taskId * Use task sdk in datasource * Add Jira issue number to TODO comments * Fix some linter comments * Add error integration to task alter * Add initial warehouse size to task alter * Removed todo for get root tasks operation * Remove old implementation for tasks predecessors * Introduce task state and fix unmarshalling predecessors * Read task parameters * Use session parameters for create * Update session parameters * Add TODO * Add test for getting root tasks * Add test for cycle * Test getting predecessor name correctly * Fix lint * Add integration test for getting predecessors and task cycles * Add integration test for multiple roots * Add missing TODO * Add missing comment * Fix terraform maps bs * Fix slice initialization * Fix slice initialization 2 * Fix getting task predecessor name * Change getting task predecessor name logic * Change getting task predecessor name logic again * Fix linter * Fix warehouse alter * Fix after review * Fix after review 2 * Add comment in test * Adjust go.mod --- go.mod | 2 - go.sum | 11 - pkg/datasources/tasks.go | 22 +- pkg/resources/task.go | 521 ++++++++--------- pkg/resources/task_internal_test.go | 47 -- pkg/sdk/parameters.go | 69 ++- pkg/sdk/parameters_impl.go | 275 +++++++++ pkg/sdk/tasks_def.go | 10 +- pkg/sdk/tasks_dto_builders_gen.go | 15 + pkg/sdk/tasks_dto_gen.go | 19 +- pkg/sdk/tasks_gen.go | 34 +- pkg/sdk/tasks_gen_test.go | 26 +- pkg/sdk/tasks_impl_gen.go | 96 ++- pkg/sdk/tasks_test.go | 113 ++++ pkg/sdk/tasks_validations_gen.go | 11 +- pkg/sdk/testint/tasks_gen_integration_test.go | 145 ++++- pkg/snowflake/task.go | 552 ------------------ pkg/snowflake/task_test.go | 174 ------ 18 files changed, 965 insertions(+), 1177 deletions(-) delete mode 100644 pkg/resources/task_internal_test.go create mode 100644 pkg/sdk/parameters_impl.go create mode 100644 pkg/sdk/tasks_test.go delete mode 100644 pkg/snowflake/task.go delete mode 100644 pkg/snowflake/task_test.go diff --git a/go.mod b/go.mod index 04664aa47e..f5c52822c6 100644 --- a/go.mod +++ b/go.mod @@ -85,8 +85,6 @@ require ( github.com/hashicorp/logutils v1.0.0 // indirect github.com/hashicorp/terraform-exec v0.19.0 // indirect github.com/hashicorp/terraform-json v0.17.1 // indirect - github.com/hashicorp/terraform-plugin-framework v1.4.1 // indirect - github.com/hashicorp/terraform-plugin-framework-validators v0.12.0 // indirect github.com/hashicorp/terraform-plugin-log v0.9.0 // indirect github.com/hashicorp/terraform-registry-address v0.2.2 // indirect github.com/hashicorp/terraform-svchost v0.1.1 // indirect diff --git a/go.sum b/go.sum index 34bc1a5ed2..8f217b4c9b 100644 --- a/go.sum +++ b/go.sum @@ -37,7 +37,6 @@ github.com/apache/arrow/go/v12 v12.0.1/go.mod h1:weuTY7JvTG/HDPtMQxEUp7pU73vkLWM github.com/apache/thrift v0.19.0 h1:sOqkWPzMj7w6XaYbJQG7m4sGqVolaW/0D28Ln7yPzMk= github.com/apache/thrift v0.19.0/go.mod h1:SUALL216IiaOw2Oy+5Vs9lboJ/t9g40C+G07Dc0QC1I= 
github.com/apparentlymart/go-textseg/v12 v12.0.0/go.mod h1:S/4uRK2UtaQttw1GenVJEynmyUenKwP++x/+DdGV/Ec= -github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= @@ -177,10 +176,6 @@ github.com/hashicorp/terraform-json v0.17.1 h1:eMfvh/uWggKmY7Pmb3T85u86E2EQg6EQH github.com/hashicorp/terraform-json v0.17.1/go.mod h1:Huy6zt6euxaY9knPAFKjUITn8QxUFIe9VuSzb4zn/0o= github.com/hashicorp/terraform-plugin-docs v0.16.0 h1:UmxFr3AScl6Wged84jndJIfFccGyBZn52KtMNsS12dI= github.com/hashicorp/terraform-plugin-docs v0.16.0/go.mod h1:M3ZrlKBJAbPMtNOPwHicGi1c+hZUh7/g0ifT/z7TVfA= -github.com/hashicorp/terraform-plugin-framework v1.4.1 h1:ZC29MoB3Nbov6axHdgPbMz7799pT5H8kIrM8YAsaVrs= -github.com/hashicorp/terraform-plugin-framework v1.4.1/go.mod h1:XC0hPcQbBvlbxwmjxuV/8sn8SbZRg4XwGMs22f+kqV0= -github.com/hashicorp/terraform-plugin-framework-validators v0.12.0 h1:HOjBuMbOEzl7snOdOoUfE2Jgeto6JOjLVQ39Ls2nksc= -github.com/hashicorp/terraform-plugin-framework-validators v0.12.0/go.mod h1:jfHGE/gzjxYz6XoUwi/aYiiKrJDeutQNUtGQXkaHklg= github.com/hashicorp/terraform-plugin-go v0.19.0 h1:BuZx/6Cp+lkmiG0cOBk6Zps0Cb2tmqQpDM3iAtnhDQU= github.com/hashicorp/terraform-plugin-go v0.19.0/go.mod h1:EhRSkEPNoylLQntYsk5KrDHTZJh9HQoumZXbOGOXmec= github.com/hashicorp/terraform-plugin-log v0.9.0 h1:i7hOA+vdAItN1/7UrfBqBwvYPQ9TFvymaRGZED3FCV0= @@ -206,7 +201,6 @@ github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM= github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= -github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= @@ -226,7 +220,6 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -285,7 +278,6 @@ github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSg github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww= github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY= -github.com/sebdah/goldie v1.0.0/go.mod 
h1:jXP4hmWywNEwZzhMuv2ccnqTSFpuq8iyQhtQdkkZBH4= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= @@ -325,7 +317,6 @@ github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zclconf/go-cty v1.14.0 h1:/Xrd39K7DXbHzlisFP9c4pHao4yyf+/Ug9LEz+Y/yhc= github.com/zclconf/go-cty v1.14.0/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE= -github.com/zclconf/go-cty-debug v0.0.0-20191215020915-b22d67c1ba0b/go.mod h1:ZRKQfBXbGkpdV6QMzT3rU1kSTAnfu1dO8dPKjYprgj8= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= @@ -423,9 +414,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= -gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pkg/datasources/tasks.go b/pkg/datasources/tasks.go index 327108d78e..1dc0f8d1c6 100644 --- a/pkg/datasources/tasks.go +++ b/pkg/datasources/tasks.go @@ -1,12 +1,12 @@ package datasources import ( + "context" "database/sql" - "errors" "fmt" "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -63,25 +63,23 @@ func Tasks() *schema.Resource { func ReadTasks(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) + client := sdk.NewClientFromDB(db) + ctx := context.Background() + databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) - currentTasks, err := snowflake.ListTasks(databaseName, schemaName, db) - if errors.Is(err, sql.ErrNoRows) { + extractedTasks, err := client.Tasks.Show(ctx, sdk.NewShowTaskRequest().WithIn(&sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)})) + if err != nil { // If not found, mark resource to be removed from state file during apply or refresh log.Printf("[DEBUG] tasks in schema (%s) not found", d.Id()) d.SetId("") return nil - } else if err != nil { - log.Printf("[DEBUG] unable to parse tasks in schema (%s)", d.Id()) - d.SetId("") - return nil } - tasks := []map[string]interface{}{} - - for _, task := range currentTasks { - taskMap := map[string]interface{}{} + tasks := make([]map[string]any, 0, len(extractedTasks)) + for _, task := range extractedTasks { + taskMap := map[string]any{} taskMap["name"] = task.Name 
taskMap["database"] = task.DatabaseName diff --git a/pkg/resources/task.go b/pkg/resources/task.go index 148fcb87b8..72e42fec09 100644 --- a/pkg/resources/task.go +++ b/pkg/resources/task.go @@ -1,26 +1,22 @@ package resources import ( - "bytes" + "context" "database/sql" - "encoding/csv" - "errors" "fmt" "log" "strconv" - "strings" + "time" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" "golang.org/x/exp/slices" ) -const ( - taskIDDelimiter = '|' -) - +// TODO [SNOW-884987]: add missing SUSPEND_TASK_AFTER_NUM_FAILURES attribute. var taskSchema = map[string]*schema.Schema{ "enabled": { Type: schema.TypeBool, @@ -117,29 +113,9 @@ var taskSchema = map[string]*schema.Schema{ }, } -type taskID struct { - DatabaseName string - SchemaName string - TaskName string -} - -// String() takes in a taskID object and returns a pipe-delimited string: -// DatabaseName|SchemaName|TaskName. -func (t *taskID) String() (string, error) { - var buf bytes.Buffer - csvWriter := csv.NewWriter(&buf) - csvWriter.Comma = taskIDDelimiter - dataIdentifiers := [][]string{{t.DatabaseName, t.SchemaName, t.TaskName}} - if err := csvWriter.WriteAll(dataIdentifiers); err != nil { - return "", err - } - strTaskID := strings.TrimSpace(buf.String()) - return strTaskID, nil -} - // difference find keys in 'a' but not in 'b'. -func difference(a, b map[string]interface{}) map[string]interface{} { - diff := make(map[string]interface{}) +func difference(a, b map[string]any) map[string]any { + diff := make(map[string]any) for k := range a { if _, ok := b[k]; !ok { diff[k] = a[k] @@ -148,31 +124,6 @@ func difference(a, b map[string]interface{}) map[string]interface{} { return diff } -// taskIDFromString() takes in a pipe-delimited string: DatabaseName|SchemaName|TaskName -// and returns a taskID object. -func taskIDFromString(stringID string) (*taskID, error) { - reader := csv.NewReader(strings.NewReader(stringID)) - reader.Comma = pipeIDDelimiter - lines, err := reader.ReadAll() - if err != nil { - return nil, fmt.Errorf("not CSV compatible") - } - - if len(lines) != 1 { - return nil, fmt.Errorf("1 line per task") - } - if len(lines[0]) != 3 { - return nil, fmt.Errorf("3 fields allowed") - } - - taskResult := &taskID{ - DatabaseName: lines[0][0], - SchemaName: lines[0][1], - TaskName: lines[0][2], - } - return taskResult, nil -} - // Task returns a pointer to the resource representing a task. func Task() *schema.Resource { return &schema.Resource{ @@ -191,110 +142,73 @@ func Task() *schema.Resource { // ReadTask implements schema.ReadFunc. 
func ReadTask(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - taskID, err := taskIDFromString(d.Id()) - if err != nil { - return err - } + client := sdk.NewClientFromDB(db) + ctx := context.Background() - database := taskID.DatabaseName - schema := taskID.SchemaName - name := taskID.TaskName + taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - builder := snowflake.NewTaskBuilder(name, database, schema) - q := builder.Show() - row := snowflake.QueryRow(db, q) - t, err := snowflake.ScanTask(row) - if errors.Is(err, sql.ErrNoRows) { + task, err := client.Tasks.ShowByID(ctx, taskId) + if err != nil { // If not found, mark resource to be removed from state file during apply or refresh log.Printf("[DEBUG] task (%s) not found", d.Id()) d.SetId("") return nil } - if err != nil { - return err - } - if err := d.Set("enabled", t.IsEnabled()); err != nil { + if err := d.Set("enabled", task.IsStarted()); err != nil { return err } - if err := d.Set("name", t.Name); err != nil { + if err := d.Set("name", task.Name); err != nil { return err } - if err := d.Set("database", t.DatabaseName); err != nil { + if err := d.Set("database", task.DatabaseName); err != nil { return err } - if err := d.Set("schema", t.SchemaName); err != nil { + if err := d.Set("schema", task.SchemaName); err != nil { return err } - if err := d.Set("warehouse", t.Warehouse); err != nil { + if err := d.Set("warehouse", task.Warehouse); err != nil { return err } - if err := d.Set("schedule", t.Schedule); err != nil { + if err := d.Set("schedule", task.Schedule); err != nil { return err } - if err := d.Set("comment", t.Comment); err != nil { + if err := d.Set("comment", task.Comment); err != nil { return err } - allowOverlappingExecutionValue, err := t.AllowOverlappingExecution.Value() - if err != nil { + if err := d.Set("allow_overlapping_execution", task.AllowOverlappingExecution); err != nil { return err } - if allowOverlappingExecutionValue != nil && allowOverlappingExecutionValue.(string) != "null" { - allowOverlappingExecution, err := strconv.ParseBool(allowOverlappingExecutionValue.(string)) - if err != nil { - return err - } - - if err := d.Set("allow_overlapping_execution", allowOverlappingExecution); err != nil { - return err - } - } else { - if err := d.Set("allow_overlapping_execution", false); err != nil { - return err - } - } - - // The "DESCRIBE TASK ..." 
command returns the string "null" for error_integration - if t.ErrorIntegration.String == "null" { - t.ErrorIntegration.Valid = false - t.ErrorIntegration.String = "" - } - - if err := d.Set("error_integration", t.ErrorIntegration.String); err != nil { + if err := d.Set("error_integration", task.ErrorIntegration); err != nil { return err } - predecessors, err := t.GetPredecessors() - if err != nil { - return err + predecessors := make([]string, len(task.Predecessors)) + for i, p := range task.Predecessors { + predecessors[i] = p.Name() } - if err := d.Set("after", predecessors); err != nil { return err } - if err := d.Set("when", t.Condition); err != nil { + if err := d.Set("when", task.Condition); err != nil { return err } - if err := d.Set("sql_statement", t.Definition); err != nil { + if err := d.Set("sql_statement", task.Definition); err != nil { return err } - q = builder.ShowParameters() - paramRows, err := snowflake.Query(db, q) - if err != nil { - return err - } - params, err := snowflake.ScanTaskParameters(paramRows) + opts := &sdk.ShowParametersOptions{In: &sdk.ParametersIn{Task: taskId}} + params, err := client.Parameters.ShowParameters(ctx, opts) if err != nil { return err } @@ -306,8 +220,6 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error { } for _, param := range params { - log.Printf("[TRACE] %+v\n", param) - if param.Level != "TASK" { continue } @@ -344,98 +256,102 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error { // CreateTask implements schema.CreateFunc. func CreateTask(d *schema.ResourceData, meta interface{}) error { - var err error db := meta.(*sql.DB) - database := d.Get("database").(string) - schema := d.Get("schema").(string) + client := sdk.NewClientFromDB(db) + ctx := context.Background() + + databaseName := d.Get("database").(string) + schemaName := d.Get("schema").(string) name := d.Get("name").(string) - sql := d.Get("sql_statement").(string) - enabled := d.Get("enabled").(bool) - builder := snowflake.NewTaskBuilder(name, database, schema) - builder.WithStatement(sql) + sqlStatement := d.Get("sql_statement").(string) + + taskId := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name) + createRequest := sdk.NewCreateTaskRequest(taskId, sqlStatement) // Set optionals if v, ok := d.GetOk("warehouse"); ok { - builder.WithWarehouse(v.(string)) + warehouseId := sdk.NewAccountObjectIdentifier(v.(string)) + createRequest.WithWarehouse(sdk.NewCreateTaskWarehouseRequest().WithWarehouse(&warehouseId)) } if v, ok := d.GetOk("user_task_managed_initial_warehouse_size"); ok { - builder.WithInitialWarehouseSize(v.(string)) + size, err := sdk.ToWarehouseSize(v.(string)) + if err != nil { + return err + } + createRequest.WithWarehouse(sdk.NewCreateTaskWarehouseRequest().WithUserTaskManagedInitialWarehouseSize(&size)) } if v, ok := d.GetOk("schedule"); ok { - builder.WithSchedule(v.(string)) + createRequest.WithSchedule(sdk.String(v.(string))) } if v, ok := d.GetOk("session_parameters"); ok { - builder.WithSessionParameters(v.(map[string]interface{})) + sessionParameters, err := sdk.GetSessionParametersFrom(v.(map[string]any)) + if err != nil { + return err + } + createRequest.WithSessionParameters(sessionParameters) } if v, ok := d.GetOk("user_task_timeout_ms"); ok { - builder.WithTimeout(v.(int)) + createRequest.WithUserTaskTimeoutMs(sdk.Int(v.(int))) } if v, ok := d.GetOk("comment"); ok { - builder.WithComment(v.(string)) + createRequest.WithComment(sdk.String(v.(string))) } if v, ok := d.GetOk("allow_overlapping_execution"); ok { - 
builder.WithAllowOverlappingExecution(v.(bool)) + createRequest.WithAllowOverlappingExecution(sdk.Bool(v.(bool))) } if v, ok := d.GetOk("error_integration"); ok { - builder.WithErrorIntegration(v.(string)) + createRequest.WithErrorIntegration(sdk.String(v.(string))) } if v, ok := d.GetOk("after"); ok { after := expandStringList(v.([]interface{})) + precedingTasks := make([]sdk.SchemaObjectIdentifier, 0) for _, dep := range after { - rootTasks, err := snowflake.GetRootTasks(dep, database, schema, db) + precedingTaskId := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, dep) + rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, precedingTaskId) if err != nil { return err } for _, rootTask := range rootTasks { - // if a root task is enabled, then it needs to be suspended before the child tasks can be created - if rootTask.IsEnabled() { - q := rootTask.Suspend() - if err := snowflake.Exec(db, q); err != nil { + // if a root task is started, then it needs to be suspended before the child tasks can be created + if rootTask.IsStarted() { + err := suspendTask(ctx, client, rootTask.ID()) + if err != nil { return err } // resume the task after modifications are complete as long as it is not a standalone task if !(rootTask.Name == name) { - defer resumeTask(db, rootTask) + defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID()) } } } - - builder.WithAfter(after) + precedingTasks = append(precedingTasks, precedingTaskId) } + createRequest.WithAfter(precedingTasks) } if v, ok := d.GetOk("when"); ok { - builder.WithCondition(v.(string)) + createRequest.WithWhen(sdk.String(v.(string))) } - q := builder.Create() - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error creating task %v err = %w", name, err) + if err := client.Tasks.Create(ctx, createRequest); err != nil { + return fmt.Errorf("error creating task %s err = %w", taskId.FullyQualifiedName(), err) } - taskID := &taskID{ - DatabaseName: database, - SchemaName: schema, - TaskName: name, - } - dataIDInput, err := taskID.String() - if err != nil { - return err - } - d.SetId(dataIDInput) + d.SetId(helpers.EncodeSnowflakeID(taskId)) + enabled := d.Get("enabled").(bool) if enabled { - if err := snowflake.WaitResumeTask(db, name, database, schema); err != nil { + if err := waitForTaskStart(ctx, client, taskId); err != nil { log.Printf("[WARN] failed to resume task %s", name) } } @@ -443,57 +359,77 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error { return ReadTask(d, meta) } -func resumeTask(db *sql.DB, rootTask *snowflake.Task) { - q := rootTask.Resume() - if err := snowflake.Exec(db, q); err != nil { - log.Printf("[WARN] failed to resume task %s", rootTask.Name) +func waitForTaskStart(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error { + err := resumeTask(ctx, client, id) + if err != nil { + return fmt.Errorf("error starting task %s err = %w", id.FullyQualifiedName(), err) + } + return helpers.Retry(5, 5*time.Second, func() (error, bool) { + task, err := client.Tasks.ShowByID(ctx, id) + if err != nil { + return fmt.Errorf("error starting task %s err = %w", id.FullyQualifiedName(), err), false + } + if !task.IsStarted() { + return nil, false + } + return nil, true + }) +} + +func suspendTask(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error { + err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(id).WithSuspend(sdk.Bool(true))) + if err != nil { + log.Printf("[WARN] failed to suspend task %s", 
id.FullyQualifiedName()) } + return err } -// UpdateTask implements schema.UpdateFunc. -func UpdateTask(d *schema.ResourceData, meta interface{}) error { - taskID, err := taskIDFromString(d.Id()) +func resumeTask(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error { + err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(id).WithResume(sdk.Bool(true))) if err != nil { - return err + log.Printf("[WARN] failed to resume task %s", id.FullyQualifiedName()) } + return err +} +// UpdateTask implements schema.UpdateFunc. +func UpdateTask(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - database := taskID.DatabaseName - schema := taskID.SchemaName - name := taskID.TaskName - builder := snowflake.NewTaskBuilder(name, database, schema) + client := sdk.NewClientFromDB(db) + ctx := context.Background() + + taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - rootTasks, err := snowflake.GetRootTasks(name, database, schema, db) + rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, taskId) if err != nil { return err } for _, rootTask := range rootTasks { - // if a root task is enabled, then it needs to be suspended before the child tasks can be created - if rootTask.IsEnabled() { - q := rootTask.Suspend() - if err := snowflake.Exec(db, q); err != nil { + // if a root task is started, then it needs to be suspended before the child tasks can be created + if rootTask.IsStarted() { + err := suspendTask(ctx, client, rootTask.ID()) + if err != nil { return err } - if !(rootTask.Name == name) { - // resume the task after modifications are complete, as long as it is not a standalone task - defer resumeTask(db, rootTask) + // resume the task after modifications are complete as long as it is not a standalone task + if !(rootTask.Name == taskId.Name()) { + defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID()) } } } if d.HasChange("warehouse") { - var q string newWarehouse := d.Get("warehouse") - + alterRequest := sdk.NewAlterTaskRequest(taskId) if newWarehouse == "" { - q = builder.SwitchWarehouseToManaged() + alterRequest.WithUnset(sdk.NewTaskUnsetRequest().WithWarehouse(sdk.Bool(true))) } else { - q = builder.ChangeWarehouse(newWarehouse.(string)) + alterRequest.WithSet(sdk.NewTaskSetRequest().WithWarehouse(sdk.Pointer(sdk.NewAccountObjectIdentifier(newWarehouse.(string))))) } - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating warehouse on task %v", d.Id()) + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating warehouse on task %s err = %w", taskId.FullyQualifiedName(), err) } } @@ -502,158 +438,156 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { warehouse := d.Get("warehouse") if warehouse == "" && newSize != "" { - q := builder.SwitchManagedWithInitialSize(newSize.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating user_task_managed_initial_warehouse_size on task %v", d.Id()) + size, err := sdk.ToWarehouseSize(newSize.(string)) + if err != nil { + return err + } + alterRequest := sdk.NewAlterTaskRequest(taskId).WithSet(sdk.NewTaskSetRequest().WithUserTaskManagedInitialWarehouseSize(&size)) + err = client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating user_task_managed_initial_warehouse_size on task %s", taskId.FullyQualifiedName()) } } } if d.HasChange("error_integration") { - var q string - if 
errorIntegration, ok := d.GetOk("error_integration"); ok { - q = builder.ChangeErrorIntegration(errorIntegration.(string)) + newErrorIntegration := d.Get("error_integration") + alterRequest := sdk.NewAlterTaskRequest(taskId) + if newErrorIntegration == "" { + alterRequest.WithUnset(sdk.NewTaskUnsetRequest().WithErrorIntegration(sdk.Bool(true))) } else { - q = builder.RemoveErrorIntegration() + alterRequest.WithSet(sdk.NewTaskSetRequest().WithErrorIntegration(sdk.String(newErrorIntegration.(string)))) } - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating task error_integration on %v", d.Id()) + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating error integration on task %s", taskId.FullyQualifiedName()) } } if d.HasChange("after") { - // preemitvely removing schedule because a task cannot have both after and schedule - q := builder.RemoveSchedule() - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating schedule on task %v", d.Id()) - } - // making changes to after require suspending the current task - q = builder.Suspend() - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error suspending task %v", d.Id()) + if err := suspendTask(ctx, client, taskId); err != nil { + return fmt.Errorf("error suspending task %s", taskId.FullyQualifiedName()) } o, n := d.GetChange("after") - var oldAfter []string - if len(o.([]interface{})) > 0 { - oldAfter = expandStringList(o.([]interface{})) - } + oldAfter := expandStringList(o.([]interface{})) + newAfter := expandStringList(n.([]interface{})) - var newAfter []string - if len(n.([]interface{})) > 0 { - newAfter = expandStringList(n.([]interface{})) + if len(newAfter) > 0 { + // preemptively removing schedule because a task cannot have both after and schedule + if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithUnset(sdk.NewTaskUnsetRequest().WithSchedule(sdk.Bool(true)))); err != nil { + return fmt.Errorf("error updating schedule on task %s", taskId.FullyQualifiedName()) + } } // Remove old dependencies that are not in new dependencies - var toRemove []string + toRemove := make([]sdk.SchemaObjectIdentifier, 0) for _, dep := range oldAfter { if !slices.Contains(newAfter, dep) { - toRemove = append(toRemove, dep) + toRemove = append(toRemove, sdk.NewSchemaObjectIdentifier(taskId.DatabaseName(), taskId.SchemaName(), dep)) } } if len(toRemove) > 0 { - q := builder.RemoveAfter(toRemove) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error removing after dependencies from task %v", d.Id()) + if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithRemoveAfter(toRemove)); err != nil { + return fmt.Errorf("error removing after dependencies from task %s", taskId.FullyQualifiedName()) } } // Add new dependencies that are not in old dependencies - var toAdd []string + toAdd := make([]sdk.SchemaObjectIdentifier, 0) for _, dep := range newAfter { if !slices.Contains(oldAfter, dep) { - toAdd = append(toAdd, dep) + toAdd = append(toAdd, sdk.NewSchemaObjectIdentifier(taskId.DatabaseName(), taskId.SchemaName(), dep)) } } + // TODO [SNOW-884987]: for now leaving old copy-pasted implementation; extract function for task suspension in following change if len(toAdd) > 0 { // need to suspend any new root tasks from dependencies before adding them for _, dep := range toAdd { - rootTasks, err := snowflake.GetRootTasks(dep, database, schema, db) + rootTasks, err := sdk.GetRootTasks(client.Tasks, 
ctx, dep) if err != nil { return err } for _, rootTask := range rootTasks { - if rootTask.IsEnabled() { - q := rootTask.Suspend() - if err := snowflake.Exec(db, q); err != nil { + // if a root task is started, then it needs to be suspended before the child tasks can be created + if rootTask.IsStarted() { + err := suspendTask(ctx, client, rootTask.ID()) + if err != nil { return err } - if !(rootTask.Name == name) { - // resume the task after modifications are complete, as long as it is not a standalone task - defer resumeTask(db, rootTask) + // resume the task after modifications are complete as long as it is not a standalone task + if !(rootTask.Name == taskId.Name()) { + defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID()) } } } } - q := builder.AddAfter(toAdd) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error adding after dependencies to task %v", d.Id()) + if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithAddAfter(toAdd)); err != nil { + return fmt.Errorf("error adding after dependencies from task %s", taskId.FullyQualifiedName()) } } } if d.HasChange("schedule") { - var q string - o, n := d.GetChange("schedule") - if o != "" && n == "" { - q = builder.RemoveSchedule() + newSchedule := d.Get("schedule") + alterRequest := sdk.NewAlterTaskRequest(taskId) + if newSchedule == "" { + alterRequest.WithUnset(sdk.NewTaskUnsetRequest().WithSchedule(sdk.Bool(true))) } else { - q = builder.ChangeSchedule(n.(string)) + alterRequest.WithSet(sdk.NewTaskSetRequest().WithSchedule(sdk.String(newSchedule.(string)))) } - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating schedule on task %v", d.Id()) + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating schedule on task %s", taskId.FullyQualifiedName()) } } if d.HasChange("user_task_timeout_ms") { - var q string o, n := d.GetChange("user_task_timeout_ms") + alterRequest := sdk.NewAlterTaskRequest(taskId) if o.(int) > 0 && n.(int) == 0 { - q = builder.RemoveTimeout() + alterRequest.WithUnset(sdk.NewTaskUnsetRequest().WithUserTaskTimeoutMs(sdk.Bool(true))) } else { - q = builder.ChangeTimeout(n.(int)) + alterRequest.WithSet(sdk.NewTaskSetRequest().WithUserTaskTimeoutMs(sdk.Int(n.(int)))) } - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating user task timeout on task %v", d.Id()) + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating user task timeout on task %s", taskId.FullyQualifiedName()) } } if d.HasChange("comment") { - var q string - o, n := d.GetChange("comment") - if o != "" && n == "" { - q = builder.RemoveComment() + newComment := d.Get("comment") + alterRequest := sdk.NewAlterTaskRequest(taskId) + if newComment == "" { + alterRequest.WithUnset(sdk.NewTaskUnsetRequest().WithComment(sdk.Bool(true))) } else { - q = builder.ChangeComment(n.(string)) + alterRequest.WithSet(sdk.NewTaskSetRequest().WithComment(sdk.String(newComment.(string)))) } - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating comment on task %v", d.Id()) + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating comment on task %s", taskId.FullyQualifiedName()) } } if d.HasChange("allow_overlapping_execution") { - var q string - _, n := d.GetChange("allow_overlapping_execution") - flag := n.(bool) - if flag { - q = 
builder.SetAllowOverlappingExecutionParameter() + n := d.Get("allow_overlapping_execution") + alterRequest := sdk.NewAlterTaskRequest(taskId) + if n == "" { + alterRequest.WithUnset(sdk.NewTaskUnsetRequest().WithAllowOverlappingExecution(sdk.Bool(true))) } else { - q = builder.UnsetAllowOverlappingExecutionParameter() + alterRequest.WithSet(sdk.NewTaskSetRequest().WithAllowOverlappingExecution(sdk.Bool(n.(bool)))) } - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating task %v", d.Id()) + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating allow overlapping execution on task %s", taskId.FullyQualifiedName()) } } + // TODO [SNOW-884987]: old implementation does not handle changing parameter value correctly (only finds for parameters to add od remove, not change) if d.HasChange("session_parameters") { - var q string o, n := d.GetChange("session_parameters") if o == nil { @@ -662,22 +596,28 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { if n == nil { n = make(map[string]interface{}) } - os := o.(map[string]interface{}) - ns := n.(map[string]interface{}) + os := o.(map[string]any) + ns := n.(map[string]any) remove := difference(os, ns) add := difference(ns, os) if len(remove) > 0 { - q = builder.RemoveSessionParameters(remove) - if err := snowflake.Exec(db, q); err != nil { + sessionParametersUnset, err := sdk.GetSessionParametersUnsetFrom(remove) + if err != nil { + return err + } + if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithUnset(sdk.NewTaskUnsetRequest().WithSessionParametersUnset(sessionParametersUnset))); err != nil { return fmt.Errorf("error removing session_parameters on task %v", d.Id()) } } if len(add) > 0 { - q = builder.AddSessionParameters(add) - if err := snowflake.Exec(db, q); err != nil { + sessionParameters, err := sdk.GetSessionParametersFrom(add) + if err != nil { + return err + } + if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithSet(sdk.NewTaskSetRequest().WithSessionParameters(sessionParameters))); err != nil { return fmt.Errorf("error adding session_parameters to task %v", d.Id()) } } @@ -685,29 +625,30 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { if d.HasChange("when") { n := d.Get("when") - q := builder.ChangeCondition(n.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating when condition on task %v", d.Id()) + alterRequest := sdk.NewAlterTaskRequest(taskId).WithModifyWhen(sdk.String(n.(string))) + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating when condition on task %s", taskId.FullyQualifiedName()) } } if d.HasChange("sql_statement") { n := d.Get("sql_statement") - q := builder.ChangeSQLStatement(n.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating sql statement on task %v", d.Id()) + alterRequest := sdk.NewAlterTaskRequest(taskId).WithModifyAs(sdk.String(n.(string))) + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating sql statement on task %s", taskId.FullyQualifiedName()) } } enabled := d.Get("enabled").(bool) if enabled { - if err := snowflake.WaitResumeTask(db, name, database, schema); err != nil { - log.Printf("[WARN] failed to resume task %s", name) + if waitForTaskStart(ctx, client, taskId) != nil { + log.Printf("[WARN] failed to resume task %s", taskId.FullyQualifiedName()) } } else { - q := builder.Suspend() 
- if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating task state %v", d.Id()) + if suspendTask(ctx, client, taskId) != nil { + return fmt.Errorf("[WARN] failed to suspend task %s", taskId.FullyQualifiedName()) } } return ReadTask(d, meta) @@ -716,40 +657,36 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { // DeleteTask implements schema.DeleteFunc. func DeleteTask(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - taskID, err := taskIDFromString(d.Id()) - if err != nil { - return err - } + client := sdk.NewClientFromDB(db) + ctx := context.Background() - database := taskID.DatabaseName - schema := taskID.SchemaName - name := taskID.TaskName + taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - rootTasks, err := snowflake.GetRootTasks(name, database, schema, db) + rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, taskId) if err != nil { return err } for _, rootTask := range rootTasks { - // if a root task is enabled, then it needs to be suspended before the child tasks can be deleted - if rootTask.IsEnabled() { - q := rootTask.Suspend() - if err := snowflake.Exec(db, q); err != nil { + // if a root task is started, then it needs to be suspended before the child tasks can be created + if rootTask.IsStarted() { + err := suspendTask(ctx, client, rootTask.ID()) + if err != nil { return err } - if !(rootTask.Name == name) { - // resume the task after modifications are complete, as long as it is not a standalone task - defer resumeTask(db, rootTask) + // resume the task after modifications are complete as long as it is not a standalone task + if !(rootTask.Name == taskId.Name()) { + defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID()) } } } - q := snowflake.NewTaskBuilder(name, database, schema).Drop() - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error deleting task %v err = %w", d.Id(), err) + dropRequest := sdk.NewDropTaskRequest(taskId) + err = client.Tasks.Drop(ctx, dropRequest) + if err != nil { + return fmt.Errorf("error deleting task %s err = %w", taskId.FullyQualifiedName(), err) } d.SetId("") - return nil } diff --git a/pkg/resources/task_internal_test.go b/pkg/resources/task_internal_test.go deleted file mode 100644 index f433665048..0000000000 --- a/pkg/resources/task_internal_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package resources - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestStringFromTaskID(t *testing.T) { - r := require.New(t) - task := taskID{DatabaseName: "test_db", SchemaName: "test_schema", TaskName: "test_task"} - id, err := task.String() - r.NoError(err) - r.Equal("test_db|test_schema|test_task", id) -} - -func TestTaskIDFromString(t *testing.T) { - r := require.New(t) - - id := "test_db|test_schema|test_task" - task, err := taskIDFromString(id) - r.NoError(err) - r.Equal("test_db", task.DatabaseName) - r.Equal("test_schema", task.SchemaName) - r.Equal("test_task", task.TaskName) - - id = "test_db" - _, err = taskIDFromString(id) - r.Equal(fmt.Errorf("3 fields allowed"), err) - - // Bad ID - id = "|" - _, err = taskIDFromString(id) - r.Equal(fmt.Errorf("3 fields allowed"), err) - - // 0 lines - id = "" - _, err = taskIDFromString(id) - r.Equal(fmt.Errorf("1 line per task"), err) - - // 2 lines - id = `database|schema|task - database|schema|task` - _, err = taskIDFromString(id) - r.Equal(fmt.Errorf("1 line per task"), err) -} diff --git 
a/pkg/sdk/parameters.go b/pkg/sdk/parameters.go index 0c6911f3bb..7cff37644a 100644 --- a/pkg/sdk/parameters.go +++ b/pkg/sdk/parameters.go @@ -971,38 +971,43 @@ func (v *SessionParameters) validate() error { } type SessionParametersUnset struct { - AbortDetachedQuery *bool `ddl:"keyword" sql:"ABORT_DETACHED_QUERY"` - Autocommit *bool `ddl:"keyword" sql:"AUTOCOMMIT"` - BinaryInputFormat *bool `ddl:"keyword" sql:"BINARY_INPUT_FORMAT"` - BinaryOutputFormat *bool `ddl:"keyword" sql:"BINARY_OUTPUT_FORMAT"` - DateInputFormat *bool `ddl:"keyword" sql:"DATE_INPUT_FORMAT"` - DateOutputFormat *bool `ddl:"keyword" sql:"DATE_OUTPUT_FORMAT"` - ErrorOnNondeterministicMerge *bool `ddl:"keyword" sql:"ERROR_ON_NONDETERMINISTIC_MERGE"` - ErrorOnNondeterministicUpdate *bool `ddl:"keyword" sql:"ERROR_ON_NONDETERMINISTIC_UPDATE"` - GeographyOutputFormat *bool `ddl:"keyword" sql:"GEOGRAPHY_OUTPUT_FORMAT"` - JSONIndent *bool `ddl:"keyword" sql:"JSON_INDENT"` - LockTimeout *bool `ddl:"keyword" sql:"LOCK_TIMEOUT"` - QueryTag *bool `ddl:"keyword" sql:"QUERY_TAG"` - RowsPerResultset *bool `ddl:"keyword" sql:"ROWS_PER_RESULTSET"` - SimulatedDataSharingConsumer *bool `ddl:"keyword" sql:"SIMULATED_DATA_SHARING_CONSUMER"` - StatementTimeoutInSeconds *bool `ddl:"keyword" sql:"STATEMENT_TIMEOUT_IN_SECONDS"` - StrictJSONOutput *bool `ddl:"keyword" sql:"STRICT_JSON_OUTPUT"` - TimestampDayIsAlways24h *bool `ddl:"keyword" sql:"TIMESTAMP_DAY_IS_ALWAYS_24H"` - TimestampInputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_INPUT_FORMAT"` - TimestampLTZOutputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_LTZ_OUTPUT_FORMAT"` - TimestampNTZOutputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_NTZ_OUTPUT_FORMAT"` - TimestampOutputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_OUTPUT_FORMAT"` - TimestampTypeMapping *bool `ddl:"keyword" sql:"TIMESTAMP_TYPE_MAPPING"` - TimestampTZOutputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_TZ_OUTPUT_FORMAT"` - Timezone *bool `ddl:"keyword" sql:"TIMEZONE"` - TimeInputFormat *bool `ddl:"keyword" sql:"TIME_INPUT_FORMAT"` - TimeOutputFormat *bool `ddl:"keyword" sql:"TIME_OUTPUT_FORMAT"` - TransactionDefaultIsolationLevel *bool `ddl:"keyword" sql:"TRANSACTION_DEFAULT_ISOLATION_LEVEL"` - TwoDigitCenturyStart *bool `ddl:"keyword" sql:"TWO_DIGIT_CENTURY_START"` - UnsupportedDDLAction *bool `ddl:"keyword" sql:"UNSUPPORTED_DDL_ACTION"` - UseCachedResult *bool `ddl:"keyword" sql:"USE_CACHED_RESULT"` - WeekOfYearPolicy *bool `ddl:"keyword" sql:"WEEK_OF_YEAR_POLICY"` - WeekStart *bool `ddl:"keyword" sql:"WEEK_START"` + AbortDetachedQuery *bool `ddl:"keyword" sql:"ABORT_DETACHED_QUERY"` + Autocommit *bool `ddl:"keyword" sql:"AUTOCOMMIT"` + BinaryInputFormat *bool `ddl:"keyword" sql:"BINARY_INPUT_FORMAT"` + BinaryOutputFormat *bool `ddl:"keyword" sql:"BINARY_OUTPUT_FORMAT"` + ClientMetadataRequestUseConnectionCtx *bool `ddl:"keyword" sql:"CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX"` + ClientMetadataUseSessionDatabase *bool `ddl:"keyword" sql:"CLIENT_METADATA_USE_SESSION_DATABASE"` + ClientResultColumnCaseInsensitive *bool `ddl:"keyword" sql:"CLIENT_RESULT_COLUMN_CASE_INSENSITIVE"` + DateInputFormat *bool `ddl:"keyword" sql:"DATE_INPUT_FORMAT"` + DateOutputFormat *bool `ddl:"keyword" sql:"DATE_OUTPUT_FORMAT"` + ErrorOnNondeterministicMerge *bool `ddl:"keyword" sql:"ERROR_ON_NONDETERMINISTIC_MERGE"` + ErrorOnNondeterministicUpdate *bool `ddl:"keyword" sql:"ERROR_ON_NONDETERMINISTIC_UPDATE"` + GeographyOutputFormat *bool `ddl:"keyword" sql:"GEOGRAPHY_OUTPUT_FORMAT"` + JSONIndent *bool `ddl:"keyword" sql:"JSON_INDENT"` + LockTimeout 
*bool `ddl:"keyword" sql:"LOCK_TIMEOUT"` + MultiStatementCount *bool `ddl:"keyword" sql:"MULTI_STATEMENT_COUNT"` + QueryTag *bool `ddl:"keyword" sql:"QUERY_TAG"` + QuotedIdentifiersIgnoreCase *bool `ddl:"keyword" sql:"QUOTED_IDENTIFIERS_IGNORE_CASE"` + RowsPerResultset *bool `ddl:"keyword" sql:"ROWS_PER_RESULTSET"` + SimulatedDataSharingConsumer *bool `ddl:"keyword" sql:"SIMULATED_DATA_SHARING_CONSUMER"` + StatementTimeoutInSeconds *bool `ddl:"keyword" sql:"STATEMENT_TIMEOUT_IN_SECONDS"` + StrictJSONOutput *bool `ddl:"keyword" sql:"STRICT_JSON_OUTPUT"` + TimestampDayIsAlways24h *bool `ddl:"keyword" sql:"TIMESTAMP_DAY_IS_ALWAYS_24H"` + TimestampInputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_INPUT_FORMAT"` + TimestampLTZOutputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_LTZ_OUTPUT_FORMAT"` + TimestampNTZOutputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_NTZ_OUTPUT_FORMAT"` + TimestampOutputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_OUTPUT_FORMAT"` + TimestampTypeMapping *bool `ddl:"keyword" sql:"TIMESTAMP_TYPE_MAPPING"` + TimestampTZOutputFormat *bool `ddl:"keyword" sql:"TIMESTAMP_TZ_OUTPUT_FORMAT"` + Timezone *bool `ddl:"keyword" sql:"TIMEZONE"` + TimeInputFormat *bool `ddl:"keyword" sql:"TIME_INPUT_FORMAT"` + TimeOutputFormat *bool `ddl:"keyword" sql:"TIME_OUTPUT_FORMAT"` + TransactionDefaultIsolationLevel *bool `ddl:"keyword" sql:"TRANSACTION_DEFAULT_ISOLATION_LEVEL"` + TwoDigitCenturyStart *bool `ddl:"keyword" sql:"TWO_DIGIT_CENTURY_START"` + UnsupportedDDLAction *bool `ddl:"keyword" sql:"UNSUPPORTED_DDL_ACTION"` + UseCachedResult *bool `ddl:"keyword" sql:"USE_CACHED_RESULT"` + WeekOfYearPolicy *bool `ddl:"keyword" sql:"WEEK_OF_YEAR_POLICY"` + WeekStart *bool `ddl:"keyword" sql:"WEEK_START"` } func (v *SessionParametersUnset) validate() error { diff --git a/pkg/sdk/parameters_impl.go b/pkg/sdk/parameters_impl.go new file mode 100644 index 0000000000..110efbeff3 --- /dev/null +++ b/pkg/sdk/parameters_impl.go @@ -0,0 +1,275 @@ +package sdk + +import ( + "fmt" + "strconv" +) + +func GetSessionParametersFrom(params map[string]any) (*SessionParameters, error) { + sessionParameters := &SessionParameters{} + for k, v := range params { + s, ok := v.(string) + if !ok { + return nil, fmt.Errorf("expecting string value for parameter %s (current value: %v)", k, v) + } + err := sessionParameters.setParam(SessionParameter(k), s) + if err != nil { + return nil, err + } + } + return sessionParameters, nil +} + +// TODO [SNOW-884987]: use this method in SetSessionParameterOnAccount and in SetSessionParameterOnUser +// TODO [SNOW-884987]: unit test this method +func (sessionParameters *SessionParameters) setParam(parameter SessionParameter, value string) error { + switch parameter { + case SessionParameterAbortDetachedQuery: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.AbortDetachedQuery = b + case SessionParameterAutocommit: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.Autocommit = b + case SessionParameterBinaryInputFormat: + sessionParameters.BinaryInputFormat = Pointer(BinaryInputFormat(value)) + case SessionParameterBinaryOutputFormat: + sessionParameters.BinaryOutputFormat = Pointer(BinaryOutputFormat(value)) + case SessionParameterClientMetadataRequestUseConnectionCtx: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.ClientMetadataRequestUseConnectionCtx = b + case 
SessionParameterClientMetadataUseSessionDatabase: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.ClientMetadataUseSessionDatabase = b + case SessionParameterClientResultColumnCaseInsensitive: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.ClientResultColumnCaseInsensitive = b + case SessionParameterDateInputFormat: + sessionParameters.DateInputFormat = &value + case SessionParameterDateOutputFormat: + sessionParameters.DateOutputFormat = &value + case SessionParameterErrorOnNondeterministicMerge: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.ErrorOnNondeterministicMerge = b + case SessionParameterErrorOnNondeterministicUpdate: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.ErrorOnNondeterministicUpdate = b + case SessionParameterGeographyOutputFormat: + sessionParameters.GeographyOutputFormat = Pointer(GeographyOutputFormat(value)) + case SessionParameterJSONIndent: + v, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("JSON_INDENT session parameter is an integer, got %v", value) + } + sessionParameters.JSONIndent = Pointer(v) + case SessionParameterLockTimeout: + v, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("LOCK_TIMEOUT session parameter is an integer, got %v", value) + } + sessionParameters.LockTimeout = Pointer(v) + case SessionParameterMultiStatementCount: + v, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("MULTI_STATEMENT_COUNT session parameter is an integer, got %v", value) + } + sessionParameters.MultiStatementCount = Pointer(v) + + case SessionParameterQueryTag: + sessionParameters.QueryTag = &value + case SessionParameterQuotedIdentifiersIgnoreCase: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.QuotedIdentifiersIgnoreCase = b + case SessionParameterRowsPerResultset: + v, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("ROWS_PER_RESULTSET session parameter is an integer, got %v", value) + } + sessionParameters.RowsPerResultset = Pointer(v) + case SessionParameterSimulatedDataSharingConsumer: + sessionParameters.SimulatedDataSharingConsumer = &value + case SessionParameterStatementTimeoutInSeconds: + v, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("STATEMENT_TIMEOUT_IN_SECONDS session parameter is an integer, got %v", value) + } + sessionParameters.StatementTimeoutInSeconds = Pointer(v) + case SessionParameterStrictJSONOutput: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.StrictJSONOutput = b + case SessionParameterTimestampDayIsAlways24h: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.TimestampDayIsAlways24h = b + case SessionParameterTimestampInputFormat: + sessionParameters.TimestampInputFormat = &value + case SessionParameterTimestampLTZOutputFormat: + sessionParameters.TimestampLTZOutputFormat = &value + case SessionParameterTimestampNTZOutputFormat: + sessionParameters.TimestampNTZOutputFormat = &value + case SessionParameterTimestampOutputFormat: + sessionParameters.TimestampOutputFormat = &value + case SessionParameterTimestampTypeMapping: + sessionParameters.TimestampTypeMapping = &value + case 
SessionParameterTimestampTZOutputFormat: + sessionParameters.TimestampTZOutputFormat = &value + case SessionParameterTimezone: + sessionParameters.Timezone = &value + case SessionParameterTimeInputFormat: + sessionParameters.TimeInputFormat = &value + case SessionParameterTimeOutputFormat: + sessionParameters.TimeOutputFormat = &value + case SessionParameterTransactionDefaultIsolationLevel: + sessionParameters.TransactionDefaultIsolationLevel = Pointer(TransactionDefaultIsolationLevel(value)) + case SessionParameterTwoDigitCenturyStart: + v, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("TWO_DIGIT_CENTURY_START session parameter is an integer, got %v", value) + } + sessionParameters.TwoDigitCenturyStart = Pointer(v) + case SessionParameterUnsupportedDDLAction: + sessionParameters.UnsupportedDDLAction = Pointer(UnsupportedDDLAction(value)) + case SessionParameterUseCachedResult: + b, err := parseBooleanParameter(string(parameter), value) + if err != nil { + return err + } + sessionParameters.UseCachedResult = b + case SessionParameterWeekOfYearPolicy: + v, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("WEEK_OF_YEAR_POLICY session parameter is an integer, got %v", value) + } + sessionParameters.WeekOfYearPolicy = Pointer(v) + case SessionParameterWeekStart: + v, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("WEEK_START session parameter is an integer, got %v", value) + } + sessionParameters.WeekStart = Pointer(v) + default: + return fmt.Errorf("%s session parameter is not supported", string(parameter)) + } + return nil +} + +func GetSessionParametersUnsetFrom(params map[string]any) (*SessionParametersUnset, error) { + sessionParametersUnset := &SessionParametersUnset{} + for k := range params { + err := sessionParametersUnset.setParam(SessionParameter(k)) + if err != nil { + return nil, err + } + } + return sessionParametersUnset, nil +} + +func (sessionParametersUnset *SessionParametersUnset) setParam(parameter SessionParameter) error { + switch parameter { + case SessionParameterAbortDetachedQuery: + sessionParametersUnset.AbortDetachedQuery = Bool(true) + case SessionParameterAutocommit: + sessionParametersUnset.Autocommit = Bool(true) + case SessionParameterBinaryInputFormat: + sessionParametersUnset.BinaryInputFormat = Bool(true) + case SessionParameterBinaryOutputFormat: + sessionParametersUnset.BinaryOutputFormat = Bool(true) + case SessionParameterClientMetadataRequestUseConnectionCtx: + sessionParametersUnset.ClientMetadataRequestUseConnectionCtx = Bool(true) + case SessionParameterClientMetadataUseSessionDatabase: + sessionParametersUnset.ClientMetadataUseSessionDatabase = Bool(true) + case SessionParameterClientResultColumnCaseInsensitive: + sessionParametersUnset.ClientResultColumnCaseInsensitive = Bool(true) + case SessionParameterDateInputFormat: + sessionParametersUnset.DateInputFormat = Bool(true) + case SessionParameterDateOutputFormat: + sessionParametersUnset.DateOutputFormat = Bool(true) + case SessionParameterErrorOnNondeterministicMerge: + sessionParametersUnset.ErrorOnNondeterministicMerge = Bool(true) + case SessionParameterErrorOnNondeterministicUpdate: + sessionParametersUnset.ErrorOnNondeterministicUpdate = Bool(true) + case SessionParameterGeographyOutputFormat: + sessionParametersUnset.GeographyOutputFormat = Bool(true) + case SessionParameterJSONIndent: + sessionParametersUnset.JSONIndent = Bool(true) + case SessionParameterLockTimeout: + sessionParametersUnset.LockTimeout = Bool(true) + case 
SessionParameterMultiStatementCount: + sessionParametersUnset.MultiStatementCount = Bool(true) + case SessionParameterQueryTag: + sessionParametersUnset.QueryTag = Bool(true) + case SessionParameterQuotedIdentifiersIgnoreCase: + sessionParametersUnset.QuotedIdentifiersIgnoreCase = Bool(true) + case SessionParameterRowsPerResultset: + sessionParametersUnset.RowsPerResultset = Bool(true) + case SessionParameterSimulatedDataSharingConsumer: + sessionParametersUnset.SimulatedDataSharingConsumer = Bool(true) + case SessionParameterStatementTimeoutInSeconds: + sessionParametersUnset.StatementTimeoutInSeconds = Bool(true) + case SessionParameterStrictJSONOutput: + sessionParametersUnset.StrictJSONOutput = Bool(true) + case SessionParameterTimestampDayIsAlways24h: + sessionParametersUnset.TimestampDayIsAlways24h = Bool(true) + case SessionParameterTimestampInputFormat: + sessionParametersUnset.TimestampInputFormat = Bool(true) + case SessionParameterTimestampLTZOutputFormat: + sessionParametersUnset.TimestampLTZOutputFormat = Bool(true) + case SessionParameterTimestampNTZOutputFormat: + sessionParametersUnset.TimestampNTZOutputFormat = Bool(true) + case SessionParameterTimestampOutputFormat: + sessionParametersUnset.TimestampOutputFormat = Bool(true) + case SessionParameterTimestampTypeMapping: + sessionParametersUnset.TimestampTypeMapping = Bool(true) + case SessionParameterTimestampTZOutputFormat: + sessionParametersUnset.TimestampTZOutputFormat = Bool(true) + case SessionParameterTimezone: + sessionParametersUnset.Timezone = Bool(true) + case SessionParameterTimeInputFormat: + sessionParametersUnset.TimeInputFormat = Bool(true) + case SessionParameterTimeOutputFormat: + sessionParametersUnset.TimeOutputFormat = Bool(true) + case SessionParameterTransactionDefaultIsolationLevel: + sessionParametersUnset.TransactionDefaultIsolationLevel = Bool(true) + case SessionParameterTwoDigitCenturyStart: + sessionParametersUnset.TwoDigitCenturyStart = Bool(true) + case SessionParameterUnsupportedDDLAction: + sessionParametersUnset.UnsupportedDDLAction = Bool(true) + case SessionParameterUseCachedResult: + sessionParametersUnset.UseCachedResult = Bool(true) + case SessionParameterWeekOfYearPolicy: + sessionParametersUnset.WeekOfYearPolicy = Bool(true) + case SessionParameterWeekStart: + sessionParametersUnset.WeekStart = Bool(true) + default: + return fmt.Errorf("%s session parameter is not supported", string(parameter)) + } + return nil +} diff --git a/pkg/sdk/tasks_def.go b/pkg/sdk/tasks_def.go index 560c48c7f2..11e6840967 100644 --- a/pkg/sdk/tasks_def.go +++ b/pkg/sdk/tasks_def.go @@ -114,15 +114,18 @@ var TasksDef = g.NewInterface( OptionalQueryStructField( "Set", g.QueryStruct("TaskSet"). - OptionalIdentifier("Warehouse", g.KindOfT[AccountObjectIdentifier](), g.IdentifierOptions().SQL("WAREHOUSE")). + OptionalIdentifier("Warehouse", g.KindOfT[AccountObjectIdentifier](), g.IdentifierOptions().Equals().SQL("WAREHOUSE")). + OptionalAssignment("USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE", "WarehouseSize", g.ParameterOptions().SingleQuotes()). OptionalTextAssignment("SCHEDULE", g.ParameterOptions().SingleQuotes()). OptionalTextAssignment("CONFIG", g.ParameterOptions().NoQuotes()). OptionalBooleanAssignment("ALLOW_OVERLAPPING_EXECUTION", nil). OptionalNumberAssignment("USER_TASK_TIMEOUT_MS", nil). OptionalNumberAssignment("SUSPEND_TASK_AFTER_NUM_FAILURES", nil). + OptionalTextAssignment("ERROR_INTEGRATION", g.ParameterOptions().NoQuotes()). 
OptionalTextAssignment("COMMENT", g.ParameterOptions().SingleQuotes()). OptionalSessionParameters(). - WithValidation(g.AtLeastOneValueSet, "Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "Comment", "SessionParameters"), + WithValidation(g.AtLeastOneValueSet, "Warehouse", "UserTaskManagedInitialWarehouseSize", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParameters"). + WithValidation(g.ConflictingFields, "Warehouse", "UserTaskManagedInitialWarehouseSize"), g.KeywordOptions().SQL("SET"), ). OptionalQueryStructField( @@ -134,9 +137,10 @@ var TasksDef = g.NewInterface( OptionalSQL("ALLOW_OVERLAPPING_EXECUTION"). OptionalSQL("USER_TASK_TIMEOUT_MS"). OptionalSQL("SUSPEND_TASK_AFTER_NUM_FAILURES"). + OptionalSQL("ERROR_INTEGRATION"). OptionalSQL("COMMENT"). OptionalSessionParametersUnset(). - WithValidation(g.AtLeastOneValueSet, "Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "Comment", "SessionParametersUnset"), + WithValidation(g.AtLeastOneValueSet, "Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParametersUnset"), g.KeywordOptions().SQL("UNSET"), ). SetTags(). diff --git a/pkg/sdk/tasks_dto_builders_gen.go b/pkg/sdk/tasks_dto_builders_gen.go index 15fadaa1c1..fd02c08b2e 100644 --- a/pkg/sdk/tasks_dto_builders_gen.go +++ b/pkg/sdk/tasks_dto_builders_gen.go @@ -195,6 +195,11 @@ func (s *TaskSetRequest) WithWarehouse(Warehouse *AccountObjectIdentifier) *Task return s } +func (s *TaskSetRequest) WithUserTaskManagedInitialWarehouseSize(UserTaskManagedInitialWarehouseSize *WarehouseSize) *TaskSetRequest { + s.UserTaskManagedInitialWarehouseSize = UserTaskManagedInitialWarehouseSize + return s +} + func (s *TaskSetRequest) WithSchedule(Schedule *string) *TaskSetRequest { s.Schedule = Schedule return s @@ -220,6 +225,11 @@ func (s *TaskSetRequest) WithSuspendTaskAfterNumFailures(SuspendTaskAfterNumFail return s } +func (s *TaskSetRequest) WithErrorIntegration(ErrorIntegration *string) *TaskSetRequest { + s.ErrorIntegration = ErrorIntegration + return s +} + func (s *TaskSetRequest) WithComment(Comment *string) *TaskSetRequest { s.Comment = Comment return s @@ -264,6 +274,11 @@ func (s *TaskUnsetRequest) WithSuspendTaskAfterNumFailures(SuspendTaskAfterNumFa return s } +func (s *TaskUnsetRequest) WithErrorIntegration(ErrorIntegration *bool) *TaskUnsetRequest { + s.ErrorIntegration = ErrorIntegration + return s +} + func (s *TaskUnsetRequest) WithComment(Comment *bool) *TaskUnsetRequest { s.Comment = Comment return s diff --git a/pkg/sdk/tasks_dto_gen.go b/pkg/sdk/tasks_dto_gen.go index 6307db4c7c..aa58a0c7aa 100644 --- a/pkg/sdk/tasks_dto_gen.go +++ b/pkg/sdk/tasks_dto_gen.go @@ -68,14 +68,16 @@ type AlterTaskRequest struct { } type TaskSetRequest struct { - Warehouse *AccountObjectIdentifier - Schedule *string - Config *string - AllowOverlappingExecution *bool - UserTaskTimeoutMs *int - SuspendTaskAfterNumFailures *int - Comment *string - SessionParameters *SessionParameters + Warehouse *AccountObjectIdentifier + UserTaskManagedInitialWarehouseSize *WarehouseSize + Schedule *string + Config *string + AllowOverlappingExecution *bool + UserTaskTimeoutMs *int + SuspendTaskAfterNumFailures *int + ErrorIntegration *string + Comment *string + SessionParameters 
*SessionParameters } type TaskUnsetRequest struct { @@ -85,6 +87,7 @@ type TaskUnsetRequest struct { AllowOverlappingExecution *bool UserTaskTimeoutMs *bool SuspendTaskAfterNumFailures *bool + ErrorIntegration *bool Comment *bool SessionParametersUnset *SessionParametersUnset } diff --git a/pkg/sdk/tasks_gen.go b/pkg/sdk/tasks_gen.go index a3527a9ae0..c7e31aeb11 100644 --- a/pkg/sdk/tasks_gen.go +++ b/pkg/sdk/tasks_gen.go @@ -75,14 +75,16 @@ type AlterTaskOptions struct { } type TaskSet struct { - Warehouse *AccountObjectIdentifier `ddl:"identifier" sql:"WAREHOUSE"` - Schedule *string `ddl:"parameter,single_quotes" sql:"SCHEDULE"` - Config *string `ddl:"parameter,no_quotes" sql:"CONFIG"` - AllowOverlappingExecution *bool `ddl:"parameter" sql:"ALLOW_OVERLAPPING_EXECUTION"` - UserTaskTimeoutMs *int `ddl:"parameter" sql:"USER_TASK_TIMEOUT_MS"` - SuspendTaskAfterNumFailures *int `ddl:"parameter" sql:"SUSPEND_TASK_AFTER_NUM_FAILURES"` - Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` - SessionParameters *SessionParameters `ddl:"list,no_parentheses"` + Warehouse *AccountObjectIdentifier `ddl:"identifier,equals" sql:"WAREHOUSE"` + UserTaskManagedInitialWarehouseSize *WarehouseSize `ddl:"parameter,single_quotes" sql:"USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE"` + Schedule *string `ddl:"parameter,single_quotes" sql:"SCHEDULE"` + Config *string `ddl:"parameter,no_quotes" sql:"CONFIG"` + AllowOverlappingExecution *bool `ddl:"parameter" sql:"ALLOW_OVERLAPPING_EXECUTION"` + UserTaskTimeoutMs *int `ddl:"parameter" sql:"USER_TASK_TIMEOUT_MS"` + SuspendTaskAfterNumFailures *int `ddl:"parameter" sql:"SUSPEND_TASK_AFTER_NUM_FAILURES"` + ErrorIntegration *string `ddl:"parameter,no_quotes" sql:"ERROR_INTEGRATION"` + Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` + SessionParameters *SessionParameters `ddl:"list,no_parentheses"` } type TaskUnset struct { @@ -92,6 +94,7 @@ type TaskUnset struct { AllowOverlappingExecution *bool `ddl:"keyword" sql:"ALLOW_OVERLAPPING_EXECUTION"` UserTaskTimeoutMs *bool `ddl:"keyword" sql:"USER_TASK_TIMEOUT_MS"` SuspendTaskAfterNumFailures *bool `ddl:"keyword" sql:"SUSPEND_TASK_AFTER_NUM_FAILURES"` + ErrorIntegration *bool `ddl:"keyword" sql:"ERROR_INTEGRATION"` Comment *bool `ddl:"keyword" sql:"COMMENT"` SessionParametersUnset *SessionParametersUnset `ddl:"list,no_parentheses"` } @@ -149,8 +152,8 @@ type Task struct { Comment string Warehouse string Schedule string - Predecessors string - State string + Predecessors []SchemaObjectIdentifier + State TaskState Definition string Condition string AllowOverlappingExecution bool @@ -180,3 +183,14 @@ type ExecuteTaskOptions struct { func (v *Task) ID() SchemaObjectIdentifier { return NewSchemaObjectIdentifier(v.DatabaseName, v.SchemaName, v.Name) } + +type TaskState string + +const ( + TaskStateStarted TaskState = "started" + TaskStateSuspended TaskState = "suspended" +) + +func (v *Task) IsStarted() bool { + return v.State == TaskStateStarted +} diff --git a/pkg/sdk/tasks_gen_test.go b/pkg/sdk/tasks_gen_test.go index 12b7eeabb8..c167a0998d 100644 --- a/pkg/sdk/tasks_gen_test.go +++ b/pkg/sdk/tasks_gen_test.go @@ -166,10 +166,19 @@ func TestTasks_Alter(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Resume", "Suspend", "RemoveAfter", "AddAfter", "Set", "Unset", "SetTags", "UnsetTags", "ModifyAs", "ModifyWhen")) }) - t.Run("validation: at least one of the fields [opts.Set.Warehouse opts.Set.Schedule opts.Set.Config opts.Set.AllowOverlappingExecution opts.Set.UserTaskTimeoutMs 
opts.Set.SuspendTaskAfterNumFailures opts.Set.Comment opts.Set.SessionParameters] should be set", func(t *testing.T) { + t.Run("validation: at least one of the fields [opts.Set.Warehouse opts.Set.UserTaskManagedInitialWarehouseSize opts.Set.Schedule opts.Set.Config opts.Set.AllowOverlappingExecution opts.Set.UserTaskTimeoutMs opts.Set.SuspendTaskAfterNumFailures opts.Set.ErrorIntegration opts.Set.Comment opts.Set.SessionParameters] should be set", func(t *testing.T) { opts := defaultOpts() opts.Set = &TaskSet{} - assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "Comment", "SessionParameters")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("Warehouse", "UserTaskManagedInitialWarehouseSize", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParameters")) + }) + + t.Run("validation: conflicting fields for [opts.Set.Warehouse opts.Set.UserTaskManagedInitialWarehouseSize]", func(t *testing.T) { + warehouseId := RandomAccountObjectIdentifier() + opts := defaultOpts() + opts.Set = &TaskSet{} + opts.Set.Warehouse = &warehouseId + opts.Set.UserTaskManagedInitialWarehouseSize = &WarehouseSizeXSmall + assertOptsInvalidJoinedErrors(t, opts, errOneOf("Set", "Warehouse", "UserTaskManagedInitialWarehouseSize")) }) t.Run("validation: opts.Set.SessionParameters.SessionParameters should be valid", func(t *testing.T) { @@ -181,10 +190,10 @@ func TestTasks_Alter(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, fmt.Errorf("JSON_INDENT must be between 0 and 16")) }) - t.Run("validation: at least one of the fields [opts.Unset.Warehouse opts.Unset.Schedule opts.Unset.Config opts.Unset.AllowOverlappingExecution opts.Unset.UserTaskTimeoutMs opts.Unset.SuspendTaskAfterNumFailures opts.Unset.Comment opts.Unset.SessionParametersUnset] should be set", func(t *testing.T) { + t.Run("validation: at least one of the fields [opts.Unset.Warehouse opts.Unset.Schedule opts.Unset.Config opts.Unset.AllowOverlappingExecution opts.Unset.UserTaskTimeoutMs opts.Unset.SuspendTaskAfterNumFailures opts.Unset.ErrorIntegration opts.Unset.Comment opts.Unset.SessionParametersUnset] should be set", func(t *testing.T) { opts := defaultOpts() opts.Unset = &TaskUnset{} - assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "Comment", "SessionParametersUnset")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParametersUnset")) }) t.Run("validation: opts.Unset.SessionParametersUnset.SessionParametersUnset should be valid", func(t *testing.T) { @@ -226,6 +235,15 @@ func TestTasks_Alter(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, "ALTER TASK %s SET COMMENT = 'some comment'", id.FullyQualifiedName()) }) + t.Run("alter set warehouse", func(t *testing.T) { + warehouseId := RandomAccountObjectIdentifier() + opts := defaultOpts() + opts.Set = &TaskSet{ + Warehouse: &warehouseId, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER TASK %s SET WAREHOUSE = %s", id.FullyQualifiedName(), warehouseId.FullyQualifiedName()) + }) + t.Run("alter set session parameter", func(t *testing.T) { opts := defaultOpts() opts.Set = &TaskSet{ diff --git 
a/pkg/sdk/tasks_impl_gen.go b/pkg/sdk/tasks_impl_gen.go index 612e0f8e6e..dd6f3070aa 100644 --- a/pkg/sdk/tasks_impl_gen.go +++ b/pkg/sdk/tasks_impl_gen.go @@ -2,6 +2,9 @@ package sdk import ( "context" + "encoding/json" + "fmt" + "strings" ) var _ Tasks = (*tasks)(nil) @@ -60,6 +63,45 @@ func (v *tasks) Execute(ctx context.Context, request *ExecuteTaskRequest) error return validateAndExec(v.client, ctx, opts) } +// GetRootTasks is a way to get all root tasks for the given tasks. +// Snowflake does not have (yet) a method to do it without traversing the task graph manually. +// Task DAG should have a single root but this is checked when the root task is being resumed; that's why we return here multiple roots. +// Cycles should not be possible in a task DAG but it is checked when the root task is being resumed; that's why this method has to be cycle-proof. +// TODO [SNOW-884987]: handle cycles +func GetRootTasks(v Tasks, ctx context.Context, id SchemaObjectIdentifier) ([]Task, error) { + task, err := v.ShowByID(ctx, id) + if err != nil { + return nil, err + } + + predecessors := task.Predecessors + // no predecessors mean this is a root task + if len(predecessors) == 0 { + return []Task{*task}, nil + } + + rootTasks := make([]Task, 0, len(predecessors)) + for _, predecessor := range predecessors { + predecessorTasks, err := GetRootTasks(v, ctx, predecessor) + if err != nil { + return nil, fmt.Errorf("unable to get predecessors for task %s err = %w", predecessor.FullyQualifiedName(), err) + } + rootTasks = append(rootTasks, predecessorTasks...) + } + + // TODO [SNOW-884987]: extract unique function in our collection helper (if cycle-proof algorithm still needs it) + keys := make(map[string]bool) + uniqueRootTasks := make([]Task, 0, len(rootTasks)) + for _, rootTask := range rootTasks { + if _, exists := keys[rootTask.ID().FullyQualifiedName()]; !exists { + keys[rootTask.ID().FullyQualifiedName()] = true + uniqueRootTasks = append(uniqueRootTasks, rootTask) + } + } + + return uniqueRootTasks, nil +} + func (r *CreateTaskRequest) toOpts() *CreateTaskOptions { opts := &CreateTaskOptions{ OrReplace: r.OrReplace, @@ -115,14 +157,16 @@ func (r *AlterTaskRequest) toOpts() *AlterTaskOptions { } if r.Set != nil { opts.Set = &TaskSet{ - Warehouse: r.Set.Warehouse, - Schedule: r.Set.Schedule, - Config: r.Set.Config, - AllowOverlappingExecution: r.Set.AllowOverlappingExecution, - UserTaskTimeoutMs: r.Set.UserTaskTimeoutMs, - SuspendTaskAfterNumFailures: r.Set.SuspendTaskAfterNumFailures, - Comment: r.Set.Comment, - SessionParameters: r.Set.SessionParameters, + Warehouse: r.Set.Warehouse, + UserTaskManagedInitialWarehouseSize: r.Set.UserTaskManagedInitialWarehouseSize, + Schedule: r.Set.Schedule, + Config: r.Set.Config, + AllowOverlappingExecution: r.Set.AllowOverlappingExecution, + UserTaskTimeoutMs: r.Set.UserTaskTimeoutMs, + SuspendTaskAfterNumFailures: r.Set.SuspendTaskAfterNumFailures, + ErrorIntegration: r.Set.ErrorIntegration, + Comment: r.Set.Comment, + SessionParameters: r.Set.SessionParameters, } } if r.Unset != nil { @@ -133,6 +177,7 @@ func (r *AlterTaskRequest) toOpts() *AlterTaskOptions { AllowOverlappingExecution: r.Unset.AllowOverlappingExecution, UserTaskTimeoutMs: r.Unset.UserTaskTimeoutMs, SuspendTaskAfterNumFailures: r.Unset.SuspendTaskAfterNumFailures, + ErrorIntegration: r.Unset.ErrorIntegration, Comment: r.Unset.Comment, SessionParametersUnset: r.Unset.SessionParametersUnset, } @@ -183,10 +228,21 @@ func (r taskDBRow) convert() *Task { task.Schedule = r.Schedule.String } if 
r.Predecessors.Valid { - task.Predecessors = r.Predecessors.String + names, err := getPredecessors(r.Predecessors.String) + ids := make([]SchemaObjectIdentifier, len(names)) + if err == nil { + for i, name := range names { + ids[i] = NewSchemaObjectIdentifier(r.DatabaseName, r.SchemaName, name) + } + } + task.Predecessors = ids } if r.State.Valid { - task.State = r.State.String + if strings.ToLower(r.State.String) == string(TaskStateStarted) { + task.State = TaskStateStarted + } else { + task.State = TaskStateSuspended + } } if r.Definition.Valid { task.Definition = r.Definition.String @@ -218,6 +274,26 @@ func (r taskDBRow) convert() *Task { return &task } +func getPredecessors(predecessors string) ([]string, error) { + // Since 2022_03, Snowflake returns this as a JSON array (even empty) + // The list is formatted, e.g.: + // e.g. `[\n \"\\\"qgb)Z1KcNWJ(\\\".\\\"glN@JtR=7dzP$7\\\".\\\"_XEL(7N_F?@frgT5>dQS>V|vSy,J\\\"\"\n]`. + predecessorNames := make([]string, 0) + err := json.Unmarshal([]byte(predecessors), &predecessorNames) + if err == nil { + for i, predecessorName := range predecessorNames { + formattedName := strings.Trim(predecessorName, "\\\"") + idx := strings.LastIndex(formattedName, "\"") + 1 + if strings.LastIndex(formattedName, ".\"")+2 < idx { + idx++ + } + formattedName = formattedName[idx:] + predecessorNames[i] = formattedName + } + } + return predecessorNames, err +} + func (r *DescribeTaskRequest) toOpts() *DescribeTaskOptions { opts := &DescribeTaskOptions{ name: r.name, diff --git a/pkg/sdk/tasks_test.go b/pkg/sdk/tasks_test.go new file mode 100644 index 0000000000..9d9f27d7d5 --- /dev/null +++ b/pkg/sdk/tasks_test.go @@ -0,0 +1,113 @@ +package sdk + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testTasks struct { + tasks + + stubbedTasks map[string]*Task +} + +func (v *testTasks) ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Task, error) { + t, ok := v.stubbedTasks[id.Name()] + if !ok { + return nil, errors.New("no task configured, check test config") + } + return t, nil +} + +func TestTasks_GetRootTasks(t *testing.T) { + db := "database" + sc := "schema" + + setUpTasks := func(p map[string][]string) map[string]*Task { + r := make(map[string]*Task) + for k, v := range p { + if k == "initial" || k == "expected" { + continue + } + t := Task{DatabaseName: db, SchemaName: sc, Name: k} + predecessors := make([]SchemaObjectIdentifier, len(v)) + for i, predecessor := range v { + predecessors[i] = NewSchemaObjectIdentifier(db, sc, predecessor) + } + t.Predecessors = predecessors + r[k] = &t + } + return r + } + + // To increase readability tests are defined as maps (anonymous structs looked much worse in this case). + // What are the contents of the map: + // - key "initial" -> one element list with the task for which we will be getting root tasks + // - key "expected" -> list of the expected root tasks + // - any other key is considered as a task definition, that contains all direct predecessors (empty list for root task). 
+ tests := []map[string][]string{ + {"t1": {}, "initial": {"t1"}, "expected": {"t1"}}, + {"t1": {"t2"}, "t2": {"t3"}, "t3": {}, "initial": {"t1"}, "expected": {"t3"}}, + {"t1": {"t2", "t3"}, "t2": {"t3"}, "t3": {}, "initial": {"t1"}, "expected": {"t3"}}, + {"t1": {"t2", "t3"}, "t2": {}, "t3": {}, "initial": {"t1"}, "expected": {"t2", "t3"}}, + {"t1": {}, "t2": {}, "initial": {"t1"}, "expected": {"t1"}}, + {"t1": {"t2", "t3", "t4"}, "t2": {}, "t3": {}, "t4": {}, "initial": {"t1"}, "expected": {"t2", "t3", "t4"}}, + {"t1": {"t2", "t3", "t4"}, "t2": {}, "t3": {"t2"}, "t4": {"t3"}, "initial": {"t1"}, "expected": {"t2"}}, + // {"r": {}, "t1": {"t2", "r"}, "t2": {"t3"}, "t3": {"t1"}, "initial": {"t1"}, "expected": {"r"}}, // cycle -> failing for current (old) implementation + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test case [%v]", i), func(t *testing.T) { + ctx := context.Background() + initial, ok := tt["initial"] + if !ok { + t.FailNow() + } + expected, ok := tt["expected"] + if !ok { + t.FailNow() + } + client := new(testTasks) + client.stubbedTasks = setUpTasks(tt) + + rootTasks, err := GetRootTasks(client, ctx, NewSchemaObjectIdentifier(db, sc, initial[0])) + require.NoError(t, err) + for _, v := range rootTasks { + assert.Contains(t, expected, v.Name) + } + require.Len(t, rootTasks, len(expected)) + }) + } +} + +func Test_getPredecessors(t *testing.T) { + special := "!@#$%&*+-_=?:;,.|(){}<>" + + tests := []struct { + predecessorsRaw string + expectedPredecessors []string + }{ + {predecessorsRaw: "[]", expectedPredecessors: []string{}}, + {predecessorsRaw: "[\n \"\\\"qgb)Z1KcNWJ(\\\".\\\"glN@JtR=7dzP$7\\\".\\\"Ls.T7-(bt{.lWd@DRWkyA6<6hNdh\\\"\"\n]", expectedPredecessors: []string{"Ls.T7-(bt{.lWd@DRWkyA6<6hNdh"}}, + {predecessorsRaw: "[\n \"\\\"qgb)Z1KcNWJ(\\\".\\\"glN@JtR=7dzP$7\\\".Ls.T7-(bt{.lWd@DRWkyA6<6hNdh\"\n]", expectedPredecessors: []string{"Ls.T7-(bt{.lWd@DRWkyA6<6hNdh"}}, + {predecessorsRaw: fmt.Sprintf("[\n \"\\\"a\\\".\\\"b\\\".\\\"%s\\\"\"\n]", special), expectedPredecessors: []string{special}}, + {predecessorsRaw: "[\n \"\\\"a\\\".\\\"b\\\".\\\"c\\\"\",\"\\\"a\\\".\\\"b\\\".\\\"d\\\"\",\"\\\"a\\\".\\\"b\\\".\\\"e\\\"\"\n]", expectedPredecessors: []string{"c", "d", "e"}}, + {predecessorsRaw: "[\"\\\"a\\\".\\\"b\\\".\\\".PHo,k:%Sz8tdx,9?23xTsgHLYxe\\\"\"]", expectedPredecessors: []string{".PHo,k:%Sz8tdx,9?23xTsgHLYxe"}}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test number %d for input: [%s]", i, tt.predecessorsRaw), func(t *testing.T) { + got, err := getPredecessors(tt.predecessorsRaw) + require.NoError(t, err) + require.Equal(t, tt.expectedPredecessors, got) + }) + } + + t.Run("incorrect json", func(t *testing.T) { + _, err := getPredecessors("[{]") + require.ErrorContains(t, err, "invalid character ']'") + }) +} diff --git a/pkg/sdk/tasks_validations_gen.go b/pkg/sdk/tasks_validations_gen.go index 4148d375a3..b3b5317b53 100644 --- a/pkg/sdk/tasks_validations_gen.go +++ b/pkg/sdk/tasks_validations_gen.go @@ -62,8 +62,11 @@ func (opts *AlterTaskOptions) validate() error { errs = append(errs, errExactlyOneOf("Resume", "Suspend", "RemoveAfter", "AddAfter", "Set", "Unset", "SetTags", "UnsetTags", "ModifyAs", "ModifyWhen")) } if valueSet(opts.Set) { - if ok := anyValueSet(opts.Set.Warehouse, opts.Set.Schedule, opts.Set.Config, opts.Set.AllowOverlappingExecution, opts.Set.UserTaskTimeoutMs, opts.Set.SuspendTaskAfterNumFailures, opts.Set.Comment, opts.Set.SessionParameters); !ok { - errs = append(errs, errAtLeastOneOf("Warehouse", "Schedule", "Config", 
"AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "Comment", "SessionParameters")) + if ok := anyValueSet(opts.Set.Warehouse, opts.Set.UserTaskManagedInitialWarehouseSize, opts.Set.Schedule, opts.Set.Config, opts.Set.AllowOverlappingExecution, opts.Set.UserTaskTimeoutMs, opts.Set.SuspendTaskAfterNumFailures, opts.Set.ErrorIntegration, opts.Set.Comment, opts.Set.SessionParameters); !ok { + errs = append(errs, errAtLeastOneOf("Warehouse", "UserTaskManagedInitialWarehouseSize", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParameters")) + } + if everyValueSet(opts.Set.Warehouse, opts.Set.UserTaskManagedInitialWarehouseSize) { + errs = append(errs, errOneOf("Set", "Warehouse", "UserTaskManagedInitialWarehouseSize")) } if valueSet(opts.Set.SessionParameters) { if err := opts.Set.SessionParameters.validate(); err != nil { @@ -72,8 +75,8 @@ func (opts *AlterTaskOptions) validate() error { } } if valueSet(opts.Unset) { - if ok := anyValueSet(opts.Unset.Warehouse, opts.Unset.Schedule, opts.Unset.Config, opts.Unset.AllowOverlappingExecution, opts.Unset.UserTaskTimeoutMs, opts.Unset.SuspendTaskAfterNumFailures, opts.Unset.Comment, opts.Unset.SessionParametersUnset); !ok { - errs = append(errs, errAtLeastOneOf("Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "Comment", "SessionParametersUnset")) + if ok := anyValueSet(opts.Unset.Warehouse, opts.Unset.Schedule, opts.Unset.Config, opts.Unset.AllowOverlappingExecution, opts.Unset.UserTaskTimeoutMs, opts.Unset.SuspendTaskAfterNumFailures, opts.Unset.ErrorIntegration, opts.Unset.Comment, opts.Unset.SessionParametersUnset); !ok { + errs = append(errs, errAtLeastOneOf("Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParametersUnset")) } if valueSet(opts.Unset.SessionParametersUnset) { if err := opts.Unset.SessionParametersUnset.validate(); err != nil { diff --git a/pkg/sdk/testint/tasks_gen_integration_test.go b/pkg/sdk/testint/tasks_gen_integration_test.go index beb0da02de..d03415764f 100644 --- a/pkg/sdk/testint/tasks_gen_integration_test.go +++ b/pkg/sdk/testint/tasks_gen_integration_test.go @@ -27,8 +27,8 @@ func TestInt_Tasks(t *testing.T) { assert.Equal(t, "", task.Comment) assert.Equal(t, "", task.Warehouse) assert.Equal(t, "", task.Schedule) - assert.Equal(t, "[]", task.Predecessors) - assert.Equal(t, "suspended", task.State) + assert.Empty(t, task.Predecessors) + assert.Equal(t, sdk.TaskStateSuspended, task.State) assert.Equal(t, sql, task.Definition) assert.Equal(t, "", task.Condition) assert.Equal(t, false, task.AllowOverlappingExecution) @@ -52,7 +52,7 @@ func TestInt_Tasks(t *testing.T) { assert.Equal(t, comment, task.Comment) assert.Equal(t, warehouse, task.Warehouse) assert.Equal(t, schedule, task.Schedule) - assert.Equal(t, "suspended", task.State) + assert.Equal(t, sdk.TaskStateSuspended, task.State) assert.Equal(t, sql, task.Definition) assert.Equal(t, condition, task.Condition) assert.Equal(t, allowOverlappingExecution, task.AllowOverlappingExecution) @@ -63,14 +63,10 @@ func TestInt_Tasks(t *testing.T) { assert.Equal(t, config, task.Config) assert.Empty(t, task.Budget) if predecessor != nil { - // Predecessors list is formatted, so matching it is unnecessarily complicated: - // e.g. 
`[\n \"\\\"qgb)Z1KcNWJ(\\\".\\\"glN@JtR=7dzP$7\\\".\\\"_XEL(7N_F?@frgT5>dQS>V|vSy,J\\\"\"\n]`. - // We just match parts of the expected predecessor. Later we can parse the output while constructing Task object. - assert.Contains(t, task.Predecessors, predecessor.DatabaseName()) - assert.Contains(t, task.Predecessors, predecessor.SchemaName()) - assert.Contains(t, task.Predecessors, predecessor.Name()) + assert.Len(t, task.Predecessors, 1) + assert.Contains(t, task.Predecessors, *predecessor) } else { - assert.Equal(t, "[]", task.Predecessors) + assert.Empty(t, task.Predecessors) } } @@ -192,6 +188,123 @@ func TestInt_Tasks(t *testing.T) { assertTaskWithOptions(t, task, request.GetName(), "", "", "", "", false, "", &otherId) }) + t.Run("create dag of tasks", func(t *testing.T) { + rootName := random.String() + rootId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, rootName) + + request := sdk.NewCreateTaskRequest(rootId, sql).WithSchedule(sdk.String("10 MINUTE")) + root := createTaskWithRequest(t, request) + + require.Empty(t, root.Predecessors) + + t1Name := random.String() + t1Id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, t1Name) + + request = sdk.NewCreateTaskRequest(t1Id, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootId}) + t1 := createTaskWithRequest(t, request) + + require.Equal(t, []sdk.SchemaObjectIdentifier{rootId}, t1.Predecessors) + + t2Name := random.String() + t2Id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, t2Name) + + request = sdk.NewCreateTaskRequest(t2Id, sql).WithAfter([]sdk.SchemaObjectIdentifier{t1Id, rootId}) + t2 := createTaskWithRequest(t, request) + + require.Contains(t, t2.Predecessors, rootId) + require.Contains(t, t2.Predecessors, t1Id) + require.Len(t, t2.Predecessors, 2) + + t3Name := random.String() + t3Id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, t3Name) + + request = sdk.NewCreateTaskRequest(t3Id, sql).WithAfter([]sdk.SchemaObjectIdentifier{t2Id, t1Id}) + t3 := createTaskWithRequest(t, request) + + require.Contains(t, t3.Predecessors, t2Id) + require.Contains(t, t3.Predecessors, t1Id) + require.Len(t, t3.Predecessors, 2) + + rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, rootId) + require.NoError(t, err) + require.Len(t, rootTasks, 1) + require.Equal(t, rootId, rootTasks[0].ID()) + + rootTasks, err = sdk.GetRootTasks(client.Tasks, ctx, t1Id) + require.NoError(t, err) + require.Len(t, rootTasks, 1) + require.Equal(t, rootId, rootTasks[0].ID()) + + rootTasks, err = sdk.GetRootTasks(client.Tasks, ctx, t2Id) + require.NoError(t, err) + require.Len(t, rootTasks, 1) + require.Equal(t, rootId, rootTasks[0].ID()) + + rootTasks, err = sdk.GetRootTasks(client.Tasks, ctx, t3Id) + require.NoError(t, err) + require.Len(t, rootTasks, 1) + require.Equal(t, rootId, rootTasks[0].ID()) + + // cannot set ALLOW_OVERLAPPING_EXECUTION on child task + alterRequest := sdk.NewAlterTaskRequest(t1Id).WithSet(sdk.NewTaskSetRequest().WithAllowOverlappingExecution(sdk.Bool(true))) + err = client.Tasks.Alter(ctx, alterRequest) + require.ErrorContains(t, err, "Cannot set allow_overlapping_execution on non-root task") + + // can set ALLOW_OVERLAPPING_EXECUTION on root task + alterRequest = sdk.NewAlterTaskRequest(rootId).WithSet(sdk.NewTaskSetRequest().WithAllowOverlappingExecution(sdk.Bool(true))) + err = client.Tasks.Alter(ctx, alterRequest) + require.NoError(t, err) + + // can create cycle, because DAG is suspended + alterRequest = 
sdk.NewAlterTaskRequest(t1Id).WithAddAfter([]sdk.SchemaObjectIdentifier{t3Id}) + err = client.Tasks.Alter(ctx, alterRequest) + require.NoError(t, err) + + // we get an error when trying to start + alterRequest = sdk.NewAlterTaskRequest(rootId).WithResume(sdk.Bool(true)) + err = client.Tasks.Alter(ctx, alterRequest) + require.ErrorContains(t, err, "Graph has at least one cycle containing task") + }) + + t.Run("create dag of tasks - multiple roots", func(t *testing.T) { + root1Name := random.String() + root1Id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, root1Name) + + request := sdk.NewCreateTaskRequest(root1Id, sql).WithSchedule(sdk.String("10 MINUTE")) + root1 := createTaskWithRequest(t, request) + + require.Empty(t, root1.Predecessors) + + root2Name := random.String() + root2Id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, root2Name) + + request = sdk.NewCreateTaskRequest(root2Id, sql).WithSchedule(sdk.String("10 MINUTE")) + root2 := createTaskWithRequest(t, request) + + require.Empty(t, root2.Predecessors) + + t1Name := random.String() + t1Id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, t1Name) + + request = sdk.NewCreateTaskRequest(t1Id, sql).WithAfter([]sdk.SchemaObjectIdentifier{root1Id, root2Id}) + t1 := createTaskWithRequest(t, request) + + require.Contains(t, t1.Predecessors, root1Id) + require.Contains(t, t1.Predecessors, root2Id) + require.Len(t, t1.Predecessors, 2) + + rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, t1Id) + require.NoError(t, err) + require.Len(t, rootTasks, 2) + require.Contains(t, []sdk.SchemaObjectIdentifier{root1Id, root2Id}, rootTasks[0].ID()) + require.Contains(t, []sdk.SchemaObjectIdentifier{root1Id, root2Id}, rootTasks[1].ID()) + + // we get an error when trying to start + alterRequest := sdk.NewAlterTaskRequest(root1Id).WithResume(sdk.Bool(true)) + err = client.Tasks.Alter(ctx, alterRequest) + require.ErrorContains(t, err, "The graph has more than one root task (one without predecessors)") + }) + // TODO: this fails with `syntax error line 1 at position 89 unexpected 'GRANTS'`. // The reason is that in the documentation there is a note: "This parameter is not supported currently.". 
// t.Run("create task: with grants", func(t *testing.T) { @@ -335,7 +448,7 @@ func TestInt_Tasks(t *testing.T) { task := createTaskWithRequest(t, request) id := task.ID() - assert.Equal(t, "suspended", task.State) + assert.Equal(t, sdk.TaskStateSuspended, task.State) alterRequest := sdk.NewAlterTaskRequest(id).WithResume(sdk.Bool(true)) err := client.Tasks.Alter(ctx, alterRequest) @@ -344,7 +457,7 @@ func TestInt_Tasks(t *testing.T) { alteredTask, err := client.Tasks.ShowByID(ctx, id) require.NoError(t, err) - assert.Equal(t, "started", alteredTask.State) + assert.Equal(t, sdk.TaskStateStarted, alteredTask.State) alterRequest = sdk.NewAlterTaskRequest(id).WithSuspend(sdk.Bool(true)) err = client.Tasks.Alter(ctx, alterRequest) @@ -353,7 +466,7 @@ func TestInt_Tasks(t *testing.T) { alteredTask, err = client.Tasks.ShowByID(ctx, id) require.NoError(t, err) - assert.Equal(t, "suspended", alteredTask.State) + assert.Equal(t, sdk.TaskStateSuspended, alteredTask.State) }) t.Run("alter task: remove after and add after", func(t *testing.T) { @@ -369,7 +482,7 @@ func TestInt_Tasks(t *testing.T) { task := createTaskWithRequest(t, request) id := task.ID() - assert.Contains(t, task.Predecessors, otherId.Name()) + assert.Contains(t, task.Predecessors, otherId) alterRequest := sdk.NewAlterTaskRequest(id).WithRemoveAfter([]sdk.SchemaObjectIdentifier{otherId}) @@ -379,7 +492,7 @@ func TestInt_Tasks(t *testing.T) { task, err = client.Tasks.ShowByID(ctx, id) require.NoError(t, err) - assert.Equal(t, "[]", task.Predecessors) + assert.Empty(t, task.Predecessors) alterRequest = sdk.NewAlterTaskRequest(id).WithAddAfter([]sdk.SchemaObjectIdentifier{otherId}) @@ -389,7 +502,7 @@ func TestInt_Tasks(t *testing.T) { task, err = client.Tasks.ShowByID(ctx, id) require.NoError(t, err) - assert.Contains(t, task.Predecessors, otherId.Name()) + assert.Contains(t, task.Predecessors, otherId) }) t.Run("alter task: modify when and as", func(t *testing.T) { diff --git a/pkg/snowflake/task.go b/pkg/snowflake/task.go deleted file mode 100644 index 5d7c9e0ab4..0000000000 --- a/pkg/snowflake/task.go +++ /dev/null @@ -1,552 +0,0 @@ -package snowflake - -import ( - "database/sql" - "encoding/json" - "errors" - "fmt" - "log" - "sort" - "strconv" - "strings" - "time" - - "github.com/jmoiron/sqlx" -) - -// TaskBuilder abstracts the creation of sql queries for a snowflake task. -type TaskBuilder struct { - name string - db string - schema string - warehouse string - schedule string - sessionParameters map[string]interface{} - userTaskTimeoutMS int - comment string - after []string - when string - SQLStatement string - disabled bool - userTaskManagedInitialWarehouseSize string - errorIntegration string - allowOverlappingExecution bool -} - -// GetFullName prepends db and schema to in parameter. -func (tb *TaskBuilder) GetFullName(name string) string { - var n strings.Builder - - n.WriteString(fmt.Sprintf(`"%v"."%v"."%v"`, tb.db, tb.schema, name)) - - return n.String() -} - -// QualifiedName prepends the db and schema and escapes everything nicely. -func (tb *TaskBuilder) QualifiedName() string { - return tb.GetFullName(tb.name) -} - -// Name returns the name of the task. -func (tb *TaskBuilder) Name() string { - return tb.name -} - -// WithWarehouse adds a warehouse to the TaskBuilder. -func (tb *TaskBuilder) WithWarehouse(s string) *TaskBuilder { - tb.warehouse = s - return tb -} - -// WithSchedule adds a schedule to the TaskBuilder. 
-func (tb *TaskBuilder) WithSchedule(s string) *TaskBuilder { - tb.schedule = s - return tb -} - -// WithSessionParameters adds session parameters to the TaskBuilder. -func (tb *TaskBuilder) WithSessionParameters(params map[string]interface{}) *TaskBuilder { - tb.sessionParameters = params - return tb -} - -// WithComment adds a comment to the TaskBuilder. -func (tb *TaskBuilder) WithComment(c string) *TaskBuilder { - tb.comment = c - return tb -} - -// WithAllowOverlappingExecution set the ALLOW_OVERLAPPING_EXECUTION on the TaskBuilder. -func (tb *TaskBuilder) WithAllowOverlappingExecution(flag bool) *TaskBuilder { - tb.allowOverlappingExecution = flag - return tb -} - -// WithTimeout adds a timeout to the TaskBuilder. -func (tb *TaskBuilder) WithTimeout(t int) *TaskBuilder { - tb.userTaskTimeoutMS = t - return tb -} - -// WithAfter adds after task dependencies to the TaskBuilder. -func (tb *TaskBuilder) WithAfter(after []string) *TaskBuilder { - tb.after = after - return tb -} - -// WithCondition adds a WHEN condition to the TaskBuilder. -func (tb *TaskBuilder) WithCondition(when string) *TaskBuilder { - tb.when = when - return tb -} - -// WithStatement adds a sql statement to the TaskBuilder. -func (tb *TaskBuilder) WithStatement(sql string) *TaskBuilder { - tb.SQLStatement = sql - return tb -} - -// WithInitialWarehouseSize adds an initial warehouse size to the TaskBuilder. -func (tb *TaskBuilder) WithInitialWarehouseSize(initialWarehouseSize string) *TaskBuilder { - tb.userTaskManagedInitialWarehouseSize = initialWarehouseSize - return tb -} - -// WithErrorIntegration adds ErrorIntegration specification to the TaskBuilder. -func (tb *TaskBuilder) WithErrorIntegration(s string) *TaskBuilder { - tb.errorIntegration = s - return tb -} - -// Task returns a pointer to a Builder that abstracts the DDL operations for a task. -// -// Supported DDL operations are: -// - CREATE TASK -// - ALTER TASK -// - DROP TASK -// - DESCRIBE TASK -// -// [Snowflake Reference](https://docs.snowflake.com/en/user-guide/tasks-intro.html#task-ddl) -func NewTaskBuilder(name, db, schema string) *TaskBuilder { - return &TaskBuilder{ - name: name, - db: db, - schema: schema, - disabled: false, // helper for when started root or standalone task gets suspended - } -} - -// Create returns the SQL that will create a new task. 
-func (tb *TaskBuilder) Create() string { - q := strings.Builder{} - q.WriteString(`CREATE`) - - q.WriteString(fmt.Sprintf(` TASK %v`, tb.QualifiedName())) - - if tb.warehouse != "" { - q.WriteString(fmt.Sprintf(` WAREHOUSE = "%v"`, EscapeString(tb.warehouse))) - } else if tb.userTaskManagedInitialWarehouseSize != "" { - q.WriteString(fmt.Sprintf(` USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE = '%v'`, EscapeString(tb.userTaskManagedInitialWarehouseSize))) - } - - if tb.schedule != "" { - q.WriteString(fmt.Sprintf(` SCHEDULE = '%v'`, EscapeString(tb.schedule))) - } - - if len(tb.sessionParameters) > 0 { - sp := make([]string, 0) - sortedKeys := make([]string, 0) - for k := range tb.sessionParameters { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - for _, k := range sortedKeys { - sp = append(sp, EscapeString(fmt.Sprintf(`%v = "%v"`, k, tb.sessionParameters[k]))) - } - q.WriteString(fmt.Sprintf(` %v`, strings.Join(sp, ", "))) - } - - if tb.comment != "" { - q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(tb.comment))) - } - - if tb.allowOverlappingExecution { - q.WriteString(` ALLOW_OVERLAPPING_EXECUTION = TRUE`) - } - - if tb.errorIntegration != "" { - q.WriteString(fmt.Sprintf(` ERROR_INTEGRATION = '%v'`, EscapeString(tb.errorIntegration))) - } - - if tb.userTaskTimeoutMS > 0 { - q.WriteString(fmt.Sprintf(` USER_TASK_TIMEOUT_MS = %v`, tb.userTaskTimeoutMS)) - } - - if len(tb.after) > 0 { - after := make([]string, 0) - for _, a := range tb.after { - after = append(after, tb.GetFullName(a)) - } - q.WriteString(fmt.Sprintf(` AFTER %v`, strings.Join(after, ", "))) - } - - if tb.when != "" { - q.WriteString(fmt.Sprintf(` WHEN %v`, tb.when)) - } - - if tb.SQLStatement != "" { - q.WriteString(fmt.Sprintf(` AS %v`, UnescapeString(tb.SQLStatement))) - } - - return q.String() -} - -// ChangeWarehouse returns the sql that will change the warehouse for the task. -func (tb *TaskBuilder) ChangeWarehouse(newWh string) string { - return fmt.Sprintf(`ALTER TASK %v SET WAREHOUSE = "%v"`, tb.QualifiedName(), EscapeString(newWh)) -} - -// SwitchWarehouseToManaged returns the sql that will switch to managed warehouse. -func (tb *TaskBuilder) SwitchWarehouseToManaged() string { - return fmt.Sprintf(`ALTER TASK %v SET WAREHOUSE = null`, tb.QualifiedName()) -} - -// SwitchManagedWithInitialSize returns the sql that will update warehouse initial size . -func (tb *TaskBuilder) SwitchManagedWithInitialSize(initialSize string) string { - return fmt.Sprintf(`ALTER TASK %v SET USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE = '%v'`, tb.QualifiedName(), EscapeString(initialSize)) -} - -// ChangeSchedule returns the sql that will change the schedule for the task. -func (tb *TaskBuilder) ChangeSchedule(newSchedule string) string { - return fmt.Sprintf(`ALTER TASK %v SET SCHEDULE = '%v'`, tb.QualifiedName(), EscapeString(newSchedule)) -} - -// RemoveSchedule returns the sql that will remove the schedule for the task. -func (tb *TaskBuilder) RemoveSchedule() string { - return fmt.Sprintf(`ALTER TASK %v UNSET SCHEDULE`, tb.QualifiedName()) -} - -// ChangeTimeout returns the sql that will change the user task timeout for the task. -func (tb *TaskBuilder) ChangeTimeout(newTimeout int) string { - return fmt.Sprintf(`ALTER TASK %v SET USER_TASK_TIMEOUT_MS = %v`, tb.QualifiedName(), newTimeout) -} - -// RemoveTimeout returns the sql that will remove the user task timeout for the task. 
-func (tb *TaskBuilder) RemoveTimeout() string { - return fmt.Sprintf(`ALTER TASK %v UNSET USER_TASK_TIMEOUT_MS`, tb.QualifiedName()) -} - -// ChangeComment returns the sql that will change the comment for the task. -func (tb *TaskBuilder) ChangeComment(newComment string) string { - return fmt.Sprintf(`ALTER TASK %v SET COMMENT = '%v'`, tb.QualifiedName(), EscapeString(newComment)) -} - -// RemoveComment returns the sql that will remove the comment for the task. -func (tb *TaskBuilder) RemoveComment() string { - return fmt.Sprintf(`ALTER TASK %v UNSET COMMENT`, tb.QualifiedName()) -} - -// SetAllowOverlappingExecutionParameter returns the sql that will change the ALLOW_OVERLAPPING_EXECUTION for the task. -func (tb *TaskBuilder) SetAllowOverlappingExecutionParameter() string { - return fmt.Sprintf(`ALTER TASK %v SET ALLOW_OVERLAPPING_EXECUTION = TRUE`, tb.QualifiedName()) -} - -// UnsetAllowOverlappingExecutionParameter returns the sql that will unset the ALLOW_OVERLAPPING_EXECUTION for the task. -func (tb *TaskBuilder) UnsetAllowOverlappingExecutionParameter() string { - return fmt.Sprintf(`ALTER TASK %v UNSET ALLOW_OVERLAPPING_EXECUTION`, tb.QualifiedName()) -} - -// AddAfter returns the sql that will add the after dependency for the task. -func (tb *TaskBuilder) AddAfter(after []string) string { - afterTasks := make([]string, 0) - for _, a := range after { - afterTasks = append(afterTasks, tb.GetFullName(a)) - } - return fmt.Sprintf(`ALTER TASK %v ADD AFTER %v`, tb.QualifiedName(), strings.Join(afterTasks, ", ")) -} - -// RemoveAfter returns the sql that will remove the after dependency for the task. -func (tb *TaskBuilder) RemoveAfter(after []string) string { - afterTasks := make([]string, 0) - for _, a := range after { - afterTasks = append(afterTasks, tb.GetFullName(a)) - } - return fmt.Sprintf(`ALTER TASK %v REMOVE AFTER %v`, tb.QualifiedName(), strings.Join(afterTasks, ", ")) -} - -// AddSessionParameters returns the sql that will remove the session parameters for the task. -func (tb *TaskBuilder) AddSessionParameters(params map[string]interface{}) string { - p := make([]string, 0) - sortedKeys := make([]string, 0) - for k := range params { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - for _, k := range sortedKeys { - p = append(p, EscapeString(fmt.Sprintf(`%v = "%v"`, k, params[k]))) - } - - return fmt.Sprintf(`ALTER TASK %v SET %v`, tb.QualifiedName(), strings.Join(p, ", ")) -} - -// RemoveSessionParameters returns the sql that will remove the session parameters for the task. -func (tb *TaskBuilder) RemoveSessionParameters(params map[string]interface{}) string { - sortedKeys := make([]string, 0) - for k := range params { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - return fmt.Sprintf(`ALTER TASK %v UNSET %v`, tb.QualifiedName(), strings.Join(sortedKeys, ", ")) -} - -// ChangeCondition returns the sql that will update the WHEN condition for the task. -func (tb *TaskBuilder) ChangeCondition(newCondition string) string { - return fmt.Sprintf(`ALTER TASK %v MODIFY WHEN %v`, tb.QualifiedName(), newCondition) -} - -// ChangeSQLStatement returns the sql that will update the sql the task executes. -func (tb *TaskBuilder) ChangeSQLStatement(newStatement string) string { - return fmt.Sprintf(`ALTER TASK %v MODIFY AS %v`, tb.QualifiedName(), UnescapeString(newStatement)) -} - -// Suspend returns the sql that will suspend the task. 
-func (tb *TaskBuilder) Suspend() string { - return fmt.Sprintf(`ALTER TASK %v SUSPEND`, tb.QualifiedName()) -} - -// Resume returns the sql that will resume the task. -func (tb *TaskBuilder) Resume() string { - return fmt.Sprintf(`ALTER TASK %v RESUME`, tb.QualifiedName()) -} - -// Drop returns the sql that will remove the task. -func (tb *TaskBuilder) Drop() string { - return fmt.Sprintf(`DROP TASK %v`, tb.QualifiedName()) -} - -// Describe returns the sql that will describe a task. -func (tb *TaskBuilder) Describe() string { - return fmt.Sprintf(`DESCRIBE TASK %v`, tb.QualifiedName()) -} - -// Show returns the sql that will show a task. -func (tb *TaskBuilder) Show() string { - return fmt.Sprintf(`SHOW TASKS LIKE '%v' IN SCHEMA "%v"."%v"`, EscapeString(tb.name), EscapeString(tb.db), EscapeString(tb.schema)) -} - -// ShowParameters returns the query to show the session parameters for the task. -func (tb *TaskBuilder) ShowParameters() string { - return fmt.Sprintf(`SHOW PARAMETERS IN TASK %v`, tb.QualifiedName()) -} - -// SetDisabled disables the task builder. -func (tb *TaskBuilder) SetDisabled() *TaskBuilder { - tb.disabled = true - return tb -} - -// IsDisabled returns if the task builder is disabled. -func (tb *TaskBuilder) IsDisabled() bool { - return tb.disabled -} - -// ChangeErrorIntegration return SQL query that will update the error_integration on the task. -func (tb *TaskBuilder) ChangeErrorIntegration(c string) string { - return fmt.Sprintf(`ALTER TASK %v SET ERROR_INTEGRATION = %v`, tb.QualifiedName(), EscapeString(c)) -} - -// RemoveErrorIntegration returns the SQL query that will remove the error_integration on the task. -func (tb *TaskBuilder) RemoveErrorIntegration() string { - return fmt.Sprintf(`ALTER TASK %v UNSET ERROR_INTEGRATION`, tb.QualifiedName()) -} - -func (tb *TaskBuilder) SetAllowOverlappingExecution() *TaskBuilder { - tb.allowOverlappingExecution = true - return tb -} - -func (tb *TaskBuilder) IsAllowOverlappingExecution() bool { - return tb.allowOverlappingExecution -} - -type Task struct { - ID string `db:"id"` - CreatedOn string `db:"created_on"` - Name string `db:"name"` - DatabaseName string `db:"database_name"` - SchemaName string `db:"schema_name"` - Owner string `db:"owner"` - Comment *string `db:"comment"` - Warehouse *string `db:"warehouse"` - Schedule *string `db:"schedule"` - Predecessors *string `db:"predecessors"` - State string `db:"state"` - Definition string `db:"definition"` - Condition *string `db:"condition"` - ErrorIntegration sql.NullString `db:"error_integration"` - AllowOverlappingExecution sql.NullString `db:"allow_overlapping_execution"` -} - -func (t *Task) QualifiedName() string { - return fmt.Sprintf(`"%v"."%v"."%v"`, EscapeString(t.DatabaseName), EscapeString(t.SchemaName), EscapeString(t.Name)) -} - -func (t *Task) Suspend() string { - return fmt.Sprintf(`ALTER TASK %v SUSPEND`, t.QualifiedName()) -} - -func (t *Task) Resume() string { - return fmt.Sprintf(`ALTER TASK %v RESUME`, t.QualifiedName()) -} - -func (t *Task) IsEnabled() bool { - return strings.ToLower(t.State) == "started" -} - -func (t *Task) GetPredecessors() ([]string, error) { - if t.Predecessors == nil { - return []string{}, nil - } - - // Since 2022_03, Snowflake returns this as a JSON array (even empty) - var predecessorNames []string // nolint: prealloc //todo: fixme - if err := json.Unmarshal([]byte(*t.Predecessors), &predecessorNames); err == nil { - for i, predecessorName := range predecessorNames { - formattedName := 
predecessorName[strings.LastIndex(predecessorName, ".")+1:] - formattedName = strings.Trim(formattedName, "\\\"") - predecessorNames[i] = formattedName - } - return predecessorNames, nil - } - - pre := strings.Split(*t.Predecessors, ".") - for _, p := range pre { - predecessorName, err := strconv.Unquote(p) - if err != nil { - return nil, err - } - predecessorNames = append(predecessorNames, predecessorName) - } - return predecessorNames, nil -} - -// ScanTask turns a sql row into a task object. -func ScanTask(row *sqlx.Row) (*Task, error) { - t := &Task{} - e := row.StructScan(t) - return t, e -} - -// TaskParams struct to represent a row of parameters. -type TaskParams struct { - Key string `db:"key"` - Value string `db:"value"` - DefaultValue string `db:"default"` - Level string `db:"level"` - Description string `db:"description"` -} - -// ScanTaskParameters takes a database row and converts it to a task parameter pointer. -func ScanTaskParameters(rows *sqlx.Rows) ([]*TaskParams, error) { - t := []*TaskParams{} - - for rows.Next() { - r := &TaskParams{} - if err := rows.StructScan(r); err != nil { - return nil, err - } - t = append(t, r) - } - return t, nil -} - -func ListTasks(databaseName string, schemaName string, db *sql.DB) ([]Task, error) { - stmt := fmt.Sprintf(`SHOW TASKS IN SCHEMA "%s"."%v"`, databaseName, schemaName) - rows, err := Query(db, stmt) - if err != nil { - return nil, err - } - defer rows.Close() - - dbs := []Task{} - if err := sqlx.StructScan(rows, &dbs); err != nil { - if errors.Is(err, sql.ErrNoRows) { - log.Println("[DEBUG] no tasks found") - return nil, nil - } - return dbs, fmt.Errorf("unable to scan row for %s err = %w", stmt, err) - } - return dbs, nil -} - -// GetRootTasks tries to retrieve the root of current task or returns the current (standalone) task. -func GetRootTasks(name string, databaseName string, schemaName string, db *sql.DB) ([]*Task, error) { - builder := NewTaskBuilder(name, databaseName, schemaName) - log.Printf("[DEBUG] retrieving predecessors for task %s\n", builder.QualifiedName()) - q := builder.Show() - row := QueryRow(db, q) - t, err := ScanTask(row) - if err != nil { - return nil, err - } - - predecessors, err := t.GetPredecessors() - if err != nil { - return nil, fmt.Errorf("unable to get predecessors for task %s err = %w", builder.QualifiedName(), err) - } - - // no predecessors mean this is a root task - if len(predecessors) == 0 { - return []*Task{t}, nil - } - - tasks := make([]*Task, 0, len(predecessors)) - // get the root tasks for each predecessor and append them all together - for _, predecessor := range predecessors { - predecessorTasks, err := GetRootTasks(predecessor, databaseName, schemaName, db) - if err != nil { - return nil, fmt.Errorf("unable to get predecessors for task %s err = %w", builder.QualifiedName(), err) - } - tasks = append(tasks, predecessorTasks...) - } - - // remove duplicate root tasks - uniqueTasks := make(map[string]*Task) - for _, task := range tasks { - uniqueTasks[task.QualifiedName()] = task - } - tasks = []*Task{} - for _, task := range uniqueTasks { - tasks = append(tasks, task) - } - - return tasks, nil -} - -func WaitResumeTask(db *sql.DB, name string, database string, schema string) error { - builder := NewTaskBuilder(name, database, schema) - - // try to resume the task, and verify that it was resumed. 
- // if its not resumed then try again up until a maximum of 5 times - for i := 0; i < 5; i++ { - q := builder.Resume() - if err := Exec(db, q); err != nil { - return fmt.Errorf("error resuming task %v err = %w", name, err) - } - - q = builder.Show() - row := QueryRow(db, q) - t, err := ScanTask(row) - if err != nil { - return err - } - if t.IsEnabled() { - return nil - } - time.Sleep(10 * time.Second) - } - return fmt.Errorf("unable to resume task %v after 5 attempts", name) -} diff --git a/pkg/snowflake/task_test.go b/pkg/snowflake/task_test.go deleted file mode 100644 index 97401cc083..0000000000 --- a/pkg/snowflake/task_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package snowflake - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestTaskCreate(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`"test_db"."test_schema"."test_task"`, st.QualifiedName()) - - st.WithWarehouse("test_wh") - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh"`, st.Create()) - - st.WithSchedule("USING CRON 0 9-17 * * SUN America/Los_Angeles") - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh" SCHEDULE = 'USING CRON 0 9-17 * * SUN America/Los_Angeles'`, st.Create()) - - st.WithSessionParameters(map[string]interface{}{"TIMESTAMP_INPUT_FORMAT": "YYYY-MM-DD HH24"}) - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh" SCHEDULE = 'USING CRON 0 9-17 * * SUN America/Los_Angeles' TIMESTAMP_INPUT_FORMAT = "YYYY-MM-DD HH24"`, st.Create()) - - st.WithComment("test comment") - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh" SCHEDULE = 'USING CRON 0 9-17 * * SUN America/Los_Angeles' TIMESTAMP_INPUT_FORMAT = "YYYY-MM-DD HH24" COMMENT = 'test comment'`, st.Create()) - - st.WithTimeout(12) - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh" SCHEDULE = 'USING CRON 0 9-17 * * SUN America/Los_Angeles' TIMESTAMP_INPUT_FORMAT = "YYYY-MM-DD HH24" COMMENT = 'test comment' USER_TASK_TIMEOUT_MS = 12`, st.Create()) - - st.WithAfter([]string{"other_task"}) - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh" SCHEDULE = 'USING CRON 0 9-17 * * SUN America/Los_Angeles' TIMESTAMP_INPUT_FORMAT = "YYYY-MM-DD HH24" COMMENT = 'test comment' USER_TASK_TIMEOUT_MS = 12 AFTER "test_db"."test_schema"."other_task"`, st.Create()) - - st.WithCondition("SYSTEM$STREAM_HAS_DATA('MYSTREAM')") - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh" SCHEDULE = 'USING CRON 0 9-17 * * SUN America/Los_Angeles' TIMESTAMP_INPUT_FORMAT = "YYYY-MM-DD HH24" COMMENT = 'test comment' USER_TASK_TIMEOUT_MS = 12 AFTER "test_db"."test_schema"."other_task" WHEN SYSTEM$STREAM_HAS_DATA('MYSTREAM')`, st.Create()) - - st.WithStatement("SELECT * FROM table WHERE column = 'name'") - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh" SCHEDULE = 'USING CRON 0 9-17 * * SUN America/Los_Angeles' TIMESTAMP_INPUT_FORMAT = "YYYY-MM-DD HH24" COMMENT = 'test comment' USER_TASK_TIMEOUT_MS = 12 AFTER "test_db"."test_schema"."other_task" WHEN SYSTEM$STREAM_HAS_DATA('MYSTREAM') AS SELECT * FROM table WHERE column = 'name'`, st.Create()) - - st.WithAllowOverlappingExecution(true) - r.Equal(`CREATE TASK "test_db"."test_schema"."test_task" WAREHOUSE = "test_wh" SCHEDULE = 'USING CRON 0 9-17 * * SUN America/Los_Angeles' TIMESTAMP_INPUT_FORMAT = "YYYY-MM-DD HH24" COMMENT = 'test comment' 
ALLOW_OVERLAPPING_EXECUTION = TRUE USER_TASK_TIMEOUT_MS = 12 AFTER "test_db"."test_schema"."other_task" WHEN SYSTEM$STREAM_HAS_DATA('MYSTREAM') AS SELECT * FROM table WHERE column = 'name'`, st.Create()) -} - -func TestChangeWarehouse(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SET WAREHOUSE = "much_wh"`, st.ChangeWarehouse("much_wh")) -} - -func TestSwitchWarehouseToManaged(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SET WAREHOUSE = null`, st.SwitchWarehouseToManaged()) -} - -func TestSwitchManagedWithInitialSize(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SET USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE = 'SMALL'`, st.SwitchManagedWithInitialSize("SMALL")) -} - -func TestChangeSchedule(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SET SCHEDULE = 'USING CRON 0 9-17 * * SUN America/New_York'`, st.ChangeSchedule("USING CRON 0 9-17 * * SUN America/New_York")) -} - -func TestRemoveSchedule(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" UNSET SCHEDULE`, st.RemoveSchedule()) -} - -func TestChangeTimeout(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SET USER_TASK_TIMEOUT_MS = 100`, st.ChangeTimeout(100)) -} - -func TestRemoveTimeout(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" UNSET USER_TASK_TIMEOUT_MS`, st.RemoveTimeout()) -} - -func TestChangeComment(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SET COMMENT = 'much comment wow'`, st.ChangeComment("much comment wow")) -} - -func TestRemoveComment(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" UNSET COMMENT`, st.RemoveComment()) -} - -func TestAddAfter(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" ADD AFTER "test_db"."test_schema"."other_task"`, st.AddAfter([]string{"other_task"})) -} - -func TestRemoveAfter(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" REMOVE AFTER "test_db"."test_schema"."first_me_task"`, st.RemoveAfter([]string{"first_me_task"})) -} - -func TestAddSessionParameters(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - params := map[string]interface{}{"TIMESTAMP_INPUT_FORMAT": "YYYY-MM-DD HH24", "CLIENT_TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_LTZ"} - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SET CLIENT_TIMESTAMP_TYPE_MAPPING = "TIMESTAMP_LTZ", TIMESTAMP_INPUT_FORMAT = "YYYY-MM-DD HH24"`, st.AddSessionParameters(params)) -} - -func TestRemoveSessionParameters(t 
*testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - params := map[string]interface{}{"TIMESTAMP_INPUT_FORMAT": "YYYY-MM-DD HH24", "CLIENT_TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_LTZ"} - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" UNSET CLIENT_TIMESTAMP_TYPE_MAPPING, TIMESTAMP_INPUT_FORMAT`, st.RemoveSessionParameters(params)) -} - -func TestChangeCondition(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" MODIFY WHEN TRUE = TRUE`, st.ChangeCondition("TRUE = TRUE")) -} - -func TestChangeSqlStatement(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" MODIFY AS SELECT * FROM table`, st.ChangeSQLStatement("SELECT * FROM table")) -} - -func TestSuspend(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SUSPEND`, st.Suspend()) -} - -func TestResume(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" RESUME`, st.Resume()) -} - -func TestShowParameters(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`SHOW PARAMETERS IN TASK "test_db"."test_schema"."test_task"`, st.ShowParameters()) -} - -func TestDrop(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`DROP TASK "test_db"."test_schema"."test_task"`, st.Drop()) -} - -func TestDescribe(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`DESCRIBE TASK "test_db"."test_schema"."test_task"`, st.Describe()) -} - -func TestShow(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`SHOW TASKS LIKE 'test_task' IN SCHEMA "test_db"."test_schema"`, st.Show()) -} - -func TestSetAllowOverlappingExecution(t *testing.T) { - r := require.New(t) - st := NewTaskBuilder("test_task", "test_db", "test_schema") - r.Equal(`ALTER TASK "test_db"."test_schema"."test_task" SET ALLOW_OVERLAPPING_EXECUTION = TRUE`, st.SetAllowOverlappingExecutionParameter()) -} From 16022ef4171e7dccf2932ae6e8d451b51c93291c Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Fri, 20 Oct 2023 15:01:27 +0200 Subject: [PATCH 04/20] chore: Set up a single warehouse for the SDK integration tests (#2141) * WH: Prepare common warehouse for integration tests * WH: Use created warehouse in almost every test --- pkg/sdk/testint/alerts_integration_test.go | 45 +++++++------------ pkg/sdk/testint/comments_integration_test.go | 7 +-- .../context_functions_integration_test.go | 1 + .../conversion_functions_integration_test.go | 2 + .../testint/dynamic_table_integration_test.go | 10 ++--- .../external_tables_integration_test.go | 4 +- pkg/sdk/testint/grants_integration_test.go | 5 +-- pkg/sdk/testint/helpers_test.go | 8 ++-- pkg/sdk/testint/sessions_integration_test.go | 32 +++++++------ pkg/sdk/testint/setup_integration_test.go | 35 +++++++++++++-- pkg/sdk/testint/tasks_gen_integration_test.go | 7 +-- .../testint/warehouses_integration_test.go | 12 +++++ 12 files changed, 93 insertions(+), 75 deletions(-) diff --git a/pkg/sdk/testint/alerts_integration_test.go 
b/pkg/sdk/testint/alerts_integration_test.go index d12dbdf839..3d7c8ad6ce 100644 --- a/pkg/sdk/testint/alerts_integration_test.go +++ b/pkg/sdk/testint/alerts_integration_test.go @@ -14,13 +14,10 @@ func TestInt_AlertsShow(t *testing.T) { client := testClient(t) ctx := testContext(t) - testWarehouse, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - - alertTest, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), testWarehouse) + alertTest, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), testWarehouse(t)) t.Cleanup(alertCleanup) - alert2Test, alert2Cleanup := createAlert(t, client, testDb(t), testSchema(t), testWarehouse) + alert2Test, alert2Cleanup := createAlert(t, client, testDb(t), testSchema(t), testWarehouse(t)) t.Cleanup(alert2Cleanup) t.Run("without show options", func(t *testing.T) { @@ -85,9 +82,6 @@ func TestInt_AlertCreate(t *testing.T) { client := testClient(t) ctx := testContext(t) - testWarehouse, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - t.Run("test complete case", func(t *testing.T) { name := random.String() schedule := "USING CRON * * * * TUE,THU UTC" @@ -95,7 +89,7 @@ func TestInt_AlertCreate(t *testing.T) { action := "SELECT 1" comment := random.Comment() id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, name) - err := client.Alerts.Create(ctx, id, testWarehouse.ID(), schedule, condition, action, &sdk.CreateAlertOptions{ + err := client.Alerts.Create(ctx, id, testWarehouse(t).ID(), schedule, condition, action, &sdk.CreateAlertOptions{ OrReplace: sdk.Bool(true), IfNotExists: sdk.Bool(false), Comment: sdk.String(comment), @@ -104,7 +98,7 @@ func TestInt_AlertCreate(t *testing.T) { alertDetails, err := client.Alerts.Describe(ctx, id) require.NoError(t, err) assert.Equal(t, name, alertDetails.Name) - assert.Equal(t, testWarehouse.Name, alertDetails.Warehouse) + assert.Equal(t, testWarehouse(t).Name, alertDetails.Warehouse) assert.Equal(t, schedule, alertDetails.Schedule) assert.Equal(t, comment, *alertDetails.Comment) assert.Equal(t, condition, alertDetails.Condition) @@ -131,7 +125,7 @@ func TestInt_AlertCreate(t *testing.T) { action := "SELECT 1" comment := random.Comment() id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, name) - err := client.Alerts.Create(ctx, id, testWarehouse.ID(), schedule, condition, action, &sdk.CreateAlertOptions{ + err := client.Alerts.Create(ctx, id, testWarehouse(t).ID(), schedule, condition, action, &sdk.CreateAlertOptions{ OrReplace: sdk.Bool(false), IfNotExists: sdk.Bool(true), Comment: sdk.String(comment), @@ -140,7 +134,7 @@ func TestInt_AlertCreate(t *testing.T) { alertDetails, err := client.Alerts.Describe(ctx, id) require.NoError(t, err) assert.Equal(t, name, alertDetails.Name) - assert.Equal(t, testWarehouse.Name, alertDetails.Warehouse) + assert.Equal(t, testWarehouse(t).Name, alertDetails.Warehouse) assert.Equal(t, schedule, alertDetails.Schedule) assert.Equal(t, comment, *alertDetails.Comment) assert.Equal(t, condition, alertDetails.Condition) @@ -166,12 +160,12 @@ func TestInt_AlertCreate(t *testing.T) { condition := "SELECT 1" action := "SELECT 1" id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, name) - err := client.Alerts.Create(ctx, id, testWarehouse.ID(), schedule, condition, action, nil) + err := client.Alerts.Create(ctx, id, testWarehouse(t).ID(), schedule, condition, action, nil) require.NoError(t, err) alertDetails, err := client.Alerts.Describe(ctx, id) 
require.NoError(t, err) assert.Equal(t, name, alertDetails.Name) - assert.Equal(t, testWarehouse.Name, alertDetails.Warehouse) + assert.Equal(t, testWarehouse(t).Name, alertDetails.Warehouse) assert.Equal(t, schedule, alertDetails.Schedule) assert.Equal(t, condition, alertDetails.Condition) assert.Equal(t, action, alertDetails.Action) @@ -204,12 +198,12 @@ func TestInt_AlertCreate(t *testing.T) { end ` id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, name) - err := client.Alerts.Create(ctx, id, testWarehouse.ID(), schedule, condition, action, nil) + err := client.Alerts.Create(ctx, id, testWarehouse(t).ID(), schedule, condition, action, nil) require.NoError(t, err) alertDetails, err := client.Alerts.Describe(ctx, id) require.NoError(t, err) assert.Equal(t, name, alertDetails.Name) - assert.Equal(t, testWarehouse.Name, alertDetails.Warehouse) + assert.Equal(t, testWarehouse(t).Name, alertDetails.Warehouse) assert.Equal(t, schedule, alertDetails.Schedule) assert.Equal(t, condition, alertDetails.Condition) assert.Equal(t, strings.TrimSpace(action), alertDetails.Action) @@ -233,10 +227,7 @@ func TestInt_AlertDescribe(t *testing.T) { client := testClient(t) ctx := testContext(t) - warehouseTest, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - - alert, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), warehouseTest) + alert, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), testWarehouse(t)) t.Cleanup(alertCleanup) t.Run("when alert exists", func(t *testing.T) { @@ -256,11 +247,8 @@ func TestInt_AlertAlter(t *testing.T) { client := testClient(t) ctx := testContext(t) - warehouseTest, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - t.Run("when setting and unsetting a value", func(t *testing.T) { - alert, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), warehouseTest) + alert, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), testWarehouse(t)) t.Cleanup(alertCleanup) newSchedule := "USING CRON * * * * TUE,FRI GMT" @@ -286,7 +274,7 @@ func TestInt_AlertAlter(t *testing.T) { }) t.Run("when modifying condition and action", func(t *testing.T) { - alert, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), warehouseTest) + alert, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), testWarehouse(t)) t.Cleanup(alertCleanup) newCondition := "select * from DUAL where false" @@ -330,7 +318,7 @@ func TestInt_AlertAlter(t *testing.T) { }) t.Run("resume and then suspend", func(t *testing.T) { - alert, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), warehouseTest) + alert, alertCleanup := createAlert(t, client, testDb(t), testSchema(t), testWarehouse(t)) t.Cleanup(alertCleanup) alterOptions := &sdk.AlterAlertOptions{ @@ -375,11 +363,8 @@ func TestInt_AlertDrop(t *testing.T) { client := testClient(t) ctx := testContext(t) - warehouseTest, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - t.Run("when alert exists", func(t *testing.T) { - alert, _ := createAlert(t, client, testDb(t), testSchema(t), warehouseTest) + alert, _ := createAlert(t, client, testDb(t), testSchema(t), testWarehouse(t)) id := alert.ID() err := client.Alerts.Drop(ctx, id) require.NoError(t, err) diff --git a/pkg/sdk/testint/comments_integration_test.go b/pkg/sdk/testint/comments_integration_test.go index 6c5b29d0b5..18afc97d52 100644 --- a/pkg/sdk/testint/comments_integration_test.go +++ 
b/pkg/sdk/testint/comments_integration_test.go @@ -13,18 +13,15 @@ func TestInt_Comment(t *testing.T) { client := testClient(t) ctx := testContext(t) - testWarehouse, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - t.Run("set", func(t *testing.T) { comment := random.Comment() err := client.Comments.Set(ctx, &sdk.SetCommentOptions{ ObjectType: sdk.ObjectTypeWarehouse, - ObjectName: testWarehouse.ID(), + ObjectName: testWarehouse(t).ID(), Value: sdk.String(comment), }) require.NoError(t, err) - wh, err := client.Warehouses.ShowByID(ctx, testWarehouse.ID()) + wh, err := client.Warehouses.ShowByID(ctx, testWarehouse(t).ID()) require.NoError(t, err) assert.Equal(t, comment, wh.Comment) }) diff --git a/pkg/sdk/testint/context_functions_integration_test.go b/pkg/sdk/testint/context_functions_integration_test.go index aa5b130cc5..12906bae8c 100644 --- a/pkg/sdk/testint/context_functions_integration_test.go +++ b/pkg/sdk/testint/context_functions_integration_test.go @@ -81,6 +81,7 @@ func TestInt_CurrentWarehouse(t *testing.T) { client := testClient(t) ctx := testContext(t) + // new warehouse created on purpose warehouseTest, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) err := client.Sessions.UseWarehouse(ctx, warehouseTest.ID()) diff --git a/pkg/sdk/testint/conversion_functions_integration_test.go b/pkg/sdk/testint/conversion_functions_integration_test.go index 7192a94160..303e25e4bb 100644 --- a/pkg/sdk/testint/conversion_functions_integration_test.go +++ b/pkg/sdk/testint/conversion_functions_integration_test.go @@ -36,6 +36,7 @@ func TestInt_ToTimestampLTZ(t *testing.T) { }) require.NoError(t, err) }) + // new warehouse created on purpose warehouseTest, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) err = client.Sessions.UseWarehouse(ctx, warehouseTest.ID()) @@ -74,6 +75,7 @@ func TestInt_ToTimestampNTZ(t *testing.T) { }) require.NoError(t, err) }) + // new warehouse created on purpose warehouseTest, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) err = client.Sessions.UseWarehouse(ctx, warehouseTest.ID()) diff --git a/pkg/sdk/testint/dynamic_table_integration_test.go b/pkg/sdk/testint/dynamic_table_integration_test.go index cd8ec42227..0c6edea858 100644 --- a/pkg/sdk/testint/dynamic_table_integration_test.go +++ b/pkg/sdk/testint/dynamic_table_integration_test.go @@ -13,8 +13,6 @@ import ( func TestInt_DynamicTableCreateAndDrop(t *testing.T) { client := testClient(t) - warehouseTest, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) tableTest, tableCleanup := createTable(t, client, testDb(t), testSchema(t)) t.Cleanup(tableCleanup) @@ -26,7 +24,7 @@ func TestInt_DynamicTableCreateAndDrop(t *testing.T) { } query := "select id from " + tableTest.ID().FullyQualifiedName() comment := random.Comment() - err := client.DynamicTables.Create(ctx, sdk.NewCreateDynamicTableRequest(name, warehouseTest.ID(), targetLag, query).WithOrReplace(true).WithComment(&comment)) + err := client.DynamicTables.Create(ctx, sdk.NewCreateDynamicTableRequest(name, testWarehouse(t).ID(), targetLag, query).WithOrReplace(true).WithComment(&comment)) require.NoError(t, err) t.Cleanup(func() { err = client.DynamicTables.Drop(ctx, sdk.NewDropDynamicTableRequest(name)) @@ -38,7 +36,7 @@ func TestInt_DynamicTableCreateAndDrop(t *testing.T) { entity := entities[0] require.Equal(t, name.Name(), entity.Name) - require.Equal(t, warehouseTest.ID().Name(), entity.Warehouse) + require.Equal(t, 
testWarehouse(t).ID().Name(), entity.Warehouse) require.Equal(t, *targetLag.MaximumDuration, entity.TargetLag) }) @@ -49,7 +47,7 @@ func TestInt_DynamicTableCreateAndDrop(t *testing.T) { } query := "select id from " + tableTest.ID().FullyQualifiedName() comment := random.Comment() - err := client.DynamicTables.Create(ctx, sdk.NewCreateDynamicTableRequest(name, warehouseTest.ID(), targetLag, query).WithOrReplace(true).WithComment(&comment)) + err := client.DynamicTables.Create(ctx, sdk.NewCreateDynamicTableRequest(name, testWarehouse(t).ID(), targetLag, query).WithOrReplace(true).WithComment(&comment)) require.NoError(t, err) t.Cleanup(func() { err = client.DynamicTables.Drop(ctx, sdk.NewDropDynamicTableRequest(name)) @@ -61,7 +59,7 @@ func TestInt_DynamicTableCreateAndDrop(t *testing.T) { entity := entities[0] require.Equal(t, name.Name(), entity.Name) - require.Equal(t, warehouseTest.ID().Name(), entity.Warehouse) + require.Equal(t, testWarehouse(t).ID().Name(), entity.Warehouse) require.Equal(t, "DOWNSTREAM", entity.TargetLag) }) } diff --git a/pkg/sdk/testint/external_tables_integration_test.go b/pkg/sdk/testint/external_tables_integration_test.go index 13ef6d55ac..a3eb9e639b 100644 --- a/pkg/sdk/testint/external_tables_integration_test.go +++ b/pkg/sdk/testint/external_tables_integration_test.go @@ -99,10 +99,8 @@ func TestInt_ExternalTables(t *testing.T) { t.Run("Create: infer schema", func(t *testing.T) { fileFormat, _ := createFileFormat(t, client, testSchema(t).ID()) - warehouse, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - err := client.Sessions.UseWarehouse(ctx, warehouse.ID()) + err := client.Sessions.UseWarehouse(ctx, testWarehouse(t).ID()) require.NoError(t, err) name := random.AlphanumericN(32) diff --git a/pkg/sdk/testint/grants_integration_test.go b/pkg/sdk/testint/grants_integration_test.go index e550f1d791..a4c6cb7063 100644 --- a/pkg/sdk/testint/grants_integration_test.go +++ b/pkg/sdk/testint/grants_integration_test.go @@ -544,9 +544,6 @@ func TestInt_GrantOwnership(t *testing.T) { }) t.Run("on account level object to role", func(t *testing.T) { - warehouse, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - // role is deliberately created after warehouse, so that cleanup is done in reverse // because after ownership grant we lose privilege to drop object // with first dropping the role, we reacquire rights to do it - a little hacky trick @@ -557,7 +554,7 @@ func TestInt_GrantOwnership(t *testing.T) { on := sdk.OwnershipGrantOn{ Object: &sdk.Object{ ObjectType: sdk.ObjectTypeWarehouse, - Name: warehouse.ID(), + Name: testWarehouse(t).ID(), }, } to := sdk.OwnershipGrantTo{ diff --git a/pkg/sdk/testint/helpers_test.go b/pkg/sdk/testint/helpers_test.go index ce25e7f157..97bd3fbb44 100644 --- a/pkg/sdk/testint/helpers_test.go +++ b/pkg/sdk/testint/helpers_test.go @@ -66,12 +66,10 @@ func testClientFromProfile(t *testing.T, profile string) (*sdk.Client, error) { func useWarehouse(t *testing.T, client *sdk.Client, warehouseID sdk.AccountObjectIdentifier) func() { t.Helper() ctx := context.Background() - orgWarehouse, err := client.ContextFunctions.CurrentWarehouse(ctx) - require.NoError(t, err) - err = client.Sessions.UseWarehouse(ctx, warehouseID) + err := client.Sessions.UseWarehouse(ctx, warehouseID) require.NoError(t, err) return func() { - err := client.Sessions.UseWarehouse(ctx, sdk.NewAccountObjectIdentifier(orgWarehouse)) + err = client.Sessions.UseWarehouse(ctx, testWarehouse(t).ID()) require.NoError(t, 
err) } } @@ -137,6 +135,8 @@ func createWarehouseWithOptions(t *testing.T, client *sdk.Client, opts *sdk.Crea }, func() { err := client.Warehouses.Drop(ctx, id, nil) require.NoError(t, err) + err = client.Sessions.UseWarehouse(ctx, testWarehouse(t).ID()) + require.NoError(t, err) } } diff --git a/pkg/sdk/testint/sessions_integration_test.go b/pkg/sdk/testint/sessions_integration_test.go index fa7c283503..f14211fac4 100644 --- a/pkg/sdk/testint/sessions_integration_test.go +++ b/pkg/sdk/testint/sessions_integration_test.go @@ -99,23 +99,19 @@ func TestInt_ShowUserParameter(t *testing.T) { func TestInt_UseWarehouse(t *testing.T) { client := testClient(t) ctx := testContext(t) - originalWH, err := client.ContextFunctions.CurrentWarehouse(ctx) - require.NoError(t, err) + t.Cleanup(func() { - originalWHIdentifier := sdk.NewAccountObjectIdentifier(originalWH) - if !sdk.ValidObjectIdentifier(originalWHIdentifier) { - return - } - err := client.Sessions.UseWarehouse(ctx, originalWHIdentifier) + err := client.Sessions.UseWarehouse(ctx, testWarehouse(t).ID()) require.NoError(t, err) }) - warehouseTest, warehouseCleanup := createWarehouse(t, client) + // new warehouse created on purpose + warehouse, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) - err = client.Sessions.UseWarehouse(ctx, warehouseTest.ID()) + err := client.Sessions.UseWarehouse(ctx, warehouse.ID()) require.NoError(t, err) actual, err := client.ContextFunctions.CurrentWarehouse(ctx) require.NoError(t, err) - expected := warehouseTest.Name + expected := warehouse.Name assert.Equal(t, expected, actual) } @@ -127,11 +123,14 @@ func TestInt_UseDatabase(t *testing.T) { err := client.Sessions.UseSchema(ctx, testSchema(t).ID()) require.NoError(t, err) }) - err := client.Sessions.UseDatabase(ctx, testDb(t).ID()) + // new database created on purpose + database, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + err := client.Sessions.UseDatabase(ctx, database.ID()) require.NoError(t, err) actual, err := client.ContextFunctions.CurrentDatabase(ctx) require.NoError(t, err) - expected := testDb(t).Name + expected := database.Name assert.Equal(t, expected, actual) } @@ -143,10 +142,15 @@ func TestInt_UseSchema(t *testing.T) { err := client.Sessions.UseSchema(ctx, testSchema(t).ID()) require.NoError(t, err) }) - err := client.Sessions.UseSchema(ctx, testSchema(t).ID()) + // new database and schema created on purpose + database, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + schema, schemaCleanup := createSchema(t, client, database) + t.Cleanup(schemaCleanup) + err := client.Sessions.UseSchema(ctx, schema.ID()) require.NoError(t, err) actual, err := client.ContextFunctions.CurrentSchema(ctx) require.NoError(t, err) - expected := testSchema(t).Name + expected := schema.Name assert.Equal(t, expected, actual) } diff --git a/pkg/sdk/testint/setup_integration_test.go b/pkg/sdk/testint/setup_integration_test.go index 8a99eb484d..4508bfdbc9 100644 --- a/pkg/sdk/testint/setup_integration_test.go +++ b/pkg/sdk/testint/setup_integration_test.go @@ -51,10 +51,12 @@ type integrationTestContext struct { client *sdk.Client ctx context.Context - database *sdk.Database - databaseCleanup func() - schema *sdk.Schema - schemaCleanup func() + database *sdk.Database + databaseCleanup func() + schema *sdk.Schema + schemaCleanup func() + warehouse *sdk.Warehouse + warehouseCleanup func() } func (itc *integrationTestContext) initialize() error { @@ -81,6 +83,13 @@ func (itc 
*integrationTestContext) initialize() error { itc.schema = sc itc.schemaCleanup = scCleanup + wh, whCleanup, err := createWh(itc.client, itc.ctx) + if err != nil { + return err + } + itc.warehouse = wh + itc.warehouseCleanup = whCleanup + return nil } @@ -110,6 +119,19 @@ func createSc(client *sdk.Client, ctx context.Context, db *sdk.Database) (*sdk.S }, err } +func createWh(client *sdk.Client, ctx context.Context) (*sdk.Warehouse, func(), error) { + name := "int_test_wh_" + random.UUID() + id := sdk.NewAccountObjectIdentifier(name) + err := client.Warehouses.Create(ctx, id, nil) + if err != nil { + return nil, nil, err + } + warehouse, err := client.Warehouses.ShowByID(ctx, id) + return warehouse, func() { + _ = client.Warehouses.Drop(ctx, id, nil) + }, err +} + // timer measures time from invocation point to the end of method. // It's supposed to be used like: // @@ -140,3 +162,8 @@ func testSchema(t *testing.T) *sdk.Schema { t.Helper() return itc.schema } + +func testWarehouse(t *testing.T) *sdk.Warehouse { + t.Helper() + return itc.warehouse +} diff --git a/pkg/sdk/testint/tasks_gen_integration_test.go b/pkg/sdk/testint/tasks_gen_integration_test.go index d03415764f..f6dd6d9e5e 100644 --- a/pkg/sdk/testint/tasks_gen_integration_test.go +++ b/pkg/sdk/testint/tasks_gen_integration_test.go @@ -149,12 +149,9 @@ func TestInt_Tasks(t *testing.T) { }) t.Run("create task: almost complete case", func(t *testing.T) { - warehouse, warehouseCleanup := createWarehouse(t, client) - t.Cleanup(warehouseCleanup) - request := createTaskBasicRequest(t). WithOrReplace(sdk.Bool(true)). - WithWarehouse(sdk.NewCreateTaskWarehouseRequest().WithWarehouse(sdk.Pointer(warehouse.ID()))). + WithWarehouse(sdk.NewCreateTaskWarehouseRequest().WithWarehouse(sdk.Pointer(testWarehouse(t).ID()))). WithSchedule(sdk.String("10 MINUTE")). WithConfig(sdk.String(`$${"output_dir": "/temp/test_directory/", "learning_rate": 0.1}$$`)). WithAllowOverlappingExecution(sdk.Bool(true)). 
@@ -169,7 +166,7 @@ func TestInt_Tasks(t *testing.T) { task := createTaskWithRequest(t, request) - assertTaskWithOptions(t, task, id, "some comment", warehouse.Name, "10 MINUTE", `SYSTEM$STREAM_HAS_DATA('MYSTREAM')`, true, `{"output_dir": "/temp/test_directory/", "learning_rate": 0.1}`, nil) + assertTaskWithOptions(t, task, id, "some comment", testWarehouse(t).Name, "10 MINUTE", `SYSTEM$STREAM_HAS_DATA('MYSTREAM')`, true, `{"output_dir": "/temp/test_directory/", "learning_rate": 0.1}`, nil) }) t.Run("create task: with after", func(t *testing.T) { diff --git a/pkg/sdk/testint/warehouses_integration_test.go b/pkg/sdk/testint/warehouses_integration_test.go index e852bdd909..4b96d45283 100644 --- a/pkg/sdk/testint/warehouses_integration_test.go +++ b/pkg/sdk/testint/warehouses_integration_test.go @@ -13,6 +13,7 @@ func TestInt_WarehousesShow(t *testing.T) { client := testClient(t) ctx := testContext(t) + // new warehouses created on purpose testWarehouse, warehouseCleanup := createWarehouseWithOptions(t, client, &sdk.CreateWarehouseOptions{ WarehouseSize: &sdk.WarehouseSizeSmall, }) @@ -161,6 +162,7 @@ func TestInt_WarehouseDescribe(t *testing.T) { client := testClient(t) ctx := testContext(t) + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) @@ -246,6 +248,7 @@ func TestInt_WarehouseAlter(t *testing.T) { }) t.Run("set", func(t *testing.T) { + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) @@ -272,6 +275,7 @@ func TestInt_WarehouseAlter(t *testing.T) { }) t.Run("rename", func(t *testing.T) { + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouse(t, client) oldID := warehouse.ID() t.Cleanup(warehouseCleanup) @@ -299,6 +303,7 @@ func TestInt_WarehouseAlter(t *testing.T) { Comment: sdk.String("test comment"), MaxClusterCount: sdk.Int(10), } + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouseWithOptions(t, client, createOptions) t.Cleanup(warehouseCleanup) id := warehouse.ID() @@ -325,6 +330,7 @@ func TestInt_WarehouseAlter(t *testing.T) { }) t.Run("suspend & resume", func(t *testing.T) { + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) @@ -360,6 +366,7 @@ func TestInt_WarehouseAlter(t *testing.T) { }) t.Run("resume without suspending", func(t *testing.T) { + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) @@ -381,6 +388,7 @@ func TestInt_WarehouseAlter(t *testing.T) { }) t.Run("abort all queries", func(t *testing.T) { + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) @@ -427,6 +435,7 @@ func TestInt_WarehouseAlter(t *testing.T) { }) t.Run("set tags", func(t *testing.T) { + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) @@ -456,6 +465,7 @@ func TestInt_WarehouseAlter(t *testing.T) { }) t.Run("unset tags", func(t *testing.T) { + // new warehouse created on purpose warehouse, warehouseCleanup := createWarehouse(t, client) t.Cleanup(warehouseCleanup) @@ -507,6 +517,7 @@ func TestInt_WarehouseDrop(t *testing.T) { ctx := testContext(t) t.Run("when warehouse exists", func(t *testing.T) { + // new warehouse created on purpose warehouse, _ := createWarehouse(t, client) err := client.Warehouses.Drop(ctx, 
warehouse.ID(), nil) @@ -522,6 +533,7 @@ func TestInt_WarehouseDrop(t *testing.T) { }) t.Run("when warehouse exists and if exists is true", func(t *testing.T) { + // new warehouse created on purpose warehouse, _ := createWarehouse(t, client) dropOptions := &sdk.DropWarehouseOptions{IfExists: sdk.Bool(true)} From 5c633be461fd373d412b02b108e64b6cfc4eb856 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Tue, 24 Oct 2023 10:42:31 +0200 Subject: [PATCH 05/20] feat: Use streams from the new SDK in resource / datasource (#2129) --- README.md | 2 +- docs/resources/stream.md | 6 +- pkg/datasources/streams.go | 40 +-- pkg/helpers/helpers.go | 44 ++- pkg/helpers/helpers_test.go | 73 ++++ pkg/resources/external_table_internal_test.go | 74 ---- pkg/resources/stream.go | 326 +++++++----------- pkg/resources/stream_internal_test.go | 104 ------ pkg/resources/stream_test.go | 238 ------------- pkg/snowflake/stream.go | 198 ----------- pkg/snowflake/stream_test.go | 85 ----- 11 files changed, 261 insertions(+), 929 deletions(-) create mode 100644 pkg/helpers/helpers_test.go delete mode 100644 pkg/resources/external_table_internal_test.go delete mode 100644 pkg/resources/stream_internal_test.go delete mode 100644 pkg/resources/stream_test.go delete mode 100644 pkg/snowflake/stream.go delete mode 100644 pkg/snowflake/stream_test.go diff --git a/README.md b/README.md index 79cc03949e..fa7aadd62e 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ Integration status - indicates if given resource / datasource is using new SDK. | Function | ❌ | snowflake_function | snowflake_function | ❌ | | External Function | ❌ | snowflake_external_function | snowflake_external_function | ❌ | | Stored Procedure | ❌ | snowflake_stored_procedure | snowflake_stored_procedure | ❌ | -| Stream | ✅ | snowflake_stream | snowflake_stream | ❌ | +| Stream | ✅ | snowflake_stream | snowflake_stream | ✅ | | Task | ✅ | snowflake_task | snowflake_task | ❌ | | Masking Policy | ✅ | snowflake_masking_policy | snowflake_masking_policy | ✅ | | Row Access Policy | ❌ | snowflake_row_access_policy | snowflake_row_access_policy | ❌ | diff --git a/docs/resources/stream.md b/docs/resources/stream.md index 7c93417b9c..13cc6ff86d 100644 --- a/docs/resources/stream.md +++ b/docs/resources/stream.md @@ -42,9 +42,9 @@ resource "snowflake_stream" "stream" { - `append_only` (Boolean) Type of the stream that will be created. - `comment` (String) Specifies a comment for the stream. - `insert_only` (Boolean) Create an insert only stream type. -- `on_stage` (String) Name of the stage the stream will monitor. -- `on_table` (String) Name of the table the stream will monitor. -- `on_view` (String) Name of the view the stream will monitor. +- `on_stage` (String) Specifies an identifier for the stage the stream will monitor. +- `on_table` (String) Specifies an identifier for the table the stream will monitor. +- `on_view` (String) Specifies an identifier for the view the stream will monitor. - `show_initial_rows` (Boolean) Specifies whether to return all existing rows in the source table as row inserts the first time the stream is consumed. 
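These attributes accept a fully qualified identifier in either unquoted form (`database.schema.object`) or quoted form (`"database"."schema"."object.name"`); the quoted form is only needed when a name itself contains dots, since unquoted dots are treated as separators (see the `DecodeSnowflakeParameterID` helper introduced later in this patch). A minimal sketch with placeholder names:

```terraform
# Hypothetical names; quoting matters only when an object name contains dots.
resource "snowflake_stream" "example" {
  database = "example_db"
  schema   = "example_schema"
  name     = "example_stream"

  # Unquoted form: dots separate database, schema and object name.
  on_table = "example_db.example_schema.example_table"

  # Quoted form, needed when the monitored object's name itself contains dots:
  # on_table = "\"example_db\".\"example_schema\".\"table.name\""
}
```
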
### Read-Only diff --git a/pkg/datasources/streams.go b/pkg/datasources/streams.go index 3ec0579fba..cd355f8b72 100644 --- a/pkg/datasources/streams.go +++ b/pkg/datasources/streams.go @@ -1,12 +1,13 @@ package datasources import ( + "context" "database/sql" - "errors" "fmt" "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -63,33 +64,30 @@ func Streams() *schema.Resource { func ReadStreams(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) + client := sdk.NewClientFromDB(db) + ctx := context.Background() databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) - currentStreams, err := snowflake.ListStreams(databaseName, schemaName, db) - if errors.Is(err, sql.ErrNoRows) { - // If not found, mark resource to be removed from state file during apply or refresh + currentStreams, err := client.Streams.Show(ctx, sdk.NewShowStreamRequest(). + WithIn(&sdk.In{ + Schema: sdk.NewDatabaseObjectIdentifier(databaseName, schemaName), + })) + if err != nil { log.Printf("[DEBUG] streams in schema (%s) not found", d.Id()) d.SetId("") return nil - } else if err != nil { - log.Printf("[DEBUG] unable to parse streams in schema (%s)", d.Id()) - d.SetId("") - return nil } - streams := []map[string]interface{}{} - - for _, stream := range currentStreams { - streamMap := map[string]interface{}{} - - streamMap["name"] = stream.StreamName.String - streamMap["database"] = stream.DatabaseName.String - streamMap["schema"] = stream.SchemaName.String - streamMap["comment"] = stream.Comment.String - streamMap["table"] = stream.TableName.String - - streams = append(streams, streamMap) + streams := make([]map[string]any, len(currentStreams)) + for i, stream := range currentStreams { + streams[i] = map[string]any{ + "name": stream.Name, + "database": stream.DatabaseName, + "schema": stream.SchemaName, + "comment": stream.Comment, + "table": stream.TableName, + } } d.SetId(fmt.Sprintf(`%v|%v`, databaseName, schemaName)) diff --git a/pkg/helpers/helpers.go b/pkg/helpers/helpers.go index 7bbebdac78..3ad4342fd1 100644 --- a/pkg/helpers/helpers.go +++ b/pkg/helpers/helpers.go @@ -1,6 +1,7 @@ package helpers import ( + "encoding/csv" "fmt" "log" "reflect" @@ -12,6 +13,11 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" ) +const ( + IDDelimiter = "|" + ParameterIDDelimiter = '.' +) + // ToDo: We can merge these two functions together and also add more functions here with similar functionality // This function converts list of string into snowflake formated string like 'ele1', 'ele2'. @@ -100,6 +106,42 @@ func DecodeSnowflakeID(id string) sdk.ObjectIdentifier { } } +// DecodeSnowflakeParameterID decodes identifier (usually passed as one of the parameter in tf configuration) into sdk.ObjectIdentifier. +// identifier can be specified in two ways: quoted and unquoted, e.g. +// +// quoted { "some_identifier": "\"database.name\".\"schema.name\".\"test.name\" } +// (note that here dots as part of the name are allowed) +// +// unquoted { "some_identifier": "database_name.schema_name.test_name" } +// (note that here dots as part of the name are NOT allowed, because they're treated in this case as dividers) +// +// The following configuration { "some_identifier": "db.name" } will be parsed as an object called "name" that lives +// inside database called "db", not a database called "db.name". 
In this case quotes should be used. +func DecodeSnowflakeParameterID(identifier string) (sdk.ObjectIdentifier, error) { + reader := csv.NewReader(strings.NewReader(identifier)) + reader.Comma = ParameterIDDelimiter + lines, err := reader.ReadAll() + if err != nil { + return nil, fmt.Errorf("unable to read identifier: %s, err = %w", identifier, err) + } + if len(lines) != 1 { + return nil, fmt.Errorf("incompatible identifier: %s", identifier) + } + parts := lines[0] + switch len(parts) { + case 1: + return sdk.NewAccountObjectIdentifier(parts[0]), nil + case 2: + return sdk.NewDatabaseObjectIdentifier(parts[0], parts[1]), nil + case 3: + return sdk.NewSchemaObjectIdentifier(parts[0], parts[1], parts[2]), nil + case 4: + return sdk.NewTableColumnIdentifier(parts[0], parts[1], parts[2], parts[3]), nil + default: + return nil, fmt.Errorf("unable to classify identifier: %s", identifier) + } +} + func Retry(attempts int, sleepDuration time.Duration, f func() (error, bool)) error { for i := 0; i < attempts; i++ { err, done := f() @@ -115,5 +157,3 @@ func Retry(attempts int, sleepDuration time.Duration, f func() (error, bool)) er } return fmt.Errorf("giving up after %v attempts", attempts) } - -const IDDelimiter = "|" diff --git a/pkg/helpers/helpers_test.go b/pkg/helpers/helpers_test.go new file mode 100644 index 0000000000..1eb536525b --- /dev/null +++ b/pkg/helpers/helpers_test.go @@ -0,0 +1,73 @@ +package helpers + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecodeSnowflakeParameterID(t *testing.T) { + testCases := map[string]struct { + id string + fullyQualifiedName string + }{ + "decodes quoted account object identifier": { + id: `"test.name"`, + fullyQualifiedName: `"test.name"`, + }, + "decodes quoted database object identifier": { + id: `"db"."test.name"`, + fullyQualifiedName: `"db"."test.name"`, + }, + "decodes quoted schema object identifier": { + id: `"db"."schema"."test.name"`, + fullyQualifiedName: `"db"."schema"."test.name"`, + }, + "decodes quoted table column identifier": { + id: `"db"."schema"."table.name"."test.name"`, + fullyQualifiedName: `"db"."schema"."table.name"."test.name"`, + }, + "decodes unquoted account object identifier": { + id: `name`, + fullyQualifiedName: `"name"`, + }, + "decodes unquoted database object identifier": { + id: `db.name`, + fullyQualifiedName: `"db"."name"`, + }, + "decodes unquoted schema object identifier": { + id: `db.schema.name`, + fullyQualifiedName: `"db"."schema"."name"`, + }, + "decodes unquoted table column identifier": { + id: `db.schema.table.name`, + fullyQualifiedName: `"db"."schema"."table"."name"`, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + id, err := DecodeSnowflakeParameterID(tc.id) + require.NoError(t, err) + require.Equal(t, tc.fullyQualifiedName, id.FullyQualifiedName()) + }) + } + + t.Run("identifier with too many parts", func(t *testing.T) { + id := `this.identifier.is.too.long.to.be.decoded` + _, err := DecodeSnowflakeParameterID(id) + require.Errorf(t, err, "unable to classify identifier: %s", id) + }) + + t.Run("incompatible empty identifier", func(t *testing.T) { + id := "" + _, err := DecodeSnowflakeParameterID(id) + require.Errorf(t, err, "incompatible identifier: %s", id) + }) + + t.Run("incompatible multiline identifier", func(t *testing.T) { + id := "db.\nname" + _, err := DecodeSnowflakeParameterID(id) + require.Errorf(t, err, "incompatible identifier: %s", id) + }) +} diff --git a/pkg/resources/external_table_internal_test.go 
b/pkg/resources/external_table_internal_test.go deleted file mode 100644 index 826e0f773c..0000000000 --- a/pkg/resources/external_table_internal_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package resources - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -func ExternalTestTableIDFromString(t *testing.T) { - t.Helper() - r := require.New(t) - // Vanilla - id := "database_name|schema_name|table" - table, err := externalTableIDFromString(id) - r.NoError(err) - r.Equal("database_name", table.DatabaseName) - r.Equal("schema_name", table.SchemaName) - r.Equal("table", table.ExternalTableName) - - // Bad ID -- not enough fields - id = "database" - _, err = streamOnObjectIDFromString(id) - r.Equal(fmt.Errorf("3 fields allowed"), err) - - // Bad ID - id = "||" - _, err = streamOnObjectIDFromString(id) - r.NoError(err) - - // 0 lines - id = "" - _, err = streamOnObjectIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) - - // 2 lines - id = `database_name|schema_name|table - database_name|schema_name|table` - _, err = streamOnObjectIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) -} - -func ExternalTestTableStruct(t *testing.T) { - t.Helper() - r := require.New(t) - - // Vanilla - table := &externalTableID{ - DatabaseName: "database_name", - SchemaName: "schema_name", - ExternalTableName: "table", - } - sID, err := table.String() - r.NoError(err) - r.Equal("database_name|schema_name|table", sID) - - // Empty grant - table = &externalTableID{} - sID, err = table.String() - r.NoError(err) - r.Equal("||", sID) - - // Grant with extra delimiters - table = &externalTableID{ - DatabaseName: "database|name", - ExternalTableName: "table|name", - } - sID, err = table.String() - r.NoError(err) - newTable, err := streamOnObjectIDFromString(sID) - r.NoError(err) - r.Equal("database|name", newTable.DatabaseName) - r.Equal("table|name", newTable.Name) -} diff --git a/pkg/resources/stream.go b/pkg/resources/stream.go index 2c570c1c31..2633d96fd9 100644 --- a/pkg/resources/stream.go +++ b/pkg/resources/stream.go @@ -1,23 +1,19 @@ package resources import ( - "bytes" + "context" "database/sql" - "encoding/csv" - "errors" "fmt" "log" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) -const ( - streamIDDelimiter = '|' - streamOnObjectIDDelimiter = '.' 
-) - var streamSchema = map[string]*schema.Schema{ "name": { Type: schema.TypeString, @@ -46,21 +42,21 @@ var streamSchema = map[string]*schema.Schema{ Type: schema.TypeString, Optional: true, ForceNew: true, - Description: "Name of the table the stream will monitor.", + Description: "Specifies an identifier for the table the stream will monitor.", ExactlyOneOf: []string{"on_table", "on_view", "on_stage"}, }, "on_view": { Type: schema.TypeString, Optional: true, ForceNew: true, - Description: "Name of the view the stream will monitor.", + Description: "Specifies an identifier for the view the stream will monitor.", ExactlyOneOf: []string{"on_table", "on_view", "on_stage"}, }, "on_stage": { Type: schema.TypeString, Optional: true, ForceNew: true, - Description: "Name of the stage the stream will monitor.", + Description: "Specifies an identifier for the stage the stream will monitor.", ExactlyOneOf: []string{"on_table", "on_view", "on_stage"}, DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { // Suppress diff if the stage name is the same, even if database and schema are not specified @@ -109,94 +105,19 @@ func Stream() *schema.Resource { } } -type streamID struct { - DatabaseName string - SchemaName string - StreamName string -} - -type streamOnObjectID struct { - DatabaseName string - SchemaName string - Name string -} - -// String() takes in a streamID object and returns a pipe-delimited string: -// DatabaseName|SchemaName|StreamName. -func (si *streamID) String() (string, error) { - var buf bytes.Buffer - csvWriter := csv.NewWriter(&buf) - csvWriter.Comma = streamIDDelimiter - dataIdentifiers := [][]string{{si.DatabaseName, si.SchemaName, si.StreamName}} - if err := csvWriter.WriteAll(dataIdentifiers); err != nil { - return "", err - } - strStreamID := strings.TrimSpace(buf.String()) - return strStreamID, nil -} - -// streamIDFromString() takes in a pipe-delimited string: DatabaseName|SchemaName|StreamName -// and returns a streamID object. -func streamIDFromString(stringID string) (*streamID, error) { - reader := csv.NewReader(strings.NewReader(stringID)) - reader.Comma = streamIDDelimiter - lines, err := reader.ReadAll() - if err != nil { - return nil, fmt.Errorf("not CSV compatible") - } - - if len(lines) != 1 { - return nil, fmt.Errorf("1 line at a time") - } - if len(lines[0]) != 3 { - return nil, fmt.Errorf("3 fields allowed") - } - - streamResult := &streamID{ - DatabaseName: lines[0][0], - SchemaName: lines[0][1], - StreamName: lines[0][2], - } - return streamResult, nil -} - -// streamOnObjectIDFromString() takes in a dot-delimited string: DatabaseName.SchemaName.TableName -// and returns a streamOnObjectID object. -func streamOnObjectIDFromString(stringID string) (*streamOnObjectID, error) { - reader := csv.NewReader(strings.NewReader(stringID)) - reader.Comma = streamOnObjectIDDelimiter - lines, err := reader.ReadAll() - if err != nil { - return nil, fmt.Errorf("not CSV compatible") - } - - if len(lines) != 1 { - return nil, fmt.Errorf("1 line at a time") - } - if len(lines[0]) != 3 { - // return nil, fmt.Errorf("on table format: database_name.schema_name.target_table_name") - return nil, fmt.Errorf("invalid format for on_table: %v , expected: ", strings.Join(lines[0], ".")) - } - - streamOnTableResult := &streamOnObjectID{ - DatabaseName: lines[0][0], - SchemaName: lines[0][1], - Name: lines[0][2], - } - return streamOnTableResult, nil -} - // CreateStream implements schema.CreateFunc. 
func CreateStream(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - database := d.Get("database").(string) - schema := d.Get("schema").(string) + databaseName := d.Get("database").(string) + schemaName := d.Get("schema").(string) name := d.Get("name").(string) appendOnly := d.Get("append_only").(bool) insertOnly := d.Get("insert_only").(bool) showInitialRows := d.Get("show_initial_rows").(bool) + id := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name) - builder := snowflake.Stream(name, database, schema) + client := sdk.NewClientFromDB(db) + ctx := context.Background() onTable, onTableSet := d.GetOk("on_table") onView, onViewSet := d.GetOk("on_view") @@ -204,78 +125,101 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { switch { case onTableSet: - id, err := streamOnObjectIDFromString(onTable.(string)) + tableObjectIdentifier, err := helpers.DecodeSnowflakeParameterID(onTable.(string)) if err != nil { return err } + tableId := tableObjectIdentifier.(sdk.SchemaObjectIdentifier) - tq := snowflake.NewTableBuilder(id.Name, id.DatabaseName, id.SchemaName).Show() + tq := snowflake.NewTableBuilder(tableId.Name(), tableId.DatabaseName(), tableId.SchemaName()).Show() tableRow := snowflake.QueryRow(db, tq) - t, err := snowflake.ScanTable(tableRow) if err != nil { return err } - builder.WithExternalTable(t.IsExternal.String == "Y") - builder.WithOnTable(t.DatabaseName.String, t.SchemaName.String, t.TableName.String) + if t.IsExternal.String == "Y" { + req := sdk.NewCreateStreamOnExternalTableRequest(id, tableId) + if insertOnly { + req.WithInsertOnly(sdk.Bool(true)) + } + if v, ok := d.GetOk("comment"); ok { + req.WithComment(sdk.String(v.(string))) + } + err := client.Streams.CreateOnExternalTable(ctx, req) + if err != nil { + return fmt.Errorf("error creating stream %v err = %w", name, err) + } + } else { + req := sdk.NewCreateStreamOnTableRequest(id, tableId) + if appendOnly { + req.WithAppendOnly(sdk.Bool(true)) + } + if showInitialRows { + req.WithShowInitialRows(sdk.Bool(true)) + } + if v, ok := d.GetOk("comment"); ok { + req.WithComment(sdk.String(v.(string))) + } + err := client.Streams.CreateOnTable(ctx, req) + if err != nil { + return fmt.Errorf("error creating stream %v err = %w", name, err) + } + } case onViewSet: - id, err := streamOnObjectIDFromString(onView.(string)) + viewObjectIdentifier, err := helpers.DecodeSnowflakeParameterID(onView.(string)) + viewId := viewObjectIdentifier.(sdk.SchemaObjectIdentifier) if err != nil { return err } - tq := snowflake.NewViewBuilder(id.Name).WithDB(id.DatabaseName).WithSchema(id.SchemaName).Show() + tq := snowflake.NewViewBuilder(viewId.Name()).WithDB(viewId.DatabaseName()).WithSchema(viewId.SchemaName()).Show() viewRow := snowflake.QueryRow(db, tq) - - t, err := snowflake.ScanView(viewRow) + _, err = snowflake.ScanView(viewRow) if err != nil { return err } - builder.WithOnView(t.DatabaseName.String, t.SchemaName.String, t.Name.String) + req := sdk.NewCreateStreamOnViewRequest(id, viewId) + if appendOnly { + req.WithAppendOnly(sdk.Bool(true)) + } + if showInitialRows { + req.WithShowInitialRows(sdk.Bool(true)) + } + if v, ok := d.GetOk("comment"); ok { + req.WithComment(sdk.String(v.(string))) + } + err = client.Streams.CreateOnView(ctx, req) + if err != nil { + return fmt.Errorf("error creating stream %v err = %w", name, err) + } case onStageSet: - id, err := streamOnObjectIDFromString(onStage.(string)) + stageObjectIdentifier, err := helpers.DecodeSnowflakeParameterID(onStage.(string)) + 
stageId := stageObjectIdentifier.(sdk.SchemaObjectIdentifier) if err != nil { return err } - stageBuilder := snowflake.NewStageBuilder(id.Name, id.DatabaseName, id.SchemaName) + stageBuilder := snowflake.NewStageBuilder(stageId.Name(), stageId.DatabaseName(), stageId.SchemaName()) sq := stageBuilder.Describe() - d, err := snowflake.DescStage(db, sq) + stageDesc, err := snowflake.DescStage(db, sq) if err != nil { return err } - if !strings.Contains(d.Directory, "ENABLE = true") { + if !strings.Contains(stageDesc.Directory, "ENABLE = true") { return fmt.Errorf("directory must be enabled on stage") } - - builder.WithOnStage(id.DatabaseName, id.SchemaName, id.Name) - } - - builder.WithAppendOnly(appendOnly) - builder.WithInsertOnly(insertOnly) - builder.WithShowInitialRows(showInitialRows) - - // Set optionals - if v, ok := d.GetOk("comment"); ok { - builder.WithComment(v.(string)) - } - - stmt := builder.Create() - if err := snowflake.Exec(db, stmt); err != nil { - return fmt.Errorf("error creating stream %v", name) + req := sdk.NewCreateStreamOnDirectoryTableRequest(id, stageId) + if v, ok := d.GetOk("comment"); ok { + req.WithComment(sdk.String(v.(string))) + } + err = client.Streams.CreateOnDirectoryTable(ctx, req) + if err != nil { + return fmt.Errorf("error creating stream %v err = %w", name, err) + } } - streamID := &streamID{ - DatabaseName: database, - SchemaName: schema, - StreamName: name, - } - dataIDInput, err := streamID.String() - if err != nil { - return err - } - d.SetId(dataIDInput) + d.SetId(helpers.EncodeSnowflakeID(id)) return ReadStream(d, meta) } @@ -283,92 +227,93 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { // ReadStream implements schema.ReadFunc. func ReadStream(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - streamID, err := streamIDFromString(d.Id()) + client := sdk.NewClientFromDB(db) + ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) + stream, err := client.Streams.ShowByID(ctx, sdk.NewShowByIdStreamRequest(id)) if err != nil { - return err - } - - dbName := streamID.DatabaseName - schema := streamID.SchemaName - name := streamID.StreamName - - stmt := snowflake.Stream(name, dbName, schema).Show() - row := snowflake.QueryRow(db, stmt) - stream, err := snowflake.ScanStream(row) - if errors.Is(err, sql.ErrNoRows) { - // If not found, mark resource to be removed from state file during apply or refresh log.Printf("[DEBUG] stream (%s) not found", d.Id()) d.SetId("") return nil } - if err != nil { - return err - } - - if err := d.Set("name", stream.StreamName.String); err != nil { + if err := d.Set("name", stream.Name); err != nil { return err } - - if err := d.Set("database", stream.DatabaseName.String); err != nil { + if err := d.Set("database", stream.DatabaseName); err != nil { return err } - - if err := d.Set("schema", stream.SchemaName.String); err != nil { + if err := d.Set("schema", stream.SchemaName); err != nil { return err } - - switch stream.SourceType.String { + switch *stream.SourceType { case "Stage": - if err := d.Set("on_stage", stream.TableName.String); err != nil { + if err := d.Set("on_stage", *stream.TableName); err != nil { return err } case "View": - if err := d.Set("on_view", stream.TableName.String); err != nil { + if err := d.Set("on_view", *stream.TableName); err != nil { return err } default: - if err := d.Set("on_table", stream.TableName.String); err != nil { + if err := d.Set("on_table", *stream.TableName); err != nil { return err } } - - 
if err := d.Set("append_only", stream.Mode.String == "APPEND_ONLY"); err != nil { + if err := d.Set("append_only", *stream.Mode == "APPEND_ONLY"); err != nil { return err } - - if err := d.Set("insert_only", stream.Mode.String == "INSERT_ONLY"); err != nil { + if err := d.Set("insert_only", *stream.Mode == "INSERT_ONLY"); err != nil { return err } - - if err := d.Set("show_initial_rows", stream.ShowInitialRows); err != nil { + // TODO: SHOW STREAMS doesn't return that value right now (I'm not sure if it ever did), but probably we can assume + // the customers got 'false' every time and hardcode it (it's only on create thing, so it's not necessary + // to track its value after creation). + if err := d.Set("show_initial_rows", false); err != nil { return err } - - if err := d.Set("comment", stream.Comment.String); err != nil { + if err := d.Set("comment", *stream.Comment); err != nil { return err } - - if err := d.Set("owner", stream.Owner.String); err != nil { + if err := d.Set("owner", *stream.Owner); err != nil { return err } return nil } -// DeleteStream implements schema.DeleteFunc. -func DeleteStream(d *schema.ResourceData, meta interface{}) error { +// UpdateStream implements schema.UpdateFunc. +func UpdateStream(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - streamID, err := streamIDFromString(d.Id()) - if err != nil { - return err + client := sdk.NewClientFromDB(db) + ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) + + if d.HasChange("comment") { + comment := d.Get("comment").(string) + if comment == "" { + err := client.Streams.Alter(ctx, sdk.NewAlterStreamRequest(id).WithUnsetComment(sdk.Bool(true))) + if err != nil { + return fmt.Errorf("error unsetting stream comment on %v", d.Id()) + } + } else { + err := client.Streams.Alter(ctx, sdk.NewAlterStreamRequest(id).WithSetComment(sdk.String(comment))) + if err != nil { + return fmt.Errorf("error setting stream comment on %v", d.Id()) + } + } } - dbName := streamID.DatabaseName - schema := streamID.SchemaName - streamName := streamID.StreamName + return ReadStream(d, meta) +} - q := snowflake.Stream(streamName, dbName, schema).Drop() +// DeleteStream implements schema.DeleteFunc. +func DeleteStream(d *schema.ResourceData, meta interface{}) error { + db := meta.(*sql.DB) + client := sdk.NewClientFromDB(db) + ctx := context.Background() + streamId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - if err := snowflake.Exec(db, q); err != nil { + err := client.Streams.Drop(ctx, sdk.NewDropStreamRequest(streamId)) + if err != nil { return fmt.Errorf("error deleting stream %v err = %w", d.Id(), err) } @@ -376,28 +321,3 @@ func DeleteStream(d *schema.ResourceData, meta interface{}) error { return nil } - -// UpdateStream implements schema.UpdateFunc. 
-func UpdateStream(d *schema.ResourceData, meta interface{}) error { - streamID, err := streamIDFromString(d.Id()) - if err != nil { - return err - } - - dbName := streamID.DatabaseName - schema := streamID.SchemaName - streamName := streamID.StreamName - - builder := snowflake.Stream(streamName, dbName, schema) - - db := meta.(*sql.DB) - if d.HasChange("comment") { - comment := d.Get("comment") - q := builder.ChangeComment(comment.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating stream comment on %v", d.Id()) - } - } - - return ReadStream(d, meta) -} diff --git a/pkg/resources/stream_internal_test.go b/pkg/resources/stream_internal_test.go deleted file mode 100644 index 0b7925cf6d..0000000000 --- a/pkg/resources/stream_internal_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package resources - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestStreamIDFromString(t *testing.T) { - r := require.New(t) - // Vanilla - id := "database_name|schema_name|stream" - stream, err := streamIDFromString(id) - r.NoError(err) - r.Equal("database_name", stream.DatabaseName) - r.Equal("schema_name", stream.SchemaName) - r.Equal("stream", stream.StreamName) - - // Bad ID -- not enough fields - id = "database" - _, err = streamIDFromString(id) - r.Equal(fmt.Errorf("3 fields allowed"), err) - - // Bad ID - id = "||" - _, err = streamIDFromString(id) - r.NoError(err) - - // 0 lines - id = "" - _, err = streamIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) - - // 2 lines - id = `database_name|schema_name|stream - database_name|schema_name|stream` - _, err = streamIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) -} - -func TestStreamStruct(t *testing.T) { - r := require.New(t) - - // Vanilla - stream := &streamID{ - DatabaseName: "database_name", - SchemaName: "schema_name", - StreamName: "stream_name", - } - sID, err := stream.String() - r.NoError(err) - r.Equal("database_name|schema_name|stream_name", sID) - - // Empty grant - stream = &streamID{} - sID, err = stream.String() - r.NoError(err) - r.Equal("||", sID) - - // Grant with extra delimiters - stream = &streamID{ - DatabaseName: "database|name", - StreamName: "stream|name", - } - sID, err = stream.String() - r.NoError(err) - newStream, err := streamIDFromString(sID) - r.NoError(err) - r.Equal("database|name", newStream.DatabaseName) - r.Equal("stream|name", newStream.StreamName) -} - -func TestStreamOnTableIDFromString(t *testing.T) { - r := require.New(t) - // Vanilla - id := "database_name.schema_name.target_table_name" - streamOnTable, err := streamOnObjectIDFromString(id) - r.NoError(err) - r.Equal("database_name", streamOnTable.DatabaseName) - r.Equal("schema_name", streamOnTable.SchemaName) - r.Equal("target_table_name", streamOnTable.Name) - - // Bad ID -- not enough fields - id = "database.schema" - _, err = streamOnObjectIDFromString(id) - r.Equal(fmt.Errorf("invalid format for on_table: database.schema , expected: "), err) - - // Bad ID - id = ".." 
- _, err = streamOnObjectIDFromString(id) - r.NoError(err) - - // 0 lines - id = "" - _, err = streamOnObjectIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) - - // 2 lines - id = `database_name.schema_name.target_table_name - database_name.schema_name.target_table_name` - _, err = streamOnObjectIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) -} diff --git a/pkg/resources/stream_test.go b/pkg/resources/stream_test.go deleted file mode 100644 index 9f3b0d8dca..0000000000 --- a/pkg/resources/stream_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package resources_test - -import ( - "database/sql" - "testing" - "time" - - sqlmock "github.com/DATA-DOG/go-sqlmock" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/resources" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" - . "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/testhelpers" - "github.com/stretchr/testify/require" -) - -func TestStream(t *testing.T) { - r := require.New(t) - err := resources.Stream().InternalValidate(provider.Provider().Schema, true) - r.NoError(err) -} - -func TestStreamCreate(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "name": "stream_name", - "database": "database_name", - "schema": "schema_name", - "comment": "great comment", - "on_table": "target_db.target_schema.target_table", - "append_only": true, - "insert_only": false, - "show_initial_rows": true, - } - d := stream(t, "database_name|schema_name|stream_name", in) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`CREATE STREAM "database_name"."schema_name"."stream_name" ON TABLE "target_db"."target_schema"."target_table" COMMENT = 'great comment' APPEND_ONLY = true INSERT_ONLY = false SHOW_INITIAL_ROWS = true`).WillReturnResult(sqlmock.NewResult(1, 1)) - expectStreamRead(mock) - expectOnTableRead(mock) - err := resources.CreateStream(d, db) - r.NoError(err) - r.Equal("stream_name", d.Get("name").(string)) - }) -} - -func TestStreamCreateOnExternalTable(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "name": "stream_name", - "database": "database_name", - "schema": "schema_name", - "comment": "great comment", - "on_table": "target_db.target_schema.target_table", - "append_only": true, - "insert_only": false, - "show_initial_rows": true, - } - d := stream(t, "database_name|schema_name|stream_name", in) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`CREATE STREAM "database_name"."schema_name"."stream_name" ON EXTERNAL TABLE "target_db"."target_schema"."target_table" COMMENT = 'great comment' APPEND_ONLY = true INSERT_ONLY = false SHOW_INITIAL_ROWS = true`).WillReturnResult(sqlmock.NewResult(1, 1)) - expectStreamRead(mock) - expectOnExternalTableRead(mock) - err := resources.CreateStream(d, db) - r.NoError(err) - r.Equal("stream_name", d.Get("name").(string)) - }) -} - -func TestStreamCreateOnView(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "name": "stream_name", - "database": "database_name", - "schema": "schema_name", - "comment": "great comment", - "on_view": "target_db.target_schema.target_view", - "append_only": true, - "insert_only": false, - "show_initial_rows": true, - } - d := stream(t, "database_name|schema_name|stream_name", in) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`CREATE STREAM "database_name"."schema_name"."stream_name" ON 
VIEW "target_db"."target_schema"."target_view" COMMENT = 'great comment' APPEND_ONLY = true INSERT_ONLY = false SHOW_INITIAL_ROWS = true`).WillReturnResult(sqlmock.NewResult(1, 1)) - expectStreamRead(mock) - expectOnViewRead(mock) - err := resources.CreateStream(d, db) - r.NoError(err) - r.Equal("stream_name", d.Get("name").(string)) - }) -} - -func TestStreamOnMultipleSource(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "name": "stream_name", - "database": "database_name", - "schema": "schema_name", - "comment": "great comment", - "on_table": "target_db.target_schema.target_table", - "on_view": "target_db.target_schema.target_view", - "append_only": true, - "insert_only": false, - "show_initial_rows": true, - } - d := stream(t, "database_name|schema_name|stream_name", in) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - err := resources.CreateStream(d, db) - r.ErrorContains(err, "all expectations were already fulfilled,") - }) -} - -func expectStreamRead(mock sqlmock.Sqlmock) { - rows := sqlmock.NewRows([]string{"name", "database_name", "schema_name", "owner", "comment", "table_name", "type", "stale", "mode"}).AddRow("stream_name", "database_name", "schema_name", "owner_name", "grand comment", "target_table", "DELTA", false, "APPEND_ONLY") - mock.ExpectQuery(`SHOW STREAMS LIKE 'stream_name' IN SCHEMA "database_name"."schema_name"`).WillReturnRows(rows) -} - -func expectOnTableRead(mock sqlmock.Sqlmock) { - rows := sqlmock.NewRows([]string{"created_on", "name", "database_name", "schema_name", "kind", "comment", "cluster_by", "row", "bytes", "owner", "retention_time", "automatic_clustering", "change_tracking", "is_external"}).AddRow("", "target_table", "target_db", "target_schema", "TABLE", "mock comment", "", "", "", "", 1, "OFF", "OFF", "N") - mock.ExpectQuery(`SHOW TABLES LIKE 'target_table' IN SCHEMA "target_db"."target_schema"`).WillReturnRows(rows) -} - -func expectOnExternalTableRead(mock sqlmock.Sqlmock) { - rows := sqlmock.NewRows([]string{"created_on", "name", "database_name", "schema_name", "kind", "comment", "cluster_by", "row", "bytes", "owner", "retention_time", "automatic_clustering", "change_tracking", "is_external"}).AddRow("", "target_table", "target_db", "target_schema", "TABLE", "mock comment", "", "", "", "", 1, "OFF", "OFF", "Y") - mock.ExpectQuery(`SHOW TABLES LIKE 'target_table' IN SCHEMA "target_db"."target_schema"`).WillReturnRows(rows) -} - -func expectOnViewRead(mock sqlmock.Sqlmock) { - rows := sqlmock.NewRows([]string{"created_on", "name", "database_name", "schema_name", "kind", "comment", "cluster_by", "row", "bytes", "owner", "retention_time", "automatic_clustering", "change_tracking", "is_external"}).AddRow(time.Now(), "target_view", "target_db", "target_schema", "VIEW", "mock comment", "", "", "", "", 1, "OFF", "OFF", "Y") - mock.ExpectQuery(`SHOW VIEWS LIKE 'target_view' IN SCHEMA "target_db"."target_schema"`).WillReturnRows(rows) -} - -func TestStreamRead(t *testing.T) { - r := require.New(t) - - d := stream(t, "database_name|schema_name|stream_name", map[string]interface{}{"name": "stream_name", "comment": "grand comment"}) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - expectStreamRead(mock) - err := resources.ReadStream(d, db) - r.NoError(err) - r.Equal("stream_name", d.Get("name").(string)) - r.Equal("database_name", d.Get("database").(string)) - r.Equal("schema_name", d.Get("schema").(string)) - r.Equal("grand comment", d.Get("comment").(string)) - - // Test when resource is not found, checking if 
state will be empty - r.NotEmpty(d.State()) - q := snowflake.Stream("stream_name", "database_name", "schema_name").Show() - mock.ExpectQuery(q).WillReturnError(sql.ErrNoRows) - err2 := resources.ReadStream(d, db) - r.Empty(d.State()) - r.Nil(err2) - }) -} - -func TestStreamReadAppendOnlyMode(t *testing.T) { - r := require.New(t) - - d := stream(t, "database_name|schema_name|stream_name", map[string]interface{}{"name": "stream_name", "comment": "grand comment"}) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - rows := sqlmock.NewRows([]string{"name", "database_name", "schema_name", "owner", "comment", "table_name", "type", "stale", "mode"}).AddRow("stream_name", "database_name", "schema_name", "owner_name", "grand comment", "target_table", "DELTA", false, "APPEND_ONLY") - mock.ExpectQuery(`SHOW STREAMS LIKE 'stream_name' IN SCHEMA "database_name"."schema_name"`).WillReturnRows(rows) - err := resources.ReadStream(d, db) - r.NoError(err) - r.Equal(true, d.Get("append_only").(bool)) - }) -} - -func TestStreamReadInsertOnlyMode(t *testing.T) { - r := require.New(t) - - d := stream(t, "database_name|schema_name|stream_name", map[string]interface{}{"name": "stream_name", "comment": "grand comment"}) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - rows := sqlmock.NewRows([]string{"name", "database_name", "schema_name", "owner", "comment", "table_name", "type", "stale", "mode"}).AddRow("stream_name", "database_name", "schema_name", "owner_name", "grand comment", "target_table", "DELTA", false, "INSERT_ONLY") - mock.ExpectQuery(`SHOW STREAMS LIKE 'stream_name' IN SCHEMA "database_name"."schema_name"`).WillReturnRows(rows) - err := resources.ReadStream(d, db) - r.NoError(err) - r.Equal(true, d.Get("insert_only").(bool)) - }) -} - -func TestStreamReadDefaultMode(t *testing.T) { - r := require.New(t) - - d := stream(t, "database_name|schema_name|stream_name", map[string]interface{}{"name": "stream_name", "comment": "grand comment"}) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - rows := sqlmock.NewRows([]string{"name", "database_name", "schema_name", "owner", "comment", "table_name", "type", "stale", "mode"}).AddRow("stream_name", "database_name", "schema_name", "owner_name", "grand comment", "target_table", "DELTA", false, "DEFAULT") - mock.ExpectQuery(`SHOW STREAMS LIKE 'stream_name' IN SCHEMA "database_name"."schema_name"`).WillReturnRows(rows) - err := resources.ReadStream(d, db) - r.NoError(err) - r.Equal(false, d.Get("append_only").(bool)) - }) -} - -func TestStreamDelete(t *testing.T) { - r := require.New(t) - - d := stream(t, "database_name|schema_name|drop_it", map[string]interface{}{"name": "drop_it"}) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`DROP STREAM "database_name"."schema_name"."drop_it"`).WillReturnResult(sqlmock.NewResult(1, 1)) - err := resources.DeleteStream(d, db) - r.NoError(err) - }) -} - -func TestStreamUpdate(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "name": "stream_name", - "database": "database_name", - "schema": "schema_name", - "comment": "new stream comment", - "on_table": "target_table", - "append_only": true, - "insert_only": false, - } - - d := stream(t, "database_name|schema_name|stream_name", in) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`ALTER STREAM "database_name"."schema_name"."stream_name" SET COMMENT = 'new stream comment'`).WillReturnResult(sqlmock.NewResult(1, 1)) - expectStreamRead(mock) - err := 
resources.UpdateStream(d, db) - r.NoError(err) - }) -} diff --git a/pkg/snowflake/stream.go b/pkg/snowflake/stream.go deleted file mode 100644 index d1baced51c..0000000000 --- a/pkg/snowflake/stream.go +++ /dev/null @@ -1,198 +0,0 @@ -package snowflake - -import ( - "database/sql" - "errors" - "fmt" - "log" - "strings" - - "github.com/jmoiron/sqlx" -) - -// StreamBuilder abstracts the creation of SQL queries for a Snowflake stream. -type StreamBuilder struct { - name string - db string - schema string - externalTable bool - onTable string - onView string - onStage string - appendOnly bool - insertOnly bool - showInitialRows bool - comment string -} - -// QualifiedName prepends the db and schema if set and escapes everything nicely. -func (sb *StreamBuilder) QualifiedName() string { - var n strings.Builder - - if sb.db != "" && sb.schema != "" { - n.WriteString(fmt.Sprintf(`"%v"."%v".`, sb.db, sb.schema)) - } - - if sb.db != "" && sb.schema == "" { - n.WriteString(fmt.Sprintf(`"%v"..`, sb.db)) - } - - if sb.db == "" && sb.schema != "" { - n.WriteString(fmt.Sprintf(`"%v".`, sb.schema)) - } - - n.WriteString(fmt.Sprintf(`"%v"`, sb.name)) - - return n.String() -} - -func (sb *StreamBuilder) WithComment(c string) *StreamBuilder { - sb.comment = c - return sb -} - -func (sb *StreamBuilder) WithOnTable(d string, s string, t string) *StreamBuilder { - sb.onTable = fmt.Sprintf(`"%v"."%v"."%v"`, d, s, t) - return sb -} - -func (sb *StreamBuilder) WithExternalTable(b bool) *StreamBuilder { - sb.externalTable = b - return sb -} - -func (sb *StreamBuilder) WithOnView(d string, s string, t string) *StreamBuilder { - sb.onView = fmt.Sprintf(`"%v"."%v"."%v"`, d, s, t) - return sb -} - -func (sb *StreamBuilder) WithOnStage(d string, s string, t string) *StreamBuilder { - sb.onStage = fmt.Sprintf(`"%v"."%v"."%v"`, d, s, t) - return sb -} - -func (sb *StreamBuilder) WithAppendOnly(b bool) *StreamBuilder { - sb.appendOnly = b - return sb -} - -func (sb *StreamBuilder) WithInsertOnly(b bool) *StreamBuilder { - sb.insertOnly = b - return sb -} - -func (sb *StreamBuilder) WithShowInitialRows(b bool) *StreamBuilder { - sb.showInitialRows = b - return sb -} - -// Stream returns a pointer to a Builder that abstracts the DDL operations for a stream. -// -// Supported DDL operations are: -// - CREATE Stream -// - ALTER Stream -// - DROP Stream -// - SHOW Stream -// -// [Snowflake Reference](https://docs.snowflake.com/en/sql-reference/sql/create-stream.html) -func Stream(name, db, schema string) *StreamBuilder { - return &StreamBuilder{ - name: name, - db: db, - schema: schema, - } -} - -// Create returns the SQL statement required to create a stream. 
-func (sb *StreamBuilder) Create() string { - q := strings.Builder{} - q.WriteString(fmt.Sprintf(`CREATE STREAM %v`, sb.QualifiedName())) - - q.WriteString(` ON`) - - switch { - case sb.onTable != "": - if sb.externalTable { - q.WriteString(` EXTERNAL`) - } - q.WriteString(fmt.Sprintf(` TABLE %v`, sb.onTable)) - case sb.onView != "": - q.WriteString(fmt.Sprintf(` VIEW %v`, sb.onView)) - case sb.onStage != "": - q.WriteString(fmt.Sprintf(` STAGE %v`, sb.onStage)) - } - - if sb.comment != "" { - q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(sb.comment))) - } - - if sb.onStage == "" { - q.WriteString(fmt.Sprintf(` APPEND_ONLY = %v`, sb.appendOnly)) - - q.WriteString(fmt.Sprintf(` INSERT_ONLY = %v`, sb.insertOnly)) - - q.WriteString(fmt.Sprintf(` SHOW_INITIAL_ROWS = %v`, sb.showInitialRows)) - } - - return q.String() -} - -// ChangeComment returns the SQL query that will update the comment on the stream. -func (sb *StreamBuilder) ChangeComment(c string) string { - return fmt.Sprintf(`ALTER STREAM %v SET COMMENT = '%v'`, sb.QualifiedName(), EscapeString(c)) -} - -// RemoveComment returns the SQL query that will remove the comment on the stream. -func (sb *StreamBuilder) RemoveComment() string { - return fmt.Sprintf(`ALTER STREAM %v UNSET COMMENT`, sb.QualifiedName()) -} - -// Drop returns the SQL query that will drop a stream. -func (sb *StreamBuilder) Drop() string { - return fmt.Sprintf(`DROP STREAM %v`, sb.QualifiedName()) -} - -// Show returns the SQL query that will show a stream. -func (sb *StreamBuilder) Show() string { - return fmt.Sprintf(`SHOW STREAMS LIKE '%v' IN SCHEMA "%v"."%v"`, sb.name, sb.db, sb.schema) -} - -type DescStreamRow struct { - CreatedOn sql.NullString `db:"created_on"` - StreamName sql.NullString `db:"name"` - DatabaseName sql.NullString `db:"database_name"` - SchemaName sql.NullString `db:"schema_name"` - Owner sql.NullString `db:"owner"` - Comment sql.NullString `db:"comment"` - ShowInitialRows bool `db:"show_initial_rows"` - TableName sql.NullString `db:"table_name"` - Type sql.NullString `db:"type"` - Stale sql.NullString `db:"stale"` - Mode sql.NullString `db:"mode"` - SourceType sql.NullString `db:"source_type"` -} - -func ScanStream(row *sqlx.Row) (*DescStreamRow, error) { - t := &DescStreamRow{} - e := row.StructScan(t) - return t, e -} - -func ListStreams(databaseName string, schemaName string, db *sql.DB) ([]DescStreamRow, error) { - stmt := fmt.Sprintf(`SHOW STREAMS IN SCHEMA "%s"."%v"`, databaseName, schemaName) - rows, err := Query(db, stmt) - if err != nil { - return nil, err - } - defer rows.Close() - - dbs := []DescStreamRow{} - if err := sqlx.StructScan(rows, &dbs); err != nil { - if errors.Is(err, sql.ErrNoRows) { - log.Println("[DEBUG] no stages found") - return nil, nil - } - return nil, fmt.Errorf("unable to scan row for %s err = %w", stmt, err) - } - return dbs, nil -} diff --git a/pkg/snowflake/stream_test.go b/pkg/snowflake/stream_test.go deleted file mode 100644 index ad5ff9c629..0000000000 --- a/pkg/snowflake/stream_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package snowflake - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestStreamCreate(t *testing.T) { - r := require.New(t) - s := Stream("test_stream", "test_db", "test_schema") - - s.WithOnTable("test_db", "test_schema", "test_target_table") - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON TABLE "test_db"."test_schema"."test_target_table" APPEND_ONLY = false INSERT_ONLY = false SHOW_INITIAL_ROWS = false`, s.Create()) - - 
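For comparison with the builder being removed here, the CREATE STREAM statement that snowflake.Stream(...).WithOnTable(...).Create() rendered as a raw DDL string now goes through the typed requests used in CreateStream above. A rough sketch of that call path, assuming the SDK package sits under the provider's pkg/sdk as elsewhere in this patch series (the helper name is illustrative and error handling is reduced to the final return):

package resources

import (
	"context"
	"database/sql"

	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
)

// createAppendOnlyStream sketches the SDK path that replaces the removed
// StreamBuilder: build a typed request, then let client.Streams render and
// execute the CREATE STREAM statement.
func createAppendOnlyStream(db *sql.DB, databaseName, schemaName, name string, tableId sdk.SchemaObjectIdentifier) error {
	client := sdk.NewClientFromDB(db)
	ctx := context.Background()

	id := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name)
	req := sdk.NewCreateStreamOnTableRequest(id, tableId)
	req.WithAppendOnly(sdk.Bool(true))
	req.WithComment(sdk.String("Terraform acceptance test"))

	return client.Streams.CreateOnTable(ctx, req)
}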
s.WithComment("Test Comment") - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON TABLE "test_db"."test_schema"."test_target_table" COMMENT = 'Test Comment' APPEND_ONLY = false INSERT_ONLY = false SHOW_INITIAL_ROWS = false`, s.Create()) - - s.WithShowInitialRows(true) - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON TABLE "test_db"."test_schema"."test_target_table" COMMENT = 'Test Comment' APPEND_ONLY = false INSERT_ONLY = false SHOW_INITIAL_ROWS = true`, s.Create()) - - s.WithAppendOnly(true) - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON TABLE "test_db"."test_schema"."test_target_table" COMMENT = 'Test Comment' APPEND_ONLY = true INSERT_ONLY = false SHOW_INITIAL_ROWS = true`, s.Create()) - - s.WithInsertOnly(true) - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON TABLE "test_db"."test_schema"."test_target_table" COMMENT = 'Test Comment' APPEND_ONLY = true INSERT_ONLY = true SHOW_INITIAL_ROWS = true`, s.Create()) - - s.WithExternalTable(true) - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON EXTERNAL TABLE "test_db"."test_schema"."test_target_table" COMMENT = 'Test Comment' APPEND_ONLY = true INSERT_ONLY = true SHOW_INITIAL_ROWS = true`, s.Create()) -} - -func TestStreamOnStageCreate(t *testing.T) { - r := require.New(t) - s := Stream("test_stream", "test_db", "test_schema") - - s.WithOnStage("test_db", "test_schema", "test_target_stage") - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON STAGE "test_db"."test_schema"."test_target_stage"`, s.Create()) - - s.WithComment("Test Comment") - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON STAGE "test_db"."test_schema"."test_target_stage" COMMENT = 'Test Comment'`, s.Create()) -} - -func TestStreamOnViewCreate(t *testing.T) { - r := require.New(t) - s := Stream("test_stream", "test_db", "test_schema") - - s.WithOnView("test_db", "test_schema", "test_target_view") - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON VIEW "test_db"."test_schema"."test_target_view" APPEND_ONLY = false INSERT_ONLY = false SHOW_INITIAL_ROWS = false`, s.Create()) - - s.WithComment("Test Comment") - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON VIEW "test_db"."test_schema"."test_target_view" COMMENT = 'Test Comment' APPEND_ONLY = false INSERT_ONLY = false SHOW_INITIAL_ROWS = false`, s.Create()) - - s.WithShowInitialRows(true) - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON VIEW "test_db"."test_schema"."test_target_view" COMMENT = 'Test Comment' APPEND_ONLY = false INSERT_ONLY = false SHOW_INITIAL_ROWS = true`, s.Create()) - - s.WithAppendOnly(true) - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON VIEW "test_db"."test_schema"."test_target_view" COMMENT = 'Test Comment' APPEND_ONLY = true INSERT_ONLY = false SHOW_INITIAL_ROWS = true`, s.Create()) - - s.WithInsertOnly(true) - r.Equal(`CREATE STREAM "test_db"."test_schema"."test_stream" ON VIEW "test_db"."test_schema"."test_target_view" COMMENT = 'Test Comment' APPEND_ONLY = true INSERT_ONLY = true SHOW_INITIAL_ROWS = true`, s.Create()) -} - -func TestStreamChangeComment(t *testing.T) { - r := require.New(t) - s := Stream("test_stream", "test_db", "test_schema") - r.Equal(`ALTER STREAM "test_db"."test_schema"."test_stream" SET COMMENT = 'new stream comment'`, s.ChangeComment("new stream comment")) -} - -func TestStreamRemoveComment(t *testing.T) { - r := require.New(t) - s := Stream("test_stream", "test_db", "test_schema") - r.Equal(`ALTER STREAM 
"test_db"."test_schema"."test_stream" UNSET COMMENT`, s.RemoveComment()) -} - -func TestStreamDrop(t *testing.T) { - r := require.New(t) - s := Stream("test_stream", "test_db", "test_schema") - r.Equal(`DROP STREAM "test_db"."test_schema"."test_stream"`, s.Drop()) -} - -func TestStreamShow(t *testing.T) { - r := require.New(t) - s := Stream("test_stream", "test_db", "test_schema") - r.Equal(`SHOW STREAMS LIKE 'test_stream' IN SCHEMA "test_db"."test_schema"`, s.Show()) -} From 5db751d1aa71952b1528e81cf2fdcd05d9d5d0fb Mon Sep 17 00:00:00 2001 From: Scott Winkler Date: Tue, 24 Oct 2023 03:45:08 -0700 Subject: [PATCH 06/20] fix: cleanup acc tests (#2135) * update acc tests * fix acc test * fix fmt * fix comments * template database and schema names * update acc * fix error * fix error * fix error * fix acc tests --- ...sword_policy_attachment_acceptance_test.go | 21 +- pkg/resources/alert_acceptance_test.go | 47 +- .../database_grant_acceptance_test.go | 14 +- .../database_role_acceptance_test.go | 19 +- ...otification_integration_acceptance_test.go | 5 +- .../external_stage_acceptance_test.go | 25 +- .../external_table_acceptance_test.go | 53 +- .../external_table_grant_acceptance_test.go | 31 +- .../failover_group_acceptance_test.go | 50 +- pkg/resources/file_format_acceptance_test.go | 205 ++---- .../file_format_grant_acceptance_test.go | 49 +- pkg/resources/function_acceptance_test.go | 35 +- .../function_grant_acceptance_test.go | 39 +- ...rant_privileges_to_role_acceptance_test.go | 142 ++-- .../internal_stage_acceptance_test.go | 25 +- .../masking_policy_acceptance_test.go | 50 +- .../masking_policy_grant_acceptance_test.go | 35 +- .../materialized_view_acceptance_test.go | 46 +- ...materialized_view_grant_acceptance_test.go | 29 +- .../object_parameter_acceptance_test.go | 14 +- .../password_policy_acceptance_test.go | 26 +- pkg/resources/pipe_acceptance_test.go | 39 +- pkg/resources/pipe_grant_acceptance_test.go | 92 +-- pkg/resources/procedure.go | 6 +- pkg/resources/procedure_acceptance_test.go | 35 +- .../procedure_grant_acceptance_test.go | 33 +- pkg/resources/resource_monitor.go | 16 - .../resource_monitor_grant_acceptance_test.go | 2 +- pkg/resources/role_grants_acceptance_test.go | 2 +- .../row_access_policy_acceptance_test.go | 25 +- ...row_access_policy_grant_acceptance_test.go | 35 +- pkg/resources/schema_acceptance_test.go | 25 +- pkg/resources/schema_grant_acceptance_test.go | 24 +- pkg/resources/sequence_acceptance_test.go | 73 +- .../sequence_grant_acceptance_test.go | 29 +- pkg/resources/stage_acceptance_test.go | 25 +- pkg/resources/stage_grant_acceptance_test.go | 53 +- .../storage_integration_acceptance_test.go | 4 +- pkg/resources/stream_acceptance_test.go | 29 +- pkg/resources/stream_grant_acceptance_test.go | 53 +- pkg/resources/table_acceptance_test.go | 695 ++++++------------ ...king_policy_application_acceptance_test.go | 31 +- .../table_constraint_acceptance_test.go | 53 +- pkg/resources/table_grant_acceptance_test.go | 28 +- pkg/resources/tag_acceptance_test.go | 27 +- .../tag_association_acceptance_test.go | 121 ++- pkg/resources/tag_grant_acceptance_test.go | 37 +- .../tag_masking_policy_association.go | 15 - ...king_policy_association_acceptance_test.go | 32 +- pkg/resources/task_acceptance_test.go | 397 +++++----- pkg/resources/task_grant_acceptance_test.go | 124 ++-- pkg/resources/view_acceptance_test.go | 32 +- pkg/resources/view_grant_acceptance_test.go | 113 ++- pkg/sdk/sweepers.go | 2 +- 54 files changed, 1173 insertions(+), 2094 deletions(-) diff 
--git a/pkg/resources/account_password_policy_attachment_acceptance_test.go b/pkg/resources/account_password_policy_attachment_acceptance_test.go index 89cbeeb62c..6cc88eb0a9 100644 --- a/pkg/resources/account_password_policy_attachment_acceptance_test.go +++ b/pkg/resources/account_password_policy_attachment_acceptance_test.go @@ -19,7 +19,7 @@ func TestAcc_AccountPasswordPolicyAttachment(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: accountPasswordPolicyAttachmentConfig(prefix), + Config: accountPasswordPolicyAttachmentConfig(acc.TestDatabaseName, acc.TestSchemaName, prefix), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttrSet("snowflake_account_password_policy_attachment.att", "id"), ), @@ -41,22 +41,11 @@ func TestAcc_AccountPasswordPolicyAttachment(t *testing.T) { }) } -func accountPasswordPolicyAttachmentConfig(prefix string) string { +func accountPasswordPolicyAttachmentConfig(databaseName, schemaName, prefix string) string { s := ` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" - } - resource "snowflake_password_policy" "pa" { - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" name = "%v" } @@ -64,5 +53,5 @@ resource "snowflake_account_password_policy_attachment" "att" { password_policy = snowflake_password_policy.pa.qualified_name } ` - return fmt.Sprintf(s, prefix, prefix, prefix) + return fmt.Sprintf(s, databaseName, schemaName, prefix) } diff --git a/pkg/resources/alert_acceptance_test.go b/pkg/resources/alert_acceptance_test.go index 5522997c76..f4e78eb593 100644 --- a/pkg/resources/alert_acceptance_test.go +++ b/pkg/resources/alert_acceptance_test.go @@ -33,19 +33,16 @@ type ( var ( warehouseName = "wh_" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - databaseName = "db_" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - schemaName = "PUBLIC" alertName = "a_" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) alertInitialState = &AccAlertTestSettings{ //nolint WarehouseName: warehouseName, - DatabaseName: databaseName, - + DatabaseName: acc.TestDatabaseName, Alert: &AlertSettings{ Name: alertName, - Schema: schemaName, Condition: "select 0 as c", Action: "select 0 as c", + Schema: acc.TestSchemaName, Enabled: true, Schedule: 5, Comment: "dummy", @@ -55,13 +52,12 @@ var ( // Changes: condition, action, comment, schedule. alertStepOne = &AccAlertTestSettings{ //nolint WarehouseName: warehouseName, - DatabaseName: databaseName, - + DatabaseName: acc.TestDatabaseName, Alert: &AlertSettings{ Name: alertName, - Schema: schemaName, Condition: "select 1 as c", Action: "select 1 as c", + Schema: acc.TestSchemaName, Enabled: true, Schedule: 15, Comment: "test", @@ -71,13 +67,12 @@ var ( // Changes: condition, action, comment, schedule. alertStepTwo = &AccAlertTestSettings{ //nolint WarehouseName: warehouseName, - DatabaseName: databaseName, - + DatabaseName: acc.TestDatabaseName, Alert: &AlertSettings{ Name: alertName, - Schema: schemaName, Condition: "select 2 as c", Action: "select 2 as c", + Schema: acc.TestSchemaName, Enabled: true, Schedule: 25, Comment: "text", @@ -87,13 +82,12 @@ var ( // Changes: condition, action, comment, schedule. 
alertStepThree = &AccAlertTestSettings{ //nolint WarehouseName: warehouseName, - DatabaseName: databaseName, - + DatabaseName: acc.TestDatabaseName, Alert: &AlertSettings{ Name: alertName, - Schema: schemaName, Condition: "select 2 as c", Action: "select 2 as c", + Schema: acc.TestSchemaName, Enabled: false, Schedule: 5, }, @@ -111,8 +105,8 @@ func TestAcc_Alert(t *testing.T) { Check: resource.ComposeTestCheckFunc( checkBool("snowflake_alert.test_alert", "enabled", alertInitialState.Alert.Enabled), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "name", alertName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", databaseName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", schemaName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "condition", alertInitialState.Alert.Condition), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "action", alertInitialState.Alert.Action), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "comment", alertInitialState.Alert.Comment), @@ -124,8 +118,8 @@ func TestAcc_Alert(t *testing.T) { Check: resource.ComposeTestCheckFunc( checkBool("snowflake_alert.test_alert", "enabled", alertStepOne.Alert.Enabled), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "name", alertName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", databaseName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", schemaName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "condition", alertStepOne.Alert.Condition), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "action", alertStepOne.Alert.Action), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "comment", alertStepOne.Alert.Comment), @@ -137,8 +131,8 @@ func TestAcc_Alert(t *testing.T) { Check: resource.ComposeTestCheckFunc( checkBool("snowflake_alert.test_alert", "enabled", alertStepTwo.Alert.Enabled), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "name", alertName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", databaseName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", schemaName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "condition", alertStepTwo.Alert.Condition), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "action", alertStepTwo.Alert.Action), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "comment", alertStepTwo.Alert.Comment), @@ -150,8 +144,8 @@ func TestAcc_Alert(t *testing.T) { Check: resource.ComposeTestCheckFunc( checkBool("snowflake_alert.test_alert", "enabled", alertStepThree.Alert.Enabled), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "name", alertName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", databaseName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", 
schemaName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "condition", alertStepThree.Alert.Condition), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "action", alertStepThree.Alert.Action), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "comment", alertStepThree.Alert.Comment), @@ -163,8 +157,8 @@ func TestAcc_Alert(t *testing.T) { Check: resource.ComposeTestCheckFunc( checkBool("snowflake_alert.test_alert", "enabled", alertInitialState.Alert.Enabled), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "name", alertName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", databaseName), - resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", schemaName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_alert.test_alert", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "condition", alertInitialState.Alert.Condition), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "action", alertInitialState.Alert.Action), resource.TestCheckResourceAttr("snowflake_alert.test_alert", "comment", alertInitialState.Alert.Comment), @@ -180,12 +174,9 @@ func alertConfig(settings *AccAlertTestSettings) string { //nolint resource "snowflake_warehouse" "test_wh" { name = "{{ .WarehouseName }}" } -resource "snowflake_database" "test_db" { - name = "{{ .DatabaseName }}" -} resource "snowflake_alert" "test_alert" { name = "{{ .Alert.Name }}" - database = snowflake_database.test_db.name + database = "{{ .DatabaseName }}" schema = "{{ .Alert.Schema }}" warehouse = snowflake_warehouse.test_wh.name alert_schedule { diff --git a/pkg/resources/database_grant_acceptance_test.go b/pkg/resources/database_grant_acceptance_test.go index c9e944f8aa..1bb7e9e040 100644 --- a/pkg/resources/database_grant_acceptance_test.go +++ b/pkg/resources/database_grant_acceptance_test.go @@ -34,7 +34,6 @@ func testRolesAndShares(t *testing.T, path string, roles []string) func(*terrafo } func TestAcc_DatabaseGrant(t *testing.T) { - dbName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) roleName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) shareName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) @@ -44,9 +43,9 @@ func TestAcc_DatabaseGrant(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: databaseGrantConfig(dbName, roleName, shareName), + Config: databaseGrantConfig(roleName, shareName, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_database_grant.test", "database_name", dbName), + resource.TestCheckResourceAttr("snowflake_database_grant.test", "database_name", acc.TestDatabaseName), resource.TestCheckResourceAttr("snowflake_database_grant.test", "privilege", "USAGE"), resource.TestCheckResourceAttr("snowflake_database_grant.test", "roles.#", "1"), resource.TestCheckResourceAttr("snowflake_database_grant.test", "shares.#", "1"), @@ -100,11 +99,8 @@ func TestAcc_DatabaseGrant(t *testing.T) { // }) // } -func databaseGrantConfig(db, role, share string) string { +func databaseGrantConfig(role, share, databaseName string) string { return fmt.Sprintf(` 
-resource "snowflake_database" "test" { - name = "%v" -} resource "snowflake_role" "test" { name = "%v" } @@ -114,9 +110,9 @@ resource "snowflake_share" "test" { } resource "snowflake_database_grant" "test" { - database_name = snowflake_database.test.name + database_name = "%s" roles = [snowflake_role.test.name] shares = [snowflake_share.test.name] } -`, db, role, share) +`, role, share, databaseName) } diff --git a/pkg/resources/database_role_acceptance_test.go b/pkg/resources/database_role_acceptance_test.go index c6f8ff4112..071ef0f9b1 100644 --- a/pkg/resources/database_role_acceptance_test.go +++ b/pkg/resources/database_role_acceptance_test.go @@ -12,7 +12,6 @@ import ( var ( resourceName = "snowflake_database_role.test_db_role" - dbName = "db_" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) dbRoleName = "db_role_" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) comment = "dummy" comment2 = "test comment" @@ -25,18 +24,18 @@ func TestAcc_DatabaseRole(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: databaseRoleConfig(dbName, dbRoleName, comment), + Config: databaseRoleConfig(dbRoleName, acc.TestDatabaseName, comment), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr(resourceName, "name", dbRoleName), - resource.TestCheckResourceAttr(resourceName, "database", dbName), + resource.TestCheckResourceAttr(resourceName, "database", acc.TestDatabaseName), resource.TestCheckResourceAttr(resourceName, "comment", comment), ), }, { - Config: databaseRoleConfig(dbName, dbRoleName, comment2), + Config: databaseRoleConfig(dbRoleName, acc.TestDatabaseName, comment2), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr(resourceName, "name", dbRoleName), - resource.TestCheckResourceAttr(resourceName, "database", dbName), + resource.TestCheckResourceAttr(resourceName, "database", acc.TestDatabaseName), resource.TestCheckResourceAttr(resourceName, "comment", comment2), ), }, @@ -44,17 +43,13 @@ func TestAcc_DatabaseRole(t *testing.T) { }) } -func databaseRoleConfig(dbName string, dbRoleName string, comment string) string { +func databaseRoleConfig(dbRoleName string, databaseName string, comment string) string { s := ` -resource "snowflake_database" "test_db" { - name = "%s" -} - resource "snowflake_database_role" "test_db_role" { name = "%s" - database = snowflake_database.test_db.name + database = "%s" comment = "%s" } ` - return fmt.Sprintf(s, dbName, dbRoleName, comment) + return fmt.Sprintf(s, dbRoleName, databaseName, comment) } diff --git a/pkg/resources/email_notification_integration_acceptance_test.go b/pkg/resources/email_notification_integration_acceptance_test.go index eaef2ade03..45086b5b75 100644 --- a/pkg/resources/email_notification_integration_acceptance_test.go +++ b/pkg/resources/email_notification_integration_acceptance_test.go @@ -12,11 +12,12 @@ import ( ) func TestAcc_EmailNotificationIntegration(t *testing.T) { - emailIntegrationName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - if _, ok := os.LookupEnv("SKIP_EMAIL_INTEGRATION_TESTS"); ok { + env := os.Getenv("SKIP_EMAIL_INTEGRATION_TESTS") + if env != "" { t.Skip("Skipping TestAcc_EmailNotificationIntegration") } + emailIntegrationName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.Test(t, resource.TestCase{ Providers: acc.TestAccProviders(), PreCheck: func() { acc.TestAccPreCheck(t) }, diff --git 
a/pkg/resources/external_stage_acceptance_test.go b/pkg/resources/external_stage_acceptance_test.go index d6af47d5ad..64f8fbb0c8 100644 --- a/pkg/resources/external_stage_acceptance_test.go +++ b/pkg/resources/external_stage_acceptance_test.go @@ -19,11 +19,11 @@ func TestAcc_ExternalStage(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: externalStageConfig(accName), + Config: externalStageConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_stage.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_stage.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_stage.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_stage.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_stage.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_stage.test", "comment", "Terraform acceptance test"), ), }, @@ -31,25 +31,14 @@ func TestAcc_ExternalStage(t *testing.T) { }) } -func externalStageConfig(n string) string { +func externalStageConfig(n, databaseName, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_stage" "test" { name = "%v" url = "s3://com.example.bucket/prefix" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" comment = "Terraform acceptance test" } -`, n, n, n) +`, n, databaseName, schemaName) } diff --git a/pkg/resources/external_table_acceptance_test.go b/pkg/resources/external_table_acceptance_test.go index deed1c5847..d6ea4367db 100644 --- a/pkg/resources/external_table_acceptance_test.go +++ b/pkg/resources/external_table_acceptance_test.go @@ -12,18 +12,27 @@ import ( ) func TestAcc_ExternalTable(t *testing.T) { - if _, ok := os.LookupEnv("SKIP_EXTERNAL_TABLE_TESTS"); ok { - t.Skip("Skipping TestAccExternalTable") + env := os.Getenv("SKIP_EXTERNAL_TABLE_TEST") + if env != "" { + t.Skip("Skipping TestAcc_ExternalTable") } accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + bucketURL := os.Getenv("AWS_EXTERNAL_BUCKET_URL") + if bucketURL == "" { + t.Skip("Skipping TestAcc_ExternalTable") + } + roleName := os.Getenv("AWS_EXTERNAL_ROLE_NAME") + if roleName == "" { + t.Skip("Skipping TestAcc_ExternalTable") + } resource.Test(t, resource.TestCase{ Providers: acc.TestAccProviders(), PreCheck: func() { acc.TestAccPreCheck(t) }, CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: externalTableConfig(accName, []string{"s3://com.example.bucket/prefix"}), + Config: externalTableConfig(accName, bucketURL, roleName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_external_table.test_table", "name", accName), resource.TestCheckResourceAttr("snowflake_external_table.test_table", "database", accName), @@ -35,39 +44,27 @@ func TestAcc_ExternalTable(t *testing.T) { }) } -func externalTableConfig(name string, locations []string) string { +func externalTableConfig(name string, bucketURL string, roleName string, databaseName string, schemaName string) string { s := ` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance 
test" -} - -resource "snowflake_schema" "test" { +resource "snowflake_storage_integration" "i" { name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" + storage_allowed_locations = ["%s"] + storage_provider = "S3" + storage_aws_role_arn = "%s" } resource "snowflake_stage" "test" { name = "%v" - url = "s3://com.example.bucket/prefix" - database = snowflake_database.test.name - schema = snowflake_schema.test.name - comment = "Terraform acceptance test" + url = "%s" + database = "%s" + schema = "%s" storage_integration = snowflake_storage_integration.i.name } -resource "snowflake_storage_integration" "i" { - name = "%v" - storage_allowed_locations = %q - storage_provider = "S3" - storage_aws_role_arn = "arn:aws:iam::000000000001:/role/test" -} - resource "snowflake_external_table" "test_table" { - database = snowflake_database.test.name - schema = snowflake_schema.test.name - name = "%v" + name = "%s" + database = "%s" + schema = "%s" comment = "Terraform acceptance test" column { name = "column1" @@ -80,8 +77,8 @@ resource "snowflake_external_table" "test_table" { as = "($1:\"CreatedDate\"::timestamp)" } file_format = "TYPE = CSV" - location = "@${snowflake_database.test.name}.${snowflake_schema.test.name}.${snowflake_stage.test.name}" + location = "@\"%s\".\"%s\".\"${snowflake_stage.test.name}\"" } ` - return fmt.Sprintf(s, name, name, name, name, locations, name) + return fmt.Sprintf(s, name, bucketURL, roleName, name, bucketURL, databaseName, schemaName, name, databaseName, schemaName, databaseName, schemaName) } diff --git a/pkg/resources/external_table_grant_acceptance_test.go b/pkg/resources/external_table_grant_acceptance_test.go index d1f1e4f376..58a7db59a9 100644 --- a/pkg/resources/external_table_grant_acceptance_test.go +++ b/pkg/resources/external_table_grant_acceptance_test.go @@ -19,10 +19,10 @@ func TestAcc_ExternalTableGrant_onAll(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: externalTableGrantConfig(name, onAll, "SELECT"), + Config: externalTableGrantConfig(name, onAll, "SELECT", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_external_table_grant.test", "external_table_name"), resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "on_all", "true"), @@ -51,10 +51,10 @@ func TestAcc_ExternalTableGrant_onFuture(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: externalTableGrantConfig(name, onFuture, "SELECT"), + Config: externalTableGrantConfig(name, onFuture, "SELECT", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "database_name", acc.TestDatabaseName), + 
resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_external_table_grant.test", "external_table_name"), resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_external_table_grant.test", "on_future", "true"), @@ -74,7 +74,7 @@ func TestAcc_ExternalTableGrant_onFuture(t *testing.T) { }) } -func externalTableGrantConfig(name string, grantType grantType, privilege string) string { +func externalTableGrantConfig(name string, grantType grantType, privilege string, databaseName string, schemaName string) string { var externalTableNameConfig string switch grantType { case onFuture: @@ -84,25 +84,16 @@ func externalTableGrantConfig(name string, grantType grantType, privilege string } return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%s" -} - -resource "snowflake_schema" "test" { - name = "%s" - database = snowflake_database.test.name -} - resource "snowflake_role" "test" { - name = "%s" + name = "%s" } resource "snowflake_external_table_grant" "test" { - database_name = snowflake_database.test.name + database_name = "%s" roles = [snowflake_role.test.name] - schema_name = snowflake_schema.test.name + schema_name = "%s" %s privilege = "%s" } -`, name, name, name, externalTableNameConfig, privilege) +`, name, databaseName, schemaName, externalTableNameConfig, privilege) } diff --git a/pkg/resources/failover_group_acceptance_test.go b/pkg/resources/failover_group_acceptance_test.go index c841ccfa70..ee77f72b90 100644 --- a/pkg/resources/failover_group_acceptance_test.go +++ b/pkg/resources/failover_group_acceptance_test.go @@ -24,7 +24,7 @@ func TestAcc_FailoverGroupBasic(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: failoverGroupBasic(randomCharacters, accountName), + Config: failoverGroupBasic(randomCharacters, accountName, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_failover_group.fg", "name", randomCharacters), resource.TestCheckResourceAttr("snowflake_failover_group.fg", "object_types.#", "4"), @@ -59,7 +59,7 @@ func TestAcc_FailoverGroupRemoveObjectTypes(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: failoverGroupWithInterval(randomCharacters, accountName, 20), + Config: failoverGroupWithInterval(randomCharacters, accountName, 20, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_failover_group.fg", "name", randomCharacters), resource.TestCheckResourceAttr("snowflake_failover_group.fg", "object_types.#", "4"), @@ -97,7 +97,7 @@ func TestAcc_FailoverGroupInterval(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: failoverGroupWithInterval(randomCharacters, accountName, 10), + Config: failoverGroupWithInterval(randomCharacters, accountName, 10, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_failover_group.fg", "name", randomCharacters), resource.TestCheckResourceAttr("snowflake_failover_group.fg", "object_types.#", "4"), @@ -111,7 +111,7 @@ func TestAcc_FailoverGroupInterval(t *testing.T) { }, // Update Interval { - Config: failoverGroupWithInterval(randomCharacters, accountName, 20), + Config: failoverGroupWithInterval(randomCharacters, accountName, 20, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( 
resource.TestCheckResourceAttr("snowflake_failover_group.fg", "name", randomCharacters), resource.TestCheckResourceAttr("snowflake_failover_group.fg", "object_types.#", "4"), @@ -125,7 +125,7 @@ func TestAcc_FailoverGroupInterval(t *testing.T) { }, // Change to Cron Expression { - Config: failoverGroupWithCronExpression(randomCharacters, accountName, "0 0 10-20 * TUE,THU"), + Config: failoverGroupWithCronExpression(randomCharacters, accountName, "0 0 10-20 * TUE,THU", acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_failover_group.fg", "name", randomCharacters), resource.TestCheckResourceAttr("snowflake_failover_group.fg", "object_types.#", "4"), @@ -141,7 +141,7 @@ func TestAcc_FailoverGroupInterval(t *testing.T) { }, // Update Cron Expression { - Config: failoverGroupWithCronExpression(randomCharacters, accountName, "0 0 5-20 * TUE,THU"), + Config: failoverGroupWithCronExpression(randomCharacters, accountName, "0 0 5-20 * TUE,THU", acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_failover_group.fg", "name", randomCharacters), resource.TestCheckResourceAttr("snowflake_failover_group.fg", "object_types.#", "4"), @@ -157,7 +157,7 @@ func TestAcc_FailoverGroupInterval(t *testing.T) { }, // Change to Interval { - Config: failoverGroupWithInterval(randomCharacters, accountName, 10), + Config: failoverGroupWithInterval(randomCharacters, accountName, 10, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_failover_group.fg", "name", randomCharacters), resource.TestCheckResourceAttr("snowflake_failover_group.fg", "object_types.#", "4"), @@ -180,17 +180,13 @@ func TestAcc_FailoverGroupInterval(t *testing.T) { }) } -func failoverGroupBasic(randomCharacters, accountName string) string { +func failoverGroupBasic(randomCharacters, accountName, databaseName string) string { return fmt.Sprintf(` -resource "snowflake_database" "db" { - name = "tst-terraform-%s" -} - resource "snowflake_failover_group" "fg" { name = "%s" object_types = ["WAREHOUSES","DATABASES", "INTEGRATIONS", "ROLES"] allowed_accounts= ["%s"] - allowed_databases = [snowflake_database.db.name] + allowed_databases = ["%s"] allowed_integration_types = ["SECURITY INTEGRATIONS"] replication_schedule { cron { @@ -199,34 +195,26 @@ resource "snowflake_failover_group" "fg" { } } } -`, randomCharacters, randomCharacters, accountName) +`, randomCharacters, accountName, databaseName) } -func failoverGroupWithInterval(randomCharacters, accountName string, interval int) string { +func failoverGroupWithInterval(randomCharacters, accountName string, interval int, databaseName string) string { return fmt.Sprintf(` -resource "snowflake_database" "db" { - name = "tst-terraform-%s" -} - resource "snowflake_failover_group" "fg" { name = "%s" object_types = ["WAREHOUSES","DATABASES", "INTEGRATIONS", "ROLES"] allowed_accounts= ["%s"] - allowed_databases = [snowflake_database.db.name] + allowed_databases = ["%s"] allowed_integration_types = ["SECURITY INTEGRATIONS"] replication_schedule { interval = %d } } -`, randomCharacters, randomCharacters, accountName, interval) +`, randomCharacters, accountName, databaseName, interval) } func failoverGroupWithNoWarehouse(randomCharacters, accountName string, interval int) string { return fmt.Sprintf(` -resource "snowflake_database" "db" { - name = "tst-terraform-%s" -} - resource "snowflake_failover_group" "fg" { name = "%s" object_types = ["DATABASES", 
"INTEGRATIONS", "ROLES"] @@ -236,20 +224,16 @@ resource "snowflake_failover_group" "fg" { interval = %d } } -`, randomCharacters, randomCharacters, accountName, interval) +`, randomCharacters, accountName, interval) } -func failoverGroupWithCronExpression(randomCharacters, accountName, expression string) string { +func failoverGroupWithCronExpression(randomCharacters, accountName, expression, databaseName string) string { return fmt.Sprintf(` -resource "snowflake_database" "db" { - name = "tst-terraform-%s" -} - resource "snowflake_failover_group" "fg" { name = "%s" object_types = ["WAREHOUSES","DATABASES", "INTEGRATIONS", "ROLES"] allowed_accounts= ["%s"] - allowed_databases = [snowflake_database.db.name] + allowed_databases = ["%s"] allowed_integration_types = ["SECURITY INTEGRATIONS"] replication_schedule { cron { @@ -258,5 +242,5 @@ resource "snowflake_failover_group" "fg" { } } } -`, randomCharacters, randomCharacters, accountName, expression) +`, randomCharacters, accountName, databaseName, expression) } diff --git a/pkg/resources/file_format_acceptance_test.go b/pkg/resources/file_format_acceptance_test.go index f076973810..0df26e29d2 100644 --- a/pkg/resources/file_format_acceptance_test.go +++ b/pkg/resources/file_format_acceptance_test.go @@ -18,11 +18,11 @@ func TestAcc_FileFormatCSV(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigCSV(accName), + Config: fileFormatConfigCSV(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "CSV"), resource.TestCheckResourceAttr("snowflake_file_format.test", "compression", "GZIP"), resource.TestCheckResourceAttr("snowflake_file_format.test", "record_delimiter", "\r"), @@ -71,11 +71,11 @@ func TestAcc_FileFormatJSON(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigJSON(accName), + Config: fileFormatConfigJSON(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "JSON"), resource.TestCheckResourceAttr("snowflake_file_format.test", "compression", "GZIP"), resource.TestCheckResourceAttr("snowflake_file_format.test", "date_format", "YYY-MM-DD"), @@ -108,11 +108,11 @@ func TestAcc_FileFormatAvro(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigAvro(accName), + Config: fileFormatConfigAvro(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", 
"name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "AVRO"), resource.TestCheckResourceAttr("snowflake_file_format.test", "compression", "GZIP"), resource.TestCheckResourceAttr("snowflake_file_format.test", "trim_space", "true"), @@ -134,11 +134,11 @@ func TestAcc_FileFormatORC(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigORC(accName), + Config: fileFormatConfigORC(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "ORC"), resource.TestCheckResourceAttr("snowflake_file_format.test", "trim_space", "true"), resource.TestCheckResourceAttr("snowflake_file_format.test", "null_if.#", "1"), @@ -159,11 +159,11 @@ func TestAcc_FileFormatParquet(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigParquet(accName), + Config: fileFormatConfigParquet(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "PARQUET"), resource.TestCheckResourceAttr("snowflake_file_format.test", "compression", "SNAPPY"), resource.TestCheckResourceAttr("snowflake_file_format.test", "binary_as_text", "true"), @@ -186,11 +186,11 @@ func TestAcc_FileFormatXML(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigXML(accName), + Config: fileFormatConfigXML(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "XML"), resource.TestCheckResourceAttr("snowflake_file_format.test", "compression", "GZIP"), resource.TestCheckResourceAttr("snowflake_file_format.test", "preserve_space", "true"), @@ 
-216,11 +216,11 @@ func TestAcc_FileFormatCSVDefaults(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigFullDefaults(accName, "CSV"), + Config: fileFormatConfigFullDefaults(accName, "CSV", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "CSV"), ), }, @@ -242,11 +242,11 @@ func TestAcc_FileFormatJSONDefaults(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigFullDefaults(accName, "JSON"), + Config: fileFormatConfigFullDefaults(accName, "JSON", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "JSON"), ), }, @@ -268,11 +268,11 @@ func TestAcc_FileFormatAVRODefaults(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigFullDefaults(accName, "AVRO"), + Config: fileFormatConfigFullDefaults(accName, "AVRO", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "AVRO"), ), }, @@ -294,11 +294,11 @@ func TestAcc_FileFormatORCDefaults(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigFullDefaults(accName, "ORC"), + Config: fileFormatConfigFullDefaults(accName, "ORC", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "ORC"), ), }, @@ -320,11 +320,11 @@ func TestAcc_FileFormatPARQUETDefaults(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigFullDefaults(accName, 
"PARQUET"), + Config: fileFormatConfigFullDefaults(accName, "PARQUET", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "PARQUET"), ), }, @@ -346,11 +346,11 @@ func TestAcc_FileFormatXMLDefaults(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatConfigFullDefaults(accName, "XML"), + Config: fileFormatConfigFullDefaults(accName, "XML", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_file_format.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format.test", "format_type", "XML"), ), }, @@ -363,23 +363,12 @@ func TestAcc_FileFormatXMLDefaults(t *testing.T) { }) } -func fileFormatConfigCSV(n string) string { +func fileFormatConfigCSV(n string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_file_format" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" format_type = "CSV" compression = "GZIP" record_delimiter = "\r" @@ -403,26 +392,15 @@ resource "snowflake_file_format" "test" { encoding = "UTF-16" comment = "Terraform acceptance test" } -`, n, n, n) +`, n, databaseName, schemaName) } -func fileFormatConfigJSON(n string) string { +func fileFormatConfigJSON(n string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_file_format" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" format_type = "JSON" compression = "GZIP" date_format = "YYY-MM-DD" @@ -440,77 +418,44 @@ resource "snowflake_file_format" "test" { skip_byte_order_mark = false comment = "Terraform acceptance test" } -`, n, n, n) +`, n, databaseName, schemaName) } -func fileFormatConfigAvro(n string) string { +func fileFormatConfigAvro(n string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name 
- comment = "Terraform acceptance test" -} - resource "snowflake_file_format" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" format_type = "AVRO" compression = "GZIP" trim_space = true null_if = ["NULL"] comment = "Terraform acceptance test" } -`, n, n, n) +`, n, databaseName, schemaName) } -func fileFormatConfigORC(n string) string { +func fileFormatConfigORC(n string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_file_format" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" format_type = "ORC" trim_space = true null_if = ["NULL"] comment = "Terraform acceptance test" } -`, n, n, n) +`, n, databaseName, schemaName) } -func fileFormatConfigParquet(n string) string { +func fileFormatConfigParquet(n string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_file_format" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" format_type = "PARQUET" compression = "SNAPPY" binary_as_text = true @@ -518,26 +463,15 @@ resource "snowflake_file_format" "test" { null_if = ["NULL"] comment = "Terraform acceptance test" } -`, n, n, n) +`, n, databaseName, schemaName) } -func fileFormatConfigXML(n string) string { +func fileFormatConfigXML(n string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_file_format" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" format_type = "XML" compression = "GZIP" ignore_utf8_errors = true @@ -548,27 +482,16 @@ resource "snowflake_file_format" "test" { skip_byte_order_mark = false comment = "Terraform acceptance test" } -`, n, n, n) +`, n, databaseName, schemaName) } -func fileFormatConfigFullDefaults(n, formatType string) string { +func fileFormatConfigFullDefaults(n string, formatType string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_file_format" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" format_type = "%s" } -`, n, n, n, formatType) +`, n, databaseName, schemaName, formatType) } diff --git a/pkg/resources/file_format_grant_acceptance_test.go b/pkg/resources/file_format_grant_acceptance_test.go index d536ac2233..ff5c37c287 100644 --- 
a/pkg/resources/file_format_grant_acceptance_test.go +++ b/pkg/resources/file_format_grant_acceptance_test.go @@ -13,16 +13,16 @@ import ( func TestAcc_FileFormatGrant_defaults(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ + resource.Test(t, resource.TestCase{ Providers: acc.TestAccProviders(), PreCheck: func() { acc.TestAccPreCheck(t) }, CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatGrantConfig(name, normal, "USAGE"), + Config: fileFormatGrantConfig(name, normal, "USAGE", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "file_format_name", name), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "privilege", "USAGE"), @@ -30,10 +30,10 @@ func TestAcc_FileFormatGrant_defaults(t *testing.T) { }, // UPDATE ALL PRIVILEGES { - Config: fileFormatGrantConfig(name, normal, "ALL PRIVILEGES"), + Config: fileFormatGrantConfig(name, normal, "ALL PRIVILEGES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "file_format_name", name), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "privilege", "ALL PRIVILEGES"), @@ -61,10 +61,10 @@ func TestAcc_FileFormatGrant_onAll(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatGrantConfig(name, onAll, "USAGE"), + Config: fileFormatGrantConfig(name, onAll, "USAGE", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_file_format_grant.test", "file_format_name"), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "on_all", "true"), @@ -93,10 +93,10 @@ func TestAcc_FileFormatGrant_onFuture(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: fileFormatGrantConfig(name, onFuture, "USAGE"), + Config: 
fileFormatGrantConfig(name, onFuture, "USAGE", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_file_format_grant.test", "file_format_name"), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_file_format_grant.test", "on_future", "true"), @@ -116,7 +116,7 @@ func TestAcc_FileFormatGrant_onFuture(t *testing.T) { }) } -func fileFormatGrantConfig(name string, grantType grantType, privilege string) string { +func fileFormatGrantConfig(name string, grantType grantType, privilege string, databaseName string, schemaName string) string { var fileFormatNameConfig string switch grantType { case normal: @@ -128,24 +128,14 @@ func fileFormatGrantConfig(name string, grantType grantType, privilege string) s } return fmt.Sprintf(` - -resource snowflake_database test { - name = "%s" -} - -resource snowflake_schema test { - name = "%s" - database = snowflake_database.test.name -} - resource snowflake_role test { name = "%s" } resource snowflake_file_format test { name = "%s" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" format_type = "PARQUET" compression = "AUTO" @@ -153,13 +143,12 @@ resource snowflake_file_format test { resource snowflake_file_format_grant test { %s - database_name = snowflake_database.test.name - schema_name = snowflake_schema.test.name + database_name = "%s" + schema_name = "%s" privilege = "%s" roles = [ snowflake_role.test.name ] } - -`, name, name, name, name, fileFormatNameConfig, privilege) +`, name, name, databaseName, schemaName, fileFormatNameConfig, databaseName, schemaName, privilege) } diff --git a/pkg/resources/function_acceptance_test.go b/pkg/resources/function_acceptance_test.go index f2c91b3504..cb09a086c1 100644 --- a/pkg/resources/function_acceptance_test.go +++ b/pkg/resources/function_acceptance_test.go @@ -16,8 +16,6 @@ func TestAcc_Function(t *testing.T) { t.Skip("Skipping TestAcc_Function") } - dbName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - schemaName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) functName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) expBody1 := "3.141592654::FLOAT" @@ -31,7 +29,7 @@ func TestAcc_Function(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: functionConfig(dbName, schemaName, functName), + Config: functionConfig(functName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_function.test_funct", "name", functName), resource.TestCheckResourceAttr("snowflake_function.test_funct", "comment", "Terraform acceptance test"), @@ -72,31 +70,20 @@ func TestAcc_Function(t *testing.T) { }) } -func functionConfig(db, schema, name string) string { +func functionConfig(name string, databaseName string, schemaName string) string { return fmt.Sprintf(` - resource "snowflake_database" "test_database" { - name = "%s" - comment = "Terraform 
acceptance test" - } - - resource "snowflake_schema" "test_schema" { - name = "%s" - database = snowflake_database.test_database.name - comment = "Terraform acceptance test" - } - resource "snowflake_function" "test_funct_simple" { name = "%s" - database = snowflake_database.test_database.name - schema = snowflake_schema.test_schema.name + database = "%s" + schema = "%s" return_type = "float" statement = "3.141592654::FLOAT" } resource "snowflake_function" "test_funct" { name = "%s" - database = snowflake_database.test_database.name - schema = snowflake_schema.test_schema.name + database = "%s" + schema = "%s" arguments { name = "arg1" type = "varchar" @@ -109,8 +96,8 @@ func functionConfig(db, schema, name string) string { resource "snowflake_function" "test_funct_java" { name = "%s" - database = snowflake_database.test_database.name - schema = snowflake_schema.test_schema.name + database = "%s" + schema = "%s" arguments { name = "arg1" type = "number" @@ -124,8 +111,8 @@ func functionConfig(db, schema, name string) string { resource "snowflake_function" "test_funct_complex" { name = "%s" - database = snowflake_database.test_database.name - schema = snowflake_schema.test_schema.name + database = "%s" + schema = "%s" arguments { name = "arg1" type = "varchar" @@ -142,5 +129,5 @@ union all select 3, 4 EOT } - `, db, schema, name, name, name, name) + `, name, databaseName, schemaName, name, databaseName, schemaName, name, databaseName, schemaName, name, databaseName, schemaName) } diff --git a/pkg/resources/function_grant_acceptance_test.go b/pkg/resources/function_grant_acceptance_test.go index 56229753f2..0102b99cca 100644 --- a/pkg/resources/function_grant_acceptance_test.go +++ b/pkg/resources/function_grant_acceptance_test.go @@ -10,7 +10,7 @@ import ( "github.com/hashicorp/terraform-plugin-testing/helper/resource" ) -func TestAccFunctionGrant_onFuture(t *testing.T) { +func TestAcc_FunctionGrant_onFuture(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -19,10 +19,10 @@ func TestAccFunctionGrant_onFuture(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: functionGrantConfig(name, onFuture, "USAGE"), + Config: functionGrantConfig(name, onFuture, "USAGE", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_function_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_function_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_function_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_function_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_function_grant.test", "function_name"), resource.TestCheckResourceAttr("snowflake_function_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_function_grant.test", "on_future", "true"), @@ -42,7 +42,7 @@ func TestAccFunctionGrant_onFuture(t *testing.T) { }) } -func TestAccFunctionGrant_onAll(t *testing.T) { +func TestAcc_FunctionGrant_onAll(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -51,10 +51,10 @@ func TestAccFunctionGrant_onAll(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: functionGrantConfig(name, onAll, "USAGE"), + Config: functionGrantConfig(name, 
onAll, "USAGE", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_function_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_function_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_function_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_function_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_function_grant.test", "function_name"), resource.TestCheckResourceAttr("snowflake_function_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_function_grant.test", "on_all", "true"), @@ -63,10 +63,10 @@ func TestAccFunctionGrant_onAll(t *testing.T) { }, // UPDATE ALL PRIVILEGES { - Config: functionGrantConfig(name, onAll, "ALL PRIVILEGES"), + Config: functionGrantConfig(name, onAll, "ALL PRIVILEGES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_function_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_function_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_function_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_function_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_function_grant.test", "function_name"), resource.TestCheckResourceAttr("snowflake_function_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_function_grant.test", "on_all", "true"), @@ -86,7 +86,7 @@ func TestAccFunctionGrant_onAll(t *testing.T) { }) } -func functionGrantConfig(name string, grantType grantType, privilege string) string { +func functionGrantConfig(name string, grantType grantType, privilege, databaseName, schemaName string) string { var functionNameConfig string switch grantType { case onFuture: @@ -96,25 +96,16 @@ func functionGrantConfig(name string, grantType grantType, privilege string) str } return fmt.Sprintf(` -resource snowflake_database test { - name = "%s" -} - -resource snowflake_schema test { - name = "%s" - database = snowflake_database.test.name -} - resource snowflake_role test { name = "%s" } resource "snowflake_function_grant" "test" { - database_name = snowflake_database.test.name + database_name = "%s" roles = [snowflake_role.test.name] - schema_name = snowflake_schema.test.name + schema_name = "%s" %s privilege = "%s" } -`, name, name, name, functionNameConfig, privilege) +`, name, databaseName, schemaName, functionNameConfig, privilege) } diff --git a/pkg/resources/grant_privileges_to_role_acceptance_test.go b/pkg/resources/grant_privileges_to_role_acceptance_test.go index 535195cd23..52e5cd0512 100644 --- a/pkg/resources/grant_privileges_to_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_role_acceptance_test.go @@ -10,7 +10,7 @@ import ( "github.com/hashicorp/terraform-plugin-testing/helper/resource" ) -func TestAccGrantPrivilegesToRole_onAccount(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onAccount(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -49,7 +49,7 @@ func TestAccGrantPrivilegesToRole_onAccount(t *testing.T) { } /* - func TestAccGrantPrivilegesToRole_onAccountAllPrivileges(t *testing.T) { + func 
TestAcc_GrantPrivilegesToRole_onAccountAllPrivileges(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -108,7 +108,7 @@ func grantPrivilegesToRole_onAccountConfigAllPrivileges(name string) string { `, name) } -func TestAccGrantPrivilegesToRole_onAccountObject(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onAccountObject(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -117,7 +117,7 @@ func TestAccGrantPrivilegesToRole_onAccountObject(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onAccountObjectConfig(name, []string{"CREATE DATABASE ROLE"}), + Config: grantPrivilegesToRole_onAccountObjectConfig(name, []string{"CREATE DATABASE ROLE"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_account_object.#", "1"), @@ -129,7 +129,7 @@ func TestAccGrantPrivilegesToRole_onAccountObject(t *testing.T) { }, // ADD PRIVILEGE { - Config: grantPrivilegesToRole_onAccountObjectConfig(name, []string{"MONITOR", "CREATE SCHEMA"}), + Config: grantPrivilegesToRole_onAccountObjectConfig(name, []string{"MONITOR", "CREATE SCHEMA"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "2"), @@ -147,7 +147,7 @@ func TestAccGrantPrivilegesToRole_onAccountObject(t *testing.T) { }) } -func TestAccGrantPrivilegesToRole_onAccountObjectAllPrivileges(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onAccountObjectAllPrivileges(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -156,7 +156,7 @@ func TestAccGrantPrivilegesToRole_onAccountObjectAllPrivileges(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onAccountObjectConfigAllPrivileges(name), + Config: grantPrivilegesToRole_onAccountObjectConfigAllPrivileges(name, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_account_object.#", "1"), @@ -175,7 +175,7 @@ func TestAccGrantPrivilegesToRole_onAccountObjectAllPrivileges(t *testing.T) { }) } -func grantPrivilegesToRole_onAccountObjectConfig(name string, privileges []string) string { +func grantPrivilegesToRole_onAccountObjectConfig(name string, privileges []string, databaseName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -191,13 +191,13 @@ func grantPrivilegesToRole_onAccountObjectConfig(name string, privileges []strin role_name = snowflake_role.r.name on_account_object { object_type = "DATABASE" - object_name = "terraform_test_database" + object_name = "%s" } } - `, name, privilegesString) + `, name, privilegesString, databaseName) } -func grantPrivilegesToRole_onAccountObjectConfigAllPrivileges(name string) string { +func grantPrivilegesToRole_onAccountObjectConfigAllPrivileges(name string, 
databaseName string) string { return fmt.Sprintf(` resource "snowflake_role" "r" { name = "%v" @@ -208,13 +208,13 @@ func grantPrivilegesToRole_onAccountObjectConfigAllPrivileges(name string) strin role_name = snowflake_role.r.name on_account_object { object_type = "DATABASE" - object_name = "terraform_test_database" + object_name = "%s" } } - `, name) + `, name, databaseName) } -func TestAccGrantPrivilegesToRole_onSchema(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchema(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -223,7 +223,7 @@ func TestAccGrantPrivilegesToRole_onSchema(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchemaConfig(name, []string{"MONITOR", "USAGE"}), + Config: grantPrivilegesToRole_onSchemaConfig(name, []string{"MONITOR", "USAGE"}, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema.#", "1"), @@ -235,7 +235,7 @@ func TestAccGrantPrivilegesToRole_onSchema(t *testing.T) { }, // ADD PRIVILEGE { - Config: grantPrivilegesToRole_onSchemaConfig(name, []string{"MONITOR"}), + Config: grantPrivilegesToRole_onSchemaConfig(name, []string{"MONITOR"}, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "1"), @@ -252,7 +252,7 @@ func TestAccGrantPrivilegesToRole_onSchema(t *testing.T) { }) } -func grantPrivilegesToRole_onSchemaConfig(name string, privileges []string) string { +func grantPrivilegesToRole_onSchemaConfig(name string, privileges []string, databaseName string, schemaName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -267,13 +267,13 @@ func grantPrivilegesToRole_onSchemaConfig(name string, privileges []string) stri role_name = snowflake_role.r.name privileges = [%s] on_schema { - schema_name = "\"terraform_test_database\".\"terraform_test_schema\"" + schema_name = "\"%s\".\"%s\"" } } - `, name, privilegesString) + `, name, privilegesString, databaseName, schemaName) } -func grantPrivilegesToRole_onSchemaConfigAllPrivileges(name string) string { +func grantPrivilegesToRole_onSchemaConfigAllPrivileges(name string, databaseName string, schemaName string) string { return fmt.Sprintf(` resource "snowflake_role" "r" { name = "%v" @@ -283,13 +283,13 @@ func grantPrivilegesToRole_onSchemaConfigAllPrivileges(name string) string { role_name = snowflake_role.r.name all_privileges = true on_schema { - schema_name = "\"terraform_test_database\".\"terraform_test_schema\"" + schema_name = "\"%s\".\"%s\"" } } - `, name) + `, name, databaseName, schemaName) } -func TestAccGrantPrivilegesToRole_onSchemaConfigAllPrivileges(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchemaConfigAllPrivileges(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -298,7 +298,7 @@ func TestAccGrantPrivilegesToRole_onSchemaConfigAllPrivileges(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: 
grantPrivilegesToRole_onSchemaConfigAllPrivileges(name), + Config: grantPrivilegesToRole_onSchemaConfigAllPrivileges(name, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema.#", "1"), @@ -316,7 +316,7 @@ func TestAccGrantPrivilegesToRole_onSchemaConfigAllPrivileges(t *testing.T) { }) } -func TestAccGrantPrivilegesToRole_onSchema_allSchemasInDatabase(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchema_allSchemasInDatabase(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -325,7 +325,7 @@ func TestAccGrantPrivilegesToRole_onSchema_allSchemasInDatabase(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchema_allSchemasInDatabaseConfig(name, []string{"MONITOR", "USAGE"}), + Config: grantPrivilegesToRole_onSchema_allSchemasInDatabaseConfig(name, []string{"MONITOR", "USAGE"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema.#", "1"), @@ -337,7 +337,7 @@ func TestAccGrantPrivilegesToRole_onSchema_allSchemasInDatabase(t *testing.T) { }, // REMOVE PRIVILEGE { - Config: grantPrivilegesToRole_onSchema_allSchemasInDatabaseConfig(name, []string{"MONITOR"}), + Config: grantPrivilegesToRole_onSchema_allSchemasInDatabaseConfig(name, []string{"MONITOR"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "1"), @@ -354,7 +354,7 @@ func TestAccGrantPrivilegesToRole_onSchema_allSchemasInDatabase(t *testing.T) { }) } -func TestAccGrantPrivilegesToRole_onSchema_futureSchemasInDatabase(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchema_futureSchemasInDatabase(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -363,7 +363,7 @@ func TestAccGrantPrivilegesToRole_onSchema_futureSchemasInDatabase(t *testing.T) CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchema_futureSchemasInDatabaseConfig(name, []string{"MONITOR", "USAGE"}), + Config: grantPrivilegesToRole_onSchema_futureSchemasInDatabaseConfig(name, []string{"MONITOR", "USAGE"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema.#", "1"), @@ -383,7 +383,7 @@ func TestAccGrantPrivilegesToRole_onSchema_futureSchemasInDatabase(t *testing.T) }) } -func grantPrivilegesToRole_onSchema_allSchemasInDatabaseConfig(name string, privileges []string) string { +func grantPrivilegesToRole_onSchema_allSchemasInDatabaseConfig(name string, privileges []string, databaseName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -398,14 +398,14 @@ func grantPrivilegesToRole_onSchema_allSchemasInDatabaseConfig(name string, priv role_name = 
snowflake_role.r.name privileges = [%s] on_schema { - all_schemas_in_database = "terraform_test_database" + all_schemas_in_database = "%s" } } - `, name, privilegesString) + `, name, privilegesString, databaseName) } -func grantPrivilegesToRole_onSchema_futureSchemasInDatabaseConfig(name string, privileges []string) string { +func grantPrivilegesToRole_onSchema_futureSchemasInDatabaseConfig(name string, privileges []string, databaseName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -420,14 +420,14 @@ func grantPrivilegesToRole_onSchema_futureSchemasInDatabaseConfig(name string, p role_name = snowflake_role.r.name privileges = [%s] on_schema { - future_schemas_in_database = "terraform_test_database" + future_schemas_in_database = "%s" } } - `, name, privilegesString) + `, name, privilegesString, databaseName) } -func TestAccGrantPrivilegesToRole_onSchemaObject_objectType(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchemaObject_objectType(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -436,7 +436,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_objectType(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchemaObject_objectType(name, []string{"SELECT", "REFERENCES"}), + Config: grantPrivilegesToRole_onSchemaObject_objectType(name, []string{"SELECT", "REFERENCES"}, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema_object.#", "1"), @@ -449,7 +449,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_objectType(t *testing.T) { }, // REMOVE PRIVILEGE { - Config: grantPrivilegesToRole_onSchemaObject_objectType(name, []string{"SELECT"}), + Config: grantPrivilegesToRole_onSchemaObject_objectType(name, []string{"SELECT"}, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "1"), @@ -466,7 +466,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_objectType(t *testing.T) { }) } -func grantPrivilegesToRole_onSchemaObject_objectType(name string, privileges []string) string { +func grantPrivilegesToRole_onSchemaObject_objectType(name string, privileges []string, databaseName string, schemaName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -479,8 +479,8 @@ func grantPrivilegesToRole_onSchemaObject_objectType(name string, privileges []s resource "snowflake_view" "v" { name = "%v" - database = "terraform_test_database" - schema = "terraform_test_schema" + database = "%s" + schema = "%s" is_secure = true statement = "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES" } @@ -491,13 +491,13 @@ func grantPrivilegesToRole_onSchemaObject_objectType(name string, privileges []s privileges = [%s] on_schema_object { object_type = "VIEW" - object_name = "\"terraform_test_database\".\"terraform_test_schema\".\"%s\"" + object_name = "\"%s\".\"%s\".\"%s\"" } } - `, name, name, privilegesString, name) + `, name, name, 
databaseName, schemaName, privilegesString, databaseName, schemaName, name) } -func TestAccGrantPrivilegesToRole_onSchemaObject_allInSchema(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchemaObject_allInSchema(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -506,7 +506,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_allInSchema(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchemaObject_allInSchema(name, []string{"SELECT", "REFERENCES"}), + Config: grantPrivilegesToRole_onSchemaObject_allInSchema(name, []string{"SELECT", "REFERENCES"}, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema_object.#", "1"), @@ -520,7 +520,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_allInSchema(t *testing.T) { }, // REMOVE PRIVILEGE { - Config: grantPrivilegesToRole_onSchemaObject_allInSchema(name, []string{"SELECT"}), + Config: grantPrivilegesToRole_onSchemaObject_allInSchema(name, []string{"SELECT"}, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "1"), @@ -537,7 +537,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_allInSchema(t *testing.T) { }) } -func grantPrivilegesToRole_onSchemaObject_allInSchema(name string, privileges []string) string { +func grantPrivilegesToRole_onSchemaObject_allInSchema(name string, privileges []string, databaseName string, schemaName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -554,14 +554,14 @@ func grantPrivilegesToRole_onSchemaObject_allInSchema(name string, privileges [] on_schema_object { all { object_type_plural = "TABLES" - in_schema = "\"terraform_test_database\".\"terraform_test_schema\"" + in_schema = "\"%s\".\"%s\"" } } } - `, name, privilegesString) + `, name, privilegesString, databaseName, schemaName) } -func TestAccGrantPrivilegesToRole_onSchemaObject_allInDatabase(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchemaObject_allInDatabase(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -570,7 +570,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_allInDatabase(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchemaObject_allInDatabase(name, []string{"SELECT", "REFERENCES"}), + Config: grantPrivilegesToRole_onSchemaObject_allInDatabase(name, []string{"SELECT", "REFERENCES"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema_object.#", "1"), @@ -584,7 +584,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_allInDatabase(t *testing.T) { }, // REMOVE PRIVILEGE { - Config: grantPrivilegesToRole_onSchemaObject_allInDatabase(name, []string{"SELECT"}), + Config: grantPrivilegesToRole_onSchemaObject_allInDatabase(name, []string{"SELECT"}, 
acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "1"), @@ -601,7 +601,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_allInDatabase(t *testing.T) { }) } -func grantPrivilegesToRole_onSchemaObject_allInDatabase(name string, privileges []string) string { +func grantPrivilegesToRole_onSchemaObject_allInDatabase(name string, privileges []string, databaseName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -618,14 +618,14 @@ func grantPrivilegesToRole_onSchemaObject_allInDatabase(name string, privileges on_schema_object { all { object_type_plural = "TABLES" - in_database = "terraform_test_database" + in_database = "%s" } } } - `, name, privilegesString) + `, name, privilegesString, databaseName) } -func TestAccGrantPrivilegesToRole_onSchemaObject_futureInSchema(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchemaObject_futureInSchema(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -634,7 +634,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_futureInSchema(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchemaObject_futureInSchema(name, []string{"SELECT", "REFERENCES"}), + Config: grantPrivilegesToRole_onSchemaObject_futureInSchema(name, []string{"SELECT", "REFERENCES"}, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema_object.#", "1"), @@ -648,7 +648,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_futureInSchema(t *testing.T) { }, // REMOVE PRIVILEGE { - Config: grantPrivilegesToRole_onSchemaObject_futureInSchema(name, []string{"SELECT"}), + Config: grantPrivilegesToRole_onSchemaObject_futureInSchema(name, []string{"SELECT"}, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "1"), @@ -665,7 +665,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_futureInSchema(t *testing.T) { }) } -func grantPrivilegesToRole_onSchemaObject_futureInSchema(name string, privileges []string) string { +func grantPrivilegesToRole_onSchemaObject_futureInSchema(name string, privileges []string, databaseName string, schemaName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -682,14 +682,14 @@ func grantPrivilegesToRole_onSchemaObject_futureInSchema(name string, privileges on_schema_object { future { object_type_plural = "TABLES" - in_schema = "\"terraform_test_database\".\"terraform_test_schema\"" + in_schema = "\"%s\".\"%s\"" } } } - `, name, privilegesString) + `, name, privilegesString, databaseName, schemaName) } -func TestAccGrantPrivilegesToRole_onSchemaObject_futureInDatabase(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchemaObject_futureInDatabase(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, 
acctest.CharSetAlpha)) objectType := "TABLES" resource.ParallelTest(t, resource.TestCase{ @@ -698,7 +698,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_futureInDatabase(t *testing.T) CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchemaObject_futureInDatabase(name, objectType, []string{"SELECT", "REFERENCES"}), + Config: grantPrivilegesToRole_onSchemaObject_futureInDatabase(name, objectType, []string{"SELECT", "REFERENCES"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema_object.#", "1"), @@ -712,7 +712,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_futureInDatabase(t *testing.T) }, // REMOVE PRIVILEGE { - Config: grantPrivilegesToRole_onSchemaObject_futureInDatabase(name, objectType, []string{"SELECT"}), + Config: grantPrivilegesToRole_onSchemaObject_futureInDatabase(name, objectType, []string{"SELECT"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "1"), @@ -729,7 +729,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_futureInDatabase(t *testing.T) }) } -func grantPrivilegesToRole_onSchemaObject_futureInDatabase(name string, objectType string, privileges []string) string { +func grantPrivilegesToRole_onSchemaObject_futureInDatabase(name string, objectType string, privileges []string, databaseName string) string { doubleQuotePrivileges := make([]string, len(privileges)) for i, p := range privileges { doubleQuotePrivileges[i] = fmt.Sprintf(`"%v"`, p) @@ -746,14 +746,14 @@ func grantPrivilegesToRole_onSchemaObject_futureInDatabase(name string, objectTy on_schema_object { future { object_type_plural = "%s" - in_database = "terraform_test_database" + in_database = "%s" } } } - `, name, privilegesString, objectType) + `, name, privilegesString, objectType, databaseName) } -func TestAccGrantPrivilegesToRole_multipleResources(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_multipleResources(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ @@ -822,7 +822,7 @@ func grantPrivilegesToRole_multipleResources(name string, privileges1, privilege `, name, privilegesString1, privilegesString2) } -func TestAccGrantPrivilegesToRole_onSchemaObject_futureInDatabase_externalTable(t *testing.T) { +func TestAcc_GrantPrivilegesToRole_onSchemaObject_futureInDatabase_externalTable(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) objectType := "EXTERNAL TABLES" resource.ParallelTest(t, resource.TestCase{ @@ -831,7 +831,7 @@ func TestAccGrantPrivilegesToRole_onSchemaObject_futureInDatabase_externalTable( CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: grantPrivilegesToRole_onSchemaObject_futureInDatabase(name, objectType, []string{"SELECT", "REFERENCES"}), + Config: grantPrivilegesToRole_onSchemaObject_futureInDatabase(name, objectType, []string{"SELECT", "REFERENCES"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "on_schema_object.#", "1"), @@ -845,7 +845,7 @@ func 
TestAccGrantPrivilegesToRole_onSchemaObject_futureInDatabase_externalTable( }, // REMOVE PRIVILEGE { - Config: grantPrivilegesToRole_onSchemaObject_futureInDatabase(name, objectType, []string{"SELECT"}), + Config: grantPrivilegesToRole_onSchemaObject_futureInDatabase(name, objectType, []string{"SELECT"}, acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "role_name", name), resource.TestCheckResourceAttr("snowflake_grant_privileges_to_role.g", "privileges.#", "1"), diff --git a/pkg/resources/internal_stage_acceptance_test.go b/pkg/resources/internal_stage_acceptance_test.go index 3cd8e77fe0..fccea3832a 100644 --- a/pkg/resources/internal_stage_acceptance_test.go +++ b/pkg/resources/internal_stage_acceptance_test.go @@ -19,11 +19,11 @@ func TestAcc_InternalStage(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: internalStageConfig(accName), + Config: internalStageConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_stage.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_stage.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_stage.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_stage.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_stage.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_stage.test", "comment", "Terraform acceptance test"), ), }, @@ -31,24 +31,13 @@ func TestAcc_InternalStage(t *testing.T) { }) } -func internalStageConfig(n string) string { +func internalStageConfig(n, databaseName, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_stage" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" comment = "Terraform acceptance test" } -`, n, n, n) +`, n, databaseName, schemaName) } diff --git a/pkg/resources/masking_policy_acceptance_test.go b/pkg/resources/masking_policy_acceptance_test.go index 606e101c07..dab09f717d 100644 --- a/pkg/resources/masking_policy_acceptance_test.go +++ b/pkg/resources/masking_policy_acceptance_test.go @@ -21,11 +21,11 @@ func TestAcc_MaskingPolicy(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: maskingPolicyConfig(accName, accName, comment), + Config: maskingPolicyConfig(accName, accName, comment, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_masking_policy.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_masking_policy.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_masking_policy.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_masking_policy.test", "comment", comment), resource.TestCheckResourceAttr("snowflake_masking_policy.test", "masking_expression", "case when current_role() in ('ANALYST') then val else sha2(val, 512) end"), 
resource.TestCheckResourceAttr("snowflake_masking_policy.test", "return_data_type", "VARCHAR"), @@ -37,7 +37,7 @@ func TestAcc_MaskingPolicy(t *testing.T) { }, // change comment { - Config: maskingPolicyConfig(accName, accName, comment2), + Config: maskingPolicyConfig(accName, accName, comment2, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_masking_policy.test", "name", accName), resource.TestCheckResourceAttr("snowflake_masking_policy.test", "comment", comment2), @@ -45,14 +45,14 @@ func TestAcc_MaskingPolicy(t *testing.T) { }, // rename { - Config: maskingPolicyConfig(accName, accName2, comment2), + Config: maskingPolicyConfig(accName, accName2, comment2, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_masking_policy.test", "name", accName2), ), }, // change body and unset comment { - Config: maskingPolicyConfigMultiline(accName, accName2), + Config: maskingPolicyConfigMultiline(accName, accName2, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_masking_policy.test", "masking_expression", "case\n\twhen current_role() in ('ROLE_A') then\n\t\tval\n\twhen is_role_in_session( 'ROLE_B' ) then\n\t\t'ABC123'\n\telse\n\t\t'******'\nend"), resource.TestCheckResourceAttr("snowflake_masking_policy.test", "comment", ""), @@ -68,23 +68,12 @@ func TestAcc_MaskingPolicy(t *testing.T) { }) } -func maskingPolicyConfig(n string, name string, comment string) string { +func maskingPolicyConfig(n string, name string, comment string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_masking_policy" "test" { name = "%s" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" signature { column { name = "val" @@ -95,26 +84,15 @@ resource "snowflake_masking_policy" "test" { return_data_type = "VARCHAR" comment = "%s" } -`, n, n, name, comment) +`, name, databaseName, schemaName, comment) } -func maskingPolicyConfigMultiline(n string, name string) string { +func maskingPolicyConfigMultiline(n string, name string, databaseName string, schemaName string) string { return fmt.Sprintf(` - resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" - } - - resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" - } - resource "snowflake_masking_policy" "test" { name = "%s" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" signature { column { name = "val" @@ -133,5 +111,5 @@ func maskingPolicyConfigMultiline(n string, name string) string { EOF return_data_type = "VARCHAR" } - `, n, n, name) + `, name, databaseName, schemaName) } diff --git a/pkg/resources/masking_policy_grant_acceptance_test.go b/pkg/resources/masking_policy_grant_acceptance_test.go index 7a55ee2b9d..fc435e2555 100644 --- a/pkg/resources/masking_policy_grant_acceptance_test.go +++ b/pkg/resources/masking_policy_grant_acceptance_test.go @@ -19,10 +19,10 @@ func TestAcc_MaskingPolicyGrant(t *testing.T) { CheckDestroy: nil, Steps: 
[]resource.TestStep{ { - Config: maskingPolicyGrantConfig(accName, "APPLY"), + Config: maskingPolicyGrantConfig(accName, "APPLY", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "database_name", accName), - resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "schema_name", accName), + resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "masking_policy_name", accName), resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "privilege", "APPLY"), @@ -30,10 +30,10 @@ func TestAcc_MaskingPolicyGrant(t *testing.T) { }, // UPDATE ALL PRIVILEGES { - Config: maskingPolicyGrantConfig(accName, "ALL PRIVILEGES"), + Config: maskingPolicyGrantConfig(accName, "ALL PRIVILEGES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "database_name", accName), - resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "schema_name", accName), + resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "masking_policy_name", accName), resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_masking_policy_grant.test", "privilege", "ALL PRIVILEGES"), @@ -52,27 +52,16 @@ func TestAcc_MaskingPolicyGrant(t *testing.T) { }) } -func maskingPolicyGrantConfig(name string, privilege string) string { +func maskingPolicyGrantConfig(name string, privilege string, databaseName string, schemaName string) string { return fmt.Sprintf(` - resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" - } - - resource "snowflake_schema" "test" { - name = "%v" - database = snowflake_database.test.name - comment = "Terraform acceptance test" - } - resource "snowflake_role" "test" { name = "%v" } resource "snowflake_masking_policy" "test" { name = "%v" - database = snowflake_database.test.name - schema = snowflake_schema.test.name + database = "%s" + schema = "%s" signature { column { name = "val" @@ -86,10 +75,10 @@ func maskingPolicyGrantConfig(name string, privilege string) string { resource "snowflake_masking_policy_grant" "test" { masking_policy_name = snowflake_masking_policy.test.name - database_name = snowflake_database.test.name + database_name = "%s" roles = [snowflake_role.test.name] - schema_name = snowflake_schema.test.name + schema_name = "%s" privilege = "%s" } - `, name, name, name, name, privilege) + `, name, name, databaseName, schemaName, databaseName, schemaName, privilege) } diff --git a/pkg/resources/materialized_view_acceptance_test.go b/pkg/resources/materialized_view_acceptance_test.go index 7159df4be0..e6a0176075 100644 --- a/pkg/resources/materialized_view_acceptance_test.go +++ b/pkg/resources/materialized_view_acceptance_test.go @@ -15,11 +15,9 @@ func TestAcc_MaterializedView(t 
*testing.T) { if _, ok := os.LookupEnv("SKIP_MATERIALIZED_VIEW_TESTS"); ok { t.Skip("Skipping TestAcc_MaterializedView") } - dbName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - schemaName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) tableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - warehouseName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) viewName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + warehouseName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ Providers: acc.TestAccProviders(), @@ -27,11 +25,11 @@ func TestAcc_MaterializedView(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: materializedViewConfig(warehouseName, dbName, schemaName, tableName, viewName, fmt.Sprintf("SELECT ID, DATA FROM \\\"%s\\\";", tableName)), + Config: materializedViewConfig(warehouseName, tableName, viewName, fmt.Sprintf("SELECT ID, DATA FROM \\\"%s\\\";", tableName), acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_materialized_view.test", "name", viewName), - resource.TestCheckResourceAttr("snowflake_materialized_view.test", "database", dbName), - resource.TestCheckResourceAttr("snowflake_materialized_view.test", "schema", schemaName), + resource.TestCheckResourceAttr("snowflake_materialized_view.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_materialized_view.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_materialized_view.test", "warehouse", warehouseName), resource.TestCheckResourceAttr("snowflake_materialized_view.test", "comment", "Terraform test resource"), checkBool("snowflake_materialized_view.test", "is_secure", true), @@ -45,8 +43,6 @@ func TestAcc_MaterializedView2(t *testing.T) { if _, ok := os.LookupEnv("SKIP_MATERIALIZED_VIEW_TESTS"); ok { t.Skip("Skipping TestAcc_MaterializedView2") } - dbName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - schemaName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) tableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) warehouseName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) viewName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) @@ -57,11 +53,11 @@ func TestAcc_MaterializedView2(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: materializedViewConfig(warehouseName, dbName, schemaName, tableName, viewName, fmt.Sprintf("SELECT ID, DATA FROM \\\"%s\\\" WHERE ID LIKE 'foo%%';", tableName)), + Config: materializedViewConfig(warehouseName, tableName, viewName, fmt.Sprintf("SELECT ID, DATA FROM \\\"%s\\\" WHERE ID LIKE 'foo%%';", tableName), acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_materialized_view.test", "name", viewName), - resource.TestCheckResourceAttr("snowflake_materialized_view.test", "database", dbName), - resource.TestCheckResourceAttr("snowflake_materialized_view.test", "schema", schemaName), + resource.TestCheckResourceAttr("snowflake_materialized_view.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_materialized_view.test", "schema", acc.TestSchemaName), 
resource.TestCheckResourceAttr("snowflake_materialized_view.test", "warehouse", warehouseName), resource.TestCheckResourceAttr("snowflake_materialized_view.test", "comment", "Terraform test resource"), checkBool("snowflake_materialized_view.test", "is_secure", true), @@ -71,27 +67,16 @@ func TestAcc_MaterializedView2(t *testing.T) { }) } -func materializedViewConfig(warehouseName string, dbName string, schemaName string, tableName string, viewName string, q string) string { +func materializedViewConfig(warehouseName string, tableName string, viewName string, q string, databaseName string, schemaName string) string { // convert the cluster from string slice to string return fmt.Sprintf(` -resource "snowflake_warehouse" "test" { - name = "%s" - initially_suspended = false -} - -resource "snowflake_database" "test" { +resource "snowflake_warehouse" "wh" { name = "%s" } - -resource "snowflake_schema" "test" { - database = snowflake_database.test.name - name = "%s" -} - resource "snowflake_table" "test" { - database = snowflake_database.test.name - schema = snowflake_schema.test.name name = "%s" + database = "%s" + schema = "%s" column { name = "ID" @@ -107,17 +92,16 @@ resource "snowflake_table" "test" { resource "snowflake_materialized_view" "test" { name = "%s" comment = "Terraform test resource" - database = snowflake_database.test.name - schema = snowflake_schema.test.name - warehouse = snowflake_warehouse.test.name + database = "%s" + schema = "%s" + warehouse = snowflake_warehouse.wh.name is_secure = true or_replace = false statement = "%s" depends_on = [ - snowflake_warehouse.test, snowflake_table.test ] } -`, warehouseName, dbName, schemaName, tableName, viewName, q) +`, warehouseName, tableName, databaseName, schemaName, viewName, databaseName, schemaName, q) } diff --git a/pkg/resources/materialized_view_grant_acceptance_test.go b/pkg/resources/materialized_view_grant_acceptance_test.go index 651d918876..f3047710a5 100644 --- a/pkg/resources/materialized_view_grant_acceptance_test.go +++ b/pkg/resources/materialized_view_grant_acceptance_test.go @@ -19,10 +19,10 @@ func TestAcc_MaterializedViewFutureGrant(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: materializedViewGrantConfigFuture(name, onFuture, "SELECT"), + Config: materializedViewGrantConfigFuture(name, onFuture, "SELECT", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_materialized_view_grant.test", "materialized_view_name"), resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "on_future", "true"), @@ -51,10 +51,10 @@ func TestAcc_MaterializedViewAllGrant(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: materializedViewGrantConfigFuture(name, onAll, "SELECT"), + Config: materializedViewGrantConfigFuture(name, onAll, "SELECT", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - 
resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckNoResourceAttr("snowflake_materialized_view_grant.test", "materialized_view_name"), resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_materialized_view_grant.test", "on_all", "true"), @@ -73,7 +73,7 @@ func TestAcc_MaterializedViewAllGrant(t *testing.T) { }) } -func materializedViewGrantConfigFuture(name string, grantType grantType, privilege string) string { +func materializedViewGrantConfigFuture(name string, grantType grantType, privilege string, databaseName string, schemaName string) string { var materializedViewNameConfig string switch grantType { case onFuture: @@ -83,25 +83,16 @@ func materializedViewGrantConfigFuture(name string, grantType grantType, privile } return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%s" -} - -resource "snowflake_schema" "test" { - name = "%s" - database = snowflake_database.test.name -} - resource "snowflake_role" "test" { name = "%s" } resource "snowflake_materialized_view_grant" "test" { - database_name = snowflake_database.test.name + database_name = "%s" roles = [snowflake_role.test.name] - schema_name = snowflake_schema.test.name + schema_name = "%s" %s privilege = "%s" } -`, name, name, name, materializedViewNameConfig, privilege) +`, name, databaseName, schemaName, materializedViewNameConfig, privilege) } diff --git a/pkg/resources/object_parameter_acceptance_test.go b/pkg/resources/object_parameter_acceptance_test.go index 490c26f220..75060270d4 100644 --- a/pkg/resources/object_parameter_acceptance_test.go +++ b/pkg/resources/object_parameter_acceptance_test.go @@ -2,23 +2,20 @@ package resources_test import ( "fmt" - "strings" "testing" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" - "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" ) func TestAcc_ObjectParameter(t *testing.T) { - prefix := "tst-terraform" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ Providers: acc.TestAccProviders(), PreCheck: func() { acc.TestAccPreCheck(t) }, CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: objectParameterConfigBasic(prefix, "USER_TASK_TIMEOUT_MS", "1000"), + Config: objectParameterConfigBasic("USER_TASK_TIMEOUT_MS", "1000", acc.TestDatabaseName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_object_parameter.p", "key", "USER_TASK_TIMEOUT_MS"), resource.TestCheckResourceAttr("snowflake_object_parameter.p", "value", "1000"), @@ -58,19 +55,16 @@ resource "snowflake_object_parameter" "p" { return fmt.Sprintf(s, key, value) } -func objectParameterConfigBasic(prefix, key, value string) string { +func objectParameterConfigBasic(key, value, databaseName string) string { s := ` -resource "snowflake_database" "d" { - name = "%s" -} resource "snowflake_object_parameter" "p" { key = "%s" value = "%s" object_type = "DATABASE" object_identifier { - name = snowflake_database.d.name + name = "%s" } } ` - 
return fmt.Sprintf(s, prefix, key, value) + return fmt.Sprintf(s, key, value, databaseName) } diff --git a/pkg/resources/password_policy_acceptance_test.go b/pkg/resources/password_policy_acceptance_test.go index 89c38c56d6..4b98264c9a 100644 --- a/pkg/resources/password_policy_acceptance_test.go +++ b/pkg/resources/password_policy_acceptance_test.go @@ -19,7 +19,7 @@ func TestAcc_PasswordPolicy(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: passwordPolicyConfig(accName, 10, 30, "this is a test resource"), + Config: passwordPolicyConfig(accName, 10, 30, "this is a test resource", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_password_policy.pa", "name", accName), resource.TestCheckResourceAttr("snowflake_password_policy.pa", "min_length", "10"), @@ -27,7 +27,7 @@ func TestAcc_PasswordPolicy(t *testing.T) { ), }, { - Config: passwordPolicyConfig(accName, 20, 50, "this is a test resource"), + Config: passwordPolicyConfig(accName, 20, 50, "this is a test resource", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_password_policy.pa", "min_length", "20"), resource.TestCheckResourceAttr("snowflake_password_policy.pa", "max_length", "50"), @@ -52,17 +52,17 @@ func TestAcc_PasswordPolicy(t *testing.T) { }) } -func passwordPolicyConfig(s string, minLength int, maxLength int, comment string) string { +func passwordPolicyConfig(s string, minLength int, maxLength int, comment string, databaseName string, schemaName string) string { return fmt.Sprintf(` resource "snowflake_password_policy" "pa" { - database = "terraform_test_database" - schema = "terraform_test_schema" name = "%v" + database = "%s" + schema = "%s" min_length = %d max_length = %d or_replace = true } - `, s, minLength, maxLength) + `, s, databaseName, schemaName, minLength, maxLength) } func TestAcc_PasswordPolicyMaxAgeDays(t *testing.T) { @@ -75,20 +75,20 @@ func TestAcc_PasswordPolicyMaxAgeDays(t *testing.T) { Steps: []resource.TestStep{ // Creation sets zero properly { - Config: passwordPolicyDefaultMaxageDaysConfig(accName, 0), + Config: passwordPolicyDefaultMaxageDaysConfig(accName, acc.TestDatabaseName, acc.TestSchemaName, 0), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_password_policy.pa", "max_age_days", "0"), ), }, { - Config: passwordPolicyDefaultMaxageDaysConfig(accName, 10), + Config: passwordPolicyDefaultMaxageDaysConfig(accName, acc.TestDatabaseName, acc.TestSchemaName, 10), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_password_policy.pa", "max_age_days", "10"), ), }, // Update sets zero properly { - Config: passwordPolicyDefaultMaxageDaysConfig(accName, 0), + Config: passwordPolicyDefaultMaxageDaysConfig(accName, acc.TestDatabaseName, acc.TestSchemaName, 0), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_password_policy.pa", "max_age_days", "0"), ), @@ -102,13 +102,13 @@ func TestAcc_PasswordPolicyMaxAgeDays(t *testing.T) { }) } -func passwordPolicyDefaultMaxageDaysConfig(s string, maxAgeDays int) string { +func passwordPolicyDefaultMaxageDaysConfig(s string, databaseName string, schemaName string, maxAgeDays int) string { return fmt.Sprintf(` resource "snowflake_password_policy" "pa" { - database = "terraform_test_database" - schema = "terraform_test_schema" name = "%v" + database = "%s" + schema = "%s" max_age_days = %d } - `, s, maxAgeDays) 
+ `, s, databaseName, schemaName, maxAgeDays) } diff --git a/pkg/resources/pipe_acceptance_test.go b/pkg/resources/pipe_acceptance_test.go index 8fd8e0b884..2de7527fc6 100644 --- a/pkg/resources/pipe_acceptance_test.go +++ b/pkg/resources/pipe_acceptance_test.go @@ -23,11 +23,11 @@ func TestAcc_Pipe(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: pipeConfig(accName), + Config: pipeConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_pipe.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_pipe.test", "database", accName), - resource.TestCheckResourceAttr("snowflake_pipe.test", "schema", accName), + resource.TestCheckResourceAttr("snowflake_pipe.test", "database", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_pipe.test", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_pipe.test", "comment", "Terraform acceptance test"), resource.TestCheckResourceAttr("snowflake_pipe.test", "auto_ingest", "false"), resource.TestCheckResourceAttr("snowflake_pipe.test", "notification_channel", ""), @@ -37,23 +37,12 @@ func TestAcc_Pipe(t *testing.T) { }) } -func pipeConfig(name string) string { +func pipeConfig(name string, databaseName string, schemaName string) string { s := ` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = snowflake_database.test.name - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_table" "test" { - database = snowflake_database.test.name - schema = snowflake_schema.test.name - name = snowflake_schema.test.name + database = "%s" + schema = "%s" + name = "%s" column { name = "id" @@ -67,17 +56,17 @@ resource "snowflake_table" "test" { } resource "snowflake_stage" "test" { - name = snowflake_schema.test.name - database = snowflake_database.test.name - schema = snowflake_schema.test.name + name = "%s" + database = "%s" + schema = "%s" comment = "Terraform acceptance test" } resource "snowflake_pipe" "test" { - database = snowflake_database.test.name - schema = snowflake_schema.test.name - name = snowflake_schema.test.name + database = "%s" + schema = "%s" + name = "%s" comment = "Terraform acceptance test" copy_statement = < { - Config: taskOwnershipGrantConfig(name, onFuture, "OWNERSHIP", name), + Config: taskOwnershipGrantConfig(name, onFuture, "OWNERSHIP", name, acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_task_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_task_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_task_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_task_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task_grant.test", "on_future", "true"), resource.TestCheckResourceAttr("snowflake_task_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_task_grant.test", "privilege", "OWNERSHIP"), @@ -215,10 +204,10 @@ func TestAcc_TaskOwnershipGrant_onFuture(t *testing.T) { }, // UPDATE SCHEMA level FUTURE OWNERSHIP grant to role { - Config: taskOwnershipGrantConfig(name, onFuture, "OWNERSHIP", new_name), + Config: taskOwnershipGrantConfig(name, onFuture, "OWNERSHIP", new_name, acc.TestDatabaseName, 
acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_task_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_task_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_task_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_task_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task_grant.test", "on_future", "true"), resource.TestCheckResourceAttr("snowflake_task_grant.test", "with_grant_option", "false"), resource.TestCheckResourceAttr("snowflake_task_grant.test", "privilege", "OWNERSHIP"), @@ -238,7 +227,7 @@ func TestAcc_TaskOwnershipGrant_onFuture(t *testing.T) { }) } -func taskOwnershipGrantConfig(name string, grantType grantType, privilege string, rolename string) string { +func taskOwnershipGrantConfig(name string, grantType grantType, privilege string, rolename string, databaseName string, schemaName string) string { var taskNameConfig string switch grantType { case normal: @@ -250,17 +239,6 @@ func taskOwnershipGrantConfig(name string, grantType grantType, privilege string } s := ` -resource "snowflake_database" "test" { - name = "%v" - comment = "Terraform acceptance test" -} - -resource "snowflake_schema" "test" { - name = snowflake_database.test.name - database = snowflake_database.test.name - comment = "Terraform acceptance test" -} - resource "snowflake_role" "test" { name = "%v" } @@ -271,12 +249,12 @@ resource "snowflake_role" "test_new" { resource "snowflake_task_grant" "test" { %s - database_name = snowflake_database.test.name roles = [ "%s" ] - schema_name = snowflake_schema.test.name + database_name = "%s" + schema_name = "%s" privilege = "%s" with_grant_option = false } ` - return fmt.Sprintf(s, name, name, name, taskNameConfig, rolename, privilege) + return fmt.Sprintf(s, name, name, taskNameConfig, rolename, databaseName, schemaName, privilege) } diff --git a/pkg/resources/view_acceptance_test.go b/pkg/resources/view_acceptance_test.go index ba8b7f7d4c..4383d3bade 100644 --- a/pkg/resources/view_acceptance_test.go +++ b/pkg/resources/view_acceptance_test.go @@ -19,10 +19,10 @@ func TestAcc_View(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewConfig(accName, false, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES"), + Config: viewConfig(accName, false, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_view.test", "database", accName), + resource.TestCheckResourceAttr("snowflake_view.test", "database", acc.TestDatabaseName), resource.TestCheckResourceAttr("snowflake_view.test", "comment", "Terraform test resource"), resource.TestCheckResourceAttr("snowflake_view.test", "copy_grants", "false"), checkBool("snowflake_view.test", "is_secure", true), // this is from user_acceptance_test.go @@ -41,10 +41,10 @@ func TestAcc_View2(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewConfig(accName, false, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES where ROLE_OWNER like 'foo%%';"), + Config: viewConfig(accName, false, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES where ROLE_OWNER like 'foo%%';", acc.TestDatabaseName, acc.TestSchemaName), Check: 
resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_view.test", "database", accName), + resource.TestCheckResourceAttr("snowflake_view.test", "database", acc.TestDatabaseName), resource.TestCheckResourceAttr("snowflake_view.test", "comment", "Terraform test resource"), resource.TestCheckResourceAttr("snowflake_view.test", "copy_grants", "false"), checkBool("snowflake_view.test", "is_secure", true), // this is from user_acceptance_test.go @@ -63,10 +63,10 @@ func TestAcc_ViewWithCopyGrants(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewConfig(accName, true, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES"), + Config: viewConfig(accName, true, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view.test", "name", accName), - resource.TestCheckResourceAttr("snowflake_view.test", "database", accName), + resource.TestCheckResourceAttr("snowflake_view.test", "database", acc.TestDatabaseName), resource.TestCheckResourceAttr("snowflake_view.test", "comment", "Terraform test resource"), resource.TestCheckResourceAttr("snowflake_view.test", "copy_grants", "true"), checkBool("snowflake_view.test", "is_secure", true), // this is from user_acceptance_test.go @@ -88,7 +88,7 @@ func TestAcc_ViewChangeCopyGrants(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewConfig(accName, false, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES"), + Config: viewConfig(accName, false, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view.test", "copy_grants", "false"), resource.TestCheckResourceAttrWith("snowflake_view.test", "created_on", func(value string) error { @@ -99,7 +99,7 @@ func TestAcc_ViewChangeCopyGrants(t *testing.T) { ), }, { - Config: viewConfig(accName, true, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES"), + Config: viewConfig(accName, true, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttrWith("snowflake_view.test", "created_on", func(value string) error { if value != createdOn { @@ -125,7 +125,7 @@ func TestAcc_ViewChangeCopyGrantsReversed(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewConfig(accName, true, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES"), + Config: viewConfig(accName, true, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view.test", "copy_grants", "true"), resource.TestCheckResourceAttrWith("snowflake_view.test", "created_on", func(value string) error { @@ -136,7 +136,7 @@ func TestAcc_ViewChangeCopyGrantsReversed(t *testing.T) { ), }, { - Config: viewConfig(accName, false, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES"), + Config: viewConfig(accName, false, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( 
resource.TestCheckResourceAttrWith("snowflake_view.test", "created_on", func(value string) error { if value != createdOn { @@ -151,21 +151,17 @@ func TestAcc_ViewChangeCopyGrantsReversed(t *testing.T) { }) } -func viewConfig(n string, copyGrants bool, q string) string { +func viewConfig(n string, copyGrants bool, q string, databaseName string, schemaName string) string { return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%v" -} - resource "snowflake_view" "test" { name = "%v" comment = "Terraform test resource" - database = snowflake_database.test.name - schema = "PUBLIC" + database = "%s" + schema = "%s" is_secure = true or_replace = %t copy_grants = %t statement = "%s" } -`, n, n, copyGrants, copyGrants, q) +`, n, databaseName, schemaName, copyGrants, copyGrants, q) } diff --git a/pkg/resources/view_grant_acceptance_test.go b/pkg/resources/view_grant_acceptance_test.go index d9616b2328..8c27671bbe 100644 --- a/pkg/resources/view_grant_acceptance_test.go +++ b/pkg/resources/view_grant_acceptance_test.go @@ -16,13 +16,13 @@ import ( func TestAcc_ViewGrantBasic(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ + resource.Test(t, resource.TestCase{ Providers: acc.TestAccProviders(), PreCheck: func() { acc.TestAccPreCheck(t) }, CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewGrantConfig(name, normal, "SELECT"), + Config: viewGrantConfig(name, normal, "SELECT", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view_grant.test", "view_name", name), resource.TestCheckResourceAttr("snowflake_view_grant.test", "privilege", "SELECT"), @@ -30,7 +30,7 @@ func TestAcc_ViewGrantBasic(t *testing.T) { }, // UPDATE ALL PRIVILEGES { - Config: viewGrantConfig(name, normal, "ALL PRIVILEGES"), + Config: viewGrantConfig(name, normal, "ALL PRIVILEGES", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view_grant.test", "view_name", name), resource.TestCheckResourceAttr("snowflake_view_grant.test", "privilege", "ALL PRIVILEGES"), @@ -49,18 +49,17 @@ func TestAcc_ViewGrantBasic(t *testing.T) { } func TestAcc_ViewGrantShares(t *testing.T) { - databaseName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) viewName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) roleName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) shareName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ + resource.Test(t, resource.TestCase{ Providers: acc.TestAccProviders(), PreCheck: func() { acc.TestAccPreCheck(t) }, CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewGrantConfigShares(t, databaseName, viewName, roleName, shareName), + Config: viewGrantConfigShares(t, viewName, roleName, shareName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view_grant.test", "view_name", viewName), resource.TestCheckResourceAttr("snowflake_view_grant.test", "privilege", "SELECT"), @@ -78,16 +77,16 @@ func TestAcc_ViewGrantShares(t *testing.T) { }) } -func TestAcc_FutureViewGrantChange(t *testing.T) { +func TestAcc_ViewGrantChange(t *testing.T) { name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ + resource.Test(t, 
resource.TestCase{ Providers: acc.TestAccProviders(), PreCheck: func() { acc.TestAccPreCheck(t) }, CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewGrantConfig(name, normal, "SELECT"), + Config: viewGrantConfig(name, normal, "SELECT", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view_grant.test", "view_name", name), resource.TestCheckResourceAttr("snowflake_view_grant.test", "on_future", "false"), @@ -96,7 +95,7 @@ func TestAcc_FutureViewGrantChange(t *testing.T) { }, // CHANGE FROM CURRENT TO FUTURE VIEWS { - Config: viewGrantConfig(name, onFuture, "SELECT"), + Config: viewGrantConfig(name, onFuture, "SELECT", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( resource.TestCheckNoResourceAttr("snowflake_view_grant.test", "view_name"), resource.TestCheckResourceAttr("snowflake_view_grant.test", "on_future", "true"), @@ -116,64 +115,53 @@ func TestAcc_FutureViewGrantChange(t *testing.T) { }) } -func viewGrantConfigShares(t *testing.T, databaseName, viewName, role, shareName string) string { +func viewGrantConfigShares(t *testing.T, viewName, role, shareName string) string { t.Helper() r := require.New(t) tmpl := template.Must(template.New("shares").Parse(` -resource "snowflake_database" "test" { - name = "{{.database_name}}" -} - -resource "snowflake_schema" "test" { - name = "{{ .schema_name }}" - database = snowflake_database.test.name -} - resource "snowflake_view" "test" { - name = "{{.view_name}}" - database = "{{.database_name}}" - schema = "{{ .schema_name }}" - statement = "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES" - is_secure = true - - depends_on = [snowflake_database.test, snowflake_schema.test] + name = "{{.view_name}}" + database = "{{.database_name}}" + schema = "{{ .schema_name }}" + statement = "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES" + is_secure = true } resource "snowflake_role" "test" { - name = "{{.role_name}}" + name = "{{.role_name}}" } resource "snowflake_share" "test" { - name = "{{.share_name}}" + name = "{{.share_name}}" } resource "snowflake_database_grant" "test" { - database_name = "{{ .database_name }}" - shares = ["{{ .share_name }}"] + database_name = "{{ .database_name }}" + shares = ["{{ .share_name }}"] - depends_on = [snowflake_database.test, snowflake_share.test] + depends_on = [snowflake_share.test] } resource "snowflake_view_grant" "test" { - view_name = "{{ .view_name }}" - database_name = "{{ .database_name }}" - roles = ["{{ .role_name }}"] + view_name = "{{ .view_name }}" + database_name = "{{ .database_name }}" + roles = ["{{ .role_name }}"] shares = ["{{ .share_name }}"] schema_name = "{{ .schema_name }}" - // HACK(el): There is a problem with the provider where - // in older versions of terraform referencing role.name will - // trick the provider into thinking there are no roles inputted - // so I hard-code the references. - depends_on = [snowflake_database_grant.test, snowflake_role.test, snowflake_share.test, snowflake_view.test, snowflake_schema.test] + // HACK(el): There is a problem with the provider where + // in older versions of terraform referencing role.name will + // trick the provider into thinking there are no roles inputted + // so I hard-code the references. 
+ depends_on = [snowflake_database_grant.test, snowflake_role.test, snowflake_share.test, snowflake_view.test] }`)) out := bytes.NewBuffer(nil) err := tmpl.Execute(out, map[string]string{ "share_name": shareName, - "database_name": databaseName, - "schema_name": databaseName, + "database_name": acc.TestDatabaseName, + "schema_name": acc.TestSchemaName, "role_name": role, "view_name": viewName, }) @@ -182,7 +170,7 @@ resource "snowflake_view_grant" "test" { return out.String() } -func viewGrantConfig(name string, grantType grantType, privilege string) string { +func viewGrantConfig(name string, grantType grantType, privilege string, databaseName string, schemaName string) string { var viewNameConfig string switch grantType { case normal: @@ -194,35 +182,26 @@ func viewGrantConfig(name string, grantType grantType, privilege string) string } return fmt.Sprintf(` -resource "snowflake_database" "test" { - name = "%s" -} - -resource "snowflake_schema" "test" { - name = "%s" - database = snowflake_database.test.name -} - resource "snowflake_view" "test" { - name = "%s" - database = snowflake_database.test.name - schema = snowflake_schema.test.name - statement = "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES" - is_secure = true + name = "%s" + database = "%s" + schema = "%s" + statement = "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES" + is_secure = true } resource "snowflake_role" "test" { - name = "%s" + name = "%s" } resource "snowflake_view_grant" "test" { - %s - database_name = snowflake_view.test.database - roles = [snowflake_role.test.name] - schema_name = snowflake_schema.test.name - privilege = "%s" + %s + database_name = "%s" + roles = [snowflake_role.test.name] + schema_name = "%s" + privilege = "%s" } -`, name, name, name, name, viewNameConfig, privilege) +`, name, databaseName, schemaName, name, viewNameConfig, databaseName, schemaName, privilege) } func TestAcc_ViewGrantOnAll(t *testing.T) { @@ -234,10 +213,10 @@ func TestAcc_ViewGrantOnAll(t *testing.T) { CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: viewGrantConfig(name, onAll, "SELECT"), + Config: viewGrantConfig(name, onAll, "SELECT", acc.TestDatabaseName, acc.TestSchemaName), Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr("snowflake_view_grant.test", "database_name", name), - resource.TestCheckResourceAttr("snowflake_view_grant.test", "schema_name", name), + resource.TestCheckResourceAttr("snowflake_view_grant.test", "database_name", acc.TestDatabaseName), + resource.TestCheckResourceAttr("snowflake_view_grant.test", "schema_name", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_view_grant.test", "on_all", "true"), resource.TestCheckResourceAttr("snowflake_view_grant.test", "privilege", "SELECT"), resource.TestCheckResourceAttr("snowflake_view_grant.test", "with_grant_option", "false"), diff --git a/pkg/sdk/sweepers.go b/pkg/sdk/sweepers.go index da8bfffa44..659d763e30 100644 --- a/pkg/sdk/sweepers.go +++ b/pkg/sdk/sweepers.go @@ -196,7 +196,7 @@ func getWarehouseSweeper(client *Client, prefix string) func() error { return err } for _, wh := range whs { - if (prefix == "" || strings.HasPrefix(wh.Name, prefix)) && wh.Name != "SNOWFLAKE" && wh.Name != "test_terraform_warehouse" { + if (prefix == "" || strings.HasPrefix(wh.Name, prefix)) && wh.Name != "SNOWFLAKE" && wh.Name != "terraform_test_warehouse" { log.Printf("[DEBUG] Dropping warehouse %s", wh.Name) if err := client.Warehouses.Drop(ctx, wh.ID(), nil); err != nil { return err 
From 7926d8de71a3122dae84957675d1bfd9745b9021 Mon Sep 17 00:00:00 2001 From: Scott Winkler Date: Thu, 26 Oct 2023 02:05:56 -0700 Subject: [PATCH 07/20] change comment to NullString (#2148) --- pkg/sdk/accounts.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/sdk/accounts.go b/pkg/sdk/accounts.go index 08facba0ac..831c1bd164 100644 --- a/pkg/sdk/accounts.go +++ b/pkg/sdk/accounts.go @@ -318,7 +318,7 @@ type accountDBRow struct { Edition string `db:"edition"` AccountURL string `db:"account_url"` CreatedOn time.Time `db:"created_on"` - Comment string `db:"comment"` + Comment sql.NullString `db:"comment"` AccountLocator string `db:"account_locator"` AccountLocatorURL string `db:"account_locator_url"` AccountOldURLSavedOn sql.NullString `db:"account_old_url_saved_on"` @@ -339,7 +339,7 @@ func (row accountDBRow) convert() *Account { Edition: AccountEdition(row.Edition), AccountURL: row.AccountURL, CreatedOn: row.CreatedOn, - Comment: row.Comment, + Comment: row.Comment.String, AccountLocator: row.AccountLocator, AccountLocatorURL: row.AccountLocatorURL, ManagedAccounts: row.ManagedAccounts, From 6de32ae6ec16ad76fb40afddfcaa7f650322cb67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 26 Oct 2023 11:13:42 +0200 Subject: [PATCH 08/20] fix: view statement update (#2152) --- pkg/resources/view.go | 47 +++++++++++----- pkg/resources/view_acceptance_test.go | 77 +++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 12 deletions(-) diff --git a/pkg/resources/view.go b/pkg/resources/view.go index 672e87a3be..3007c48612 100644 --- a/pkg/resources/view.go +++ b/pkg/resources/view.go @@ -64,7 +64,6 @@ var viewSchema = map[string]*schema.Schema{ Type: schema.TypeString, Required: true, Description: "Specifies the query used to create the view.", - ForceNew: true, DiffSuppressFunc: DiffSuppressStatement, }, "created_on": { @@ -81,7 +80,7 @@ func normalizeQuery(str string) string { // DiffSuppressStatement will suppress diffs between statements if they differ in only case or in // runs of whitespace (\s+ = \s). This is needed because the snowflake api does not faithfully -// round-trip queries so we cannot do a simple character-wise comparison to detect changes. +// round-trip queries, so we cannot do a simple character-wise comparison to detect changes. // // Warnings: We will have false positives in cases where a change in case or run of whitespace is // semantically significant. @@ -279,11 +278,38 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error { schema := viewID.SchemaName view := viewID.ViewName builder := snowflake.NewViewBuilder(view).WithDB(dbName).WithSchema(schema) - db := meta.(*sql.DB) + + // The only way to update the statement field in a view is to perform create or replace with the new statement. + // In case of any statement change, create or replace will be performed with all the old parameters, except statement + // and copy grants (which is always set to true to keep the permissions from the previous state). + if d.HasChange("statement") { + isSecureOld, _ := d.GetChange("is_secure") + commentOld, _ := d.GetChange("comment") + tagsOld, _ := d.GetChange("tag") + + if isSecureOld.(bool) { + builder.WithSecure() + } + + query, err := builder. + WithReplace(). + WithStatement(d.Get("statement").(string)). + WithCopyGrants(). + WithComment(commentOld.(string)). + WithTags(getTags(tagsOld).toSnowflakeTagValues()). 
+ Create() + if err != nil { + return fmt.Errorf("error when building sql query on %v, err = %w", d.Id(), err) + } + + if err := snowflake.Exec(db, query); err != nil { + return fmt.Errorf("error when changing property on %v and performing create or replace to update view statements, err = %w", d.Id(), err) + } + } + if d.HasChange("name") { name := d.Get("name") - q, err := builder.Rename(name.(string)) if err != nil { return err @@ -291,7 +317,6 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error { if err = snowflake.Exec(db, q); err != nil { return fmt.Errorf("error renaming view %v", d.Id()) } - viewID := &ViewID{ DatabaseName: dbName, SchemaName: schema, @@ -305,9 +330,7 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error { } if d.HasChange("comment") { - comment := d.Get("comment") - - if c := comment.(string); c == "" { + if comment := d.Get("comment").(string); comment == "" { q, err := builder.RemoveComment() if err != nil { return err @@ -316,7 +339,7 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error { return fmt.Errorf("error unsetting comment for view %v", d.Id()) } } else { - q, err := builder.ChangeComment(c) + q, err := builder.ChangeComment(comment) if err != nil { return err } @@ -325,10 +348,9 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error { } } } - if d.HasChange("is_secure") { - secure := d.Get("is_secure") - if secure.(bool) { + if d.HasChange("is_secure") { + if d.Get("is_secure").(bool) { q, err := builder.Secure() if err != nil { return err @@ -346,6 +368,7 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error { } } } + tagChangeErr := handleTagChanges(db, d, builder) if tagChangeErr != nil { return tagChangeErr diff --git a/pkg/resources/view_acceptance_test.go b/pkg/resources/view_acceptance_test.go index 4383d3bade..8b3e123ce1 100644 --- a/pkg/resources/view_acceptance_test.go +++ b/pkg/resources/view_acceptance_test.go @@ -151,6 +151,31 @@ func TestAcc_ViewChangeCopyGrantsReversed(t *testing.T) { }) } +func TestAcc_ViewStatementUpdate(t *testing.T) { + resource.ParallelTest(t, resource.TestCase{ + Providers: acc.TestAccProviders(), + PreCheck: func() { acc.TestAccPreCheck(t) }, + CheckDestroy: nil, + Steps: []resource.TestStep{ + { + Config: viewConfigWithGrants(acc.TestDatabaseName, acc.TestSchemaName, `\"name\"`), + Check: resource.ComposeTestCheckFunc( + // there should be more than one privilege, because we applied grant all privileges and initially there's always one which is ownership + resource.TestCheckResourceAttr("data.snowflake_grants.grants", "grants.#", "2"), + resource.TestCheckResourceAttr("data.snowflake_grants.grants", "grants.1.privilege", "SELECT"), + ), + }, + { + Config: viewConfigWithGrants(acc.TestDatabaseName, acc.TestSchemaName, "*"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("data.snowflake_grants.grants", "grants.#", "2"), + resource.TestCheckResourceAttr("data.snowflake_grants.grants", "grants.1.privilege", "SELECT"), + ), + }, + }, + }) +} + func viewConfig(n string, copyGrants bool, q string, databaseName string, schemaName string) string { return fmt.Sprintf(` resource "snowflake_view" "test" { @@ -165,3 +190,55 @@ resource "snowflake_view" "test" { } `, n, databaseName, schemaName, copyGrants, copyGrants, q) } + +func viewConfigWithGrants(databaseName string, schemaName string, selectStatement string) string { + return fmt.Sprintf(` +resource "snowflake_table" "table" { + database = "%s" + schema = "%s" + name = 
"view_test_table" + + column { + name = "name" + type = "text" + } +} + +resource "snowflake_view" "test" { + depends_on = [snowflake_table.table] + name = "test" + comment = "created by terraform" + database = "%s" + schema = "%s" + statement = "select %s from \"%s\".\"%s\".\"${snowflake_table.table.name}\"" + or_replace = true + copy_grants = true + is_secure = true +} + +resource "snowflake_role" "test" { + name = "test" +} + +resource "snowflake_view_grant" "grant" { + database_name = "%s" + schema_name = "%s" + view_name = snowflake_view.test.name + privilege = "SELECT" + roles = [snowflake_role.test.name] +} + +data "snowflake_grants" "grants" { + depends_on = [snowflake_view_grant.grant, snowflake_view.test] + grants_on { + object_name = "\"%s\".\"%s\".\"${snowflake_view.test.name}\"" + object_type = "VIEW" + } +} + `, databaseName, schemaName, + databaseName, schemaName, + selectStatement, + databaseName, schemaName, + databaseName, schemaName, + databaseName, schemaName) +} From 4d4bcdbe841807da2fa08d534eaf846234934f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 26 Oct 2023 14:09:14 +0200 Subject: [PATCH 09/20] chore: Return multiple errors in existing validations (#2122) --- pkg/sdk/accounts.go | 143 +++++----- pkg/sdk/alerts.go | 35 ++- pkg/sdk/alerts_test.go | 8 +- pkg/sdk/comments.go | 7 + pkg/sdk/database_role_test.go | 10 +- pkg/sdk/database_role_validations.go | 12 +- pkg/sdk/databases.go | 106 +++++--- pkg/sdk/databases_test.go | 6 +- pkg/sdk/dynamic_table_test.go | 6 +- pkg/sdk/dynamic_table_validations.go | 16 +- pkg/sdk/errors.go | 30 ++- pkg/sdk/external_tables_test.go | 2 +- pkg/sdk/external_tables_validations.go | 54 +++- pkg/sdk/failover_groups.go | 61 +++-- pkg/sdk/grants_test.go | 62 ++--- pkg/sdk/grants_validations.go | 249 +++++++++++------- pkg/sdk/integration_test_imports.go | 8 + pkg/sdk/masking_policy.go | 61 +++-- pkg/sdk/masking_policy_test.go | 15 +- pkg/sdk/network_policies_gen_test.go | 4 +- pkg/sdk/network_policies_validations_gen.go | 4 +- pkg/sdk/parameters.go | 81 +++--- pkg/sdk/password_policy.go | 43 +-- pkg/sdk/password_policy_test.go | 10 +- pkg/sdk/pipes_test.go | 38 +-- pkg/sdk/pipes_validations.go | 64 ++--- pkg/sdk/poc/generator/field.go | 9 + pkg/sdk/poc/generator/validation.go | 12 +- pkg/sdk/replication_functions.go | 4 + pkg/sdk/resource_monitors.go | 32 ++- pkg/sdk/resource_monitors_test.go | 8 +- pkg/sdk/roles_test.go | 17 +- pkg/sdk/roles_validations.go | 28 +- pkg/sdk/schemas.go | 26 +- pkg/sdk/session_policies_gen_test.go | 8 +- pkg/sdk/session_policies_validations_gen.go | 6 +- pkg/sdk/sessions.go | 13 +- pkg/sdk/shares.go | 46 +++- pkg/sdk/shares_test.go | 2 +- pkg/sdk/streams_gen_test.go | 14 +- pkg/sdk/streams_validations_gen.go | 14 +- pkg/sdk/tags_test.go | 14 +- pkg/sdk/tags_validations.go | 36 +-- pkg/sdk/tasks_gen_test.go | 19 +- pkg/sdk/tasks_validations_gen.go | 10 +- .../testint/dynamic_table_integration_test.go | 3 +- pkg/sdk/testint/tags_integration_test.go | 3 +- pkg/sdk/users.go | 76 +++--- pkg/sdk/users_test.go | 10 +- pkg/sdk/validations.go | 4 + pkg/sdk/warehouses.go | 55 ++-- pkg/sdk/warehouses_test.go | 2 +- 52 files changed, 931 insertions(+), 675 deletions(-) diff --git a/pkg/sdk/accounts.go b/pkg/sdk/accounts.go index 831c1bd164..e3f2388058 100644 --- a/pkg/sdk/accounts.go +++ b/pkg/sdk/accounts.go @@ -3,6 +3,7 @@ package sdk import ( "context" "database/sql" + "errors" "fmt" "time" ) @@ -57,19 +58,23 @@ type CreateAccountOptions struct { } func (opts *CreateAccountOptions) 
validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if opts.AdminName == "" { - return fmt.Errorf("AdminName is required") + errs = append(errs, errNotSet("CreateAccountOptions", "AdminName")) } if !anyValueSet(opts.AdminPassword, opts.AdminRSAPublicKey) { - return fmt.Errorf("at least one of AdminPassword or AdminRSAPublicKey must be set") + errs = append(errs, errAtLeastOneOf("CreateAccountOptions", "AdminPassword", "AdminRSAPublicKey")) } if opts.Email == "" { - return fmt.Errorf("email is required") + errs = append(errs, errNotSet("CreateAccountOptions", "Email")) } if opts.Edition == "" { - return fmt.Errorf("edition is required") + errs = append(errs, errNotSet("CreateAccountOptions", "Edition")) } - return nil + return errors.Join(errs...) } func (c *accounts) Create(ctx context.Context, id AccountObjectIdentifier, opts *CreateAccountOptions) error { @@ -92,26 +97,34 @@ type AlterAccountOptions struct { } func (opts *AlterAccountOptions) validate() error { - if ok := exactlyOneValueSet( - opts.Set, - opts.Unset, - opts.Drop, - opts.Rename); !ok { - return fmt.Errorf("exactly one of Set, Unset, Drop, Rename must be set") + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error + if !exactlyOneValueSet(opts.Set, opts.Unset, opts.Drop, opts.Rename) { + errs = append(errs, errExactlyOneOf("CreateAccountOptions", "Set", "Unset", "Drop", "Rename")) } if valueSet(opts.Set) { - return opts.Set.validate() + if err := opts.Set.validate(); err != nil { + errs = append(errs, err) + } } if valueSet(opts.Unset) { - return opts.Unset.validate() + if err := opts.Unset.validate(); err != nil { + errs = append(errs, err) + } } if valueSet(opts.Drop) { - return opts.Drop.validate() + if err := opts.Drop.validate(); err != nil { + errs = append(errs, err) + } } if valueSet(opts.Rename) { - return opts.Rename.validate() + if err := opts.Rename.validate(); err != nil { + errs = append(errs, err) + } } - return nil + return errors.Join(errs...) } type AccountLevelParameters struct { @@ -122,27 +135,28 @@ type AccountLevelParameters struct { } func (opts *AccountLevelParameters) validate() error { + var errs []error if valueSet(opts.AccountParameters) { if err := opts.AccountParameters.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.SessionParameters) { if err := opts.SessionParameters.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.ObjectParameters) { if err := opts.ObjectParameters.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.UserParameters) { if err := opts.UserParameters.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
} type AccountSet struct { @@ -154,34 +168,16 @@ type AccountSet struct { } func (opts *AccountSet) validate() error { - if !anyValueSet(opts.Parameters, opts.ResourceMonitor, opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { - return fmt.Errorf("at least one of parameters, resource monitor, password policy, session policy, or tag must be set") + var errs []error + if !exactlyOneValueSet(opts.Parameters, opts.ResourceMonitor, opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { + errs = append(errs, errExactlyOneOf("AccountSet", "Parameters", "ResourceMonitor", "PasswordPolicy", "SessionPolicy", "Tag")) } if valueSet(opts.Parameters) { - if !everyValueNil(opts.ResourceMonitor, opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { - return fmt.Errorf("cannot set both parameters and resource monitor, password policy, session policy, or tag") - } - return opts.Parameters.validate() - } - if valueSet(opts.ResourceMonitor) { - if !everyValueNil(opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { - return fmt.Errorf("cannot set both resource monitor and password policy, session policy, or tag") - } - return nil - } - if valueSet(opts.PasswordPolicy) { - if !everyValueNil(opts.SessionPolicy, opts.Tag) { - return fmt.Errorf("cannot set both password policy and session policy or tag") + if err := opts.Parameters.validate(); err != nil { + errs = append(errs, err) } - return nil } - if valueSet(opts.SessionPolicy) { - if !everyValueNil(opts.Tag) { - return fmt.Errorf("cannot set both session policy and tag") - } - return nil - } - return nil + return errors.Join(errs...) } type AccountLevelParametersUnset struct { @@ -193,7 +189,7 @@ type AccountLevelParametersUnset struct { func (opts *AccountLevelParametersUnset) validate() error { if !anyValueSet(opts.AccountParameters, opts.SessionParameters, opts.ObjectParameters, opts.UserParameters) { - return fmt.Errorf("at least one of account parameters, session parameters, object parameters, or user parameters must be set") + return errAtLeastOneOf("AccountLevelParametersUnset", "AccountParameters", "SessionParameters", "ObjectParameters", "UserParameters") } return nil } @@ -206,28 +202,16 @@ type AccountUnset struct { } func (opts *AccountUnset) validate() error { - if !anyValueSet(opts.Parameters, opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { - return fmt.Errorf("at least one of parameters, password policy, session policy, or tag must be set") + var errs []error + if !exactlyOneValueSet(opts.Parameters, opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { + errs = append(errs, errExactlyOneOf("AccountUnset", "Parameters", "PasswordPolicy", "SessionPolicy", "Tag")) } if valueSet(opts.Parameters) { - if !everyValueNil(opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { - return fmt.Errorf("cannot unset both parameters and password policy, session policy, or tag") - } - return opts.Parameters.validate() - } - if valueSet(opts.PasswordPolicy) { - if !everyValueNil(opts.SessionPolicy, opts.Tag) { - return fmt.Errorf("cannot unset both password policy and session policy or tag") + if err := opts.Parameters.validate(); err != nil { + errs = append(errs, err) } - return nil } - if valueSet(opts.SessionPolicy) { - if !everyValueNil(opts.Tag) { - return fmt.Errorf("cannot unset both session policy and tag") - } - return nil - } - return nil + return errors.Join(errs...) 
} type AccountRename struct { @@ -237,13 +221,14 @@ type AccountRename struct { } func (opts *AccountRename) validate() error { + var errs []error if !ValidObjectIdentifier(opts.Name) { - return fmt.Errorf("Name must be set") + errs = append(errs, ErrInvalidObjectIdentifier) } if !ValidObjectIdentifier(opts.NewName) { - return fmt.Errorf("NewName must be set") + errs = append(errs, errInvalidIdentifier("AccountRename", "NewName")) } - return nil + return errors.Join(errs...) } type AccountDrop struct { @@ -252,17 +237,19 @@ type AccountDrop struct { } func (opts *AccountDrop) validate() error { + var errs []error if !ValidObjectIdentifier(opts.Name) { - return fmt.Errorf("Name must be set") + errs = append(errs, ErrInvalidObjectIdentifier) } if valueSet(opts.OldURL) { + // TODO: Should this really be validated to be true ? if !*opts.OldURL { - return fmt.Errorf("OldURL must be true") + errs = append(errs, fmt.Errorf("OldURL must be true")) } } else { - return fmt.Errorf("OldURL must be set") + errs = append(errs, errNotSet("AccountDrop", "OldURL")) } - return nil + return errors.Join(errs...) } func (c *accounts) Alter(ctx context.Context, opts *AlterAccountOptions) error { @@ -280,6 +267,9 @@ type ShowAccountOptions struct { } func (opts *ShowAccountOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -380,7 +370,6 @@ func (c *accounts) ShowByID(ctx context.Context, id AccountObjectIdentifier) (*A if err != nil { return nil, err } - for _, account := range accounts { if account.AccountName == id.Name() || account.AccountLocator == id.Name() { return &account, nil @@ -399,13 +388,17 @@ type DropAccountOptions struct { } func (opts *DropAccountOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return fmt.Errorf("Name must be set") + errs = append(errs, ErrInvalidObjectIdentifier) } if !validateIntGreaterThanOrEqual(opts.gracePeriodInDays, 3) { - return fmt.Errorf("gracePeriodInDays must be greater than or equal to 3") + errs = append(errs, errIntValue("DropAccountOptions", "gracePeriodInDays", IntErrGreaterOrEqual, 3)) } - return nil + return errors.Join(errs...) } func (c *accounts) Drop(ctx context.Context, id AccountObjectIdentifier, gracePeriodInDays int, opts *DropAccountOptions) error { diff --git a/pkg/sdk/alerts.go b/pkg/sdk/alerts.go index 31cc85d42d..a057df1bf3 100644 --- a/pkg/sdk/alerts.go +++ b/pkg/sdk/alerts.go @@ -55,10 +55,12 @@ type AlertCondition struct { } func (opts *CreateAlertOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return errors.New("invalid object identifier") + return errors.Join(ErrInvalidObjectIdentifier) } - return nil } @@ -115,15 +117,17 @@ type AlterAlertOptions struct { } func (opts *AlterAlertOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return errors.New("invalid object identifier") + errs = append(errs, ErrInvalidObjectIdentifier) } - if !exactlyOneValueSet(opts.Action, opts.Set, opts.Unset, opts.ModifyCondition, opts.ModifyAction) { - return errExactlyOneOf("Action", "Set", "Unset", "ModifyCondition", "ModifyAction") + errs = append(errs, errExactlyOneOf("AlterAlertOptions", "Action", "Set", "Unset", "ModifyCondition", "ModifyAction")) } - - return nil + return errors.Join(errs...) 
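
Each validate() in this patch now starts with an explicit nil-receiver guard returning ErrNilOptions, wrapped in errors.Join, which for a single non-nil error still matches under errors.Is. Without the guard, a caller passing a nil *CreateAlertOptions or *DropAccountOptions would panic on the first field access inside validate(). A rough sketch of the shape; only the ErrNilOptions name comes from the patch, the type, field and message are invented:

package main

import (
    "errors"
    "fmt"
)

var ErrNilOptions = errors.New("options cannot be nil")

// createOptions is a stand-in for any of the SDK's *Options structs.
type createOptions struct {
    name string
}

func (opts *createOptions) validate() error {
    // Calling a method on a nil pointer is legal; dereferencing a field is
    // not, so guard before touching opts.name.
    if opts == nil {
        // errors.Join around a single error still satisfies errors.Is.
        return errors.Join(ErrNilOptions)
    }
    if opts.name == "" {
        return errors.New("name must be set")
    }
    return nil
}

func main() {
    var opts *createOptions // caller forgot to build the options
    fmt.Println(errors.Is(opts.validate(), ErrNilOptions)) // true, and no panic
}
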
} type AlertSet struct { @@ -163,8 +167,11 @@ type dropAlertOptions struct { } func (opts *dropAlertOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -254,6 +261,9 @@ func (row alertDBRow) convert() *Alert { } func (opts *ShowAlertOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -295,9 +305,12 @@ type describeAlertOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` } -func (v *describeAlertOptions) validate() error { - if !ValidObjectIdentifier(v.name) { - return ErrInvalidObjectIdentifier +func (opts *describeAlertOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + if !ValidObjectIdentifier(opts.name) { + return errors.Join(ErrInvalidObjectIdentifier) } return nil } diff --git a/pkg/sdk/alerts_test.go b/pkg/sdk/alerts_test.go index b1305bc100..8466e5bc2d 100644 --- a/pkg/sdk/alerts_test.go +++ b/pkg/sdk/alerts_test.go @@ -37,7 +37,7 @@ func TestAlertAlter(t *testing.T) { opts := &AlterAlertOptions{ name: id, } - assertOptsInvalid(t, opts, errExactlyOneOf("Action", "Set", "Unset", "ModifyCondition", "ModifyAction")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterAlertOptions", "Action", "Set", "Unset", "ModifyCondition", "ModifyAction")) }) t.Run("fail when 2 alter actions specified", func(t *testing.T) { @@ -49,7 +49,7 @@ func TestAlertAlter(t *testing.T) { Comment: String(newComment), }, } - assertOptsInvalid(t, opts, errExactlyOneOf("Action", "Set", "Unset", "ModifyCondition", "ModifyAction")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterAlertOptions", "Action", "Set", "Unset", "ModifyCondition", "ModifyAction")) }) t.Run("with resume", func(t *testing.T) { @@ -119,7 +119,7 @@ func TestAlertDrop(t *testing.T) { t.Run("empty options", func(t *testing.T) { opts := &dropAlertOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { @@ -210,7 +210,7 @@ func TestAlertDescribe(t *testing.T) { t.Run("empty options", func(t *testing.T) { opts := &describeAlertOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { diff --git a/pkg/sdk/comments.go b/pkg/sdk/comments.go index df0bfcf06b..a52fd6e4ca 100644 --- a/pkg/sdk/comments.go +++ b/pkg/sdk/comments.go @@ -2,6 +2,7 @@ package sdk import ( "context" + "errors" ) var ( @@ -31,6 +32,9 @@ type SetCommentOptions struct { } func (opts *SetCommentOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -60,6 +64,9 @@ type SetColumnCommentOptions struct { } func (opts *SetColumnCommentOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } diff --git a/pkg/sdk/database_role_test.go b/pkg/sdk/database_role_test.go index f40c2514be..645db94160 100644 --- a/pkg/sdk/database_role_test.go +++ b/pkg/sdk/database_role_test.go @@ -30,7 +30,7 @@ func TestDatabaseRoleCreate(t *testing.T) { opts := defaultOpts() opts.IfNotExists = Bool(true) opts.OrReplace = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, errOneOf("OrReplace", "IfNotExists")) + assertOptsInvalidJoinedErrors(t, opts, errOneOf("createDatabaseRoleOptions", 
"OrReplace", "IfNotExists")) }) t.Run("validation: multiple errors", func(t *testing.T) { @@ -38,7 +38,7 @@ func TestDatabaseRoleCreate(t *testing.T) { opts.name = NewDatabaseObjectIdentifier("", "") opts.IfNotExists = Bool(true) opts.OrReplace = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier, errOneOf("OrReplace", "IfNotExists")) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier, errOneOf("createDatabaseRoleOptions", "OrReplace", "IfNotExists")) }) t.Run("basic", func(t *testing.T) { @@ -77,7 +77,7 @@ func TestDatabaseRoleAlter(t *testing.T) { t.Run("validation: no alter action", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsExactlyOneAction) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("alterDatabaseRoleOptions", "Rename", "Set", "Unset")) }) t.Run("validation: multiple alter actions", func(t *testing.T) { @@ -88,7 +88,7 @@ func TestDatabaseRoleAlter(t *testing.T) { opts.Unset = &DatabaseRoleUnset{ Comment: true, } - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsExactlyOneAction) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("alterDatabaseRoleOptions", "Rename", "Set", "Unset")) }) t.Run("validation: invalid new name", func(t *testing.T) { @@ -114,7 +114,7 @@ func TestDatabaseRoleAlter(t *testing.T) { opts.Unset = &DatabaseRoleUnset{ Comment: false, } - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsAtLeastOneProperty) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("alterDatabaseRoleOptions.Unset", "Comment")) }) t.Run("rename", func(t *testing.T) { diff --git a/pkg/sdk/database_role_validations.go b/pkg/sdk/database_role_validations.go index c62162e130..e09bc81c98 100644 --- a/pkg/sdk/database_role_validations.go +++ b/pkg/sdk/database_role_validations.go @@ -22,7 +22,7 @@ func (opts *createDatabaseRoleOptions) validate() error { errs = append(errs, ErrInvalidObjectIdentifier) } if everyValueSet(opts.OrReplace, opts.IfNotExists) && *opts.OrReplace && *opts.IfNotExists { - errs = append(errs, errOneOf("OrReplace", "IfNotExists")) + errs = append(errs, errOneOf("createDatabaseRoleOptions", "OrReplace", "IfNotExists")) } return errors.Join(errs...) } @@ -35,12 +35,8 @@ func (opts *alterDatabaseRoleOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if ok := exactlyOneValueSet( - opts.Rename, - opts.Set, - opts.Unset, - ); !ok { - errs = append(errs, errAlterNeedsExactlyOneAction) + if !exactlyOneValueSet(opts.Rename, opts.Set, opts.Unset) { + errs = append(errs, errExactlyOneOf("alterDatabaseRoleOptions", "Rename", "Set", "Unset")) } if rename := opts.Rename; valueSet(rename) { if !ValidObjectIdentifier(rename.Name) { @@ -52,7 +48,7 @@ func (opts *alterDatabaseRoleOptions) validate() error { } if unset := opts.Unset; valueSet(unset) { if !unset.Comment { - errs = append(errs, errAlterNeedsAtLeastOneProperty) + errs = append(errs, errAtLeastOneOf("alterDatabaseRoleOptions.Unset", "Comment")) } } return errors.Join(errs...) 
diff --git a/pkg/sdk/databases.go b/pkg/sdk/databases.go index ee27b23cb8..3c3e0745c0 100644 --- a/pkg/sdk/databases.go +++ b/pkg/sdk/databases.go @@ -148,15 +148,22 @@ type CreateDatabaseOptions struct { } func (opts *CreateDatabaseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) + } if valueSet(opts.Clone) { if err := opts.Clone.validate(); err != nil { - return err + errs = append(errs, err) } } if everyValueSet(opts.OrReplace, opts.IfNotExists) { - return errors.New("IF NOT EXISTS and OR REPLACE are incompatible.") + errs = append(errs, errOneOf("CreateDatabaseOptions", "OrReplace", "IfNotExists")) } - return nil + return errors.Join(errs...) } func (v *databases) Create(ctx context.Context, id AccountObjectIdentifier, opts *CreateDatabaseOptions) error { @@ -185,13 +192,17 @@ type CreateSharedDatabaseOptions struct { } func (opts *CreateSharedDatabaseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if !ValidObjectIdentifier(opts.fromShare) { - return ErrInvalidObjectIdentifier + errs = append(errs, errInvalidIdentifier("CreateSharedDatabaseOptions", "fromShare")) } - return nil + return errors.Join(errs...) } func (v *databases) CreateShared(ctx context.Context, id AccountObjectIdentifier, shareID ExternalObjectIdentifier, opts *CreateSharedDatabaseOptions) error { @@ -223,13 +234,17 @@ type CreateSecondaryDatabaseOptions struct { } func (opts *CreateSecondaryDatabaseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if !ValidObjectIdentifier(opts.primaryDatabase) { - return ErrInvalidObjectIdentifier + errs = append(errs, errInvalidIdentifier("CreateSecondaryDatabaseOptions", "primaryDatabase")) } - return nil + return errors.Join(errs...) 
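
CreateSharedDatabaseOptions and CreateSecondaryDatabaseOptions above used to return the same ErrInvalidObjectIdentifier regardless of which of their two identifiers was bad; the secondary one is now reported through errInvalidIdentifier with the struct and field name. The small program below just prints what the joined result looks like, reusing the format string added in pkg/sdk/errors.go; the sentinel is redeclared locally so the snippet stands alone.

package main

import (
    "errors"
    "fmt"
)

var ErrInvalidObjectIdentifier = errors.New("invalid object identifier")

// errInvalidIdentifier mirrors the constructor added in pkg/sdk/errors.go.
func errInvalidIdentifier(structName, fieldName string) error {
    return fmt.Errorf("invalid object identifier of %s field: %s", structName, fieldName)
}

func main() {
    // Both identifiers invalid: the joined error names the secondary field
    // explicitly instead of repeating the generic sentinel twice.
    err := errors.Join(
        ErrInvalidObjectIdentifier,
        errInvalidIdentifier("CreateSharedDatabaseOptions", "fromShare"),
    )
    fmt.Println(err)
    // invalid object identifier
    // invalid object identifier of CreateSharedDatabaseOptions field: fromShare
}

That makes a joined validation failure self-describing: the caller can tell at a glance whether name or fromShare needs fixing.
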
} func (v *databases) CreateSecondary(ctx context.Context, id AccountObjectIdentifier, primaryID ExternalObjectIdentifier, opts *CreateSecondaryDatabaseOptions) error { @@ -262,31 +277,27 @@ type AlterDatabaseOptions struct { } func (opts *AlterDatabaseOptions) validate() error { - if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier - } - if ValidObjectIdentifier(opts.NewName) && anyValueSet(opts.Set, opts.Unset, opts.SwapWith) { - return errors.New("RENAME TO cannot be set with other options") + if opts == nil { + return errors.Join(ErrNilOptions) } - - if ValidObjectIdentifier(opts.SwapWith) && anyValueSet(opts.Set, opts.Unset, opts.NewName) { - return errors.New("SWAP WITH cannot be set with other options") + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) } - - if valueSet(opts.Set) && valueSet(opts.Unset) { - return errors.New("only one of SET or UNSET can be set") + if !exactlyOneValueSet(opts.NewName, opts.Set, opts.Unset, opts.SwapWith) { + errs = append(errs, errExactlyOneOf("AlterDatabaseOptions", "NewName", "Set", "Unset", "SwapWith")) } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Unset) { if err := opts.Unset.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) } type DatabaseSet struct { @@ -311,7 +322,7 @@ type DatabaseUnset struct { func (v *DatabaseUnset) validate() error { if valueSet(v.Tag) { if anyValueSet(v.DataRetentionTimeInDays, v.MaxDataExtensionTimeInDays, v.DefaultDDLCollation, v.Comment) { - return errors.New("Tag cannot be set with other options") + return errors.New("tag cannot be set with other options") } } return nil @@ -344,23 +355,27 @@ type AlterDatabaseReplicationOptions struct { } func (opts *AlterDatabaseReplicationOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if !exactlyOneValueSet(opts.EnableReplication, opts.DisableReplication, opts.Refresh) { - return errExactlyOneOf("EnableReplication", "DisableReplication", "Refresh") + errs = append(errs, errExactlyOneOf("AlterDatabaseReplicationOptions", "EnableReplication", "DisableReplication", "Refresh")) } if valueSet(opts.EnableReplication) { if err := opts.EnableReplication.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.DisableReplication) { if err := opts.DisableReplication.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
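
The *_test.go files in this patch move from assertOptsInvalid to assertOptsInvalidJoinedErrors, which has to find each expected error somewhere inside the joined result rather than comparing one value. The helper's implementation is not part of this diff; the sketch below shows one way such a helper could work, matching sentinels via errors.Is and freshly constructed errors by message. The validator interface is declared only to keep the sketch self-contained.

package sdk

import (
    "errors"
    "strings"
    "testing"
)

// validator mirrors the unexported validate() methods on the options structs.
type validator interface {
    validate() error
}

// assertOptsInvalidJoinedErrors is a sketch, not the helper's real body: it
// requires validation to fail and every expected error to appear in the
// joined result.
func assertOptsInvalidJoinedErrors(t *testing.T, opts validator, expected ...error) {
    t.Helper()
    err := opts.validate()
    if err == nil {
        t.Fatal("expected validation to fail, got nil")
    }
    for _, want := range expected {
        // errors.Is walks errors produced by errors.Join; for errors built on
        // the fly (errExactlyOneOf, errNotSet, ...) fall back to the message.
        if errors.Is(err, want) || strings.Contains(err.Error(), want.Error()) {
            continue
        }
        t.Errorf("expected %q to be part of %q", want, err)
    }
}
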
} type EnableReplication struct { @@ -407,26 +422,27 @@ type AlterDatabaseFailoverOptions struct { } func (opts *AlterDatabaseFailoverOptions) validate() error { - if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + if opts == nil { + return errors.Join(ErrNilOptions) } - if everyValueNil(opts.EnableFailover, opts.DisableFailover, opts.Primary) { - return errors.New("one of ENABLE FAILOVER, DISABLE FAILOVER or PRIMARY must be set") + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) } if !exactlyOneValueSet(opts.EnableFailover, opts.DisableFailover, opts.Primary) { - return errors.New("only one of ENABLE FAILOVER, DISABLE FAILOVER or PRIMARY can be set") + errs = append(errs, errExactlyOneOf("AlterDatabaseFailoverOptions", "EnableFailover", "DisableFailover", "Primary")) } if valueSet(opts.EnableFailover) { if err := opts.EnableFailover.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.DisableFailover) { if err := opts.DisableFailover.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) } type EnableFailover struct { @@ -470,6 +486,9 @@ type DropDatabaseOptions struct { } func (opts *DropDatabaseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { return ErrInvalidObjectIdentifier } @@ -500,8 +519,11 @@ type undropDatabaseOptions struct { } func (opts *undropDatabaseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -533,6 +555,9 @@ type ShowDatabasesOptions struct { } func (opts *ShowDatabasesOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -581,8 +606,11 @@ type describeDatabaseOptions struct { } func (opts *describeDatabaseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } diff --git a/pkg/sdk/databases_test.go b/pkg/sdk/databases_test.go index 873530efb9..d0906dc0c0 100644 --- a/pkg/sdk/databases_test.go +++ b/pkg/sdk/databases_test.go @@ -8,6 +8,7 @@ import ( func TestDatabasesCreate(t *testing.T) { t.Run("clone", func(t *testing.T) { opts := &CreateDatabaseOptions{ + name: NewAccountObjectIdentifier("db"), Clone: &Clone{ SourceObject: NewAccountObjectIdentifier("db1"), At: &TimeTravel{ @@ -15,11 +16,12 @@ func TestDatabasesCreate(t *testing.T) { }, }, } - assertOptsValidAndSQLEquals(t, opts, `CREATE DATABASE CLONE "db1" AT (TIMESTAMP => '2021-01-01 00:00:00 +0000 UTC')`) + assertOptsValidAndSQLEquals(t, opts, `CREATE DATABASE "db" CLONE "db1" AT (TIMESTAMP => '2021-01-01 00:00:00 +0000 UTC')`) }) t.Run("complete", func(t *testing.T) { opts := &CreateDatabaseOptions{ + name: NewAccountObjectIdentifier("db"), OrReplace: Bool(true), Transient: Bool(true), Comment: String("comment"), @@ -32,7 +34,7 @@ func TestDatabasesCreate(t *testing.T) { }, }, } - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TRANSIENT DATABASE DATA_RETENTION_TIME_IN_DAYS = 1 MAX_DATA_EXTENSION_TIME_IN_DAYS = 1 COMMENT = 'comment' TAG ("db1"."schema1"."tag1" = 'v1')`) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TRANSIENT DATABASE "db" 
DATA_RETENTION_TIME_IN_DAYS = 1 MAX_DATA_EXTENSION_TIME_IN_DAYS = 1 COMMENT = 'comment' TAG ("db1"."schema1"."tag1" = 'v1')`) }) } diff --git a/pkg/sdk/dynamic_table_test.go b/pkg/sdk/dynamic_table_test.go index f5c9c637eb..e770b3031b 100644 --- a/pkg/sdk/dynamic_table_test.go +++ b/pkg/sdk/dynamic_table_test.go @@ -64,19 +64,19 @@ func TestDynamicTableAlter(t *testing.T) { t.Run("validation: no alter action", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsExactlyOneAction) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("alterDynamicTableOptions", "Suspend", "Resume", "Refresh", "Set")) }) t.Run("validation: multiple alter actions", func(t *testing.T) { opts := defaultOpts() opts.Resume = Bool(true) opts.Suspend = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsExactlyOneAction) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("alterDynamicTableOptions", "Suspend", "Resume", "Refresh", "Set")) }) t.Run("validation: no property to unset", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsAtLeastOneProperty) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("alterDynamicTableOptions", "Suspend", "Resume", "Refresh", "Set")) }) t.Run("suspend", func(t *testing.T) { diff --git a/pkg/sdk/dynamic_table_validations.go b/pkg/sdk/dynamic_table_validations.go index 4448223985..d5323f1d8d 100644 --- a/pkg/sdk/dynamic_table_validations.go +++ b/pkg/sdk/dynamic_table_validations.go @@ -19,7 +19,7 @@ func (tl *TargetLag) validate() error { } var errs []error if everyValueSet(tl.MaximumDuration, tl.Downstream) { - errs = append(errs, errOneOf("MaximumDuration", "Downstream")) + errs = append(errs, errOneOf("TargetLag", "MaximumDuration", "Downstream")) } return errors.Join(errs...) } @@ -60,16 +60,8 @@ func (opts *alterDynamicTableOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if ok := exactlyOneValueSet( - opts.Suspend, - opts.Resume, - opts.Refresh, - opts.Set, - ); !ok { - errs = append(errs, errAlterNeedsExactlyOneAction) - } - if !anyValueSet(opts.Suspend, opts.Resume, opts.Refresh, opts.Set) { - errs = append(errs, errAlterNeedsAtLeastOneProperty) + if ok := exactlyOneValueSet(opts.Suspend, opts.Resume, opts.Refresh, opts.Set); !ok { + errs = append(errs, errExactlyOneOf("alterDynamicTableOptions", "Suspend", "Resume", "Refresh", "Set")) } if valueSet(opts.Set) && valueSet(opts.Set.TargetLag) { errs = append(errs, opts.Set.TargetLag.validate()) @@ -86,7 +78,7 @@ func (opts *showDynamicTableOptions) validate() error { errs = append(errs, ErrPatternRequiredForLikeKeyword) } if valueSet(opts.In) && !exactlyOneValueSet(opts.In.Account, opts.In.Database, opts.In.Schema) { - errs = append(errs, errScopeRequiredForInKeyword) + errs = append(errs, errExactlyOneOf("showDynamicTableOptions.In", "Account", "Database", "Schema")) } return errors.Join(errs...) 
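
In alterDynamicTableOptions the separate anyValueSet check, and with it errAlterNeedsAtLeastOneProperty, is dropped because exactlyOneValueSet already rejects both "no action" and "more than one action"; the updated dynamic_table_test.go cases above therefore expect the same errExactlyOneOf in both situations. A trivial boolean illustration, where exactlyOne is merely a stand-in for the SDK's helper:

package main

import "fmt"

// exactlyOne reports whether exactly one of the flags is set.
func exactlyOne(flags ...bool) bool {
    n := 0
    for _, f := range flags {
        if f {
            n++
        }
    }
    return n == 1
}

func main() {
    suspend, resume, refresh, set := false, false, false, false
    fmt.Println(exactlyOne(suspend, resume, refresh, set)) // false: no action, same error
    suspend, resume = true, true
    fmt.Println(exactlyOne(suspend, resume, refresh, set)) // false: two actions, same error
    resume = false
    fmt.Println(exactlyOne(suspend, resume, refresh, set)) // true: valid alter
}
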
} diff --git a/pkg/sdk/errors.go b/pkg/sdk/errors.go index e90d643889..16f8d50aa5 100644 --- a/pkg/sdk/errors.go +++ b/pkg/sdk/errors.go @@ -20,6 +20,28 @@ var ( ErrDifferentDatabase = errors.New("database must be the same") ) +type IntErrType string + +const ( + IntErrEqual IntErrType = "equal to" + IntErrGreaterOrEqual IntErrType = "greater than or equal to" + IntErrGreater IntErrType = "greater than" + IntErrLessOrEqual IntErrType = "less than or equal to" + IntErrLess IntErrType = "less than" +) + +func errIntValue(structName string, fieldName string, intErrType IntErrType, limit int) error { + return fmt.Errorf("%s field: %s must be %s %d", structName, fieldName, string(intErrType), limit) +} + +func errIntBetween(structName string, fieldName string, from int, to int) error { + return fmt.Errorf("%s field: %s must be between %d and %d", structName, fieldName, from, to) +} + +func errInvalidIdentifier(structName string, identifierField string) error { + return fmt.Errorf("invalid object identifier of %s field: %s", structName, identifierField) +} + func errOneOf(structName string, fieldNames ...string) error { return fmt.Errorf("%v fields: %v are incompatible and cannot be set at the same time", structName, fieldNames) } @@ -28,12 +50,12 @@ func errNotSet(structName string, fieldNames ...string) error { return fmt.Errorf("%v fields: %v should be set", structName, fieldNames) } -func errExactlyOneOf(fieldNames ...string) error { - return fmt.Errorf("exactly one of %v must be set", fieldNames) +func errExactlyOneOf(structName string, fieldNames ...string) error { + return fmt.Errorf("exactly one of %s fileds %v must be set", structName, fieldNames) } -func errAtLeastOneOf(fieldNames ...string) error { - return fmt.Errorf("at least one of %v must be set", fieldNames) +func errAtLeastOneOf(structName string, fieldNames ...string) error { + return fmt.Errorf("at least one of %s fields %v must be set", structName, fieldNames) } func decodeDriverError(err error) error { diff --git a/pkg/sdk/external_tables_test.go b/pkg/sdk/external_tables_test.go index 71ea7b83e7..e6ed3fb29a 100644 --- a/pkg/sdk/external_tables_test.go +++ b/pkg/sdk/external_tables_test.go @@ -356,7 +356,7 @@ func TestExternalTablesAlter(t *testing.T) { assertOptsInvalidJoinedErrors( t, opts, ErrInvalidObjectIdentifier, - errOneOf("AlterExternalTableOptions", "Refresh", "AddFiles", "RemoveFiles", "AutoRefresh", "SetTag", "UnsetTag"), + errExactlyOneOf("AlterExternalTableOptions", "Refresh", "AddFiles", "RemoveFiles", "AutoRefresh", "SetTag", "UnsetTag"), ) }) } diff --git a/pkg/sdk/external_tables_validations.go b/pkg/sdk/external_tables_validations.go index 7d56812a59..a34264ce0c 100644 --- a/pkg/sdk/external_tables_validations.go +++ b/pkg/sdk/external_tables_validations.go @@ -19,6 +19,9 @@ var ( ) func (opts *CreateExternalTableOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateExternalTableOptions", "OrReplace", "IfNotExists")) @@ -45,6 +48,9 @@ func (opts *CreateExternalTableOptions) validate() error { } func (opts *CreateWithManualPartitioningExternalTableOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateWithManualPartitioningExternalTableOptions", "OrReplace", "IfNotExists")) @@ -71,6 +77,9 @@ func (opts 
*CreateWithManualPartitioningExternalTableOptions) validate() error { } func (opts *CreateDeltaLakeExternalTableOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateDeltaLakeExternalTableOptions", "OrReplace", "IfNotExists")) @@ -97,6 +106,9 @@ func (opts *CreateDeltaLakeExternalTableOptions) validate() error { } func (opts *CreateExternalTableUsingTemplateOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) @@ -123,18 +135,23 @@ func (opts *CreateExternalTableUsingTemplateOptions) validate() error { } func (opts *AlterExternalTableOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if anyValueSet(opts.Refresh, opts.AddFiles, opts.RemoveFiles, opts.AutoRefresh, opts.SetTag, opts.UnsetTag) && - !exactlyOneValueSet(opts.Refresh, opts.AddFiles, opts.RemoveFiles, opts.AutoRefresh, opts.SetTag, opts.UnsetTag) { - errs = append(errs, errOneOf("AlterExternalTableOptions", "Refresh", "AddFiles", "RemoveFiles", "AutoRefresh", "SetTag", "UnsetTag")) + if !exactlyOneValueSet(opts.Refresh, opts.AddFiles, opts.RemoveFiles, opts.AutoRefresh, opts.SetTag, opts.UnsetTag) { + errs = append(errs, errExactlyOneOf("AlterExternalTableOptions", "Refresh", "AddFiles", "RemoveFiles", "AutoRefresh", "SetTag", "UnsetTag")) } return errors.Join(errs...) } func (opts *AlterExternalTablePartitionOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) @@ -146,6 +163,9 @@ func (opts *AlterExternalTablePartitionOptions) validate() error { } func (opts *DropExternalTableOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) @@ -159,25 +179,34 @@ func (opts *DropExternalTableOptions) validate() error { } func (opts *ShowExternalTableOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } -func (v *describeExternalTableColumnsOptions) validate() error { - if !ValidObjectIdentifier(v.name) { - return ErrInvalidObjectIdentifier +func (opts *describeExternalTableColumnsOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + if !ValidObjectIdentifier(opts.name) { + return errors.Join(ErrInvalidObjectIdentifier) } return nil } -func (v *describeExternalTableStageOptions) validate() error { - if !ValidObjectIdentifier(v.name) { - return ErrInvalidObjectIdentifier +func (opts *describeExternalTableStageOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + if !ValidObjectIdentifier(opts.name) { + return errors.Join(ErrInvalidObjectIdentifier) } return nil } func (cpp *CloudProviderParams) validate() error { - if anyValueSet(cpp.GoogleCloudStorageIntegration, cpp.MicrosoftAzureIntegration) && exactlyOneValueSet(cpp.GoogleCloudStorageIntegration, cpp.MicrosoftAzureIntegration) { + if anyValueSet(cpp.GoogleCloudStorageIntegration, cpp.MicrosoftAzureIntegration) && !exactlyOneValueSet(cpp.GoogleCloudStorageIntegration, 
cpp.MicrosoftAzureIntegration) { return errOneOf("CloudProviderParams", "GoogleCloudStorageIntegration", "MicrosoftAzureIntegration") } return nil @@ -201,8 +230,11 @@ func (opts *ExternalTableFileFormat) validate() error { } func (opts *ExternalTableDropOption) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if anyValueSet(opts.Restrict, opts.Cascade) && !exactlyOneValueSet(opts.Restrict, opts.Cascade) { - return errOneOf("ExternalTableDropOption", "Restrict", "Cascade") + return errors.Join(errOneOf("ExternalTableDropOption", "Restrict", "Cascade")) } return nil } diff --git a/pkg/sdk/failover_groups.go b/pkg/sdk/failover_groups.go index f97163693a..7baac32971 100644 --- a/pkg/sdk/failover_groups.go +++ b/pkg/sdk/failover_groups.go @@ -68,8 +68,11 @@ type CreateFailoverGroupOptions struct { } func (opts *CreateFailoverGroupOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -102,13 +105,17 @@ type CreateSecondaryReplicationGroupOptions struct { } func (opts *CreateSecondaryReplicationGroupOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if !ValidObjectIdentifier(opts.primaryFailoverGroup) { - return ErrInvalidObjectIdentifier + errs = append(errs, errInvalidIdentifier("CreateSecondaryReplicationGroupOptions", "primaryFailoverGroup")) } - return nil + return errors.Join(errs...) } func (v *failoverGroups) CreateSecondaryReplicationGroup(ctx context.Context, id AccountObjectIdentifier, primaryFailoverGroupID ExternalObjectIdentifier, opts *CreateSecondaryReplicationGroupOptions) error { @@ -142,33 +149,37 @@ type AlterSourceFailoverGroupOptions struct { } func (opts *AlterSourceFailoverGroupOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if !exactlyOneValueSet(opts.Set, opts.Add, opts.Move, opts.Remove, opts.NewName) { - return errors.New("exactly one of SET, ADD, MOVE, REMOVE, or NewName must be specified") + errs = append(errs, errExactlyOneOf("AlterSourceFailoverGroupOptions", "Set", "Add", "Move", "Remove", "NewName")) } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Add) { if err := opts.Add.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Move) { if err := opts.Move.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Remove) { if err := opts.Remove.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
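
The constructors added to pkg/sdk/errors.go earlier in this patch put the struct name first and the offending fields after it, so failures such as the DropAccountOptions grace-period check read the same way across all objects. The snippet reproduces the relevant format strings from that hunk and prints what a few of them produce (expected output in comments):

package main

import "fmt"

type IntErrType string

const IntErrGreaterOrEqual IntErrType = "greater than or equal to"

// These mirror the format strings introduced in pkg/sdk/errors.go.
func errIntValue(structName, fieldName string, t IntErrType, limit int) error {
    return fmt.Errorf("%s field: %s must be %s %d", structName, fieldName, string(t), limit)
}

func errAtLeastOneOf(structName string, fieldNames ...string) error {
    return fmt.Errorf("at least one of %s fields %v must be set", structName, fieldNames)
}

func errOneOf(structName string, fieldNames ...string) error {
    return fmt.Errorf("%v fields: %v are incompatible and cannot be set at the same time", structName, fieldNames)
}

func main() {
    fmt.Println(errIntValue("DropAccountOptions", "gracePeriodInDays", IntErrGreaterOrEqual, 3))
    // DropAccountOptions field: gracePeriodInDays must be greater than or equal to 3
    fmt.Println(errAtLeastOneOf("AccountLevelParametersUnset", "AccountParameters", "SessionParameters"))
    // at least one of AccountLevelParametersUnset fields [AccountParameters SessionParameters] must be set
    fmt.Println(errOneOf("CreateDatabaseOptions", "OrReplace", "IfNotExists"))
    // CreateDatabaseOptions fields: [OrReplace IfNotExists] are incompatible and cannot be set at the same time
}
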
} type FailoverGroupSet struct { @@ -248,13 +259,17 @@ type AlterTargetFailoverGroupOptions struct { } func (opts *AlterTargetFailoverGroupOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if !exactlyOneValueSet(opts.Refresh, opts.Primary, opts.Suspend, opts.Resume) { - return errors.New("must set one of [Refresh, Primary, Suspend, Resume]") + errs = append(errs, errExactlyOneOf("AlterTargetFailoverGroupOptions", "Refresh", "Primary", "Suspend", "Resume")) } - return nil + return errors.Join(errs...) } func (v *failoverGroups) AlterTarget(ctx context.Context, id AccountObjectIdentifier, opts *AlterTargetFailoverGroupOptions) error { @@ -282,8 +297,11 @@ type DropFailoverGroupOptions struct { } func (opts *DropFailoverGroupOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -315,6 +333,9 @@ type ShowFailoverGroupOptions struct { } func (opts *ShowFailoverGroupOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -486,8 +507,11 @@ type showFailoverGroupDatabasesOptions struct { } func (opts *showFailoverGroupDatabasesOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.in) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -525,8 +549,11 @@ type showFailoverGroupSharesOptions struct { } func (opts *showFailoverGroupSharesOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.in) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } diff --git a/pkg/sdk/grants_test.go b/pkg/sdk/grants_test.go index f478978562..3c8b0825a9 100644 --- a/pkg/sdk/grants_test.go +++ b/pkg/sdk/grants_test.go @@ -326,16 +326,16 @@ func TestGrants_GrantPrivilegesToDatabaseRole(t *testing.T) { } } - t.Run("validation: no privileges set", func(t *testing.T) { + t.Run("validation: nil privileges set", func(t *testing.T) { opts := defaultGrantsForDb() opts.privileges = nil - assertOptsInvalid(t, opts, fmt.Errorf("privileges must be set")) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("GrantPrivilegesToDatabaseRoleOptions", "privileges")) }) t.Run("validation: no privileges set", func(t *testing.T) { opts := defaultGrantsForDb() opts.privileges = &DatabaseRoleGrantPrivileges{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of DatabasePrivileges, SchemaPrivileges, or SchemaObjectPrivileges must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("DatabaseRoleGrantPrivileges", "DatabasePrivileges", "SchemaPrivileges", "SchemaObjectPrivileges")) }) t.Run("validation: too many privileges set", func(t *testing.T) { @@ -344,19 +344,19 @@ func TestGrants_GrantPrivilegesToDatabaseRole(t *testing.T) { DatabasePrivileges: []AccountObjectPrivilege{AccountObjectPrivilegeCreateSchema}, SchemaPrivileges: []SchemaPrivilege{SchemaPrivilegeCreateAlert}, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of DatabasePrivileges, SchemaPrivileges, or SchemaObjectPrivileges must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("DatabaseRoleGrantPrivileges", 
"DatabasePrivileges", "SchemaPrivileges", "SchemaObjectPrivileges")) }) t.Run("validation: no on set", func(t *testing.T) { opts := defaultGrantsForDb() opts.on = nil - assertOptsInvalid(t, opts, fmt.Errorf("on must be set")) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("GrantPrivilegesToDatabaseRoleOptions", "on")) }) t.Run("validation: no on set", func(t *testing.T) { opts := defaultGrantsForDb() opts.on = &DatabaseRoleGrantOn{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of Database, Schema, or SchemaObject must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("DatabaseRoleGrantOn", "Database", "Schema", "SchemaObject")) }) t.Run("validation: too many ons set", func(t *testing.T) { @@ -367,19 +367,19 @@ func TestGrants_GrantPrivilegesToDatabaseRole(t *testing.T) { Schema: Pointer(NewDatabaseObjectIdentifier("db1", "schema1")), }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of Database, Schema, or SchemaObject must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("DatabaseRoleGrantOn", "Database", "Schema", "SchemaObject")) }) t.Run("validation: grant on schema", func(t *testing.T) { opts := defaultGrantsForSchema() opts.on.Schema = &GrantOnSchema{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of Schema, AllSchemasInDatabase, or FutureSchemasInDatabase must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchema", "Schema", "AllSchemasInDatabase", "FutureSchemasInDatabase")) }) t.Run("validation: grant on schema object", func(t *testing.T) { opts := defaultGrantsForSchemaObject() opts.on.SchemaObject = &GrantOnSchemaObject{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of Object, AllIn or Future must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObject", "SchemaObject", "All", "Future")) }) t.Run("validation: grant on schema object - all", func(t *testing.T) { @@ -391,7 +391,7 @@ func TestGrants_GrantPrivilegesToDatabaseRole(t *testing.T) { }, }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of InDatabase, or InSchema must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObjectIn", "InDatabase", "InSchema")) }) t.Run("validation: grant on schema object - future", func(t *testing.T) { @@ -403,13 +403,13 @@ func TestGrants_GrantPrivilegesToDatabaseRole(t *testing.T) { }, }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of InDatabase, or InSchema must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObjectIn", "InDatabase", "InSchema")) }) t.Run("validation: unsupported database privilege", func(t *testing.T) { opts := defaultGrantsForDb() opts.privileges.DatabasePrivileges = []AccountObjectPrivilege{AccountObjectPrivilegeCreateDatabaseRole} - assertOptsInvalid(t, opts, fmt.Errorf("privilege CREATE DATABASE ROLE is not allowed")) + assertOptsInvalidJoinedErrors(t, opts, fmt.Errorf("privilege CREATE DATABASE ROLE is not allowed")) }) t.Run("on database", func(t *testing.T) { @@ -512,16 +512,16 @@ func TestGrants_RevokePrivilegesFromDatabaseRoleRole(t *testing.T) { } } - t.Run("validation: no privileges set", func(t *testing.T) { + t.Run("validation: nil privileges set", func(t *testing.T) { opts := defaultGrantsForDb() opts.privileges = nil - assertOptsInvalid(t, opts, fmt.Errorf("privileges must be set")) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("RevokePrivilegesFromDatabaseRoleOptions", "privileges")) }) t.Run("validation: no privileges set", func(t *testing.T) { 
opts := defaultGrantsForDb() opts.privileges = &DatabaseRoleGrantPrivileges{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of DatabasePrivileges, SchemaPrivileges, or SchemaObjectPrivileges must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("DatabaseRoleGrantPrivileges", "DatabasePrivileges", "SchemaPrivileges", "SchemaObjectPrivileges")) }) t.Run("validation: too many privileges set", func(t *testing.T) { @@ -530,19 +530,19 @@ func TestGrants_RevokePrivilegesFromDatabaseRoleRole(t *testing.T) { DatabasePrivileges: []AccountObjectPrivilege{AccountObjectPrivilegeCreateSchema}, SchemaPrivileges: []SchemaPrivilege{SchemaPrivilegeCreateAlert}, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of DatabasePrivileges, SchemaPrivileges, or SchemaObjectPrivileges must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("DatabaseRoleGrantPrivileges", "DatabasePrivileges", "SchemaPrivileges", "SchemaObjectPrivileges")) }) - t.Run("validation: no on set", func(t *testing.T) { + t.Run("validation: nil on set", func(t *testing.T) { opts := defaultGrantsForDb() opts.on = nil - assertOptsInvalid(t, opts, fmt.Errorf("on must be set")) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("RevokePrivilegesFromDatabaseRoleOptions", "on")) }) t.Run("validation: no on set", func(t *testing.T) { opts := defaultGrantsForDb() opts.on = &DatabaseRoleGrantOn{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of Database, Schema, or SchemaObject must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("DatabaseRoleGrantOn", "Database", "Schema", "SchemaObject")) }) t.Run("validation: too many ons set", func(t *testing.T) { @@ -553,19 +553,19 @@ func TestGrants_RevokePrivilegesFromDatabaseRoleRole(t *testing.T) { Schema: Pointer(NewDatabaseObjectIdentifier("db1", "schema1")), }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of Database, Schema, or SchemaObject must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("DatabaseRoleGrantOn", "Database", "Schema", "SchemaObject")) }) t.Run("validation: grant on schema", func(t *testing.T) { opts := defaultGrantsForSchema() opts.on.Schema = &GrantOnSchema{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of Schema, AllSchemasInDatabase, or FutureSchemasInDatabase must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchema", "Schema", "AllSchemasInDatabase", "FutureSchemasInDatabase")) }) t.Run("validation: grant on schema object", func(t *testing.T) { opts := defaultGrantsForSchemaObject() opts.on.SchemaObject = &GrantOnSchemaObject{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of Object, AllIn or Future must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObject", "SchemaObject", "All", "Future")) }) t.Run("validation: grant on schema object - all", func(t *testing.T) { @@ -577,7 +577,7 @@ func TestGrants_RevokePrivilegesFromDatabaseRoleRole(t *testing.T) { }, }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of InDatabase, or InSchema must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObjectIn", "InDatabase", "InSchema")) }) t.Run("validation: grant on schema object - future", func(t *testing.T) { @@ -589,13 +589,13 @@ func TestGrants_RevokePrivilegesFromDatabaseRoleRole(t *testing.T) { }, }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of InDatabase, or InSchema must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObjectIn", 
"InDatabase", "InSchema")) }) t.Run("validation: unsupported database privilege", func(t *testing.T) { opts := defaultGrantsForDb() opts.privileges.DatabasePrivileges = []AccountObjectPrivilege{AccountObjectPrivilegeCreateDatabaseRole} - assertOptsInvalid(t, opts, errors.New("privilege CREATE DATABASE ROLE is not allowed")) + assertOptsInvalidJoinedErrors(t, opts, errors.New("privilege CREATE DATABASE ROLE is not allowed")) }) t.Run("on database", func(t *testing.T) { @@ -828,7 +828,7 @@ func TestGrants_GrantOwnership(t *testing.T) { t.Run("validation: grant on empty", func(t *testing.T) { opts := defaultOpts() opts.On = OwnershipGrantOn{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of [Object AllIn Future] must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("OwnershipGrantOn", "Object", "AllIn", "Future")) }) t.Run("validation: grant on too many", func(t *testing.T) { @@ -843,7 +843,7 @@ func TestGrants_GrantOwnership(t *testing.T) { InDatabase: Pointer(dbId), }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of [Object AllIn Future] must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("OwnershipGrantOn", "Object", "AllIn", "Future")) }) t.Run("validation: grant on schema object - all", func(t *testing.T) { @@ -853,7 +853,7 @@ func TestGrants_GrantOwnership(t *testing.T) { PluralObjectType: PluralObjectTypeTables, }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of InDatabase, or InSchema must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObjectIn", "InDatabase", "InSchema")) }) t.Run("validation: grant on schema object - future", func(t *testing.T) { @@ -863,13 +863,13 @@ func TestGrants_GrantOwnership(t *testing.T) { PluralObjectType: PluralObjectTypeTables, }, } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of InDatabase, or InSchema must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObjectIn", "InDatabase", "InSchema")) }) t.Run("validation: grant to empty", func(t *testing.T) { opts := defaultOpts() opts.To = OwnershipGrantTo{} - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of [databaseRoleName accountRoleName] must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("OwnershipGrantTo", "databaseRoleName", "accountRoleName")) }) t.Run("validation: grant to role and database role", func(t *testing.T) { @@ -878,7 +878,7 @@ func TestGrants_GrantOwnership(t *testing.T) { DatabaseRoleName: Pointer(databaseRoleId), AccountRoleName: Pointer(roleId), } - assertOptsInvalid(t, opts, fmt.Errorf("exactly one of [databaseRoleName accountRoleName] must be set")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("OwnershipGrantTo", "databaseRoleName", "accountRoleName")) }) t.Run("on schema object to role", func(t *testing.T) { diff --git a/pkg/sdk/grants_validations.go b/pkg/sdk/grants_validations.go index 05f6a598bb..a66029ac63 100644 --- a/pkg/sdk/grants_validations.go +++ b/pkg/sdk/grants_validations.go @@ -1,6 +1,7 @@ package sdk import ( + "errors" "fmt" // TODO: change to slices with go 1.21 @@ -19,129 +20,150 @@ var ( ) func (opts *GrantPrivilegesToAccountRoleOptions) validate() error { - if !valueSet(opts.privileges) { - return fmt.Errorf("privileges must be set") + if opts == nil { + return errors.Join(ErrNilOptions) } - if err := opts.privileges.validate(); err != nil { - return err + var errs []error + if !valueSet(opts.privileges) { + errs = append(errs, errNotSet("GrantPrivilegesToAccountRoleOptions", "privileges")) + 
} else { + if err := opts.privileges.validate(); err != nil { + errs = append(errs, err) + } } if !valueSet(opts.on) { - return fmt.Errorf("on must be set") - } - if err := opts.on.validate(); err != nil { - return err + errs = append(errs, errNotSet("GrantPrivilegesToAccountRoleOptions", "on")) + } else { + if err := opts.on.validate(); err != nil { + errs = append(errs, err) + } } - return nil + return errors.Join(errs...) } func (v *AccountRoleGrantPrivileges) validate() error { if !exactlyOneValueSet(v.AllPrivileges, v.GlobalPrivileges, v.AccountObjectPrivileges, v.SchemaPrivileges, v.SchemaObjectPrivileges) { - return fmt.Errorf("exactly one of AllPrivileges, GlobalPrivileges, AccountObjectPrivileges, SchemaPrivileges, or SchemaObjectPrivileges must be set") + return errExactlyOneOf("AccountRoleGrantPrivileges", "AllPrivileges", "GlobalPrivileges", "AccountObjectPrivileges", "SchemaPrivileges", "SchemaObjectPrivileges") } return nil } func (v *AccountRoleGrantOn) validate() error { + var errs []error if !exactlyOneValueSet(v.Account, v.AccountObject, v.Schema, v.SchemaObject) { - return fmt.Errorf("exactly one of Account, AccountObject, Schema, or SchemaObject must be set") + errs = append(errs, errExactlyOneOf("AccountRoleGrantOn", "Account", "AccountObject", "Schema", "SchemaObject")) } if valueSet(v.AccountObject) { if err := v.AccountObject.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(v.Schema) { if err := v.Schema.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(v.SchemaObject) { if err := v.SchemaObject.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) } func (v *GrantOnAccountObject) validate() error { if !exactlyOneValueSet(v.User, v.ResourceMonitor, v.Warehouse, v.Database, v.Integration, v.FailoverGroup, v.ReplicationGroup) { - return fmt.Errorf("exactly one of User, ResourceMonitor, Warehouse, Database, Integration, FailoverGroup, or ReplicationGroup must be set") + return errExactlyOneOf("GrantOnAccountObject", "User", "ResourceMonitor", "Warehouse", "Database", "Integration", "FailoverGroup", "ReplicationGroup") } return nil } func (v *GrantOnSchema) validate() error { if !exactlyOneValueSet(v.Schema, v.AllSchemasInDatabase, v.FutureSchemasInDatabase) { - return fmt.Errorf("exactly one of Schema, AllSchemasInDatabase, or FutureSchemasInDatabase must be set") + return errExactlyOneOf("GrantOnSchema", "Schema", "AllSchemasInDatabase", "FutureSchemasInDatabase") } return nil } func (v *GrantOnSchemaObject) validate() error { + var errs []error if !exactlyOneValueSet(v.SchemaObject, v.All, v.Future) { - return fmt.Errorf("exactly one of Object, AllIn or Future must be set") + errs = append(errs, errExactlyOneOf("GrantOnSchemaObject", "SchemaObject", "All", "Future")) } if valueSet(v.All) { if err := v.All.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(v.Future) { if err := v.Future.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
} func (v *GrantOnSchemaObjectIn) validate() error { if !exactlyOneValueSet(v.InDatabase, v.InSchema) { - return fmt.Errorf("exactly one of InDatabase, or InSchema must be set") + return errExactlyOneOf("GrantOnSchemaObjectIn", "InDatabase", "InSchema") } return nil } func (opts *RevokePrivilegesFromAccountRoleOptions) validate() error { - if !valueSet(opts.privileges) { - return fmt.Errorf("privileges must be set") + if opts == nil { + return errors.Join(ErrNilOptions) } - if err := opts.privileges.validate(); err != nil { - return err + var errs []error + if !valueSet(opts.privileges) { + errs = append(errs, errNotSet("RevokePrivilegesFromAccountRoleOptions", "privileges")) + } else { + if err := opts.privileges.validate(); err != nil { + errs = append(errs, err) + } } if !valueSet(opts.on) { - return fmt.Errorf("on must be set") - } - if err := opts.on.validate(); err != nil { - return err + errs = append(errs, errNotSet("RevokePrivilegesFromAccountRoleOptions", "on")) + } else { + if err := opts.on.validate(); err != nil { + errs = append(errs, err) + } } if !ValidObjectIdentifier(opts.accountRole) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if everyValueSet(opts.Restrict, opts.Cascade) { - return fmt.Errorf("either Restrict or Cascade can be set, or neither but not both") + errs = append(errs, errOneOf("RevokePrivilegesFromAccountRoleOptions", "Restrict", "Cascade")) } - return nil + return errors.Join(errs...) } func (opts *GrantPrivilegesToDatabaseRoleOptions) validate() error { - if !valueSet(opts.privileges) { - return fmt.Errorf("privileges must be set") + if opts == nil { + return errors.Join(ErrNilOptions) } - if err := opts.privileges.validate(); err != nil { - return err + var errs []error + if !valueSet(opts.privileges) { + errs = append(errs, errNotSet("GrantPrivilegesToDatabaseRoleOptions", "privileges")) + } else { + if err := opts.privileges.validate(); err != nil { + errs = append(errs, err) + } } if !valueSet(opts.on) { - return fmt.Errorf("on must be set") - } - if err := opts.on.validate(); err != nil { - return err + errs = append(errs, errNotSet("GrantPrivilegesToDatabaseRoleOptions", "on")) + } else { + if err := opts.on.validate(); err != nil { + errs = append(errs, err) + } } - return nil + return errors.Join(errs...) } func (v *DatabaseRoleGrantPrivileges) validate() error { + var errs []error if !exactlyOneValueSet(v.DatabasePrivileges, v.SchemaPrivileges, v.SchemaObjectPrivileges) { - return fmt.Errorf("exactly one of DatabasePrivileges, SchemaPrivileges, or SchemaObjectPrivileges must be set") + errs = append(errs, errExactlyOneOf("DatabaseRoleGrantPrivileges", "DatabasePrivileges", "SchemaPrivileges", "SchemaObjectPrivileges")) } if valueSet(v.DatabasePrivileges) { allowedPrivileges := []AccountObjectPrivilege{ @@ -152,163 +174,192 @@ func (v *DatabaseRoleGrantPrivileges) validate() error { } for _, p := range v.DatabasePrivileges { if !slices.Contains(allowedPrivileges, p) { - return fmt.Errorf("privilege %s is not allowed", p.String()) + errs = append(errs, fmt.Errorf("privilege %s is not allowed", p.String())) } } } - return nil + return errors.Join(errs...) 
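
The grants_validations.go functions above could previously return early on "privileges/on must be set" before ever touching those pointers; with errors collected into a slice, that early return is gone, so the nested validate() is now called only in the else branch. A condensed sketch of the shape, with illustrative types and messages rather than the SDK's:

package main

import (
    "errors"
    "fmt"
)

type onClause struct {
    Database *string
    Schema   *string
}

func (o *onClause) validate() error {
    if (o.Database == nil) == (o.Schema == nil) {
        return errors.New("exactly one of Database, Schema must be set")
    }
    return nil
}

type grantOptions struct {
    on *onClause
}

func (opts *grantOptions) validate() error {
    var errs []error
    if opts.on == nil {
        // Report the missing block once...
        errs = append(errs, errors.New("grantOptions field on should be set"))
    } else {
        // ...and only dig into it when it is actually there.
        if err := opts.on.validate(); err != nil {
            errs = append(errs, err)
        }
    }
    return errors.Join(errs...)
}

func main() {
    fmt.Println((&grantOptions{}).validate())
    db := "db1"
    fmt.Println((&grantOptions{on: &onClause{Database: &db}}).validate()) // <nil>
}

This keeps a nil block from producing both a "should be set" error and a panic or a second, redundant error from its own validate().
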
} func (v *DatabaseRoleGrantOn) validate() error { + var errs []error if !exactlyOneValueSet(v.Database, v.Schema, v.SchemaObject) { - return fmt.Errorf("exactly one of Database, Schema, or SchemaObject must be set") + errs = append(errs, errExactlyOneOf("DatabaseRoleGrantOn", "Database", "Schema", "SchemaObject")) } if valueSet(v.Schema) { if err := v.Schema.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(v.SchemaObject) { if err := v.SchemaObject.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) } func (opts *RevokePrivilegesFromDatabaseRoleOptions) validate() error { - if !valueSet(opts.privileges) { - return fmt.Errorf("privileges must be set") + if opts == nil { + return errors.Join(ErrNilOptions) } - if err := opts.privileges.validate(); err != nil { - return err + var errs []error + if !valueSet(opts.privileges) { + errs = append(errs, errNotSet("RevokePrivilegesFromDatabaseRoleOptions", "privileges")) + } else { + if err := opts.privileges.validate(); err != nil { + errs = append(errs, err) + } } if !valueSet(opts.on) { - return fmt.Errorf("on must be set") - } - if err := opts.on.validate(); err != nil { - return err + errs = append(errs, errNotSet("RevokePrivilegesFromDatabaseRoleOptions", "on")) + } else { + if err := opts.on.validate(); err != nil { + errs = append(errs, err) + } } if !ValidObjectIdentifier(opts.databaseRole) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if everyValueSet(opts.Restrict, opts.Cascade) { - return fmt.Errorf("either Restrict or Cascade can be set, or neither but not both") + errs = append(errs, errOneOf("RevokePrivilegesFromDatabaseRoleOptions", "Restrict", "Cascade")) } - return nil + return errors.Join(errs...) } func (opts *grantPrivilegeToShareOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.to) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if !valueSet(opts.On) || opts.privilege == "" { - return fmt.Errorf("on and privilege are required") + errs = append(errs, fmt.Errorf("on and privilege are required")) } - if err := opts.On.validate(); err != nil { - return err + if valueSet(opts.On) { + if err := opts.On.validate(); err != nil { + errs = append(errs, err) + } } - return nil + return errors.Join(errs...) } func (v *GrantPrivilegeToShareOn) validate() error { + var errs []error if !exactlyOneValueSet(v.Database, v.Schema, v.Function, v.Table, v.View) { - return fmt.Errorf("only one of database, schema, function, table, or view can be set") + errs = append(errs, errExactlyOneOf("GrantPrivilegeToShareOn", "Database", "Schema", "Function", "Table", "View")) } if valueSet(v.Table) { if err := v.Table.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
} func (v *OnTable) validate() error { if !exactlyOneValueSet(v.Name, v.AllInSchema) { - return fmt.Errorf("only one of name or allInSchema can be set") + return errExactlyOneOf("OnTable", "Name", "AllInSchema") } return nil } func (opts *revokePrivilegeFromShareOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.from) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if !valueSet(opts.On) || opts.privilege == "" { - return fmt.Errorf("on and privilege are required") - } - if !exactlyOneValueSet(opts.On.Database, opts.On.Schema, opts.On.Table, opts.On.View) { - return fmt.Errorf("only one of database, schema, function, table, or view can be set") + errs = append(errs, errNotSet("revokePrivilegeFromShareOptions", "On", "privilege")) } - - if err := opts.On.validate(); err != nil { - return err + if valueSet(opts.On) { + if !exactlyOneValueSet(opts.On.Database, opts.On.Schema, opts.On.Table, opts.On.View) { + errs = append(errs, errExactlyOneOf("revokePrivilegeFromShareOptions", "On.Database", "On.Schema", "On.Table", "On.View")) + } + if err := opts.On.validate(); err != nil { + errs = append(errs, err) + } } - - return nil + return errors.Join(errs...) } func (v *RevokePrivilegeFromShareOn) validate() error { + var errs []error if !exactlyOneValueSet(v.Database, v.Schema, v.Table, v.View) { - return fmt.Errorf("only one of database, schema, table, or view can be set") + errs = append(errs, errExactlyOneOf("RevokePrivilegeFromShareOn", "Database", "Schema", "Table", "View")) } if valueSet(v.Table) { - return v.Table.validate() + if err := v.Table.validate(); err != nil { + errs = append(errs, err) + } } if valueSet(v.View) { - return v.View.validate() + if err := v.View.validate(); err != nil { + errs = append(errs, err) + } } - return nil + return errors.Join(errs...) } func (v *OnView) validate() error { if !exactlyOneValueSet(v.Name, v.AllInSchema) { - return fmt.Errorf("only one of name or allInSchema can be set") + return errExactlyOneOf("OnView", "Name", "AllInSchema") } return nil } func (opts *GrantOwnershipOptions) validate() error { - if err := opts.On.validate(); err != nil { - return err + if opts == nil { + return errors.Join(ErrNilOptions) } - if err := opts.To.validate(); err != nil { - return err + var errs []error + if valueSet(opts.On) { + if err := opts.On.validate(); err != nil { + errs = append(errs, err) + } } - return nil + if valueSet(opts.To) { + if err := opts.To.validate(); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) } func (v *OwnershipGrantOn) validate() error { + var errs []error if !exactlyOneValueSet(v.Object, v.All, v.Future) { - return errExactlyOneOf("Object", "AllIn", "Future") + errs = append(errs, errExactlyOneOf("OwnershipGrantOn", "Object", "AllIn", "Future")) } if valueSet(v.All) { if err := v.All.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(v.Future) { if err := v.Future.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
} func (v *OwnershipGrantTo) validate() error { if !exactlyOneValueSet(v.DatabaseRoleName, v.AccountRoleName) { - return errExactlyOneOf("databaseRoleName", "accountRoleName") + return errExactlyOneOf("OwnershipGrantTo", "databaseRoleName", "accountRoleName") } return nil } // TODO: add validations for ShowGrantsOn, ShowGrantsTo, ShowGrantsOf and ShowGrantsIn func (opts *ShowGrantOptions) validate() error { - if everyValueNil(opts.On, opts.To, opts.Of, opts.In) { - return nil - } - if !exactlyOneValueSet(opts.On, opts.To, opts.Of, opts.In) { - return fmt.Errorf("only one of [on, to, of, in] can be set") + if moreThanOneValueSet(opts.On, opts.To, opts.Of, opts.In) { + return errOneOf("ShowGrantOptions", "On", "To", "Of", "In") } return nil } diff --git a/pkg/sdk/integration_test_imports.go b/pkg/sdk/integration_test_imports.go index ea3398351c..d038d50765 100644 --- a/pkg/sdk/integration_test_imports.go +++ b/pkg/sdk/integration_test_imports.go @@ -15,3 +15,11 @@ func (c *Client) ExecForTests(ctx context.Context, sql string) (sql.Result, erro result, err := c.db.ExecContext(ctx, sql) return result, decodeDriverError(err) } + +func ErrExactlyOneOf(structName string, fieldNames ...string) error { + return errExactlyOneOf(structName, fieldNames...) +} + +func ErrOneOf(structName string, fieldNames ...string) error { + return errOneOf(structName, fieldNames...) +} diff --git a/pkg/sdk/masking_policy.go b/pkg/sdk/masking_policy.go index 26120d6b57..1bb20269cd 100644 --- a/pkg/sdk/masking_policy.go +++ b/pkg/sdk/masking_policy.go @@ -52,20 +52,23 @@ type CreateMaskingPolicyOptions struct { } func (opts *CreateMaskingPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return errors.New("invalid object identifier") + errs = append(errs, errors.Join(ErrInvalidObjectIdentifier)) } if !valueSet(opts.signature) { - return errNotSet("CreateMaskingPolicyOptions", "signature") + errs = append(errs, errNotSet("CreateMaskingPolicyOptions", "signature")) } if !valueSet(opts.returns) { - return errNotSet("CreateMaskingPolicyOptions", "returns") + errs = append(errs, errNotSet("CreateMaskingPolicyOptions", "returns")) } if !valueSet(opts.body) { - return errNotSet("CreateMaskingPolicyOptions", "body") + errs = append(errs, errNotSet("CreateMaskingPolicyOptions", "body")) } - - return nil + return errors.Join(errs...) 
} func (v *maskingPolicies) Create(ctx context.Context, id SchemaObjectIdentifier, signature []TableColumnSignature, returns DataType, body string, opts *CreateMaskingPolicyOptions) error { @@ -99,27 +102,30 @@ type AlterMaskingPolicyOptions struct { } func (opts *AlterMaskingPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return errors.New("invalid object identifier") + errs = append(errs, ErrInvalidObjectIdentifier) } - - if !exactlyOneValueSet(opts.Set, opts.Unset, opts.NewName) { - return errExactlyOneOf("Set", "Unset", "NewName") + if opts.NewName != nil && !ValidObjectIdentifier(opts.NewName) { + errs = append(errs, errInvalidIdentifier("AlterMaskingPolicyOptions", "NewName")) + } + if !exactlyOneValueSet(opts.NewName, opts.Set, opts.Unset) { + errs = append(errs, errExactlyOneOf("AlterMaskingPolicyOptions", "NewName", "Set", "Unset")) } - if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { - return err + errs = append(errs, err) } } - if valueSet(opts.Unset) { if err := opts.Unset.validate(); err != nil { - return err + errs = append(errs, err) } } - - return nil + return errors.Join(errs...) } type MaskingPolicySet struct { @@ -130,7 +136,7 @@ type MaskingPolicySet struct { func (v *MaskingPolicySet) validate() error { if !exactlyOneValueSet(v.Body, v.Tag, v.Comment) { - return errors.New("only one parameter can be set at a time") + return errExactlyOneOf("MaskingPolicySet", "Body", "Tag", "Comment") } return nil } @@ -142,7 +148,7 @@ type MaskingPolicyUnset struct { func (v *MaskingPolicyUnset) validate() error { if !exactlyOneValueSet(v.Tag, v.Comment) { - return errors.New("only one parameter can be unset at a time") + return errExactlyOneOf("MaskingPolicyUnset", "Tag", "Comment") } return nil } @@ -171,8 +177,11 @@ type DropMaskingPolicyOptions struct { } func (opts *DropMaskingPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -205,7 +214,10 @@ type ShowMaskingPolicyOptions struct { Limit *int `ddl:"parameter,no_equals" sql:"LIMIT"` } -func (input *ShowMaskingPolicyOptions) validate() error { +func (opts *ShowMaskingPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -298,9 +310,12 @@ type describeMaskingPolicyOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` } -func (v *describeMaskingPolicyOptions) validate() error { - if !ValidObjectIdentifier(v.name) { - return ErrInvalidObjectIdentifier +func (opts *describeMaskingPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + if !ValidObjectIdentifier(opts.name) { + return errors.Join(ErrInvalidObjectIdentifier) } return nil } diff --git a/pkg/sdk/masking_policy_test.go b/pkg/sdk/masking_policy_test.go index 6bd8129d05..23d9362387 100644 --- a/pkg/sdk/masking_policy_test.go +++ b/pkg/sdk/masking_policy_test.go @@ -26,7 +26,7 @@ func TestMaskingPolicyCreate(t *testing.T) { signature: signature, returns: DataTypeVARCHAR, } - assertOptsInvalid(t, opts, errNotSet("CreateMaskingPolicyOptions", "body")) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateMaskingPolicyOptions", "body")) }) t.Run("validation: no signature", func(t *testing.T) { @@ -35,7 +35,7 @@ func TestMaskingPolicyCreate(t *testing.T) { body: expression, 
returns: DataTypeVARCHAR, } - assertOptsInvalid(t, opts, errNotSet("CreateMaskingPolicyOptions", "signature")) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateMaskingPolicyOptions", "signature")) }) t.Run("validation: no returns", func(t *testing.T) { @@ -44,7 +44,7 @@ func TestMaskingPolicyCreate(t *testing.T) { signature: signature, body: expression, } - assertOptsInvalid(t, opts, errNotSet("CreateMaskingPolicyOptions", "returns")) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateMaskingPolicyOptions", "returns")) }) t.Run("only required options", func(t *testing.T) { @@ -81,15 +81,14 @@ func TestMaskingPolicyAlter(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &AlterMaskingPolicyOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("validation: no option", func(t *testing.T) { opts := &AlterMaskingPolicyOptions{ name: id, } - - assertOptsInvalid(t, opts, errExactlyOneOf("Set", "Unset", "NewName")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterMaskingPolicyOptions", "NewName", "Set", "Unset")) }) t.Run("with set", func(t *testing.T) { @@ -128,7 +127,7 @@ func TestMaskingPolicyDrop(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &DropMaskingPolicyOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { @@ -207,7 +206,7 @@ func TestMaskingPolicyDescribe(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &describeMaskingPolicyOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { diff --git a/pkg/sdk/network_policies_gen_test.go b/pkg/sdk/network_policies_gen_test.go index f03fd49bb3..697661c4ff 100644 --- a/pkg/sdk/network_policies_gen_test.go +++ b/pkg/sdk/network_policies_gen_test.go @@ -60,13 +60,13 @@ func TestNetworkPolicies_Alter(t *testing.T) { t.Run("validation: exactly one field from [opts.Set opts.UnsetComment opts.RenameTo] should be present", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Set", "UnsetComment", "RenameTo")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterNetworkPolicyOptions", "Set", "UnsetComment", "RenameTo")) }) t.Run("validation: at least one of the fields [opts.Set.AllowedIpList opts.Set.BlockedIpList opts.Set.Comment] should be set", func(t *testing.T) { opts := defaultOpts() opts.Set = &NetworkPolicySet{} - assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AllowedIpList", "BlockedIpList", "Comment")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterNetworkPolicyOptions.Set", "AllowedIpList", "BlockedIpList", "Comment")) }) t.Run("set allowed ip list", func(t *testing.T) { diff --git a/pkg/sdk/network_policies_validations_gen.go b/pkg/sdk/network_policies_validations_gen.go index 6a7e5aff61..2b06d456e4 100644 --- a/pkg/sdk/network_policies_validations_gen.go +++ b/pkg/sdk/network_policies_validations_gen.go @@ -30,14 +30,14 @@ func (opts *AlterNetworkPolicyOptions) validate() error { errs = append(errs, ErrInvalidObjectIdentifier) } if ok := exactlyOneValueSet(opts.Set, opts.UnsetComment, opts.RenameTo); !ok { - errs = append(errs, errExactlyOneOf("Set", "UnsetComment", "RenameTo")) + errs = append(errs, 
errExactlyOneOf("AlterNetworkPolicyOptions", "Set", "UnsetComment", "RenameTo")) } if valueSet(opts.RenameTo) && !ValidObjectIdentifier(opts.RenameTo) { errs = append(errs, ErrInvalidObjectIdentifier) } if valueSet(opts.Set) { if ok := anyValueSet(opts.Set.AllowedIpList, opts.Set.BlockedIpList, opts.Set.Comment); !ok { - errs = append(errs, errAtLeastOneOf("AllowedIpList", "BlockedIpList", "Comment")) + errs = append(errs, errAtLeastOneOf("AlterNetworkPolicyOptions.Set", "AllowedIpList", "BlockedIpList", "Comment")) } } return errors.Join(errs...) diff --git a/pkg/sdk/parameters.go b/pkg/sdk/parameters.go index 7cff37644a..6a7318bd1d 100644 --- a/pkg/sdk/parameters.go +++ b/pkg/sdk/parameters.go @@ -3,6 +3,7 @@ package sdk import ( "context" "database/sql" + "errors" "fmt" "strconv" ) @@ -815,24 +816,24 @@ type AccountParameters struct { } func (v *AccountParameters) validate() error { + var errs []error if valueSet(v.ClientEncryptionKeySize) { if !(*v.ClientEncryptionKeySize == 128 || *v.ClientEncryptionKeySize == 256) { - return fmt.Errorf("CLIENT_ENCRYPTION_KEY_SIZE must be either 128 or 256") + errs = append(errs, fmt.Errorf("CLIENT_ENCRYPTION_KEY_SIZE must be either 128 or 256")) } } if valueSet(v.InitialReplicationSizeLimitInTB) { l := *v.InitialReplicationSizeLimitInTB if l < 0.0 || (l < 0.0 && l < 1.0) { - return fmt.Errorf("%v must be 0.0 and above with a scale of at least 1 (e.g. 20.5, 32.25, 33.333, etc.)", l) + errs = append(errs, fmt.Errorf("%v must be 0.0 and above with a scale of at least 1 (e.g. 20.5, 32.25, 33.333, etc.)", l)) } - return nil } if valueSet(v.MinDataRetentionTimeInDays) { - if ok := validateIntInRange(*v.MinDataRetentionTimeInDays, 0, 90); !ok { - return fmt.Errorf("MIN_DATA_RETENTION_TIME_IN_DAYS must be between 0 and 90") + if !validateIntInRange(*v.MinDataRetentionTimeInDays, 0, 90) { + errs = append(errs, errIntBetween("AccountParameters", "MinDataRetentionTimeInDays", 0, 90)) } } - return nil + return errors.Join(errs...) 
} type AccountParametersUnset struct { @@ -932,42 +933,43 @@ type SessionParameters struct { } func (v *SessionParameters) validate() error { + var errs []error if valueSet(v.JSONIndent) { - if ok := validateIntInRange(*v.JSONIndent, 0, 16); !ok { - return fmt.Errorf("JSON_INDENT must be between 0 and 16") + if !validateIntInRange(*v.JSONIndent, 0, 16) { + errs = append(errs, errIntBetween("SessionParameters", "JSONIndent", 0, 16)) } } if valueSet(v.LockTimeout) { - if ok := validateIntGreaterThanOrEqual(*v.LockTimeout, 0); !ok { - return fmt.Errorf("LOCK_TIMEOUT must be greater than or equal to 0") + if !validateIntGreaterThanOrEqual(*v.LockTimeout, 0) { + errs = append(errs, errIntValue("SessionParameters", "LockTimeout", IntErrGreaterOrEqual, 0)) } } if valueSet(v.QueryTag) { if len(*v.QueryTag) > 2000 { - return fmt.Errorf("QUERY_TAG must be less than 2000 characters") + errs = append(errs, errIntValue("SessionParameters", "QueryTag", IntErrLess, 2000)) } } if valueSet(v.RowsPerResultset) { - if ok := validateIntGreaterThanOrEqual(*v.RowsPerResultset, 0); !ok { - return fmt.Errorf("ROWS_PER_RESULTSET must be greater than or equal to 0") + if !validateIntGreaterThanOrEqual(*v.RowsPerResultset, 0) { + errs = append(errs, errIntValue("SessionParameters", "RowsPerResultset", IntErrGreaterOrEqual, 0)) } } if valueSet(v.TwoDigitCenturyStart) { - if ok := validateIntInRange(*v.TwoDigitCenturyStart, 1900, 2100); !ok { - return fmt.Errorf("TWO_DIGIT_CENTURY_START must be between 1900 and 2100") + if !validateIntInRange(*v.TwoDigitCenturyStart, 1900, 2100) { + errs = append(errs, errIntBetween("SessionParameters", "TwoDigitCenturyStart", 1900, 2100)) } } if valueSet(v.WeekOfYearPolicy) { - if ok := validateIntInRange(*v.WeekOfYearPolicy, 0, 1); !ok { - return fmt.Errorf("WEEK_OF_YEAR_POLICY must be either 0 or 1") + if !validateIntInRange(*v.WeekOfYearPolicy, 0, 1) { + errs = append(errs, fmt.Errorf("WEEK_OF_YEAR_POLICY must be either 0 or 1")) } } if valueSet(v.WeekStart) { - if ok := validateIntInRange(*v.WeekStart, 0, 1); !ok { - return fmt.Errorf("WEEK_START must be either 0 or 1") + if !validateIntInRange(*v.WeekStart, 0, 1) { + errs = append(errs, fmt.Errorf("WEEK_START must be either 0 or 1")) } } - return nil + return errors.Join(errs...) 
} type SessionParametersUnset struct { @@ -1011,8 +1013,8 @@ type SessionParametersUnset struct { } func (v *SessionParametersUnset) validate() error { - if ok := anyValueSet(v.AbortDetachedQuery, v.Autocommit, v.BinaryInputFormat, v.BinaryOutputFormat, v.DateInputFormat, v.DateOutputFormat, v.ErrorOnNondeterministicMerge, v.ErrorOnNondeterministicUpdate, v.GeographyOutputFormat, v.JSONIndent, v.LockTimeout, v.QueryTag, v.RowsPerResultset, v.SimulatedDataSharingConsumer, v.StatementTimeoutInSeconds, v.StrictJSONOutput, v.TimestampDayIsAlways24h, v.TimestampInputFormat, v.TimestampLTZOutputFormat, v.TimestampNTZOutputFormat, v.TimestampOutputFormat, v.TimestampTypeMapping, v.TimestampTZOutputFormat, v.Timezone, v.TimeInputFormat, v.TimeOutputFormat, v.TransactionDefaultIsolationLevel, v.TwoDigitCenturyStart, v.UnsupportedDDLAction, v.UseCachedResult, v.WeekOfYearPolicy, v.WeekStart); !ok { - return fmt.Errorf("at least one session parameter must be set") + if !anyValueSet(v.AbortDetachedQuery, v.Autocommit, v.BinaryInputFormat, v.BinaryOutputFormat, v.DateInputFormat, v.DateOutputFormat, v.ErrorOnNondeterministicMerge, v.ErrorOnNondeterministicUpdate, v.GeographyOutputFormat, v.JSONIndent, v.LockTimeout, v.QueryTag, v.RowsPerResultset, v.SimulatedDataSharingConsumer, v.StatementTimeoutInSeconds, v.StrictJSONOutput, v.TimestampDayIsAlways24h, v.TimestampInputFormat, v.TimestampLTZOutputFormat, v.TimestampNTZOutputFormat, v.TimestampOutputFormat, v.TimestampTypeMapping, v.TimestampTZOutputFormat, v.Timezone, v.TimeInputFormat, v.TimeOutputFormat, v.TransactionDefaultIsolationLevel, v.TwoDigitCenturyStart, v.UnsupportedDDLAction, v.UseCachedResult, v.WeekOfYearPolicy, v.WeekStart) { + return errors.Join(errAtLeastOneOf("SessionParametersUnset", "AbortDetachedQuery", "Autocommit", "BinaryInputFormat", "BinaryOutputFormat", "DateInputFormat", "DateOutputFormat", "ErrorOnNondeterministicMerge", "ErrorOnNondeterministicUpdate", "GeographyOutputFormat", "JSONIndent", "LockTimeout", "QueryTag", "RowsPerResultset", "SimulatedDataSharingConsumer", "StatementTimeoutInSeconds", "StrictJSONOutput", "TimestampDayIsAlways24h", "TimestampInputFormat", "TimestampLTZOutputFormat", "TimestampNTZOutputFormat", "TimestampOutputFormat", "TimestampTypeMapping", "TimestampTZOutputFormat", "Timezone", "TimeInputFormat", "TimeOutputFormat", "TransactionDefaultIsolationLevel", "TwoDigitCenturyStart", "UnsupportedDDLAction", "UseCachedResult", "WeekOfYearPolicy", "WeekStart")) } return nil } @@ -1057,41 +1059,38 @@ type ObjectParameters struct { } func (v *ObjectParameters) validate() error { + var errs []error if valueSet(v.DataRetentionTimeInDays) { - if ok := validateIntInRange(*v.DataRetentionTimeInDays, 0, 90); !ok { - return fmt.Errorf("DATA_RETENTION_TIME_IN_DAYS must be between 0 and 90") + if !validateIntInRange(*v.DataRetentionTimeInDays, 0, 90) { + errs = append(errs, errIntBetween("ObjectParameters", "DataRetentionTimeInDays", 0, 90)) } } if valueSet(v.MaxConcurrencyLevel) { - if ok := validateIntGreaterThanOrEqual(*v.MaxConcurrencyLevel, 1); !ok { - return fmt.Errorf("MAX_CONCURRENCY_LEVEL must be greater than or equal to 1") + if !validateIntGreaterThanOrEqual(*v.MaxConcurrencyLevel, 1) { + errs = append(errs, errIntValue("ObjectParameters", "MaxConcurrencyLevel", IntErrGreaterOrEqual, 1)) } } - if valueSet(v.MaxDataExtensionTimeInDays) { - if ok := validateIntInRange(*v.MaxDataExtensionTimeInDays, 0, 90); !ok { - return fmt.Errorf("MAX_DATA_EXTENSION_TIME_IN_DAYS must be between 0 and 90") + if 
!validateIntInRange(*v.MaxDataExtensionTimeInDays, 0, 90) { + errs = append(errs, errIntBetween("ObjectParameters", "MaxDataExtensionTimeInDays", 0, 90)) } } - if valueSet(v.StatementQueuedTimeoutInSeconds) { - if ok := validateIntGreaterThanOrEqual(*v.StatementQueuedTimeoutInSeconds, 0); !ok { - return fmt.Errorf("STATEMENT_QUEUED_TIMEOUT_IN_SECONDS must be greater than or equal to 0") + if !validateIntGreaterThanOrEqual(*v.StatementQueuedTimeoutInSeconds, 0) { + errs = append(errs, errIntValue("ObjectParameters", "StatementQueuedTimeoutInSeconds", IntErrGreaterOrEqual, 0)) } } - if valueSet(v.SuspendTaskAfterNumFailures) { - if ok := validateIntGreaterThanOrEqual(*v.SuspendTaskAfterNumFailures, 0); !ok { - return fmt.Errorf("SUSPEND_TASK_AFTER_NUM_FAILURES must be greater than or equal to 0") + if !validateIntGreaterThanOrEqual(*v.SuspendTaskAfterNumFailures, 0) { + errs = append(errs, errIntValue("ObjectParameters", "SuspendTaskAfterNumFailures", IntErrGreaterOrEqual, 0)) } } - if valueSet(v.UserTaskTimeoutMs) { - if ok := validateIntInRange(*v.UserTaskTimeoutMs, 0, 86400000); !ok { - return fmt.Errorf("USER_TASK_TIMEOUT_MS must be between 0 and 86400000") + if !validateIntInRange(*v.UserTaskTimeoutMs, 0, 86400000) { + errs = append(errs, errIntBetween("ObjectParameters", "UserTaskTimeoutMs", 0, 86400000)) } } - return nil + return errors.Join(errs...) } type ObjectParametersUnset struct { @@ -1152,8 +1151,8 @@ type ParametersIn struct { } func (v *ParametersIn) validate() error { - if ok := anyValueSet(v.Session, v.Account, v.User, v.Warehouse, v.Database, v.Schema, v.Task, v.Table); !ok { - return fmt.Errorf("at least one IN parameter must be set") + if !anyValueSet(v.Session, v.Account, v.User, v.Warehouse, v.Database, v.Schema, v.Task, v.Table) { + return errors.Join(errAtLeastOneOf("Session", "Account", "User", "Warehouse", "Database", "Schema", "Task", "Table")) } return nil } diff --git a/pkg/sdk/password_policy.go b/pkg/sdk/password_policy.go index 578e8b264e..9fea0bdd8c 100644 --- a/pkg/sdk/password_policy.go +++ b/pkg/sdk/password_policy.go @@ -53,10 +53,12 @@ type CreatePasswordPolicyOptions struct { } func (opts *CreatePasswordPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } - return nil } @@ -88,27 +90,27 @@ type AlterPasswordPolicyOptions struct { } func (opts *AlterPasswordPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } - if !exactlyOneValueSet(opts.Set, opts.Unset, opts.NewName) { - return errExactlyOneOf("Set", "Unset", "NewName") + errs = append(errs, errExactlyOneOf("Set", "Unset", "NewName")) } - if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { - return err + errs = append(errs, err) } } - if valueSet(opts.Unset) { if err := opts.Unset.validate(); err != nil { - return err + errs = append(errs, err) } } - - return nil + return errors.Join(errs...) 
} type PasswordPolicySet struct { @@ -209,8 +211,11 @@ type DropPasswordPolicyOptions struct { } func (opts *DropPasswordPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -240,7 +245,10 @@ type ShowPasswordPolicyOptions struct { Limit *int `ddl:"parameter,no_equals" sql:"LIMIT"` } -func (input *ShowPasswordPolicyOptions) validate() error { +func (opts *ShowPasswordPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -339,9 +347,12 @@ type describePasswordPolicyOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` } -func (v *describePasswordPolicyOptions) validate() error { - if !ValidObjectIdentifier(v.name) { - return ErrInvalidObjectIdentifier +func (opts *describePasswordPolicyOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + if !ValidObjectIdentifier(opts.name) { + return errors.Join(ErrInvalidObjectIdentifier) } return nil } diff --git a/pkg/sdk/password_policy_test.go b/pkg/sdk/password_policy_test.go index a3b62c2fd5..05c15edcd3 100644 --- a/pkg/sdk/password_policy_test.go +++ b/pkg/sdk/password_policy_test.go @@ -11,7 +11,7 @@ func TestPasswordPolicyCreate(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &CreatePasswordPolicyOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { @@ -46,14 +46,14 @@ func TestPasswordPolicyAlter(t *testing.T) { t.Run("empty options", func(t *testing.T) { opts := &AlterPasswordPolicyOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { opts := &AlterPasswordPolicyOptions{ name: id, } - assertOptsInvalid(t, opts, errExactlyOneOf("Set", "Unset", "NewName")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Set", "Unset", "NewName")) }) t.Run("with set", func(t *testing.T) { @@ -93,7 +93,7 @@ func TestPasswordPolicyDrop(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &DropPasswordPolicyOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { @@ -180,7 +180,7 @@ func TestPasswordPolicyDescribe(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &describePasswordPolicyOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { diff --git a/pkg/sdk/pipes_test.go b/pkg/sdk/pipes_test.go index 9c2047f1ff..773603072f 100644 --- a/pkg/sdk/pipes_test.go +++ b/pkg/sdk/pipes_test.go @@ -16,19 +16,19 @@ func TestPipesCreate(t *testing.T) { t.Run("validation: nil options", func(t *testing.T) { var opts *CreatePipeOptions = nil - assertOptsInvalid(t, opts, ErrNilOptions) + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = NewSchemaObjectIdentifier("", "", "") - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) 
t.Run("validation: copy statement required", func(t *testing.T) { opts := defaultOpts() opts.copyStatement = "" - assertOptsInvalid(t, opts, errCopyStatementRequired) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreatePipeOptions", "copyStatement")) }) t.Run("basic", func(t *testing.T) { @@ -59,18 +59,18 @@ func TestPipesAlter(t *testing.T) { t.Run("validation: nil options", func(t *testing.T) { var opts *AlterPipeOptions = nil - assertOptsInvalid(t, opts, ErrNilOptions) + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = NewSchemaObjectIdentifier("", "", "") - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("validation: no alter action", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalid(t, opts, errAlterNeedsExactlyOneAction) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTags", "UnsetTags", "Refresh")) }) t.Run("validation: multiple alter actions", func(t *testing.T) { @@ -81,13 +81,13 @@ func TestPipesAlter(t *testing.T) { opts.Unset = &PipeUnset{ Comment: Bool(true), } - assertOptsInvalid(t, opts, errAlterNeedsExactlyOneAction) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTags", "UnsetTags", "Refresh")) }) t.Run("validation: no property to set", func(t *testing.T) { opts := defaultOpts() opts.Set = &PipeSet{} - assertOptsInvalid(t, opts, errAlterNeedsAtLeastOneProperty) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterPipeOptions.Set", "ErrorIntegration", "PipeExecutionPaused", "Comment")) }) t.Run("validation: empty tags slice for set", func(t *testing.T) { @@ -95,13 +95,13 @@ func TestPipesAlter(t *testing.T) { opts.SetTags = &PipeSetTags{ Tag: []TagAssociation{}, } - assertOptsInvalid(t, opts, errAlterNeedsAtLeastOneProperty) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("AlterPipeOptions.SetTags", "Tag")) }) t.Run("validation: no property to unset", func(t *testing.T) { opts := defaultOpts() opts.Unset = &PipeUnset{} - assertOptsInvalid(t, opts, errAlterNeedsAtLeastOneProperty) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterPipeOptions.Unset", "PipeExecutionPaused", "Comment")) }) t.Run("validation: empty tags slice for unset", func(t *testing.T) { @@ -109,7 +109,7 @@ func TestPipesAlter(t *testing.T) { opts.UnsetTags = &PipeUnsetTags{ Tag: []ObjectIdentifier{}, } - assertOptsInvalid(t, opts, errAlterNeedsAtLeastOneProperty) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("AlterPipeOptions.UnsetTags", "Tag")) }) t.Run("set tag: single", func(t *testing.T) { @@ -212,13 +212,13 @@ func TestPipesDrop(t *testing.T) { t.Run("validation: nil options", func(t *testing.T) { var opts *DropPipeOptions = nil - assertOptsInvalid(t, opts, ErrNilOptions) + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = NewSchemaObjectIdentifier("", "", "") - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("empty options", func(t *testing.T) { @@ -244,19 +244,19 @@ func TestPipesShow(t *testing.T) { t.Run("validation: nil options", func(t *testing.T) { var opts *ShowPipeOptions = nil - assertOptsInvalid(t, opts, ErrNilOptions) + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) 
}) t.Run("validation: empty like", func(t *testing.T) { opts := defaultOpts() opts.Like = &Like{} - assertOptsInvalid(t, opts, ErrPatternRequiredForLikeKeyword) + assertOptsInvalidJoinedErrors(t, opts, ErrPatternRequiredForLikeKeyword) }) t.Run("validation: empty in", func(t *testing.T) { opts := defaultOpts() opts.In = &In{} - assertOptsInvalid(t, opts, errScopeRequiredForInKeyword) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("ShowPipeOptions.In", "Account", "Database", "Schema")) }) t.Run("validation: exactly one scope for in", func(t *testing.T) { @@ -265,7 +265,7 @@ func TestPipesShow(t *testing.T) { Account: Bool(true), Database: databaseIdentifier, } - assertOptsInvalid(t, opts, errScopeRequiredForInKeyword) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("ShowPipeOptions.In", "Account", "Database", "Schema")) }) t.Run("empty options", func(t *testing.T) { @@ -350,13 +350,13 @@ func TestPipesDescribe(t *testing.T) { t.Run("validation: nil options", func(t *testing.T) { var opts *describePipeOptions = nil - assertOptsInvalid(t, opts, ErrNilOptions) + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = NewSchemaObjectIdentifier("", "", "") - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("with name", func(t *testing.T) { diff --git a/pkg/sdk/pipes_validations.go b/pkg/sdk/pipes_validations.go index 593dd5f5e7..22b50dceac 100644 --- a/pkg/sdk/pipes_validations.go +++ b/pkg/sdk/pipes_validations.go @@ -14,92 +14,84 @@ var ( func (opts *CreatePipeOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if opts.copyStatement == "" { - return errCopyStatementRequired + errs = append(errs, errNotSet("CreatePipeOptions", "copyStatement")) } - return nil + return errors.Join(errs...) 
} func (opts *AlterPipeOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } - if ok := exactlyOneValueSet( - opts.Set, - opts.Unset, - opts.SetTags, - opts.UnsetTags, - opts.Refresh, - ); !ok { - return errAlterNeedsExactlyOneAction + if ok := exactlyOneValueSet(opts.Set, opts.Unset, opts.SetTags, opts.UnsetTags, opts.Refresh); !ok { + errs = append(errs, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTags", "UnsetTags", "Refresh")) } if set := opts.Set; valueSet(set) { if !anyValueSet(set.ErrorIntegration, set.PipeExecutionPaused, set.Comment) { - return errAlterNeedsAtLeastOneProperty + errs = append(errs, errAtLeastOneOf("AlterPipeOptions.Set", "ErrorIntegration", "PipeExecutionPaused", "Comment")) } } if unset := opts.Unset; valueSet(unset) { if !anyValueSet(unset.PipeExecutionPaused, unset.Comment) { - return errAlterNeedsAtLeastOneProperty + errs = append(errs, errAtLeastOneOf("AlterPipeOptions.Unset", "PipeExecutionPaused", "Comment")) } } if setTags := opts.SetTags; valueSet(setTags) { if !valueSet(setTags.Tag) { - return errAlterNeedsAtLeastOneProperty + errs = append(errs, errNotSet("AlterPipeOptions.SetTags", "Tag")) } } if unsetTags := opts.UnsetTags; valueSet(unsetTags) { if !valueSet(unsetTags.Tag) { - return errAlterNeedsAtLeastOneProperty + errs = append(errs, errNotSet("AlterPipeOptions.UnsetTags", "Tag")) } } - return nil + return errors.Join(errs...) } func (opts *DropPipeOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } - return nil + return errors.Join(errs...) } func (opts *ShowPipeOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } + var errs []error if valueSet(opts.Like) && !valueSet(opts.Like.Pattern) { - return ErrPatternRequiredForLikeKeyword + errs = append(errs, ErrPatternRequiredForLikeKeyword) } if valueSet(opts.In) && !exactlyOneValueSet(opts.In.Account, opts.In.Database, opts.In.Schema) { - return errScopeRequiredForInKeyword + errs = append(errs, errExactlyOneOf("ShowPipeOptions.In", "Account", "Database", "Schema")) } - return nil + return errors.Join(errs...) } func (opts *describePipeOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } - return nil + return errors.Join(errs...) } - -var ( - errCopyStatementRequired = errors.New("copy statement required") - errScopeRequiredForInKeyword = errors.New("exactly one scope must be specified for in keyword") - errAlterNeedsExactlyOneAction = errors.New("alter statement needs exactly one action from: set, unset, refresh") - errAlterNeedsAtLeastOneProperty = errors.New("alter statement needs at least one property") -) diff --git a/pkg/sdk/poc/generator/field.go b/pkg/sdk/poc/generator/field.go index edfc2b7106..b3eaf639f5 100644 --- a/pkg/sdk/poc/generator/field.go +++ b/pkg/sdk/poc/generator/field.go @@ -133,6 +133,15 @@ func (f *Field) Path() string { } } +// PathWithRoot returns the way through the tree to the top, with dot separator and root included (e.g. 
Struct.SomeField.SomeChild) +func (f *Field) PathWithRoot() string { + if f.IsRoot() { + return f.Name + } else { + return fmt.Sprintf("%s.%s", f.Parent.Path(), f.Name) + } +} + // DtoKind returns what should be fields kind in generated DTO, because it may differ from Kind func (f *Field) DtoKind() string { switch { diff --git a/pkg/sdk/poc/generator/validation.go b/pkg/sdk/poc/generator/validation.go index 39729f3a47..7c4d29b881 100644 --- a/pkg/sdk/poc/generator/validation.go +++ b/pkg/sdk/poc/generator/validation.go @@ -57,13 +57,13 @@ func (v *Validation) Condition(field *Field) string { case ValidIdentifier: return fmt.Sprintf("!ValidObjectIdentifier(%s)", strings.Join(v.fieldsWithPath(field), ",")) case ValidIdentifierIfSet: - return fmt.Sprintf("valueSet(%s) && !ValidObjectIdentifier(%s)", strings.Join(v.fieldsWithPath(field), ","), strings.Join(v.fieldsWithPath(field), ",")) + return fmt.Sprintf("%s != nil && !ValidObjectIdentifier(%s)", strings.Join(v.fieldsWithPath(field), ","), strings.Join(v.fieldsWithPath(field), ",")) case ConflictingFields: return fmt.Sprintf("everyValueSet(%s)", strings.Join(v.fieldsWithPath(field), ",")) case ExactlyOneValueSet: - return fmt.Sprintf("ok := exactlyOneValueSet(%s); !ok", strings.Join(v.fieldsWithPath(field), ",")) + return fmt.Sprintf("!exactlyOneValueSet(%s)", strings.Join(v.fieldsWithPath(field), ",")) case AtLeastOneValueSet: - return fmt.Sprintf("ok := anyValueSet(%s); !ok", strings.Join(v.fieldsWithPath(field), ",")) + return fmt.Sprintf("!anyValueSet(%s)", strings.Join(v.fieldsWithPath(field), ",")) case ValidateValue: return fmt.Sprintf("err := %s.validate(); err != nil", strings.Join(v.fieldsWithPath(field.Parent), ",")) } @@ -77,11 +77,11 @@ func (v *Validation) ReturnedError(field *Field) string { case ValidIdentifierIfSet: return "ErrInvalidObjectIdentifier" case ConflictingFields: - return fmt.Sprintf(`errOneOf("%s", %s)`, field.Name, strings.Join(v.paramsQuoted(), ",")) + return fmt.Sprintf(`errOneOf("%s", %s)`, field.PathWithRoot(), strings.Join(v.paramsQuoted(), ",")) case ExactlyOneValueSet: - return fmt.Sprintf("errExactlyOneOf(%s)", strings.Join(v.paramsQuoted(), ",")) + return fmt.Sprintf(`errExactlyOneOf("%s", %s)`, field.PathWithRoot(), strings.Join(v.paramsQuoted(), ",")) case AtLeastOneValueSet: - return fmt.Sprintf("errAtLeastOneOf(%s)", strings.Join(v.paramsQuoted(), ",")) + return fmt.Sprintf(`errAtLeastOneOf("%s", %s)`, field.PathWithRoot(), strings.Join(v.paramsQuoted(), ",")) case ValidateValue: return "err" } diff --git a/pkg/sdk/replication_functions.go b/pkg/sdk/replication_functions.go index 2b42c42dbb..99efec33cc 100644 --- a/pkg/sdk/replication_functions.go +++ b/pkg/sdk/replication_functions.go @@ -2,6 +2,7 @@ package sdk import ( "context" + "errors" "time" ) @@ -94,6 +95,9 @@ type ShowRegionsOptions struct { } func (opts *ShowRegionsOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } diff --git a/pkg/sdk/resource_monitors.go b/pkg/sdk/resource_monitors.go index ce9cea34ca..6a865f9d07 100644 --- a/pkg/sdk/resource_monitors.go +++ b/pkg/sdk/resource_monitors.go @@ -196,8 +196,11 @@ type ResourceMonitorWith struct { } func (opts *CreateResourceMonitorOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -281,17 +284,22 @@ type AlterResourceMonitorOptions struct { } func (opts 
*AlterResourceMonitorOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } - if opts.Set == nil { - return nil + if valueSet(opts.Set) { + if (opts.Set.Frequency != nil && opts.Set.StartTimestamp == nil) || (opts.Set.Frequency == nil && opts.Set.StartTimestamp != nil) { + errs = append(errs, errors.New("must specify frequency and start time together")) + } } - if (opts.Set.Frequency != nil && opts.Set.StartTimestamp == nil) || (opts.Set.Frequency == nil && opts.Set.StartTimestamp != nil) { - return errors.New("must specify frequency and start time together") + if !exactlyOneValueSet(opts.Set, opts.NotifyUsers) && opts.Triggers == nil { + errs = append(errs, errExactlyOneOf("AlterResourceMonitorOptions", "Set", "NotifyUsers", "Triggers")) } - - return nil + return errors.Join(errs...) } func (v *resourceMonitors) Alter(ctx context.Context, id AccountObjectIdentifier, opts *AlterResourceMonitorOptions) error { @@ -327,8 +335,11 @@ type dropResourceMonitorOptions struct { } func (opts *dropResourceMonitorOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -356,6 +367,9 @@ type ShowResourceMonitorOptions struct { } func (opts *ShowResourceMonitorOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } diff --git a/pkg/sdk/resource_monitors_test.go b/pkg/sdk/resource_monitors_test.go index b42dbf8795..2357d5305d 100644 --- a/pkg/sdk/resource_monitors_test.go +++ b/pkg/sdk/resource_monitors_test.go @@ -10,7 +10,7 @@ func TestResourceMonitorCreate(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &CreateResourceMonitorOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("with complete options", func(t *testing.T) { @@ -55,14 +55,14 @@ func TestResourceMonitorAlter(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &AlterResourceMonitorOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { opts := &AlterResourceMonitorOptions{ name: id, } - assertOptsValidAndSQLEquals(t, opts, "ALTER RESOURCE MONITOR %s", id.FullyQualifiedName()) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterResourceMonitorOptions", "Set", "NotifyUsers", "Triggers")) }) t.Run("with a single set", func(t *testing.T) { @@ -97,7 +97,7 @@ func TestResourceMonitorDrop(t *testing.T) { t.Run("empty options", func(t *testing.T) { opts := &dropResourceMonitorOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { diff --git a/pkg/sdk/roles_test.go b/pkg/sdk/roles_test.go index 96b219092d..061d6f7aee 100644 --- a/pkg/sdk/roles_test.go +++ b/pkg/sdk/roles_test.go @@ -1,7 +1,6 @@ package sdk import ( - "errors" "testing" ) @@ -42,7 +41,7 @@ func TestRolesCreate(t *testing.T) { IfNotExists: Bool(true), OrReplace: Bool(true), } - assertOptsInvalidJoinedErrors(t, opts, errOneOf("OrReplace", "IfNotExists")) + assertOptsInvalidJoinedErrors(t, opts, 
errOneOf("CreateRoleOptions", "OrReplace", "IfNotExists")) }) } @@ -66,7 +65,7 @@ func TestRolesDrop(t *testing.T) { opts := &DropRoleOptions{ name: NewAccountObjectIdentifier(""), } - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) } @@ -136,7 +135,7 @@ func TestRolesAlter(t *testing.T) { opts := &AlterRoleOptions{ name: RandomAccountObjectIdentifier(), } - assertOptsInvalidJoinedErrors(t, opts, errors.New("no alter action specified")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterRoleOptions", "RenameTo", "SetComment", "UnsetComment", "SetTags", "UnsetTags")) }) t.Run("validation: more than one alter action specified", func(t *testing.T) { @@ -145,7 +144,7 @@ func TestRolesAlter(t *testing.T) { SetComment: String("comment"), UnsetComment: Bool(true), } - assertOptsInvalidJoinedErrors(t, opts, errOneOf("RenameTo", "SetComment", "UnsetComment", "SetTags", "UnsetTags")) + assertOptsInvalidJoinedErrors(t, opts, errOneOf("AlterRoleOptions", "RenameTo", "SetComment", "UnsetComment", "SetTags", "UnsetTags")) }) } @@ -216,7 +215,7 @@ func TestRolesGrant(t *testing.T) { opts := &GrantRoleOptions{ name: NewAccountObjectIdentifier(""), } - assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier, errors.New("only one grant option can be set [TO ROLE or TO USER]")) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier, errOneOf("GrantRoleOptions.Grant", "Role", "User")) }) t.Run("validation: invalid object identifier for granted role", func(t *testing.T) { @@ -227,7 +226,7 @@ func TestRolesGrant(t *testing.T) { Role: &id, }, } - assertOptsInvalidJoinedErrors(t, opts, errors.New("invalid object identifier for granted role")) + assertOptsInvalidJoinedErrors(t, opts, errInvalidIdentifier("GrantRoleOptions.Grant", "Role")) }) t.Run("validation: invalid object identifier for granted user", func(t *testing.T) { @@ -238,7 +237,7 @@ func TestRolesGrant(t *testing.T) { User: &id, }, } - assertOptsInvalidJoinedErrors(t, opts, errors.New("invalid object identifier for granted user")) + assertOptsInvalidJoinedErrors(t, opts, errInvalidIdentifier("GrantRoleOptions.Grant", "User")) }) } @@ -267,6 +266,6 @@ func TestRolesRevoke(t *testing.T) { opts := &RevokeRoleOptions{ name: NewAccountObjectIdentifier(""), } - assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier, errors.New("only one revoke option can be set [FROM ROLE or FROM USER]")) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier, errOneOf("RevokeRoleOptions.Revoke", "Role", "User")) }) } diff --git a/pkg/sdk/roles_validations.go b/pkg/sdk/roles_validations.go index a231928b59..44c30754af 100644 --- a/pkg/sdk/roles_validations.go +++ b/pkg/sdk/roles_validations.go @@ -13,49 +13,49 @@ var ( func (opts *CreateRoleOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } if everyValueSet(opts.OrReplace, opts.IfNotExists) { - errs = append(errs, errOneOf("OrReplace", "IfNotExists")) + errs = append(errs, errOneOf("CreateRoleOptions", "OrReplace", "IfNotExists")) } return errors.Join(errs...) 
} func (opts *AlterRoleOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } if everyValueNil(opts.RenameTo, opts.SetComment, opts.UnsetComment, opts.SetTags, opts.UnsetTags) { - errs = append(errs, errors.New("no alter action specified")) + errs = append(errs, errAtLeastOneOf("AlterRoleOptions", "RenameTo", "SetComment", "UnsetComment", "SetTags", "UnsetTags")) } if anyValueSet(opts.RenameTo, opts.SetComment, opts.UnsetComment, opts.SetTags, opts.UnsetTags) && !exactlyOneValueSet(opts.RenameTo, opts.SetComment, opts.UnsetComment, opts.SetTags, opts.UnsetTags) { - errs = append(errs, errOneOf("RenameTo", "SetComment", "UnsetComment", "SetTags", "UnsetTags")) + errs = append(errs, errOneOf("AlterRoleOptions", "RenameTo", "SetComment", "UnsetComment", "SetTags", "UnsetTags")) } return errors.Join(errs...) } func (opts *DropRoleOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } func (opts *ShowRoleOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } var errs []error if valueSet(opts.Like) && !valueSet(opts.Like.Pattern) { @@ -69,34 +69,34 @@ func (opts *ShowRoleOptions) validate() error { func (opts *GrantRoleOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } if (opts.Grant.Role != nil && opts.Grant.User != nil) || (opts.Grant.Role == nil && opts.Grant.User == nil) { - errs = append(errs, errors.New("only one grant option can be set [TO ROLE or TO USER]")) + errs = append(errs, errOneOf("GrantRoleOptions.Grant", "Role", "User")) } if opts.Grant.Role != nil && !ValidObjectIdentifier(opts.Grant.Role) { - errs = append(errs, errors.New("invalid object identifier for granted role")) + errs = append(errs, errInvalidIdentifier("GrantRoleOptions.Grant", "Role")) } if opts.Grant.User != nil && !ValidObjectIdentifier(opts.Grant.User) { - errs = append(errs, errors.New("invalid object identifier for granted user")) + errs = append(errs, errInvalidIdentifier("GrantRoleOptions.Grant", "User")) } return errors.Join(errs...) } func (opts *RevokeRoleOptions) validate() error { if opts == nil { - return ErrNilOptions + return errors.Join(ErrNilOptions) } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } if (opts.Revoke.Role != nil && opts.Revoke.User != nil) || (opts.Revoke.Role == nil && opts.Revoke.User == nil) { - errs = append(errs, errors.New("only one revoke option can be set [FROM ROLE or FROM USER]")) + errs = append(errs, errOneOf("RevokeRoleOptions.Revoke", "Role", "User")) } return errors.Join(errs...) 
} diff --git a/pkg/sdk/schemas.go b/pkg/sdk/schemas.go index 90b0e32d7f..5fbe0b6fd3 100644 --- a/pkg/sdk/schemas.go +++ b/pkg/sdk/schemas.go @@ -108,6 +108,9 @@ type CreateSchemaOptions struct { } func (opts *CreateSchemaOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) @@ -118,7 +121,7 @@ func (opts *CreateSchemaOptions) validate() error { } } if everyValueSet(opts.OrReplace, opts.IfNotExists) { - errs = append(errs, errOneOf("IfNotExists", "OrReplace")) + errs = append(errs, errOneOf("CreateSchemaOptions", "IfNotExists", "OrReplace")) } return errors.Join(errs...) } @@ -155,6 +158,9 @@ type AlterSchemaOptions struct { } func (opts *AlterSchemaOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) @@ -233,12 +239,15 @@ type DropSchemaOptions struct { } func (opts *DropSchemaOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } if everyValueSet(opts.Cascade, opts.Restrict) { - errs = append(errs, errors.New("only one of the fields [ Cascade | Restrict ] can be set at once")) + errs = append(errs, errOneOf("DropSchemaOptions", "Cascade", "Restrict")) } return errors.Join(errs...) } @@ -267,8 +276,11 @@ type undropSchemaOptions struct { } func (opts *undropSchemaOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -296,8 +308,11 @@ type describeSchemaOptions struct { } func (opts *describeSchemaOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -346,6 +361,9 @@ type ShowSchemaOptions struct { } func (opts *ShowSchemaOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } diff --git a/pkg/sdk/session_policies_gen_test.go b/pkg/sdk/session_policies_gen_test.go index 0c68390c53..f45948ff72 100644 --- a/pkg/sdk/session_policies_gen_test.go +++ b/pkg/sdk/session_policies_gen_test.go @@ -68,7 +68,7 @@ func TestSessionPolicies_Alter(t *testing.T) { t.Run("validation: exactly one field from [opts.RenameTo opts.Set opts.SetTags opts.UnsetTags opts.Unset] should be present - none present", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("RenameTo", "Set", "SetTags", "UnsetTags", "Unset")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterSessionPolicyOptions", "RenameTo", "Set", "SetTags", "UnsetTags", "Unset")) }) t.Run("validation: exactly one field from [opts.RenameTo opts.Set opts.SetTags opts.UnsetTags opts.Unset] should be present - more present", func(t *testing.T) { @@ -79,19 +79,19 @@ func TestSessionPolicies_Alter(t *testing.T) { opts.Unset = &SessionPolicyUnset{ Comment: Bool(true), } - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("RenameTo", "Set", "SetTags", "UnsetTags", "Unset")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterSessionPolicyOptions", "RenameTo", "Set", "SetTags", "UnsetTags", "Unset")) 
}) t.Run("validation: at least one of the fields [opts.Set.SessionIdleTimeoutMins opts.Set.SessionUiIdleTimeoutMins opts.Set.Comment] should be set", func(t *testing.T) { opts := defaultOpts() opts.Set = &SessionPolicySet{} - assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterSessionPolicyOptions.Set", "SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment")) }) t.Run("validation: at least one of the fields [opts.Unset.SessionIdleTimeoutMins opts.Unset.SessionUiIdleTimeoutMins opts.Unset.Comment] should be set", func(t *testing.T) { opts := defaultOpts() opts.Unset = &SessionPolicyUnset{} - assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterSessionPolicyOptions.Unset", "SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment")) }) t.Run("alter set", func(t *testing.T) { diff --git a/pkg/sdk/session_policies_validations_gen.go b/pkg/sdk/session_policies_validations_gen.go index f14894942c..b30a05f1d8 100644 --- a/pkg/sdk/session_policies_validations_gen.go +++ b/pkg/sdk/session_policies_validations_gen.go @@ -33,16 +33,16 @@ func (opts *AlterSessionPolicyOptions) validate() error { errs = append(errs, ErrInvalidObjectIdentifier) } if ok := exactlyOneValueSet(opts.RenameTo, opts.Set, opts.SetTags, opts.UnsetTags, opts.Unset); !ok { - errs = append(errs, errExactlyOneOf("RenameTo", "Set", "SetTags", "UnsetTags", "Unset")) + errs = append(errs, errExactlyOneOf("AlterSessionPolicyOptions", "RenameTo", "Set", "SetTags", "UnsetTags", "Unset")) } if valueSet(opts.Set) { if ok := anyValueSet(opts.Set.SessionIdleTimeoutMins, opts.Set.SessionUiIdleTimeoutMins, opts.Set.Comment); !ok { - errs = append(errs, errAtLeastOneOf("SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment")) + errs = append(errs, errAtLeastOneOf("AlterSessionPolicyOptions.Set", "SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment")) } } if valueSet(opts.Unset) { if ok := anyValueSet(opts.Unset.SessionIdleTimeoutMins, opts.Unset.SessionUiIdleTimeoutMins, opts.Unset.Comment); !ok { - errs = append(errs, errAtLeastOneOf("SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment")) + errs = append(errs, errAtLeastOneOf("AlterSessionPolicyOptions.Unset", "SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment")) } } return errors.Join(errs...) diff --git a/pkg/sdk/sessions.go b/pkg/sdk/sessions.go index 74e95439ad..fcccc978ec 100644 --- a/pkg/sdk/sessions.go +++ b/pkg/sdk/sessions.go @@ -2,6 +2,7 @@ package sdk import ( "context" + "errors" "fmt" ) @@ -37,20 +38,24 @@ type AlterSessionOptions struct { } func (opts *AlterSessionOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if everyValueNil(opts.Set, opts.Unset) { - return fmt.Errorf("either SET or UNSET must be set") + errs = append(errs, errOneOf("AlterSessionOptions", "Set", "Unset")) } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Unset) { if err := opts.Unset.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
} type SessionSet struct { diff --git a/pkg/sdk/shares.go b/pkg/sdk/shares.go index 7a55136e9e..a4946c875f 100644 --- a/pkg/sdk/shares.go +++ b/pkg/sdk/shares.go @@ -2,6 +2,7 @@ package sdk import ( "context" + "errors" "fmt" "strings" "time" @@ -111,8 +112,11 @@ type CreateShareOptions struct { } func (opts *CreateShareOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return fmt.Errorf("not a valid object identifier: %s", opts.name) + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -141,6 +145,12 @@ type dropShareOptions struct { } func (opts *dropShareOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + if !ValidObjectIdentifier(opts.name) { + return errors.Join(ErrInvalidObjectIdentifier) + } return nil } @@ -172,33 +182,37 @@ type AlterShareOptions struct { } func (opts *AlterShareOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return fmt.Errorf("not a valid object identifier: %s", opts.name) + errs = append(errs, ErrInvalidObjectIdentifier) } - if ok := exactlyOneValueSet(opts.Add, opts.Remove, opts.Set, opts.Unset); !ok { - return errExactlyOneOf("Add", "Remove", "Set", "Unset") + if !exactlyOneValueSet(opts.Add, opts.Remove, opts.Set, opts.Unset) { + errs = append(errs, errExactlyOneOf("AlterShareOptions", "Add", "Remove", "Set", "Unset")) } if valueSet(opts.Add) { if err := opts.Add.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Remove) { if err := opts.Remove.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Unset) { if err := opts.Unset.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
} type ShareAdd struct { @@ -243,8 +257,8 @@ type ShareUnset struct { } func (v *ShareUnset) validate() error { - if ok := exactlyOneValueSet(v.Comment, v.Tag); !ok { - return fmt.Errorf("exactly one of comment, tag must be set") + if !exactlyOneValueSet(v.Comment, v.Tag) { + return errExactlyOneOf("ShareUnset", "Comment", "Tag") } return nil } @@ -275,6 +289,9 @@ type ShowShareOptions struct { } func (opts *ShowShareOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -348,8 +365,11 @@ type describeShareOptions struct { } func (opts *describeShareOptions) validate() error { - if ok := ValidObjectIdentifier(opts.name); !ok { - return ErrInvalidObjectIdentifier + if opts == nil { + return errors.Join(ErrNilOptions) + } + if !ValidObjectIdentifier(opts.name) { + return errors.Join(ErrInvalidObjectIdentifier) } return nil } diff --git a/pkg/sdk/shares_test.go b/pkg/sdk/shares_test.go index 3f6e038442..94b52e95bc 100644 --- a/pkg/sdk/shares_test.go +++ b/pkg/sdk/shares_test.go @@ -30,7 +30,7 @@ func TestShareAlter(t *testing.T) { opts := &AlterShareOptions{ name: NewAccountObjectIdentifier("myshare"), } - assertOptsInvalid(t, opts, errExactlyOneOf("Add", "Remove", "Set", "Unset")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterShareOptions", "Add", "Remove", "Set", "Unset")) }) t.Run("with add", func(t *testing.T) { diff --git a/pkg/sdk/streams_gen_test.go b/pkg/sdk/streams_gen_test.go index 9ddf6362c3..9070314bc0 100644 --- a/pkg/sdk/streams_gen_test.go +++ b/pkg/sdk/streams_gen_test.go @@ -48,14 +48,14 @@ func TestStreams_CreateOnTable(t *testing.T) { opts := defaultOpts() opts.On.At = Bool(true) opts.On.Before = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("At", "Before")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateOnTableStreamOptions.On", "At", "Before")) }) t.Run("validation: exactly one field from [opts.On.Statement.Timestamp opts.On.Statement.Offset opts.On.Statement.Statement opts.On.Statement.Stream] should be present", func(t *testing.T) { opts := defaultOpts() opts.On.At = Bool(true) opts.On.Statement = OnStreamStatement{} - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Timestamp", "Offset", "Statement", "Stream")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateOnTableStreamOptions.On.Statement", "Timestamp", "Offset", "Statement", "Stream")) }) t.Run("basic", func(t *testing.T) { @@ -126,13 +126,13 @@ func TestStreams_CreateOnExternalTable(t *testing.T) { opts := defaultOpts() opts.On.At = Bool(true) opts.On.Before = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("At", "Before")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateOnExternalTableStreamOptions.On", "At", "Before")) }) t.Run("validation: exactly one field from [opts.On.Statement.Timestamp opts.On.Statement.Offset opts.On.Statement.Statement opts.On.Statement.Stream] should be present", func(t *testing.T) { opts := defaultOpts() opts.On.Statement = OnStreamStatement{} - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Timestamp", "Offset", "Statement", "Stream")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateOnExternalTableStreamOptions.On.Statement", "Timestamp", "Offset", "Statement", "Stream")) }) t.Run("basic", func(t *testing.T) { @@ -253,13 +253,13 @@ func TestStreams_CreateOnView(t *testing.T) { opts := defaultOpts() opts.On.At = Bool(true) opts.On.Before = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, 
errExactlyOneOf("At", "Before")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateOnViewStreamOptions.On", "At", "Before")) }) t.Run("validation: exactly one field from [opts.On.Statement.Timestamp opts.On.Statement.Offset opts.On.Statement.Statement opts.On.Statement.Stream] should be present", func(t *testing.T) { opts := defaultOpts() opts.On.Statement = OnStreamStatement{} - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Timestamp", "Offset", "Statement", "Stream")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateOnViewStreamOptions.On.Statement", "Timestamp", "Offset", "Statement", "Stream")) }) t.Run("basic", func(t *testing.T) { @@ -351,7 +351,7 @@ func TestStreams_Alter(t *testing.T) { t.Run("validation: exactly one field from [opts.SetComment opts.UnsetComment opts.SetTags opts.UnsetTags] should be present", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("SetComment", "UnsetComment", "SetTags", "UnsetTags")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterStreamOptions", "SetComment", "UnsetComment", "SetTags", "UnsetTags")) }) t.Run("set comment", func(t *testing.T) { diff --git a/pkg/sdk/streams_validations_gen.go b/pkg/sdk/streams_validations_gen.go index b350ec10bb..fb2e53fba6 100644 --- a/pkg/sdk/streams_validations_gen.go +++ b/pkg/sdk/streams_validations_gen.go @@ -30,11 +30,11 @@ func (opts *CreateOnTableStreamOptions) validate() error { } if valueSet(opts.On) { if ok := exactlyOneValueSet(opts.On.At, opts.On.Before); !ok { - errs = append(errs, errExactlyOneOf("At", "Before")) + errs = append(errs, errExactlyOneOf("CreateOnTableStreamOptions.On", "At", "Before")) } if valueSet(opts.On.Statement) { if ok := exactlyOneValueSet(opts.On.Statement.Timestamp, opts.On.Statement.Offset, opts.On.Statement.Statement, opts.On.Statement.Stream); !ok { - errs = append(errs, errExactlyOneOf("Timestamp", "Offset", "Statement", "Stream")) + errs = append(errs, errExactlyOneOf("CreateOnTableStreamOptions.On.Statement", "Timestamp", "Offset", "Statement", "Stream")) } } } @@ -57,11 +57,11 @@ func (opts *CreateOnExternalTableStreamOptions) validate() error { } if valueSet(opts.On) { if ok := exactlyOneValueSet(opts.On.At, opts.On.Before); !ok { - errs = append(errs, errExactlyOneOf("At", "Before")) + errs = append(errs, errExactlyOneOf("CreateOnExternalTableStreamOptions.On", "At", "Before")) } if valueSet(opts.On.Statement) { if ok := exactlyOneValueSet(opts.On.Statement.Timestamp, opts.On.Statement.Offset, opts.On.Statement.Statement, opts.On.Statement.Stream); !ok { - errs = append(errs, errExactlyOneOf("Timestamp", "Offset", "Statement", "Stream")) + errs = append(errs, errExactlyOneOf("CreateOnExternalTableStreamOptions.On.Statement", "Timestamp", "Offset", "Statement", "Stream")) } } } @@ -101,11 +101,11 @@ func (opts *CreateOnViewStreamOptions) validate() error { } if valueSet(opts.On) { if ok := exactlyOneValueSet(opts.On.At, opts.On.Before); !ok { - errs = append(errs, errExactlyOneOf("At", "Before")) + errs = append(errs, errExactlyOneOf("CreateOnViewStreamOptions.On", "At", "Before")) } if valueSet(opts.On.Statement) { if ok := exactlyOneValueSet(opts.On.Statement.Timestamp, opts.On.Statement.Offset, opts.On.Statement.Statement, opts.On.Statement.Stream); !ok { - errs = append(errs, errExactlyOneOf("Timestamp", "Offset", "Statement", "Stream")) + errs = append(errs, errExactlyOneOf("CreateOnViewStreamOptions.On.Statement", "Timestamp", "Offset", "Statement", "Stream")) } } 
} @@ -135,7 +135,7 @@ func (opts *AlterStreamOptions) validate() error { errs = append(errs, errOneOf("AlterStreamOptions", "IfExists", "UnsetTags")) } if ok := exactlyOneValueSet(opts.SetComment, opts.UnsetComment, opts.SetTags, opts.UnsetTags); !ok { - errs = append(errs, errExactlyOneOf("SetComment", "UnsetComment", "SetTags", "UnsetTags")) + errs = append(errs, errExactlyOneOf("AlterStreamOptions", "SetComment", "UnsetComment", "SetTags", "UnsetTags")) } return errors.Join(errs...) } diff --git a/pkg/sdk/tags_test.go b/pkg/sdk/tags_test.go index a630b66b28..ba5503c402 100644 --- a/pkg/sdk/tags_test.go +++ b/pkg/sdk/tags_test.go @@ -66,14 +66,14 @@ func TestTagCreate(t *testing.T) { }, } opts.Comment = String("comment") - assertOptsInvalidJoinedErrors(t, opts, errOneOf("Comment", "AllowedValues")) + assertOptsInvalidJoinedErrors(t, opts, errOneOf("createTagOptions", "Comment", "AllowedValues")) }) t.Run("validation: both ifNotExists and orReplace present", func(t *testing.T) { opts := defaultOpts() opts.IfNotExists = Bool(true) opts.OrReplace = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, errOneOf("OrReplace", "IfNotExists")) + assertOptsInvalidJoinedErrors(t, opts, errOneOf("createTagOptions", "OrReplace", "IfNotExists")) }) t.Run("validation: multiple errors", func(t *testing.T) { @@ -81,7 +81,7 @@ func TestTagCreate(t *testing.T) { opts.name = NewSchemaObjectIdentifier("", "", "") opts.IfNotExists = Bool(true) opts.OrReplace = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier, errOneOf("OrReplace", "IfNotExists")) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier, errOneOf("createTagOptions", "OrReplace", "IfNotExists")) }) } @@ -159,7 +159,7 @@ func TestTagShow(t *testing.T) { t.Run("validation: empty in", func(t *testing.T) { opts := defaultOpts() opts.In = &In{} - assertOptsInvalidJoinedErrors(t, opts, errScopeRequiredForInKeyword) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("showTagOptions.In", "Account", "Database", "Schema")) }) t.Run("show with empty options", func(t *testing.T) { @@ -284,7 +284,7 @@ func TestTagAlter(t *testing.T) { t.Run("validation: no alter action", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsExactlyOneAction) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("alterTagOptions", "Add", "Drop", "Set", "Unset", "Rename")) }) t.Run("validation: multiple alter actions", func(t *testing.T) { @@ -295,7 +295,7 @@ func TestTagAlter(t *testing.T) { opts.Unset = &TagUnset{ AllowedValues: Bool(true), } - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsExactlyOneAction) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("alterTagOptions", "Add", "Drop", "Set", "Unset", "Rename")) }) t.Run("validation: invalid new name", func(t *testing.T) { @@ -319,6 +319,6 @@ func TestTagAlter(t *testing.T) { t.Run("validation: no property to unset", func(t *testing.T) { opts := defaultOpts() opts.Unset = &TagUnset{} - assertOptsInvalidJoinedErrors(t, opts, errAlterNeedsAtLeastOneProperty) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("TagUnset", "MaskingPolicies", "AllowedValues", "Comment")) }) } diff --git a/pkg/sdk/tags_validations.go b/pkg/sdk/tags_validations.go index dc3d19c0d3..d13c4b4a50 100644 --- a/pkg/sdk/tags_validations.go +++ b/pkg/sdk/tags_validations.go @@ -2,7 +2,6 @@ package sdk import ( "errors" - "fmt" ) var ( @@ -25,10 +24,10 @@ func (opts *createTagOptions) validate() error { errs = append(errs, 
ErrInvalidObjectIdentifier) } if everyValueSet(opts.OrReplace, opts.IfNotExists) && *opts.OrReplace && *opts.IfNotExists { - errs = append(errs, errOneOf("OrReplace", "IfNotExists")) + errs = append(errs, errOneOf("createTagOptions", "OrReplace", "IfNotExists")) } if valueSet(opts.Comment) && valueSet(opts.AllowedValues) { - errs = append(errs, errOneOf("Comment", "AllowedValues")) + errs = append(errs, errOneOf("createTagOptions", "Comment", "AllowedValues")) } if valueSet(opts.AllowedValues) { if err := opts.AllowedValues.validate(); err != nil { @@ -39,8 +38,8 @@ func (opts *createTagOptions) validate() error { } func (v *AllowedValues) validate() error { - if ok := validateIntInRange(len(v.Values), 1, 50); !ok { - return fmt.Errorf("number of the AllowedValues must be between 1 and 50") + if !validateIntInRange(len(v.Values), 1, 50) { + return errIntBetween("AllowedValues", "Values", 1, 50) } return nil } @@ -48,11 +47,11 @@ func (v *AllowedValues) validate() error { func (v *TagSet) validate() error { var errs []error if !exactlyOneValueSet(v.MaskingPolicies, v.Comment) { - errs = append(errs, errOneOf("MaskingPolicies", "Comment")) + errs = append(errs, errOneOf("TagSet", "MaskingPolicies", "Comment")) } if valueSet(v.MaskingPolicies) { - if ok := validateIntGreaterThanOrEqual(len(v.MaskingPolicies.MaskingPolicies), 1); !ok { - errs = append(errs, fmt.Errorf("number of the MaskingPolicies must be greater than zero")) + if !validateIntGreaterThan(len(v.MaskingPolicies.MaskingPolicies), 0) { + errs = append(errs, errIntValue("TagSet.MaskingPolicies", "MaskingPolicies", IntErrGreater, 0)) } } return errors.Join(errs...) @@ -61,11 +60,11 @@ func (v *TagSet) validate() error { func (v *TagUnset) validate() error { var errs []error if !exactlyOneValueSet(v.MaskingPolicies, v.AllowedValues, v.Comment) { - errs = append(errs, errOneOf("MaskingPolicies", "AllowedValues", "Comment")) + errs = append(errs, errExactlyOneOf("TagUnset", "MaskingPolicies", "AllowedValues", "Comment")) } if valueSet(v.MaskingPolicies) { - if ok := validateIntGreaterThanOrEqual(len(v.MaskingPolicies.MaskingPolicies), 1); !ok { - errs = append(errs, fmt.Errorf("number of the MaskingPolicies must be greater than zero")) + if !validateIntGreaterThan(len(v.MaskingPolicies.MaskingPolicies), 0) { + errs = append(errs, errIntValue("TagUnset.MaskingPolicies", "MaskingPolicies", IntErrGreater, 0)) } } return errors.Join(errs...) 
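// Illustrative sketch, not part of the patch: the validations in this change all converge on the
// same shape — nil-check the receiver, collect every failed check into a slice, and return
// errors.Join(errs...) so callers see all problems at once rather than only the first one.
// A minimal version of that pattern, assuming a hypothetical exampleOptions type defined inside
// the sdk package (plus the standard errors import) so it can reuse the helpers shown above
// (ValidObjectIdentifier, exactlyOneValueSet, errExactlyOneOf, ErrNilOptions, ErrInvalidObjectIdentifier):

type exampleOptions struct {
	name    AccountObjectIdentifier // hypothetical field, mirroring the real options structs
	Comment *string
	Tag     *string
}

func (opts *exampleOptions) validate() error {
	if opts == nil {
		return ErrNilOptions
	}
	var errs []error
	if !ValidObjectIdentifier(opts.name) {
		errs = append(errs, ErrInvalidObjectIdentifier)
	}
	if !exactlyOneValueSet(opts.Comment, opts.Tag) {
		errs = append(errs, errExactlyOneOf("exampleOptions", "Comment", "Tag"))
	}
	return errors.Join(errs...)
}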
@@ -79,14 +78,8 @@ func (opts *alterTagOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if ok := exactlyOneValueSet( - opts.Add, - opts.Drop, - opts.Set, - opts.Unset, - opts.Rename, - ); !ok { - errs = append(errs, errAlterNeedsExactlyOneAction) + if !exactlyOneValueSet(opts.Add, opts.Drop, opts.Set, opts.Unset, opts.Rename) { + errs = append(errs, errExactlyOneOf("alterTagOptions", "Add", "Drop", "Set", "Unset", "Rename")) } if valueSet(opts.Add) && valueSet(opts.Add.AllowedValues) { if err := opts.Add.AllowedValues.validate(); err != nil { @@ -107,9 +100,6 @@ func (opts *alterTagOptions) validate() error { if err := opts.Unset.validate(); err != nil { errs = append(errs, err) } - if !anyValueSet(opts.Unset.MaskingPolicies, opts.Unset.AllowedValues, opts.Unset.Comment) { - errs = append(errs, errAlterNeedsAtLeastOneProperty) - } } if valueSet(opts.Rename) { if !ValidObjectIdentifier(opts.Rename.Name) { @@ -128,7 +118,7 @@ func (opts *showTagOptions) validate() error { errs = append(errs, ErrPatternRequiredForLikeKeyword) } if valueSet(opts.In) && !exactlyOneValueSet(opts.In.Account, opts.In.Database, opts.In.Schema) { - errs = append(errs, errScopeRequiredForInKeyword) + errs = append(errs, errExactlyOneOf("showTagOptions.In", "Account", "Database", "Schema")) } return errors.Join(errs...) } diff --git a/pkg/sdk/tasks_gen_test.go b/pkg/sdk/tasks_gen_test.go index c167a0998d..56bb1f65e9 100644 --- a/pkg/sdk/tasks_gen_test.go +++ b/pkg/sdk/tasks_gen_test.go @@ -1,7 +1,6 @@ package sdk import ( - "fmt" "testing" ) @@ -38,7 +37,7 @@ func TestTasks_Create(t *testing.T) { t.Run("validation: exactly one field from [opts.Warehouse.Warehouse opts.Warehouse.UserTaskManagedInitialWarehouseSize] should be present", func(t *testing.T) { opts := defaultOpts() opts.Warehouse = &CreateTaskWarehouse{} - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Warehouse", "UserTaskManagedInitialWarehouseSize")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateTaskOptions.Warehouse", "Warehouse", "UserTaskManagedInitialWarehouseSize")) }) t.Run("validation: opts.SessionParameters.SessionParameters should be valid", func(t *testing.T) { @@ -46,7 +45,7 @@ func TestTasks_Create(t *testing.T) { opts.SessionParameters = &SessionParameters{ JSONIndent: Int(25), } - assertOptsInvalidJoinedErrors(t, opts, fmt.Errorf("JSON_INDENT must be between 0 and 16")) + assertOptsInvalidJoinedErrors(t, opts, errIntBetween("SessionParameters", "JSONIndent", 0, 16)) }) t.Run("basic", func(t *testing.T) { @@ -156,20 +155,20 @@ func TestTasks_Alter(t *testing.T) { t.Run("validation: exactly one field from [opts.Resume opts.Suspend opts.RemoveAfter opts.AddAfter opts.Set opts.Unset opts.SetTags opts.UnsetTags opts.ModifyAs opts.ModifyWhen] should be present", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Resume", "Suspend", "RemoveAfter", "AddAfter", "Set", "Unset", "SetTags", "UnsetTags", "ModifyAs", "ModifyWhen")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterTaskOptions", "Resume", "Suspend", "RemoveAfter", "AddAfter", "Set", "Unset", "SetTags", "UnsetTags", "ModifyAs", "ModifyWhen")) }) t.Run("validation: exactly one field from [opts.Resume opts.Suspend opts.RemoveAfter opts.AddAfter opts.Set opts.Unset opts.SetTags opts.UnsetTags opts.ModifyAs opts.ModifyWhen] should be present - more present", func(t *testing.T) { opts := defaultOpts() opts.Resume = Bool(true) 
opts.Suspend = Bool(true) - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("Resume", "Suspend", "RemoveAfter", "AddAfter", "Set", "Unset", "SetTags", "UnsetTags", "ModifyAs", "ModifyWhen")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterTaskOptions", "Resume", "Suspend", "RemoveAfter", "AddAfter", "Set", "Unset", "SetTags", "UnsetTags", "ModifyAs", "ModifyWhen")) }) t.Run("validation: at least one of the fields [opts.Set.Warehouse opts.Set.UserTaskManagedInitialWarehouseSize opts.Set.Schedule opts.Set.Config opts.Set.AllowOverlappingExecution opts.Set.UserTaskTimeoutMs opts.Set.SuspendTaskAfterNumFailures opts.Set.ErrorIntegration opts.Set.Comment opts.Set.SessionParameters] should be set", func(t *testing.T) { opts := defaultOpts() opts.Set = &TaskSet{} - assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("Warehouse", "UserTaskManagedInitialWarehouseSize", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParameters")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterTaskOptions.Set", "Warehouse", "UserTaskManagedInitialWarehouseSize", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParameters")) }) t.Run("validation: conflicting fields for [opts.Set.Warehouse opts.Set.UserTaskManagedInitialWarehouseSize]", func(t *testing.T) { @@ -178,7 +177,7 @@ func TestTasks_Alter(t *testing.T) { opts.Set = &TaskSet{} opts.Set.Warehouse = &warehouseId opts.Set.UserTaskManagedInitialWarehouseSize = &WarehouseSizeXSmall - assertOptsInvalidJoinedErrors(t, opts, errOneOf("Set", "Warehouse", "UserTaskManagedInitialWarehouseSize")) + assertOptsInvalidJoinedErrors(t, opts, errOneOf("AlterTaskOptions.Set", "Warehouse", "UserTaskManagedInitialWarehouseSize")) }) t.Run("validation: opts.Set.SessionParameters.SessionParameters should be valid", func(t *testing.T) { @@ -187,20 +186,20 @@ func TestTasks_Alter(t *testing.T) { opts.Set.SessionParameters = &SessionParameters{ JSONIndent: Int(25), } - assertOptsInvalidJoinedErrors(t, opts, fmt.Errorf("JSON_INDENT must be between 0 and 16")) + assertOptsInvalidJoinedErrors(t, opts, errIntBetween("SessionParameters", "JSONIndent", 0, 16)) }) t.Run("validation: at least one of the fields [opts.Unset.Warehouse opts.Unset.Schedule opts.Unset.Config opts.Unset.AllowOverlappingExecution opts.Unset.UserTaskTimeoutMs opts.Unset.SuspendTaskAfterNumFailures opts.Unset.ErrorIntegration opts.Unset.Comment opts.Unset.SessionParametersUnset] should be set", func(t *testing.T) { opts := defaultOpts() opts.Unset = &TaskUnset{} - assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParametersUnset")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterTaskOptions.Unset", "Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParametersUnset")) }) t.Run("validation: opts.Unset.SessionParametersUnset.SessionParametersUnset should be valid", func(t *testing.T) { opts := defaultOpts() opts.Unset = &TaskUnset{} opts.Unset.SessionParametersUnset = &SessionParametersUnset{} - assertOptsInvalidJoinedErrors(t, opts, fmt.Errorf("at least one session parameter must be set")) + assertOptsInvalidJoinedErrors(t, 
opts, errAtLeastOneOf("SessionParametersUnset", "AbortDetachedQuery", "Autocommit", "BinaryInputFormat", "BinaryOutputFormat", "DateInputFormat", "DateOutputFormat", "ErrorOnNondeterministicMerge", "ErrorOnNondeterministicUpdate", "GeographyOutputFormat", "JSONIndent", "LockTimeout", "QueryTag", "RowsPerResultset", "SimulatedDataSharingConsumer", "StatementTimeoutInSeconds", "StrictJSONOutput", "TimestampDayIsAlways24h", "TimestampInputFormat", "TimestampLTZOutputFormat", "TimestampNTZOutputFormat", "TimestampOutputFormat", "TimestampTypeMapping", "TimestampTZOutputFormat", "Timezone", "TimeInputFormat", "TimeOutputFormat", "TransactionDefaultIsolationLevel", "TwoDigitCenturyStart", "UnsupportedDDLAction", "UseCachedResult", "WeekOfYearPolicy", "WeekStart")) }) t.Run("alter resume", func(t *testing.T) { diff --git a/pkg/sdk/tasks_validations_gen.go b/pkg/sdk/tasks_validations_gen.go index b3b5317b53..48197e59d3 100644 --- a/pkg/sdk/tasks_validations_gen.go +++ b/pkg/sdk/tasks_validations_gen.go @@ -25,7 +25,7 @@ func (opts *CreateTaskOptions) validate() error { } if valueSet(opts.Warehouse) { if ok := exactlyOneValueSet(opts.Warehouse.Warehouse, opts.Warehouse.UserTaskManagedInitialWarehouseSize); !ok { - errs = append(errs, errExactlyOneOf("Warehouse", "UserTaskManagedInitialWarehouseSize")) + errs = append(errs, errExactlyOneOf("CreateTaskOptions.Warehouse", "Warehouse", "UserTaskManagedInitialWarehouseSize")) } } if valueSet(opts.SessionParameters) { @@ -59,14 +59,14 @@ func (opts *AlterTaskOptions) validate() error { errs = append(errs, ErrInvalidObjectIdentifier) } if ok := exactlyOneValueSet(opts.Resume, opts.Suspend, opts.RemoveAfter, opts.AddAfter, opts.Set, opts.Unset, opts.SetTags, opts.UnsetTags, opts.ModifyAs, opts.ModifyWhen); !ok { - errs = append(errs, errExactlyOneOf("Resume", "Suspend", "RemoveAfter", "AddAfter", "Set", "Unset", "SetTags", "UnsetTags", "ModifyAs", "ModifyWhen")) + errs = append(errs, errExactlyOneOf("AlterTaskOptions", "Resume", "Suspend", "RemoveAfter", "AddAfter", "Set", "Unset", "SetTags", "UnsetTags", "ModifyAs", "ModifyWhen")) } if valueSet(opts.Set) { if ok := anyValueSet(opts.Set.Warehouse, opts.Set.UserTaskManagedInitialWarehouseSize, opts.Set.Schedule, opts.Set.Config, opts.Set.AllowOverlappingExecution, opts.Set.UserTaskTimeoutMs, opts.Set.SuspendTaskAfterNumFailures, opts.Set.ErrorIntegration, opts.Set.Comment, opts.Set.SessionParameters); !ok { - errs = append(errs, errAtLeastOneOf("Warehouse", "UserTaskManagedInitialWarehouseSize", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParameters")) + errs = append(errs, errAtLeastOneOf("AlterTaskOptions.Set", "Warehouse", "UserTaskManagedInitialWarehouseSize", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParameters")) } if everyValueSet(opts.Set.Warehouse, opts.Set.UserTaskManagedInitialWarehouseSize) { - errs = append(errs, errOneOf("Set", "Warehouse", "UserTaskManagedInitialWarehouseSize")) + errs = append(errs, errOneOf("AlterTaskOptions.Set", "Warehouse", "UserTaskManagedInitialWarehouseSize")) } if valueSet(opts.Set.SessionParameters) { if err := opts.Set.SessionParameters.validate(); err != nil { @@ -76,7 +76,7 @@ func (opts *AlterTaskOptions) validate() error { } if valueSet(opts.Unset) { if ok := anyValueSet(opts.Unset.Warehouse, opts.Unset.Schedule, opts.Unset.Config, 
opts.Unset.AllowOverlappingExecution, opts.Unset.UserTaskTimeoutMs, opts.Unset.SuspendTaskAfterNumFailures, opts.Unset.ErrorIntegration, opts.Unset.Comment, opts.Unset.SessionParametersUnset); !ok { - errs = append(errs, errAtLeastOneOf("Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParametersUnset")) + errs = append(errs, errAtLeastOneOf("AlterTaskOptions.Unset", "Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParametersUnset")) } if valueSet(opts.Unset.SessionParametersUnset) { if err := opts.Unset.SessionParametersUnset.validate(); err != nil { diff --git a/pkg/sdk/testint/dynamic_table_integration_test.go b/pkg/sdk/testint/dynamic_table_integration_test.go index 0c6edea858..67359cf35d 100644 --- a/pkg/sdk/testint/dynamic_table_integration_test.go +++ b/pkg/sdk/testint/dynamic_table_integration_test.go @@ -131,8 +131,7 @@ func TestInt_DynamicTableAlter(t *testing.T) { err := client.DynamicTables.Alter(ctx, sdk.NewAlterDynamicTableRequest(dynamicTable.ID()).WithSuspend(sdk.Bool(true)).WithResume(sdk.Bool(true))) require.Error(t, err) - expected := "alter statement needs exactly one action from: set, unset, refresh" - require.Equal(t, expected, err.Error()) + require.Equal(t, sdk.ErrExactlyOneOf("alterDynamicTableOptions", "Suspend", "Resume", "Refresh", "Set").Error(), err.Error()) }) t.Run("alter with set", func(t *testing.T) { diff --git a/pkg/sdk/testint/tags_integration_test.go b/pkg/sdk/testint/tags_integration_test.go index b473932c5f..1fc675e432 100644 --- a/pkg/sdk/testint/tags_integration_test.go +++ b/pkg/sdk/testint/tags_integration_test.go @@ -88,8 +88,7 @@ func TestInt_Tags(t *testing.T) { comment := random.Comment() values := []string{"value1", "value2"} err := client.Tags.Create(ctx, sdk.NewCreateTagRequest(id).WithOrReplace(true).WithComment(&comment).WithAllowedValues(values)) - expected := "Comment fields: [AllowedValues] are incompatible and cannot be set at the same time" - require.Equal(t, expected, err.Error()) + require.Equal(t, sdk.ErrOneOf("createTagOptions", "Comment", "AllowedValues").Error(), err.Error()) }) t.Run("create tag: no optionals", func(t *testing.T) { diff --git a/pkg/sdk/users.go b/pkg/sdk/users.go index 8d45365db6..2edcf06ea3 100644 --- a/pkg/sdk/users.go +++ b/pkg/sdk/users.go @@ -170,8 +170,11 @@ type UserTag struct { } func (opts *CreateUserOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return errors.New("invalid object identifier") + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -273,37 +276,32 @@ type AlterUserOptions struct { } func (opts *AlterUserOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return errors.New("invalid object identifier") - } - if ok := exactlyOneValueSet( - opts.NewName, - opts.ResetPassword, - opts.AbortAllQueries, - opts.AddDelegatedAuthorization, - opts.RemoveDelegatedAuthorization, - opts.Set, - opts.Unset, - ); !ok { - return errExactlyOneOf("NewName", "ResetPassword", "AbortAllQueries", "AddDelegatedAuthorization", "RemoveDelegatedAuthorization", "Set", "Unset") + errs = append(errs, ErrInvalidObjectIdentifier) + } + if !exactlyOneValueSet(opts.NewName, opts.ResetPassword, opts.AbortAllQueries, 
opts.AddDelegatedAuthorization, opts.RemoveDelegatedAuthorization, opts.Set, opts.Unset) { + errs = append(errs, errExactlyOneOf("AlterUserOptions", "NewName", "ResetPassword", "AbortAllQueries", "AddDelegatedAuthorization", "RemoveDelegatedAuthorization", "Set", "Unset")) } if valueSet(opts.RemoveDelegatedAuthorization) { if err := opts.RemoveDelegatedAuthorization.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Unset) { if err := opts.Unset.validate(); err != nil { - return err + errs = append(errs, err) } } - - return nil + return errors.Join(errs...) } func (v *users) Alter(ctx context.Context, id AccountObjectIdentifier, opts *AlterUserOptions) error { @@ -326,6 +324,7 @@ type AddDelegatedAuthorization struct { Role string `ddl:"parameter,no_equals" sql:"ADD DELEGATED AUTHORIZATION OF ROLE"` Integration string `ddl:"parameter,no_equals" sql:"TO SECURITY INTEGRATION"` } + type RemoveDelegatedAuthorization struct { // one of Role *string `ddl:"parameter,no_equals" sql:"REMOVE DELEGATED AUTHORIZATION OF ROLE"` @@ -335,13 +334,14 @@ type RemoveDelegatedAuthorization struct { } func (opts *RemoveDelegatedAuthorization) validate() error { + var errs []error if !exactlyOneValueSet(opts.Role, opts.Authorizations) { - return fmt.Errorf("exactly one of role or authorizations must be set") + errs = append(errs, errExactlyOneOf("RemoveDelegatedAuthorization", "Role", "Authorization")) } if !valueSet(opts.Integration) { - return fmt.Errorf("integration name must be set") + errs = append(errs, errNotSet("RemoveDelegatedAuthorization", "Integration")) } - return nil + return errors.Join(errs...) } type UserSet struct { @@ -354,16 +354,8 @@ type UserSet struct { } func (opts *UserSet) validate() error { - if !anyValueSet(opts.PasswordPolicy, opts.SessionPolicy, opts.Tags, opts.ObjectProperties, opts.ObjectParameters, opts.SessionParameters) { - return fmt.Errorf("at least one of password policy, tag, object properties, object parameters, or session parameters must be set") - } - if moreThanOneValueSet(opts.SessionPolicy, opts.PasswordPolicy, opts.Tags) { - return fmt.Errorf("setting session policy, password policy and tags must be done separately") - } - if anyValueSet(opts.ObjectParameters, opts.SessionParameters, opts.ObjectProperties) { - if anyValueSet(opts.PasswordPolicy, opts.SessionPolicy, opts.Tags) { - return fmt.Errorf("cannot set both {object parameters, session parameters,object properties} and password policy, session policy, or tag") - } + if !exactlyOneValueSet(opts.PasswordPolicy, opts.SessionPolicy, opts.Tags, opts.ObjectProperties, opts.ObjectParameters, opts.SessionParameters) { + return errExactlyOneOf("UserSet", "PasswordPolicy", "SessionPolicy", "Tags", "ObjectProperties", "ObjectParameters", "SessionParameters") } return nil } @@ -379,7 +371,7 @@ type UserUnset struct { func (opts *UserUnset) validate() error { if !exactlyOneValueSet(opts.Tags, opts.PasswordPolicy, opts.SessionPolicy, opts.ObjectProperties, opts.ObjectParameters, opts.SessionParameters) { - return fmt.Errorf("exactly one of password policy, tag, object properties, object parameters, or session parameters must be set") + return errExactlyOneOf("UserUnset", "Tags", "PasswordPolicy", "SessionPolicy", "ObjectProperties", "ObjectParameters", "SessionParameters") } return nil } @@ -392,8 +384,11 @@ type DropUserOptions struct { } func (opts *DropUserOptions) validate() 
error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -401,7 +396,6 @@ func (opts *DropUserOptions) validate() error { func (v *users) Drop(ctx context.Context, id AccountObjectIdentifier) error { opts := &DropUserOptions{} opts.name = id - if err := opts.validate(); err != nil { return fmt.Errorf("validate drop options: %w", err) } @@ -526,9 +520,12 @@ type describeUserOptions struct { name AccountObjectIdentifier `ddl:"identifier"` } -func (v *describeUserOptions) validate() error { - if !ValidObjectIdentifier(v.name) { - return ErrInvalidObjectIdentifier +func (opts *describeUserOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + if !ValidObjectIdentifier(opts.name) { + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -565,7 +562,10 @@ type ShowUserOptions struct { From *string `ddl:"parameter,no_equals,single_quotes" sql:"FROM"` } -func (input *ShowUserOptions) validate() error { +func (opts *ShowUserOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } diff --git a/pkg/sdk/users_test.go b/pkg/sdk/users_test.go index f7952d913b..82a13912a8 100644 --- a/pkg/sdk/users_test.go +++ b/pkg/sdk/users_test.go @@ -11,7 +11,7 @@ func TestUserCreate(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &CreateUserOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("with complete options", func(t *testing.T) { @@ -51,14 +51,14 @@ func TestUserAlter(t *testing.T) { t.Run("empty options", func(t *testing.T) { opts := &AlterUserOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { opts := &AlterUserOptions{ name: id, } - assertOptsInvalid(t, opts, errExactlyOneOf("NewName", "ResetPassword", "AbortAllQueries", "AddDelegatedAuthorization", "RemoveDelegatedAuthorization", "Set", "Unset")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterUserOptions", "NewName", "ResetPassword", "AbortAllQueries", "AddDelegatedAuthorization", "RemoveDelegatedAuthorization", "Set", "Unset")) }) t.Run("with setting a policy", func(t *testing.T) { @@ -225,7 +225,7 @@ func TestUserDrop(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &DropUserOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { @@ -292,7 +292,7 @@ func TestUserDescribe(t *testing.T) { t.Run("validation: empty options", func(t *testing.T) { opts := &describeUserOptions{} - assertOptsInvalid(t, opts, ErrInvalidObjectIdentifier) + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("only name", func(t *testing.T) { diff --git a/pkg/sdk/validations.go b/pkg/sdk/validations.go index 546f7ac15f..a0011d44db 100644 --- a/pkg/sdk/validations.go +++ b/pkg/sdk/validations.go @@ -99,6 +99,10 @@ func validateIntInRange(value int, min int, max int) bool { return true } +func validateIntGreaterThan(value int, min int) bool { + return value > min +} + func validateIntGreaterThanOrEqual(value int, min int) bool { return value >= min } diff --git a/pkg/sdk/warehouses.go b/pkg/sdk/warehouses.go 
index a5f98fb63b..5e61ebf5b2 100644 --- a/pkg/sdk/warehouses.go +++ b/pkg/sdk/warehouses.go @@ -3,6 +3,7 @@ package sdk import ( "context" "database/sql" + "errors" "fmt" "strconv" "strings" @@ -119,16 +120,20 @@ type CreateWarehouseOptions struct { } func (opts *CreateWarehouseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } if valueSet(opts.MinClusterCount) && valueSet(opts.MaxClusterCount) && !validateIntGreaterThanOrEqual(*opts.MaxClusterCount, *opts.MinClusterCount) { - return fmt.Errorf("MinClusterCount must be less than or equal to MaxClusterCount") + errs = append(errs, fmt.Errorf("MinClusterCount must be less than or equal to MaxClusterCount")) } if valueSet(opts.QueryAccelerationMaxScaleFactor) && !validateIntInRange(*opts.QueryAccelerationMaxScaleFactor, 0, 100) { - return fmt.Errorf("QueryAccelerationMaxScaleFactor must be between 0 and 100") + errs = append(errs, errIntBetween("CreateWarehouseOptions", "QueryAccelerationMaxScaleFactor", 0, 100)) } - return nil + return errors.Join(errs...) } func (c *warehouses) Create(ctx context.Context, id AccountObjectIdentifier, opts *CreateWarehouseOptions) error { @@ -165,38 +170,33 @@ type AlterWarehouseOptions struct { } func (opts *AlterWarehouseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + errs = append(errs, ErrInvalidObjectIdentifier) } - if ok := exactlyOneValueSet( - opts.Suspend, - opts.Resume, - opts.AbortAllQueries, - opts.NewName, - opts.Set, - opts.Unset); !ok { - return errExactlyOneOf("Suspend", "Resume", "AbortAllQueries", "NewName", "Set", "Unset") + if !exactlyOneValueSet(opts.Suspend, opts.Resume, opts.AbortAllQueries, opts.NewName, opts.Set, opts.Unset) { + errs = append(errs, errExactlyOneOf("AlterWarehouseOptions", "Suspend", "Resume", "AbortAllQueries", "NewName", "Set", "Unset")) } if everyValueSet(opts.Suspend, opts.Resume) && (*opts.Suspend && *opts.Resume) { - return fmt.Errorf("suspend and Resume cannot both be true") + errs = append(errs, errOneOf("AlterWarehouseOptions", "Suspend", "Resume")) } if (valueSet(opts.IfSuspended) && *opts.IfSuspended) && (!valueSet(opts.Resume) || !*opts.Resume) { - return fmt.Errorf(`"Resume" has to be set when using "IfSuspended"`) - } - if everyValueSet(opts.Set, opts.Unset) { - return fmt.Errorf("set and Unset cannot both be set") + errs = append(errs, fmt.Errorf(`"Resume" has to be set when using "IfSuspended"`)) } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { - return err + errs = append(errs, err) } } if valueSet(opts.Unset) { if err := opts.Unset.validate(); err != nil { - return err + errs = append(errs, err) } } - return nil + return errors.Join(errs...) 
} type WarehouseSet struct { @@ -304,8 +304,11 @@ type DropWarehouseOptions struct { } func (opts *DropWarehouseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } @@ -339,6 +342,9 @@ type ShowWarehouseOptions struct { } func (opts *ShowWarehouseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } return nil } @@ -493,8 +499,11 @@ type describeWarehouseOptions struct { } func (opts *describeWarehouseOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } if !ValidObjectIdentifier(opts.name) { - return ErrInvalidObjectIdentifier + return errors.Join(ErrInvalidObjectIdentifier) } return nil } diff --git a/pkg/sdk/warehouses_test.go b/pkg/sdk/warehouses_test.go index 403644eac5..c13564c52c 100644 --- a/pkg/sdk/warehouses_test.go +++ b/pkg/sdk/warehouses_test.go @@ -61,7 +61,7 @@ func TestWarehouseSizing(t *testing.T) { MaxClusterCount: Int(1), MinClusterCount: Int(2), } - assertOptsInvalid(t, opts, fmt.Errorf("MinClusterCount must be less than or equal to MaxClusterCount")) + assertOptsInvalidJoinedErrors(t, opts, fmt.Errorf("MinClusterCount must be less than or equal to MaxClusterCount")) }) t.Run("Max equal Min", func(t *testing.T) { From d33a41cd0f477e8375113a546117ccd6e5dbb0eb Mon Sep 17 00:00:00 2001 From: "snowflake-release-please[bot]" <105954990+snowflake-release-please[bot]@users.noreply.github.com> Date: Thu, 26 Oct 2023 06:28:53 -0700 Subject: [PATCH 10/20] chore(main): release 0.75.0 (#2138) Co-authored-by: snowflake-release-please[bot] <105954990+snowflake-release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e6b49afc16..469dfe244c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # Changelog +## [0.75.0](https://github.com/Snowflake-Labs/terraform-provider-snowflake/compare/v0.74.0...v0.75.0) (2023-10-26) + + +### Features + +* add parse_header option to file format resource ([#2132](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2132)) ([1e6e54f](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/1e6e54f828efa60edd258b316709fc4dfd370f93)) +* Use streams from the new SDK in resource / datasource ([#2129](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2129)) ([5c633be](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/5c633be461fd373d412b02b108e64b6cfc4eb856)) +* Use task from SDK in resource and data source ([#2140](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2140)) ([de23f2b](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/de23f2ba939eb368d9734217e1bb2d4ebc75eef4)) + + +### Misc + +* Return multiple errors in existing validations ([#2122](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2122)) ([4d4bcdb](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/4d4bcdbe841807da2fa08d534eaf846234934f7c)) +* Set up a single warehouse for the SDK integration tests ([#2141](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2141)) ([16022ef](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/16022ef4171e7dccf2932ae6e8d451b51c93291c)) + + +### BugFixes + +* cleanup acc tests 
([#2135](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2135)) ([5db751d](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/5db751d1aa71952b1528e81cf2fdcd05d9d5d0fb)) +* provider config ([#2136](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2136)) ([07b9b4f](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/07b9b4fee800fe3f34890783cc463d4fc5904717)) +* view statement update ([#2152](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2152)) ([6de32ae](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/6de32ae6ec16ad76fb40afddfcaa7f650322cb67)) + ## [0.74.0](https://github.com/Snowflake-Labs/terraform-provider-snowflake/compare/v0.73.0...v0.74.0) (2023-10-18) From b86c4c34d05f8b982fb6218a3a3a7500a23abf72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Mon, 30 Oct 2023 13:31:17 +0100 Subject: [PATCH 11/20] feat: Poc custom error type (#2052) --- pkg/sdk/assertions_test.go | 18 ++- pkg/sdk/dynamic_table_validations.go | 49 +++---- pkg/sdk/errors.go | 114 +++++++++++++-- pkg/sdk/errors_test.go | 134 ++++++++++++++++++ pkg/sdk/integration_test_imports.go | 17 +++ pkg/sdk/masking_policy.go | 2 +- pkg/sdk/poc/generator/templates.go | 6 +- .../testint/dynamic_table_integration_test.go | 2 +- pkg/sdk/testint/tags_integration_test.go | 2 +- 9 files changed, 292 insertions(+), 52 deletions(-) create mode 100644 pkg/sdk/errors_test.go diff --git a/pkg/sdk/assertions_test.go b/pkg/sdk/assertions_test.go index e67f67963a..ddf962c360 100644 --- a/pkg/sdk/assertions_test.go +++ b/pkg/sdk/assertions_test.go @@ -1,6 +1,7 @@ package sdk import ( + "errors" "fmt" "testing" @@ -14,7 +15,14 @@ func assertOptsInvalid(t *testing.T, opts validatable, expectedError error) { t.Helper() err := opts.validate() assert.Error(t, err) - assert.Equal(t, expectedError, err) + var sdkErr *Error + if errors.As(err, &sdkErr) { + errorWithoutFileInfo := errorFileInfoRegexp.ReplaceAllString(sdkErr.Error(), "") + expectedErrorWithoutFileInfo := errorFileInfoRegexp.ReplaceAllString(expectedError.Error(), "") + assert.Equal(t, expectedErrorWithoutFileInfo, errorWithoutFileInfo) + } else { + assert.Equal(t, expectedError, err) + } } // assertOptsInvalidJoinedErrors could be reused in tests for other interfaces in sdk package. 
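// Illustrative, standalone sketch (not from the patch): the assertions above can only compare
// sdk.Error values after stripping the new "[file.go:line] " prefix, since that prefix shifts
// whenever the surrounding code moves. The pattern below copies errorFileInfoRegexp from
// pkg/sdk/errors.go; the message and line number are made up for illustration.

package main

import (
	"fmt"
	"regexp"
)

// same pattern as errorFileInfoRegexp in pkg/sdk/errors.go
var fileInfo = regexp.MustCompile(`\[\w+\.\w+:\d+\] `)

func main() {
	msg := "[errors.go:15] invalid object identifier" // line number is illustrative
	fmt.Println(fileInfo.ReplaceAllString(msg, ""))   // prints: invalid object identifier
}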
@@ -23,7 +31,13 @@ func assertOptsInvalidJoinedErrors(t *testing.T, opts validatable, expectedError err := opts.validate() assert.Error(t, err) for _, expectedError := range expectedErrors { - assert.Contains(t, err.Error(), expectedError.Error()) + var sdkErr *Error + if errors.As(expectedError, &sdkErr) { + expectedErrorWithoutFileInfo := errorFileInfoRegexp.ReplaceAllString(sdkErr.Error(), "") + assert.Contains(t, err.Error(), expectedErrorWithoutFileInfo) + } else { + assert.Contains(t, err.Error(), expectedError.Error()) + } } } diff --git a/pkg/sdk/dynamic_table_validations.go b/pkg/sdk/dynamic_table_validations.go index d5323f1d8d..e0a6c95a18 100644 --- a/pkg/sdk/dynamic_table_validations.go +++ b/pkg/sdk/dynamic_table_validations.go @@ -1,9 +1,5 @@ package sdk -import ( - "errors" -) - var ( _ validatable = new(createDynamicTableOptions) _ validatable = new(alterDynamicTableOptions) @@ -15,27 +11,26 @@ var ( func (tl *TargetLag) validate() error { if tl == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } - var errs []error if everyValueSet(tl.MaximumDuration, tl.Downstream) { - errs = append(errs, errOneOf("TargetLag", "MaximumDuration", "Downstream")) + return errOneOf("TargetLag", "MaximumDuration", "Downstream") } - return errors.Join(errs...) + return nil } func (opts *createDynamicTableOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } if !ValidObjectIdentifier(opts.warehouse) { - errs = append(errs, ErrInvalidObjectIdentifier) + errs = append(errs, errInvalidIdentifier("createDynamicTableOptions", "warehouse")) } - return errors.Join(errs...) + return JoinErrors(errs...) } func (dts *DynamicTableSet) validate() error { @@ -43,18 +38,15 @@ func (dts *DynamicTableSet) validate() error { if valueSet(dts.TargetLag) { errs = append(errs, dts.TargetLag.validate()) } - - if valueSet(dts.Warehouse) { - if !ValidObjectIdentifier(*dts.Warehouse) { - errs = append(errs, ErrInvalidObjectIdentifier) - } + if dts.Warehouse != nil && !ValidObjectIdentifier(*dts.Warehouse) { + errs = append(errs, errInvalidIdentifier("DynamicTableSet", "Warehouse")) } - return errors.Join(errs...) + return JoinErrors(errs...) } func (opts *alterDynamicTableOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error if !ValidObjectIdentifier(opts.name) { @@ -66,12 +58,12 @@ func (opts *alterDynamicTableOptions) validate() error { if valueSet(opts.Set) && valueSet(opts.Set.TargetLag) { errs = append(errs, opts.Set.TargetLag.validate()) } - return errors.Join(errs...) + return JoinErrors(errs...) } func (opts *showDynamicTableOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error if valueSet(opts.Like) && !valueSet(opts.Like.Pattern) { @@ -80,28 +72,25 @@ func (opts *showDynamicTableOptions) validate() error { if valueSet(opts.In) && !exactlyOneValueSet(opts.In.Account, opts.In.Database, opts.In.Schema) { errs = append(errs, errExactlyOneOf("showDynamicTableOptions.In", "Account", "Database", "Schema")) } - return errors.Join(errs...) + return JoinErrors(errs...) 
} func (opts *dropDynamicTableOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } - var errs []error - if !ValidObjectIdentifier(opts.name) { - errs = append(errs, ErrInvalidObjectIdentifier) + return ErrInvalidObjectIdentifier } - return errors.Join(errs...) + return nil } func (opts *describeDynamicTableOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } - var errs []error if !ValidObjectIdentifier(opts.name) { - errs = append(errs, ErrInvalidObjectIdentifier) + return ErrInvalidObjectIdentifier } - return errors.Join(errs...) + return nil } diff --git a/pkg/sdk/errors.go b/pkg/sdk/errors.go index 16f8d50aa5..3ba28aaa2f 100644 --- a/pkg/sdk/errors.go +++ b/pkg/sdk/errors.go @@ -4,20 +4,22 @@ import ( "errors" "fmt" "log" + "regexp" + "runtime" "strings" ) var ( - ErrNilOptions = errors.New("options cannot be nil") - ErrPatternRequiredForLikeKeyword = errors.New("pattern must be specified for like keyword") + ErrNilOptions = NewError("options cannot be nil") + ErrPatternRequiredForLikeKeyword = NewError("pattern must be specified for like keyword") // go-snowflake errors. - ErrObjectNotExistOrAuthorized = errors.New("object does not exist or not authorized") - ErrAccountIsEmpty = errors.New("account is empty") + ErrObjectNotExistOrAuthorized = NewError("object does not exist or not authorized") + ErrAccountIsEmpty = NewError("account is empty") // snowflake-sdk errors. - ErrInvalidObjectIdentifier = errors.New("invalid object identifier") - ErrDifferentDatabase = errors.New("database must be the same") + ErrInvalidObjectIdentifier = NewError("invalid object identifier") + ErrDifferentDatabase = NewError("database must be the same") ) type IntErrType string @@ -31,31 +33,31 @@ const ( ) func errIntValue(structName string, fieldName string, intErrType IntErrType, limit int) error { - return fmt.Errorf("%s field: %s must be %s %d", structName, fieldName, string(intErrType), limit) + return newError(fmt.Sprintf("%s field: %s must be %s %d", structName, fieldName, string(intErrType), limit), 2) } func errIntBetween(structName string, fieldName string, from int, to int) error { - return fmt.Errorf("%s field: %s must be between %d and %d", structName, fieldName, from, to) + return newError(fmt.Sprintf("%s field: %s must be between %d and %d", structName, fieldName, from, to), 2) } func errInvalidIdentifier(structName string, identifierField string) error { - return fmt.Errorf("invalid object identifier of %s field: %s", structName, identifierField) + return newError(fmt.Sprintf("invalid object identifier of %s field: %s", structName, identifierField), 2) } func errOneOf(structName string, fieldNames ...string) error { - return fmt.Errorf("%v fields: %v are incompatible and cannot be set at the same time", structName, fieldNames) + return newError(fmt.Sprintf("%v fields: %v are incompatible and cannot be set at the same time", structName, fieldNames), 2) } func errNotSet(structName string, fieldNames ...string) error { - return fmt.Errorf("%v fields: %v should be set", structName, fieldNames) + return newError(fmt.Sprintf("%v fields: %v should be set", structName, fieldNames), 2) } func errExactlyOneOf(structName string, fieldNames ...string) error { - return fmt.Errorf("exactly one of %s fileds %v must be set", structName, fieldNames) + return newError(fmt.Sprintf("exactly one of %s fileds %v must be set", structName, fieldNames), 2) } func errAtLeastOneOf(structName string, fieldNames 
...string) error { - return fmt.Errorf("at least one of %s fields %v must be set", structName, fieldNames) + return newError(fmt.Sprintf("at least one of %s fields %v must be set", structName, fieldNames), 2) } func decodeDriverError(err error) error { @@ -75,3 +77,89 @@ func decodeDriverError(err error) error { return err } + +const errorIndentRune = '›' + +var errorFileInfoRegexp = regexp.MustCompile(`\[\w+\.\w+:\d+\] `) + +type Error struct { + file string + line int + message string + nestedErrors []error +} + +func (e *Error) Error() string { + builder := new(strings.Builder) + writeTree(e, builder, 0) + return builder.String() +} + +// NewError creates new sdk.Error with information like filename or line number (depending on where NewError was called) +func NewError(message string, nestedErrors ...error) error { + return newError(message, 2, nestedErrors...) +} + +// JoinErrors returns an error that wraps the given errors. Any nil error values are discarded. +// JoinErrors returns nil if errs contains no non-nil values, otherwise returns sdk.Error with errs as its nested errors +func JoinErrors(errs ...error) error { + notNilErrs := make([]error, 0) + for _, err := range errs { + if err != nil { + notNilErrs = append(notNilErrs, err) + } + } + if len(notNilErrs) == 0 { + return nil + } + return newError("joined error", 2, notNilErrs...) +} + +// newError is a function that is supposed to be used by other sdk.Error constructors like NewError or JoinErrors. +// First of all, it returns error implementation which is against Golang conventions, but it's convenient to use +// in other constructors, because then there's no need for casting and guessing which type of error it is. +// The second reason is that there is a mysterious skip parameter which is only useful for other constructor functions. +// It determines how many function stack calls have to be skipped to get the right filename and line information, +// which is too low-level for normal use. 
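// Illustrative, standalone sketch (not part of the patch) of the skip mechanics described in the
// comment above: runtime.Caller(0) reports the call site of runtime.Caller itself, and each extra
// level walks one frame up the stack. That is why getCallerInfo adds 1 and the exported
// constructors pass skip=2 — the file and line stored in the Error end up pointing at the code
// that called NewError or JoinErrors, not at the error-handling internals.

package main

import (
	"fmt"
	"runtime"
)

func where(skip int) string {
	_, file, line, ok := runtime.Caller(skip)
	if !ok {
		return "unknown"
	}
	return fmt.Sprintf("%s:%d", file, line)
}

func helper() string {
	// skip 0 -> the runtime.Caller call inside where
	// skip 1 -> this helper function
	// skip 2 -> whoever called helper
	return where(2)
}

func main() {
	fmt.Println(helper()) // prints the file:line of this helper() call in main
}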
+func newError(message string, skip int, nested ...error) *Error { + line, filename := getCallerInfo(skip) + return &Error{ + file: filename, + line: line, + message: message, + nestedErrors: nested, + } +} + +func getCallerInfo(skip int) (int, string) { + _, file, line, _ := runtime.Caller(skip + 1) + fileSplit := strings.Split(file, "/") + var filename string + if len(fileSplit) > 1 { + filename = fileSplit[len(fileSplit)-1] + } else { + filename = fileSplit[0] + } + return line, filename +} + +func writeTree(e error, builder *strings.Builder, indent int) { + var sdkErr *Error + if joinedErr, ok := e.(interface{ Unwrap() []error }); ok { //nolint:all + errs := joinedErr.Unwrap() + for i, err := range errs { + if i > 0 { + builder.WriteByte('\n') + } + writeTree(err, builder, indent) + } + } else if errors.As(e, &sdkErr) { + builder.WriteString(strings.Repeat(fmt.Sprintf("%b ", errorIndentRune), indent) + fmt.Sprintf("[%s:%d] %s", sdkErr.file, sdkErr.line, sdkErr.message)) + for _, err := range sdkErr.nestedErrors { + builder.WriteByte('\n') + writeTree(err, builder, indent+2) + } + } else { + builder.WriteString(strings.Repeat(fmt.Sprintf("%b ", errorIndentRune), indent) + e.Error()) + } +} diff --git a/pkg/sdk/errors_test.go b/pkg/sdk/errors_test.go new file mode 100644 index 0000000000..a0533042ad --- /dev/null +++ b/pkg/sdk/errors_test.go @@ -0,0 +1,134 @@ +package sdk + +import ( + "errors" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWriteTree(t *testing.T) { + fileInfoRegTemplate := func(filename string) string { + return fmt.Sprintf("\\[%s:.*\\]", filename) + } + errorsTestFileInfoReg := fileInfoRegTemplate("errors_test.go") + + testCases := map[string]struct { + Error error + Indent int + Expected string + MatchContains []string + }{ + "basic error - no indent": { + Error: errors.New("some error"), + Indent: 0, + Expected: "some error", + }, + "basic error - indent": { + Error: errors.New("some error"), + Indent: 1, + Expected: fmt.Sprintf("%b some error", errorIndentRune), + }, + "basic error - double indent": { + Error: errors.New("some error"), + Indent: 2, + Expected: fmt.Sprintf("%b %b some error", errorIndentRune, errorIndentRune), + }, + "joined error - no indent": { + Error: errors.Join(errors.New("err one"), errors.New("err two")), + Indent: 0, + Expected: "err one\nerr two", + }, + "joined error - indent": { + Error: errors.Join(errors.New("err one"), errors.New("err two")), + Indent: 1, + Expected: fmt.Sprintf("%b err one\n%b err two", errorIndentRune, errorIndentRune), + }, + "joined error - double indent": { + Error: errors.Join(errors.New("err one"), errors.New("err two")), + Indent: 2, + Expected: fmt.Sprintf("%b %b err one\n%b %b err two", errorIndentRune, errorIndentRune, errorIndentRune, errorIndentRune), + }, + "custom error - no indent": { + Error: NewError("some error"), + Indent: 0, + MatchContains: []string{fmt.Sprintf("%s some error", errorsTestFileInfoReg)}, + }, + "custom error - indent": { + Error: errors.Join(NewError("err one"), NewError("err two")), + Indent: 1, + MatchContains: []string{ + fmt.Sprintf("%b %s err one", errorIndentRune, errorsTestFileInfoReg), + fmt.Sprintf("%b %s err two", errorIndentRune, errorsTestFileInfoReg), + }, + }, + "custom error - double indent": { + Error: errors.Join(NewError("err one"), NewError("err two")), + Indent: 2, + MatchContains: []string{ + fmt.Sprintf("%b %b %s err one", errorIndentRune, errorIndentRune, errorsTestFileInfoReg), + fmt.Sprintf("%b %b %s err two", 
errorIndentRune, errorIndentRune, errorsTestFileInfoReg), + }, + }, + "nested errors - custom errors combined with std errors": { + Error: NewError("root error", + errors.New("regular error"), + errors.Join( + errors.New("regular nested error"), + NewError("custom nested error"), + JoinErrors( + errors.New("regular nested nested error"), + NewError("custom nested nested error"), + ), + ), + NewError("custom error"), + ), + Indent: 0, + MatchContains: []string{ + fmt.Sprintf("%s root error", errorsTestFileInfoReg), + fmt.Sprintf("%b %b regular error", errorIndentRune, errorIndentRune), + // Nested errors (errors.Join-ed) are on the same level, because there's no root error there + // we could make another indent here by introducing a root error for every errors.Join-ed error + fmt.Sprintf("%b %b regular nested error", errorIndentRune, errorIndentRune), + fmt.Sprintf("%b %b %s custom nested error", errorIndentRune, errorIndentRune, errorsTestFileInfoReg), + // Here, for example, I've made a root error inside JoinErrors function + fmt.Sprintf("%b %b %s joined error", errorIndentRune, errorIndentRune, errorsTestFileInfoReg), + fmt.Sprintf("%b %b %b %b regular nested nested error", errorIndentRune, errorIndentRune, errorIndentRune, errorIndentRune), + fmt.Sprintf("%b %b %b %b %s custom nested nested error", errorIndentRune, errorIndentRune, errorIndentRune, errorIndentRune, errorsTestFileInfoReg), + fmt.Sprintf("%b %b %s custom error", errorIndentRune, errorIndentRune, errorsTestFileInfoReg), + }, + }, + "custom error - predefined errors": { + Error: errors.Join( + ErrInvalidObjectIdentifier, + errNotSet("Struct", "Field"), + ), + Indent: 2, + MatchContains: []string{ + // Predefined errors are pointing to the file where they're declared + fmt.Sprintf("%s invalid object identifier", fileInfoRegTemplate("errors.go")), + fmt.Sprintf("%s Struct fields: \\[Field\\] should be set", errorsTestFileInfoReg), + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + builder := new(strings.Builder) + writeTree(tc.Error, builder, tc.Indent) + if len(tc.Expected) == 0 && len(tc.MatchContains) == 0 { + t.Fatal("expected or contains should be specified on a test case") + } + if len(tc.Expected) > 0 { + require.Equal(t, tc.Expected, builder.String()) + } + if len(tc.MatchContains) > 0 { + for _, regex := range tc.MatchContains { + require.Regexpf(t, regex, builder.String(), "regex %s not in: %s", regex, builder.String()) + } + } + }) + } +} diff --git a/pkg/sdk/integration_test_imports.go b/pkg/sdk/integration_test_imports.go index d038d50765..86f7dfc715 100644 --- a/pkg/sdk/integration_test_imports.go +++ b/pkg/sdk/integration_test_imports.go @@ -3,6 +3,10 @@ package sdk import ( "context" "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/assert" ) // All the contents of this file were added to be able to use them outside the sdk package (i.e. integration tests package). 
@@ -16,6 +20,19 @@ func (c *Client) ExecForTests(ctx context.Context, sql string) (sql.Result, erro return result, decodeDriverError(err) } +func ErrorsEqual(t *testing.T, expected error, actual error) { + t.Helper() + var expectedErr *Error + var actualErr *Error + if errors.As(expected, &expectedErr) && errors.As(actual, &actualErr) { + expectedErrorWithoutFileInfo := errorFileInfoRegexp.ReplaceAllString(expectedErr.Error(), "") + errorWithoutFileInfo := errorFileInfoRegexp.ReplaceAllString(actualErr.Error(), "") + assert.Equal(t, expectedErrorWithoutFileInfo, errorWithoutFileInfo) + } else { + assert.Equal(t, expected, actual) + } +} + func ErrExactlyOneOf(structName string, fieldNames ...string) error { return errExactlyOneOf(structName, fieldNames...) } diff --git a/pkg/sdk/masking_policy.go b/pkg/sdk/masking_policy.go index 1bb20269cd..ac9fe60b9f 100644 --- a/pkg/sdk/masking_policy.go +++ b/pkg/sdk/masking_policy.go @@ -57,7 +57,7 @@ func (opts *CreateMaskingPolicyOptions) validate() error { } var errs []error if !ValidObjectIdentifier(opts.name) { - errs = append(errs, errors.Join(ErrInvalidObjectIdentifier)) + errs = append(errs, ErrInvalidObjectIdentifier) } if !valueSet(opts.signature) { errs = append(errs, errNotSet("CreateMaskingPolicyOptions", "signature")) diff --git a/pkg/sdk/poc/generator/templates.go b/pkg/sdk/poc/generator/templates.go index 919b8e7fa3..6a1b36f166 100644 --- a/pkg/sdk/poc/generator/templates.go +++ b/pkg/sdk/poc/generator/templates.go @@ -272,8 +272,6 @@ var ValidationsImplTemplate, _ = template.New("validationsImplTemplate").Parse(` {{- end -}} {{ end }} -import "errors" - var ( {{- range .Operations }} {{- if .OptsField }} @@ -285,11 +283,11 @@ var ( {{- if .OptsField }} func (opts *{{ .OptsField.KindNoPtr }}) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error {{- template "VALIDATIONS" .OptsField }} - return errors.Join(errs...) + return JoinErrors(errs...) 
} {{- end }} {{ end }} diff --git a/pkg/sdk/testint/dynamic_table_integration_test.go b/pkg/sdk/testint/dynamic_table_integration_test.go index 67359cf35d..b6efbd9d59 100644 --- a/pkg/sdk/testint/dynamic_table_integration_test.go +++ b/pkg/sdk/testint/dynamic_table_integration_test.go @@ -131,7 +131,7 @@ func TestInt_DynamicTableAlter(t *testing.T) { err := client.DynamicTables.Alter(ctx, sdk.NewAlterDynamicTableRequest(dynamicTable.ID()).WithSuspend(sdk.Bool(true)).WithResume(sdk.Bool(true))) require.Error(t, err) - require.Equal(t, sdk.ErrExactlyOneOf("alterDynamicTableOptions", "Suspend", "Resume", "Refresh", "Set").Error(), err.Error()) + sdk.ErrorsEqual(t, sdk.JoinErrors(sdk.ErrExactlyOneOf("alterDynamicTableOptions", "Suspend", "Resume", "Refresh", "Set")), err) }) t.Run("alter with set", func(t *testing.T) { diff --git a/pkg/sdk/testint/tags_integration_test.go b/pkg/sdk/testint/tags_integration_test.go index 1fc675e432..0e358c4ce1 100644 --- a/pkg/sdk/testint/tags_integration_test.go +++ b/pkg/sdk/testint/tags_integration_test.go @@ -88,7 +88,7 @@ func TestInt_Tags(t *testing.T) { comment := random.Comment() values := []string{"value1", "value2"} err := client.Tags.Create(ctx, sdk.NewCreateTagRequest(id).WithOrReplace(true).WithComment(&comment).WithAllowedValues(values)) - require.Equal(t, sdk.ErrOneOf("createTagOptions", "Comment", "AllowedValues").Error(), err.Error()) + sdk.ErrorsEqual(t, sdk.ErrOneOf("createTagOptions", "Comment", "AllowedValues"), err) }) t.Run("create tag: no optionals", func(t *testing.T) { From 82c3c13b6166168e470d7cb9b2982a8979275f17 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Mon, 30 Oct 2023 13:34:09 +0100 Subject: [PATCH 12/20] feat: Use tasks from the SDK followup (#2153) * Support cycles when getting root tasks * Test setting parameter * Use set parameter method * Add missing parameter * Fix linter * Fix session parameters handling for task * Try to fix failing test * Fix tests * Fix typo --- docs/resources/task.md | 1 + pkg/resources/task.go | 64 +++- pkg/resources/task_acceptance_test.go | 79 +++-- pkg/sdk/internal/collections/queue.go | 30 ++ pkg/sdk/internal/collections/queue_test.go | 51 +++ pkg/sdk/parameters.go | 329 +----------------- pkg/sdk/parameters_impl.go | 2 - pkg/sdk/parameters_impl_test.go | 117 +++++++ pkg/sdk/tasks_impl_gen.go | 52 +-- pkg/sdk/tasks_test.go | 3 +- pkg/sdk/testint/tasks_gen_integration_test.go | 6 + 11 files changed, 368 insertions(+), 366 deletions(-) create mode 100644 pkg/sdk/internal/collections/queue.go create mode 100644 pkg/sdk/internal/collections/queue_test.go create mode 100644 pkg/sdk/parameters_impl_test.go diff --git a/docs/resources/task.md b/docs/resources/task.md index 2889ff0a42..43e07c6ac7 100644 --- a/docs/resources/task.md +++ b/docs/resources/task.md @@ -88,6 +88,7 @@ resource "snowflake_task" "test_task" { - `error_integration` (String) Specifies the name of the notification integration used for error notifications. - `schedule` (String) The schedule for periodically running the task. This can be a cron or interval in minutes. (Conflict with after) - `session_parameters` (Map of String) Specifies session parameters to set for the session when the task runs. A task supports all session parameters. +- `suspend_task_after_num_failures` (Number) Specifies the number of consecutive failed task runs after which the current task is suspended automatically. The default is 0 (no automatic suspension). 
- `user_task_managed_initial_warehouse_size` (String) Specifies the size of the compute resources to provision for the first run of the task, before a task history is available for Snowflake to determine an ideal size. Once a task has successfully completed a few runs, Snowflake ignores this parameter setting. (Conflicts with warehouse) - `user_task_timeout_ms` (Number) Specifies the time limit on a single run of the task before it times out (in milliseconds). - `warehouse` (String) The warehouse the task will use. Omit this parameter to use Snowflake-managed compute resources for runs of this task. (Conflicts with user_task_managed_initial_warehouse_size) diff --git a/pkg/resources/task.go b/pkg/resources/task.go index 72e42fec09..2134398f29 100644 --- a/pkg/resources/task.go +++ b/pkg/resources/task.go @@ -16,7 +16,6 @@ import ( "golang.org/x/exp/slices" ) -// TODO [SNOW-884987]: add missing SUSPEND_TASK_AFTER_NUM_FAILURES attribute. var taskSchema = map[string]*schema.Schema{ "enabled": { Type: schema.TypeBool, @@ -67,6 +66,13 @@ var taskSchema = map[string]*schema.Schema{ ValidateFunc: validation.IntBetween(0, 86400000), Description: "Specifies the time limit on a single run of the task before it times out (in milliseconds).", }, + "suspend_task_after_num_failures": { + Type: schema.TypeInt, + Optional: true, + Default: 0, + ValidateFunc: validation.IntAtLeast(0), + Description: "Specifies the number of consecutive failed task runs after which the current task is suspended automatically. The default is 0 (no automatic suspension).", + }, "comment": { Type: schema.TypeString, Optional: true, @@ -124,6 +130,19 @@ func difference(a, b map[string]any) map[string]any { return diff } +// differentValue find keys present both in 'a' and 'b' but having different values. +func differentValue(a, b map[string]any) map[string]any { + diff := make(map[string]any) + for k, va := range a { + if vb, ok := b[k]; ok { + if vb != va { + diff[k] = vb + } + } + } + return diff +} + // Task returns a pointer to the resource representing a task. 
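// Aside (illustrative only, not part of the patch): a minimal sketch of how the
// difference and differentValue helpers split an old/new session_parameters map
// into the remove/add/change buckets consumed later in UpdateTask. The map
// contents here are made up for the example.
func exampleSessionParameterDiff() {
	oldParams := map[string]any{"LOCK_TIMEOUT": "1000", "STRICT_JSON_OUTPUT": "true"}
	newParams := map[string]any{"LOCK_TIMEOUT": "2000", "MULTI_STATEMENT_COUNT": "5"}

	remove := difference(oldParams, newParams)     // only in old  -> UNSET (STRICT_JSON_OUTPUT)
	add := difference(newParams, oldParams)        // only in new  -> SET   (MULTI_STATEMENT_COUNT)
	change := differentValue(oldParams, newParams) // in both, value changed -> SET (LOCK_TIMEOUT = "2000")

	_, _, _ = remove, add, change
}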
func Task() *schema.Resource { return &schema.Resource{ @@ -214,7 +233,7 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error { } if len(params) > 0 { - sessionParameters := map[string]interface{}{} + sessionParameters := make(map[string]any) fieldParameters := map[string]interface{}{ "user_task_managed_initial_warehouse_size": "", } @@ -233,6 +252,13 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error { } fieldParameters["user_task_timeout_ms"] = timeout + case "SUSPEND_TASK_AFTER_NUM_FAILURES": + num, err := strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return err + } + + fieldParameters["suspend_task_after_num_failures"] = num default: sessionParameters[param.Key] = param.Value } @@ -299,6 +325,10 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error { createRequest.WithUserTaskTimeoutMs(sdk.Int(v.(int))) } + if v, ok := d.GetOk("suspend_task_after_num_failures"); ok { + createRequest.WithSuspendTaskAfterNumFailures(sdk.Int(v.(int))) + } + if v, ok := d.GetOk("comment"); ok { createRequest.WithComment(sdk.String(v.(string))) } @@ -558,6 +588,20 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { } } + if d.HasChange("suspend_task_after_num_failures") { + o, n := d.GetChange("suspend_task_after_num_failures") + alterRequest := sdk.NewAlterTaskRequest(taskId) + if o.(int) > 0 && n.(int) == 0 { + alterRequest.WithUnset(sdk.NewTaskUnsetRequest().WithSuspendTaskAfterNumFailures(sdk.Bool(true))) + } else { + alterRequest.WithSet(sdk.NewTaskSetRequest().WithSuspendTaskAfterNumFailures(sdk.Int(n.(int)))) + } + err := client.Tasks.Alter(ctx, alterRequest) + if err != nil { + return fmt.Errorf("error updating suspend task after num failures on task %s", taskId.FullyQualifiedName()) + } + } + if d.HasChange("comment") { newComment := d.Get("comment") alterRequest := sdk.NewAlterTaskRequest(taskId) @@ -586,7 +630,6 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { } } - // TODO [SNOW-884987]: old implementation does not handle changing parameter value correctly (only finds for parameters to add od remove, not change) if d.HasChange("session_parameters") { o, n := d.GetChange("session_parameters") @@ -601,6 +644,7 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { remove := difference(os, ns) add := difference(ns, os) + change := differentValue(os, ns) if len(remove) > 0 { sessionParametersUnset, err := sdk.GetSessionParametersUnsetFrom(remove) @@ -608,7 +652,7 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { return err } if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithUnset(sdk.NewTaskUnsetRequest().WithSessionParametersUnset(sessionParametersUnset))); err != nil { - return fmt.Errorf("error removing session_parameters on task %v", d.Id()) + return fmt.Errorf("error removing session_parameters on task %v err = %w", d.Id(), err) } } @@ -618,7 +662,17 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { return err } if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithSet(sdk.NewTaskSetRequest().WithSessionParameters(sessionParameters))); err != nil { - return fmt.Errorf("error adding session_parameters to task %v", d.Id()) + return fmt.Errorf("error adding session_parameters to task %v err = %w", d.Id(), err) + } + } + + if len(change) > 0 { + sessionParameters, err := sdk.GetSessionParametersFrom(change) + if err != nil { + return err + } + if err := client.Tasks.Alter(ctx, 
sdk.NewAlterTaskRequest(taskId).WithSet(sdk.NewTaskSetRequest().WithSessionParameters(sessionParameters))); err != nil { + return fmt.Errorf("error updating session_parameters in task %v err = %w", d.Id(), err) } } } diff --git a/pkg/resources/task_acceptance_test.go b/pkg/resources/task_acceptance_test.go index fa7af5e679..91f32876ec 100644 --- a/pkg/resources/task_acceptance_test.go +++ b/pkg/resources/task_acceptance_test.go @@ -8,6 +8,7 @@ import ( "text/template" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" "github.com/hashicorp/terraform-plugin-testing/terraform" @@ -50,6 +51,10 @@ var ( Enabled: true, Schedule: "5 MINUTE", UserTaskTimeoutMs: 1800000, + SessionParams: map[string]string{ + string(sdk.SessionParameterLockTimeout): "1000", + string(sdk.SessionParameterStrictJSONOutput): "true", + }, }, ChildTask: &TaskSettings{ @@ -79,6 +84,10 @@ var ( Enabled: true, Schedule: "5 MINUTE", UserTaskTimeoutMs: 1800000, + SessionParams: map[string]string{ + string(sdk.SessionParameterLockTimeout): "1000", + string(sdk.SessionParameterStrictJSONOutput): "true", + }, }, ChildTask: &TaskSettings{ @@ -95,7 +104,7 @@ var ( When: "TRUE", Enabled: true, SessionParams: map[string]string{ - "TIMESTAMP_INPUT_FORMAT": "YYYY-MM-DD HH24", + string(sdk.SessionParameterTimestampInputFormat): "YYYY-MM-DD HH24", }, Schedule: "5 MINUTE", UserTaskTimeoutMs: 1800000, @@ -113,6 +122,10 @@ var ( Enabled: true, Schedule: "15 MINUTE", UserTaskTimeoutMs: 1800000, + SessionParams: map[string]string{ + string(sdk.SessionParameterLockTimeout): "1000", + string(sdk.SessionParameterStrictJSONOutput): "true", + }, }, ChildTask: &TaskSettings{ @@ -144,6 +157,11 @@ var ( Enabled: false, Schedule: "5 MINUTE", UserTaskTimeoutMs: 1800000, + // Changes session params: one is updated, one is removed, one is added + SessionParams: map[string]string{ + string(sdk.SessionParameterLockTimeout): "2000", + string(sdk.SessionParameterMultiStatementCount): "5", + }, }, ChildTask: &TaskSettings{ @@ -160,7 +178,7 @@ var ( When: "TRUE", Enabled: true, SessionParams: map[string]string{ - "TIMESTAMP_INPUT_FORMAT": "YYYY-MM-DD HH24", + string(sdk.SessionParameterTimestampInputFormat): "YYYY-MM-DD HH24", }, Schedule: "5 MINUTE", UserTaskTimeoutMs: 0, @@ -193,6 +211,9 @@ func TestAcc_Task(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.child_task", "schedule", initialState.ChildTask.Schedule), checkInt64("snowflake_task.root_task", "user_task_timeout_ms", initialState.RootTask.UserTaskTimeoutMs), resource.TestCheckNoResourceAttr("snowflake_task.solo_task", "user_task_timeout_ms"), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 1000), + checkBool("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT", true), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT"), ), }, { @@ -213,6 +234,9 @@ func TestAcc_Task(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.child_task", "schedule", stepOne.ChildTask.Schedule), checkInt64("snowflake_task.root_task", "user_task_timeout_ms", stepOne.RootTask.UserTaskTimeoutMs), checkInt64("snowflake_task.solo_task", "user_task_timeout_ms", stepOne.SoloTask.UserTaskTimeoutMs), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 1000), + 
checkBool("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT", true), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT"), ), }, { @@ -233,6 +257,9 @@ func TestAcc_Task(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.child_task", "schedule", stepTwo.ChildTask.Schedule), checkInt64("snowflake_task.root_task", "user_task_timeout_ms", stepTwo.RootTask.UserTaskTimeoutMs), checkInt64("snowflake_task.solo_task", "user_task_timeout_ms", stepTwo.SoloTask.UserTaskTimeoutMs), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 1000), + checkBool("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT", true), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT"), ), }, { @@ -253,6 +280,9 @@ func TestAcc_Task(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.child_task", "schedule", stepThree.ChildTask.Schedule), checkInt64("snowflake_task.root_task", "user_task_timeout_ms", stepThree.RootTask.UserTaskTimeoutMs), checkInt64("snowflake_task.solo_task", "user_task_timeout_ms", stepThree.SoloTask.UserTaskTimeoutMs), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 2000), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT"), + checkInt64("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT", 5), ), }, { @@ -279,6 +309,9 @@ func TestAcc_Task(t *testing.T) { // `user_task_timeout_ms` by unsetting the // USER_TASK_TIMEOUT_MS session variable. checkInt64("snowflake_task.solo_task", "user_task_timeout_ms", initialState.ChildTask.UserTaskTimeoutMs), + checkInt64("snowflake_task.root_task", "session_parameters.LOCK_TIMEOUT", 1000), + checkBool("snowflake_task.root_task", "session_parameters.STRICT_JSON_OUTPUT", true), + resource.TestCheckNoResourceAttr("snowflake_task.root_task", "session_parameters.MULTI_STATEMENT_COUNT"), ), }, }, @@ -302,12 +335,12 @@ resource "snowflake_task" "root_task" { user_task_timeout_ms = {{ .RootTask.UserTaskTimeoutMs }} {{- end }} - {{ if .ChildTask.SessionParams }} + {{ if .RootTask.SessionParams }} session_parameters = { - {{ range $key, $value := .RootTask.SessionParams}} + {{ range $key, $value := .RootTask.SessionParams}} {{ $key }} = "{{ $value }}", - } {{- end }} + } {{- end }} } resource "snowflake_task" "child_task" { @@ -325,10 +358,10 @@ resource "snowflake_task" "child_task" { {{ if .ChildTask.SessionParams }} session_parameters = { - {{ range $key, $value := .ChildTask.SessionParams}} + {{ range $key, $value := .ChildTask.SessionParams}} {{ $key }} = "{{ $value }}", - } {{- end }} + } {{- end }} } resource "snowflake_task" "solo_task" { @@ -351,8 +384,8 @@ resource "snowflake_task" "solo_task" { session_parameters = { {{ range $key, $value := .SoloTask.SessionParams}} {{ $key }} = "{{ $value }}", - } {{- end }} + } {{- end }} } `) @@ -519,6 +552,7 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.test_task", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task.test_task", "sql_statement", "SELECT 1"), resource.TestCheckResourceAttr("snowflake_task.test_task", "schedule", "5 MINUTE"), + resource.TestCheckResourceAttr("snowflake_task.test_task_root", "suspend_task_after_num_failures", "1"), ), }, { @@ -529,6 +563,7 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { 
resource.TestCheckResourceAttr("snowflake_task.test_task", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task.test_task", "sql_statement", "SELECT 1"), resource.TestCheckResourceAttr("snowflake_task.test_task", "schedule", ""), + resource.TestCheckResourceAttr("snowflake_task.test_task_root", "suspend_task_after_num_failures", "2"), ), }, { @@ -539,6 +574,7 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.test_task", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task.test_task", "sql_statement", "SELECT 1"), resource.TestCheckResourceAttr("snowflake_task.test_task", "schedule", "5 MINUTE"), + resource.TestCheckResourceAttr("snowflake_task.test_task_root", "suspend_task_after_num_failures", "1"), ), }, { @@ -549,6 +585,7 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { resource.TestCheckResourceAttr("snowflake_task.test_task", "schema", acc.TestSchemaName), resource.TestCheckResourceAttr("snowflake_task.test_task", "sql_statement", "SELECT 1"), resource.TestCheckResourceAttr("snowflake_task.test_task", "schedule", ""), + resource.TestCheckResourceAttr("snowflake_task.test_task_root", "suspend_task_after_num_failures", "0"), ), }, }, @@ -558,12 +595,13 @@ func TestAcc_Task_SwitchScheduled(t *testing.T) { func taskConfigManagedScheduled(name string, taskRootName string, databaseName string, schemaName string) string { s := ` resource "snowflake_task" "test_task_root" { - name = "%s" - database = "%s" - schema = "%s" - sql_statement = "SELECT 1" - enabled = true - schedule = "5 MINUTE" + name = "%s" + database = "%s" + schema = "%s" + sql_statement = "SELECT 1" + enabled = true + schedule = "5 MINUTE" + suspend_task_after_num_failures = 1 } resource "snowflake_task" "test_task" { @@ -581,12 +619,13 @@ resource "snowflake_task" "test_task" { func taskConfigManagedScheduled2(name string, taskRootName string, databaseName string, schemaName string) string { s := ` resource "snowflake_task" "test_task_root" { - name = "%s" - database = "%s" - schema = "%s" - sql_statement = "SELECT 1" - enabled = true - schedule = "5 MINUTE" + name = "%s" + database = "%s" + schema = "%s" + sql_statement = "SELECT 1" + enabled = true + schedule = "5 MINUTE" + suspend_task_after_num_failures = 2 } resource "snowflake_task" "test_task" { diff --git a/pkg/sdk/internal/collections/queue.go b/pkg/sdk/internal/collections/queue.go new file mode 100644 index 0000000000..3749f1bc3e --- /dev/null +++ b/pkg/sdk/internal/collections/queue.go @@ -0,0 +1,30 @@ +package collections + +type Queue[T any] struct { + data []T +} + +func (s *Queue[T]) Head() *T { + if len(s.data) == 0 { + return nil + } + return &s.data[0] +} + +func (s *Queue[T]) Pop() *T { + elem := s.Head() + if elem != nil { + s.data = s.data[1:] + } + return elem +} + +func (s *Queue[T]) Push(elem T) { + s.data = append(s.data, elem) +} + +func NewQueue[T any]() Queue[T] { + return Queue[T]{ + data: make([]T, 0), + } +} diff --git a/pkg/sdk/internal/collections/queue_test.go b/pkg/sdk/internal/collections/queue_test.go new file mode 100644 index 0000000000..05387df658 --- /dev/null +++ b/pkg/sdk/internal/collections/queue_test.go @@ -0,0 +1,51 @@ +package collections + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestQueue(t *testing.T) { + t.Run("empty queue initialization", func(t *testing.T) { + queue := NewQueue[int]() + + require.Nil(t, queue.Head()) + require.Nil(t, queue.Pop()) + }) + + t.Run("returns head multiple 
times", func(t *testing.T) { + queue := NewQueue[int]() + + queue.Push(1) + + require.Equal(t, 1, *queue.Head()) + require.Equal(t, 1, *queue.Head()) + }) + + t.Run("returns empty head after pop", func(t *testing.T) { + queue := NewQueue[int]() + + queue.Pop() + + require.Nil(t, queue.Head()) + }) + + t.Run("multiple operations", func(t *testing.T) { + queue := NewQueue[int]() + + queue.Push(1) + require.Equal(t, 1, *queue.Head()) + + queue.Push(2) + require.Equal(t, 1, *queue.Head()) + + elem := queue.Pop() + require.Equal(t, 1, *elem) + require.Equal(t, 2, *queue.Head()) + + elem = queue.Pop() + require.Equal(t, 2, *elem) + require.Nil(t, queue.Head()) + }) +} diff --git a/pkg/sdk/parameters.go b/pkg/sdk/parameters.go index 6a7318bd1d..25304c5851 100644 --- a/pkg/sdk/parameters.go +++ b/pkg/sdk/parameters.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strconv" + "strings" ) var ( @@ -150,327 +151,31 @@ func (parameters *parameters) SetAccountParameter(ctx context.Context, parameter } func (parameters *parameters) SetSessionParameterOnAccount(ctx context.Context, parameter SessionParameter, value string) error { - opts := AlterAccountOptions{Set: &AccountSet{Parameters: &AccountLevelParameters{SessionParameters: &SessionParameters{}}}} - switch parameter { - case SessionParameterAbortDetachedQuery: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.AbortDetachedQuery = b - case SessionParameterAutocommit: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.Autocommit = b - case SessionParameterBinaryInputFormat: - opts.Set.Parameters.SessionParameters.BinaryInputFormat = Pointer(BinaryInputFormat(value)) - case SessionParameterBinaryOutputFormat: - opts.Set.Parameters.SessionParameters.BinaryOutputFormat = Pointer(BinaryOutputFormat(value)) - case SessionParameterClientMetadataRequestUseConnectionCtx: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ClientMetadataRequestUseConnectionCtx = b - case SessionParameterClientMetadataUseSessionDatabase: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ClientMetadataUseSessionDatabase = b - case SessionParameterClientResultColumnCaseInsensitive: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ClientResultColumnCaseInsensitive = b - case SessionParameterDateInputFormat: - opts.Set.Parameters.SessionParameters.DateInputFormat = &value - case SessionParameterDateOutputFormat: - opts.Set.Parameters.SessionParameters.DateOutputFormat = &value - case SessionParameterErrorOnNondeterministicMerge: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ErrorOnNondeterministicMerge = b - case SessionParameterErrorOnNondeterministicUpdate: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.ErrorOnNondeterministicUpdate = b - case SessionParameterGeographyOutputFormat: - opts.Set.Parameters.SessionParameters.GeographyOutputFormat = Pointer(GeographyOutputFormat(value)) - case SessionParameterJSONIndent: - v, err := strconv.Atoi(value) - if err != nil { - return 
fmt.Errorf("JSON_INDENT session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.JSONIndent = Pointer(v) - case SessionParameterLockTimeout: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("LOCK_TIMEOUT session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.LockTimeout = Pointer(v) - case SessionParameterMultiStatementCount: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("MULTI_STATEMENT_COUNT session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.MultiStatementCount = Pointer(v) - - case SessionParameterQueryTag: - opts.Set.Parameters.SessionParameters.QueryTag = &value - case SessionParameterQuotedIdentifiersIgnoreCase: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.QuotedIdentifiersIgnoreCase = b - case SessionParameterRowsPerResultset: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("ROWS_PER_RESULTSET session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.RowsPerResultset = Pointer(v) - case SessionParameterSimulatedDataSharingConsumer: - opts.Set.Parameters.SessionParameters.SimulatedDataSharingConsumer = &value - case SessionParameterStatementTimeoutInSeconds: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("STATEMENT_TIMEOUT_IN_SECONDS session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.StatementTimeoutInSeconds = Pointer(v) - case SessionParameterStrictJSONOutput: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.StrictJSONOutput = b - case SessionParameterTimestampDayIsAlways24h: - b, err := parseBooleanParameter(string(parameter), value) + sp := &SessionParameters{} + err := sp.setParam(parameter, value) + if err == nil { + opts := AlterAccountOptions{Set: &AccountSet{Parameters: &AccountLevelParameters{SessionParameters: sp}}} + err = parameters.client.Accounts.Alter(ctx, &opts) if err != nil { return err } - opts.Set.Parameters.SessionParameters.TimestampDayIsAlways24h = b - case SessionParameterTimestampInputFormat: - opts.Set.Parameters.SessionParameters.TimestampInputFormat = &value - case SessionParameterTimestampLTZOutputFormat: - opts.Set.Parameters.SessionParameters.TimestampLTZOutputFormat = &value - case SessionParameterTimestampNTZOutputFormat: - opts.Set.Parameters.SessionParameters.TimestampNTZOutputFormat = &value - case SessionParameterTimestampOutputFormat: - opts.Set.Parameters.SessionParameters.TimestampOutputFormat = &value - case SessionParameterTimestampTypeMapping: - opts.Set.Parameters.SessionParameters.TimestampTypeMapping = &value - case SessionParameterTimestampTZOutputFormat: - opts.Set.Parameters.SessionParameters.TimestampTZOutputFormat = &value - case SessionParameterTimezone: - opts.Set.Parameters.SessionParameters.Timezone = &value - case SessionParameterTimeInputFormat: - opts.Set.Parameters.SessionParameters.TimeInputFormat = &value - case SessionParameterTimeOutputFormat: - opts.Set.Parameters.SessionParameters.TimeOutputFormat = &value - case SessionParameterTransactionDefaultIsolationLevel: - opts.Set.Parameters.SessionParameters.TransactionDefaultIsolationLevel = Pointer(TransactionDefaultIsolationLevel(value)) - case SessionParameterTwoDigitCenturyStart: - v, err := strconv.Atoi(value) - if err 
!= nil { - return fmt.Errorf("TWO_DIGIT_CENTURY_START session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.TwoDigitCenturyStart = Pointer(v) - case SessionParameterUnsupportedDDLAction: - opts.Set.Parameters.SessionParameters.UnsupportedDDLAction = Pointer(UnsupportedDDLAction(value)) - case SessionParameterUseCachedResult: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.Parameters.SessionParameters.UseCachedResult = b - case SessionParameterWeekOfYearPolicy: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("WEEK_OF_YEAR_POLICY session parameter is an integer, got %v", value) + return nil + } else { + if strings.Contains(err.Error(), "session parameter is not supported") { + return parameters.SetObjectParameterOnAccount(ctx, ObjectParameter(parameter), value) } - opts.Set.Parameters.SessionParameters.WeekOfYearPolicy = Pointer(v) - case SessionParameterWeekStart: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("WEEK_START session parameter is an integer, got %v", value) - } - opts.Set.Parameters.SessionParameters.WeekStart = Pointer(v) - default: - return parameters.SetObjectParameterOnAccount(ctx, ObjectParameter(parameter), value) - } - err := parameters.client.Accounts.Alter(ctx, &opts) - if err != nil { return err } - return nil } func (parameters *parameters) SetSessionParameterOnUser(ctx context.Context, userId AccountObjectIdentifier, parameter SessionParameter, value string) error { - opts := AlterUserOptions{Set: &UserSet{SessionParameters: &SessionParameters{}}} - switch parameter { - case SessionParameterAbortDetachedQuery: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.AbortDetachedQuery = b - case SessionParameterAutocommit: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.Autocommit = b - case SessionParameterBinaryInputFormat: - opts.Set.SessionParameters.BinaryInputFormat = Pointer(BinaryInputFormat(value)) - case SessionParameterBinaryOutputFormat: - opts.Set.SessionParameters.BinaryOutputFormat = Pointer(BinaryOutputFormat(value)) - case SessionParameterClientMetadataRequestUseConnectionCtx: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ClientMetadataRequestUseConnectionCtx = b - case SessionParameterClientMetadataUseSessionDatabase: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ClientMetadataUseSessionDatabase = b - case SessionParameterClientResultColumnCaseInsensitive: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ClientResultColumnCaseInsensitive = b - case SessionParameterDateInputFormat: - opts.Set.SessionParameters.DateInputFormat = &value - case SessionParameterDateOutputFormat: - opts.Set.SessionParameters.DateOutputFormat = &value - case SessionParameterErrorOnNondeterministicMerge: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ErrorOnNondeterministicMerge = b - case SessionParameterErrorOnNondeterministicUpdate: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.ErrorOnNondeterministicUpdate = b 
- case SessionParameterGeographyOutputFormat: - opts.Set.SessionParameters.GeographyOutputFormat = Pointer(GeographyOutputFormat(value)) - case SessionParameterJSONIndent: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("JSON_INDENT session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.JSONIndent = Pointer(v) - case SessionParameterLockTimeout: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("LOCK_TIMEOUT session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.LockTimeout = Pointer(v) - case SessionParameterMultiStatementCount: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("MULTI_STATEMENT_COUNT session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.MultiStatementCount = Pointer(v) - - case SessionParameterQueryTag: - opts.Set.SessionParameters.QueryTag = &value - case SessionParameterQuotedIdentifiersIgnoreCase: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.QuotedIdentifiersIgnoreCase = b - case SessionParameterRowsPerResultset: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("ROWS_PER_RESULTSET session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.RowsPerResultset = Pointer(v) - case SessionParameterSimulatedDataSharingConsumer: - opts.Set.SessionParameters.SimulatedDataSharingConsumer = &value - case SessionParameterStatementTimeoutInSeconds: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("STATEMENT_TIMEOUT_IN_SECONDS session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.StatementTimeoutInSeconds = Pointer(v) - case SessionParameterStrictJSONOutput: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.StrictJSONOutput = b - case SessionParameterTimestampDayIsAlways24h: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.TimestampDayIsAlways24h = b - case SessionParameterTimestampInputFormat: - opts.Set.SessionParameters.TimestampInputFormat = &value - case SessionParameterTimestampLTZOutputFormat: - opts.Set.SessionParameters.TimestampLTZOutputFormat = &value - case SessionParameterTimestampNTZOutputFormat: - opts.Set.SessionParameters.TimestampNTZOutputFormat = &value - case SessionParameterTimestampOutputFormat: - opts.Set.SessionParameters.TimestampOutputFormat = &value - case SessionParameterTimestampTypeMapping: - opts.Set.SessionParameters.TimestampTypeMapping = &value - case SessionParameterTimestampTZOutputFormat: - opts.Set.SessionParameters.TimestampTZOutputFormat = &value - case SessionParameterTimezone: - opts.Set.SessionParameters.Timezone = &value - case SessionParameterTimeInputFormat: - opts.Set.SessionParameters.TimeInputFormat = &value - case SessionParameterTimeOutputFormat: - opts.Set.SessionParameters.TimeOutputFormat = &value - case SessionParameterTransactionDefaultIsolationLevel: - opts.Set.SessionParameters.TransactionDefaultIsolationLevel = Pointer(TransactionDefaultIsolationLevel(value)) - case SessionParameterTwoDigitCenturyStart: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("TWO_DIGIT_CENTURY_START session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.TwoDigitCenturyStart = Pointer(v) - case SessionParameterUnsupportedDDLAction: - 
opts.Set.SessionParameters.UnsupportedDDLAction = Pointer(UnsupportedDDLAction(value)) - case SessionParameterUseCachedResult: - b, err := parseBooleanParameter(string(parameter), value) - if err != nil { - return err - } - opts.Set.SessionParameters.UseCachedResult = b - case SessionParameterWeekOfYearPolicy: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("WEEK_OF_YEAR_POLICY session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.WeekOfYearPolicy = Pointer(v) - case SessionParameterWeekStart: - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("WEEK_START session parameter is an integer, got %v", value) - } - opts.Set.SessionParameters.WeekStart = Pointer(v) - default: - return fmt.Errorf("Invalid session parameter: %v", string(parameter)) + sp := &SessionParameters{} + err := sp.setParam(parameter, value) + if err != nil { + return err } - err := parameters.client.Users.Alter(ctx, userId, &opts) + opts := AlterUserOptions{Set: &UserSet{SessionParameters: sp}} + err = parameters.client.Users.Alter(ctx, userId, &opts) if err != nil { return err } @@ -1013,7 +718,7 @@ type SessionParametersUnset struct { } func (v *SessionParametersUnset) validate() error { - if !anyValueSet(v.AbortDetachedQuery, v.Autocommit, v.BinaryInputFormat, v.BinaryOutputFormat, v.DateInputFormat, v.DateOutputFormat, v.ErrorOnNondeterministicMerge, v.ErrorOnNondeterministicUpdate, v.GeographyOutputFormat, v.JSONIndent, v.LockTimeout, v.QueryTag, v.RowsPerResultset, v.SimulatedDataSharingConsumer, v.StatementTimeoutInSeconds, v.StrictJSONOutput, v.TimestampDayIsAlways24h, v.TimestampInputFormat, v.TimestampLTZOutputFormat, v.TimestampNTZOutputFormat, v.TimestampOutputFormat, v.TimestampTypeMapping, v.TimestampTZOutputFormat, v.Timezone, v.TimeInputFormat, v.TimeOutputFormat, v.TransactionDefaultIsolationLevel, v.TwoDigitCenturyStart, v.UnsupportedDDLAction, v.UseCachedResult, v.WeekOfYearPolicy, v.WeekStart) { + if !anyValueSet(v.AbortDetachedQuery, v.Autocommit, v.BinaryInputFormat, v.BinaryOutputFormat, v.ClientMetadataRequestUseConnectionCtx, v.ClientMetadataUseSessionDatabase, v.ClientResultColumnCaseInsensitive, v.DateInputFormat, v.DateOutputFormat, v.ErrorOnNondeterministicMerge, v.ErrorOnNondeterministicUpdate, v.GeographyOutputFormat, v.JSONIndent, v.LockTimeout, v.MultiStatementCount, v.QueryTag, v.QuotedIdentifiersIgnoreCase, v.RowsPerResultset, v.SimulatedDataSharingConsumer, v.StatementTimeoutInSeconds, v.StrictJSONOutput, v.TimestampDayIsAlways24h, v.TimestampInputFormat, v.TimestampLTZOutputFormat, v.TimestampNTZOutputFormat, v.TimestampOutputFormat, v.TimestampTypeMapping, v.TimestampTZOutputFormat, v.Timezone, v.TimeInputFormat, v.TimeOutputFormat, v.TransactionDefaultIsolationLevel, v.TwoDigitCenturyStart, v.UnsupportedDDLAction, v.UseCachedResult, v.WeekOfYearPolicy, v.WeekStart) { return errors.Join(errAtLeastOneOf("SessionParametersUnset", "AbortDetachedQuery", "Autocommit", "BinaryInputFormat", "BinaryOutputFormat", "DateInputFormat", "DateOutputFormat", "ErrorOnNondeterministicMerge", "ErrorOnNondeterministicUpdate", "GeographyOutputFormat", "JSONIndent", "LockTimeout", "QueryTag", "RowsPerResultset", "SimulatedDataSharingConsumer", "StatementTimeoutInSeconds", "StrictJSONOutput", "TimestampDayIsAlways24h", "TimestampInputFormat", "TimestampLTZOutputFormat", "TimestampNTZOutputFormat", "TimestampOutputFormat", "TimestampTypeMapping", "TimestampTZOutputFormat", "Timezone", "TimeInputFormat", "TimeOutputFormat", 
"TransactionDefaultIsolationLevel", "TwoDigitCenturyStart", "UnsupportedDDLAction", "UseCachedResult", "WeekOfYearPolicy", "WeekStart")) } return nil diff --git a/pkg/sdk/parameters_impl.go b/pkg/sdk/parameters_impl.go index 110efbeff3..235f88c4b8 100644 --- a/pkg/sdk/parameters_impl.go +++ b/pkg/sdk/parameters_impl.go @@ -20,8 +20,6 @@ func GetSessionParametersFrom(params map[string]any) (*SessionParameters, error) return sessionParameters, nil } -// TODO [SNOW-884987]: use this method in SetSessionParameterOnAccount and in SetSessionParameterOnUser -// TODO [SNOW-884987]: unit test this method func (sessionParameters *SessionParameters) setParam(parameter SessionParameter, value string) error { switch parameter { case SessionParameterAbortDetachedQuery: diff --git a/pkg/sdk/parameters_impl_test.go b/pkg/sdk/parameters_impl_test.go new file mode 100644 index 0000000000..59a6bdc04b --- /dev/null +++ b/pkg/sdk/parameters_impl_test.go @@ -0,0 +1,117 @@ +package sdk + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSessionParameters_setParam(t *testing.T) { + tests := []struct { + parameter SessionParameter + value string + expectedValue any + accessor func(*SessionParameters) any + }{ + {parameter: SessionParameterAbortDetachedQuery, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.AbortDetachedQuery }}, + {parameter: SessionParameterAutocommit, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.Autocommit }}, + {parameter: SessionParameterBinaryInputFormat, value: "some", expectedValue: BinaryInputFormat("some"), accessor: func(sp *SessionParameters) any { return *sp.BinaryInputFormat }}, + {parameter: SessionParameterBinaryOutputFormat, value: "some", expectedValue: BinaryOutputFormat("some"), accessor: func(sp *SessionParameters) any { return *sp.BinaryOutputFormat }}, + {parameter: SessionParameterClientMetadataRequestUseConnectionCtx, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ClientMetadataRequestUseConnectionCtx }}, + {parameter: SessionParameterClientMetadataUseSessionDatabase, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ClientMetadataUseSessionDatabase }}, + {parameter: SessionParameterClientResultColumnCaseInsensitive, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ClientResultColumnCaseInsensitive }}, + {parameter: SessionParameterDateInputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.DateInputFormat }}, + {parameter: SessionParameterGeographyOutputFormat, value: "some", expectedValue: GeographyOutputFormat("some"), accessor: func(sp *SessionParameters) any { return *sp.GeographyOutputFormat }}, + {parameter: SessionParameterDateOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.DateOutputFormat }}, + {parameter: SessionParameterErrorOnNondeterministicMerge, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ErrorOnNondeterministicMerge }}, + {parameter: SessionParameterErrorOnNondeterministicUpdate, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.ErrorOnNondeterministicUpdate }}, + {parameter: SessionParameterJSONIndent, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.JSONIndent }}, + {parameter: 
SessionParameterLockTimeout, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.LockTimeout }}, + {parameter: SessionParameterMultiStatementCount, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.MultiStatementCount }}, + {parameter: SessionParameterQueryTag, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.QueryTag }}, + {parameter: SessionParameterQuotedIdentifiersIgnoreCase, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.QuotedIdentifiersIgnoreCase }}, + {parameter: SessionParameterRowsPerResultset, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.RowsPerResultset }}, + {parameter: SessionParameterSimulatedDataSharingConsumer, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.SimulatedDataSharingConsumer }}, + {parameter: SessionParameterStatementTimeoutInSeconds, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.StatementTimeoutInSeconds }}, + {parameter: SessionParameterStrictJSONOutput, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.StrictJSONOutput }}, + {parameter: SessionParameterTimeInputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimeInputFormat }}, + {parameter: SessionParameterTimeOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimeOutputFormat }}, + {parameter: SessionParameterTimestampDayIsAlways24h, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.TimestampDayIsAlways24h }}, + {parameter: SessionParameterTimestampInputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampInputFormat }}, + {parameter: SessionParameterTimestampLTZOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampLTZOutputFormat }}, + {parameter: SessionParameterTimestampNTZOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampNTZOutputFormat }}, + {parameter: SessionParameterTimestampOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampOutputFormat }}, + {parameter: SessionParameterTimestampTypeMapping, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampTypeMapping }}, + {parameter: SessionParameterTimestampTZOutputFormat, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.TimestampTZOutputFormat }}, + {parameter: SessionParameterTimezone, value: "some", expectedValue: "some", accessor: func(sp *SessionParameters) any { return *sp.Timezone }}, + {parameter: SessionParameterTransactionDefaultIsolationLevel, value: "some", expectedValue: TransactionDefaultIsolationLevel("some"), accessor: func(sp *SessionParameters) any { return *sp.TransactionDefaultIsolationLevel }}, + {parameter: SessionParameterTwoDigitCenturyStart, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.TwoDigitCenturyStart }}, + {parameter: SessionParameterUnsupportedDDLAction, value: "some", expectedValue: UnsupportedDDLAction("some"), accessor: func(sp *SessionParameters) any { return *sp.UnsupportedDDLAction 
}}, + {parameter: SessionParameterUseCachedResult, value: "true", expectedValue: true, accessor: func(sp *SessionParameters) any { return *sp.UseCachedResult }}, + {parameter: SessionParameterWeekOfYearPolicy, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.WeekOfYearPolicy }}, + {parameter: SessionParameterWeekStart, value: "1", expectedValue: 1, accessor: func(sp *SessionParameters) any { return *sp.WeekStart }}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("test valid value '%s' for parameter %s", tt.value, tt.parameter), func(t *testing.T) { + sessionParameters := &SessionParameters{} + + err := sessionParameters.setParam(tt.parameter, tt.value) + + require.NoError(t, err) + require.Equal(t, tt.expectedValue, tt.accessor(sessionParameters)) + }) + } + + invalidCases := []struct { + parameter SessionParameter + value string + }{ + {parameter: SessionParameterAbortDetachedQuery, value: "true123"}, + {parameter: SessionParameterAutocommit, value: "true123"}, + // {parameter: SessionParameterBinaryInputFormat, value: "some"}, // add validation + // {parameter: SessionParameterBinaryOutputFormat, value: "some"}, // add validation + {parameter: SessionParameterClientMetadataRequestUseConnectionCtx, value: "true123"}, + {parameter: SessionParameterClientMetadataUseSessionDatabase, value: "true123"}, + {parameter: SessionParameterClientResultColumnCaseInsensitive, value: "true123"}, + // {parameter: SessionParameterDateInputFormat, value: "some"}, // add validation + // {parameter: SessionParameterGeographyOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterDateOutputFormat, value: "some"}, // add validation + {parameter: SessionParameterErrorOnNondeterministicMerge, value: "true123"}, + {parameter: SessionParameterErrorOnNondeterministicUpdate, value: "true123"}, + {parameter: SessionParameterJSONIndent, value: "aaa"}, + {parameter: SessionParameterLockTimeout, value: "aaa"}, + {parameter: SessionParameterMultiStatementCount, value: "aaa"}, + // {parameter: SessionParameterQueryTag, value: "some"}, // add validation + {parameter: SessionParameterQuotedIdentifiersIgnoreCase, value: "true123"}, + {parameter: SessionParameterRowsPerResultset, value: "aaa"}, + // {parameter: SessionParameterSimulatedDataSharingConsumer, value: "some"}, // add validation + {parameter: SessionParameterStatementTimeoutInSeconds, value: "aaa"}, + {parameter: SessionParameterStrictJSONOutput, value: "true123"}, + // {parameter: SessionParameterTimeInputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimeOutputFormat, value: "some"}, // add validation + {parameter: SessionParameterTimestampDayIsAlways24h, value: "true123"}, + // {parameter: SessionParameterTimestampInputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimestampLTZOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimestampNTZOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimestampOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimestampTypeMapping, value: "some"}, // add validation + // {parameter: SessionParameterTimestampTZOutputFormat, value: "some"}, // add validation + // {parameter: SessionParameterTimezone, value: "some"}, // add validation + // {parameter: SessionParameterTransactionDefaultIsolationLevel, value: "some"}, // add validation + {parameter: SessionParameterTwoDigitCenturyStart, value: "aaa"}, + // {parameter: 
SessionParameterUnsupportedDDLAction, value: "some"}, // add validation + {parameter: SessionParameterUseCachedResult, value: "true123"}, + {parameter: SessionParameterWeekOfYearPolicy, value: "aaa"}, + {parameter: SessionParameterWeekStart, value: "aaa"}, + } + for _, tt := range invalidCases { + t.Run(fmt.Sprintf("test invalid value '%s' for parameter %s", tt.value, tt.parameter), func(t *testing.T) { + sessionParameters := &SessionParameters{} + + err := sessionParameters.setParam(tt.parameter, tt.value) + + require.Error(t, err) + }) + } +} diff --git a/pkg/sdk/tasks_impl_gen.go b/pkg/sdk/tasks_impl_gen.go index dd6f3070aa..6a4cb38d4c 100644 --- a/pkg/sdk/tasks_impl_gen.go +++ b/pkg/sdk/tasks_impl_gen.go @@ -3,8 +3,10 @@ package sdk import ( "context" "encoding/json" - "fmt" "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" + "golang.org/x/exp/slices" ) var _ Tasks = (*tasks)(nil) @@ -66,40 +68,38 @@ func (v *tasks) Execute(ctx context.Context, request *ExecuteTaskRequest) error // GetRootTasks is a way to get all root tasks for the given tasks. // Snowflake does not have (yet) a method to do it without traversing the task graph manually. // Task DAG should have a single root but this is checked when the root task is being resumed; that's why we return here multiple roots. -// Cycles should not be possible in a task DAG but it is checked when the root task is being resumed; that's why this method has to be cycle-proof. -// TODO [SNOW-884987]: handle cycles +// Cycles should not be possible in a task DAG, but it is checked when the root task is being resumed; that's why this method has to be cycle-proof. func GetRootTasks(v Tasks, ctx context.Context, id SchemaObjectIdentifier) ([]Task, error) { - task, err := v.ShowByID(ctx, id) - if err != nil { - return nil, err - } + tasksToExamine := collections.NewQueue[SchemaObjectIdentifier]() + alreadyExaminedTasksNames := make([]string, 0) + rootTasks := make([]Task, 0) - predecessors := task.Predecessors - // no predecessors mean this is a root task - if len(predecessors) == 0 { - return []Task{*task}, nil - } + tasksToExamine.Push(id) + + for tasksToExamine.Head() != nil { + current := tasksToExamine.Pop() - rootTasks := make([]Task, 0, len(predecessors)) - for _, predecessor := range predecessors { - predecessorTasks, err := GetRootTasks(v, ctx, predecessor) + if slices.Contains(alreadyExaminedTasksNames, current.Name()) { + continue + } + + task, err := v.ShowByID(ctx, *current) if err != nil { - return nil, fmt.Errorf("unable to get predecessors for task %s err = %w", predecessor.FullyQualifiedName(), err) + return nil, err } - rootTasks = append(rootTasks, predecessorTasks...) 
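	// Aside (illustrative only, not part of the patch): the replacement above is a
	// breadth-first traversal with a visited list, which is what makes it
	// cycle-proof; the removed recursive variant around this point had no such
	// guard, so a cycle in the task graph made it recurse without end. The
	// essential pattern, using the new collections.Queue, looks roughly like:
	//
	//   tasksToExamine := collections.NewQueue[SchemaObjectIdentifier]()
	//   alreadyExamined := make([]string, 0)
	//   tasksToExamine.Push(id)
	//   for tasksToExamine.Head() != nil {
	//       current := tasksToExamine.Pop()
	//       if slices.Contains(alreadyExamined, current.Name()) {
	//           continue // reaching a task a second time (e.g. via a cycle) is a no-op
	//       }
	//       task, err := v.ShowByID(ctx, *current)
	//       // ...collect task if it has no predecessors, otherwise push them onto the queue...
	//       alreadyExamined = append(alreadyExamined, current.Name())
	//   }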
- } - // TODO [SNOW-884987]: extract unique function in our collection helper (if cycle-proof algorithm still needs it) - keys := make(map[string]bool) - uniqueRootTasks := make([]Task, 0, len(rootTasks)) - for _, rootTask := range rootTasks { - if _, exists := keys[rootTask.ID().FullyQualifiedName()]; !exists { - keys[rootTask.ID().FullyQualifiedName()] = true - uniqueRootTasks = append(uniqueRootTasks, rootTask) + predecessors := task.Predecessors + if len(predecessors) == 0 { + rootTasks = append(rootTasks, *task) + } else { + for _, p := range predecessors { + tasksToExamine.Push(p) + } } + alreadyExaminedTasksNames = append(alreadyExaminedTasksNames, current.Name()) } - return uniqueRootTasks, nil + return rootTasks, nil } func (r *CreateTaskRequest) toOpts() *CreateTaskOptions { diff --git a/pkg/sdk/tasks_test.go b/pkg/sdk/tasks_test.go index 9d9f27d7d5..0faf8bf20e 100644 --- a/pkg/sdk/tasks_test.go +++ b/pkg/sdk/tasks_test.go @@ -58,7 +58,8 @@ func TestTasks_GetRootTasks(t *testing.T) { {"t1": {}, "t2": {}, "initial": {"t1"}, "expected": {"t1"}}, {"t1": {"t2", "t3", "t4"}, "t2": {}, "t3": {}, "t4": {}, "initial": {"t1"}, "expected": {"t2", "t3", "t4"}}, {"t1": {"t2", "t3", "t4"}, "t2": {}, "t3": {"t2"}, "t4": {"t3"}, "initial": {"t1"}, "expected": {"t2"}}, - // {"r": {}, "t1": {"t2", "r"}, "t2": {"t3"}, "t3": {"t1"}, "initial": {"t1"}, "expected": {"r"}}, // cycle -> failing for current (old) implementation + {"r": {}, "t1": {"t2", "r"}, "t2": {"t3"}, "t3": {"t1"}, "initial": {"t1"}, "expected": {"r"}}, // cycle -> failing for the old implementation + {"r": {}, "t1": {"t2", "r"}, "t2": {"t3"}, "t3": {"t1"}, "initial": {"t3"}, "expected": {"r"}}, // cycle -> failing for the old implementation } for i, tt := range tests { t.Run(fmt.Sprintf("test case [%v]", i), func(t *testing.T) { diff --git a/pkg/sdk/testint/tasks_gen_integration_test.go b/pkg/sdk/testint/tasks_gen_integration_test.go index f6dd6d9e5e..f68847fa47 100644 --- a/pkg/sdk/testint/tasks_gen_integration_test.go +++ b/pkg/sdk/testint/tasks_gen_integration_test.go @@ -257,6 +257,12 @@ func TestInt_Tasks(t *testing.T) { err = client.Tasks.Alter(ctx, alterRequest) require.NoError(t, err) + // can get the root task even with cycle + rootTasks, err = sdk.GetRootTasks(client.Tasks, ctx, t3Id) + require.NoError(t, err) + require.Len(t, rootTasks, 1) + require.Equal(t, rootId, rootTasks[0].ID()) + // we get an error when trying to start alterRequest = sdk.NewAlterTaskRequest(rootId).WithResume(sdk.Bool(true)) err = client.Tasks.Alter(ctx, alterRequest) From dbb7c9136c586490a0856cc07ae879be491c8150 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Mon, 30 Oct 2023 17:30:17 +0100 Subject: [PATCH 13/20] chore: Split existing alter operations (#2156) * Extract set and unset tags * Fix warehouse validation * Fix after review * Fix warehouses --- pkg/resources/schema.go | 8 +-- pkg/sdk/accounts.go | 24 +++---- pkg/sdk/accounts_test.go | 24 +++---- pkg/sdk/errors.go | 2 +- pkg/sdk/masking_policy.go | 22 +++---- pkg/sdk/masking_policy_test.go | 2 +- pkg/sdk/pipes.go | 18 ++---- pkg/sdk/pipes_test.go | 62 ++++++------------- pkg/sdk/pipes_validations.go | 20 +++--- pkg/sdk/schemas.go | 32 +++++----- pkg/sdk/schemas_test.go | 26 ++++---- pkg/sdk/shares.go | 18 +++--- pkg/sdk/shares_test.go | 18 +++--- pkg/sdk/testint/accounts_integration_test.go | 18 +++--- .../masking_policy_integration_test.go | 8 +-- pkg/sdk/testint/pipes_integration_test.go | 16 ++--- pkg/sdk/testint/schemas_integration_test.go | 16 ++--- 
pkg/sdk/testint/shares_integration_test.go | 10 +-- .../system_functions_integration_test.go | 10 ++- .../testint/warehouses_integration_test.go | 44 ++++++------- pkg/sdk/users.go | 16 ++--- pkg/sdk/users_test.go | 22 +++---- pkg/sdk/warehouses.go | 28 ++++----- pkg/sdk/warehouses_test.go | 53 +++++++--------- 24 files changed, 212 insertions(+), 305 deletions(-) diff --git a/pkg/resources/schema.go b/pkg/resources/schema.go index ed8ea70be9..83b07d31ee 100644 --- a/pkg/resources/schema.go +++ b/pkg/resources/schema.go @@ -237,9 +237,7 @@ func UpdateSchema(d *schema.ResourceData, meta interface{}) error { unsetTags[i] = sdk.NewDatabaseObjectIdentifier(t.database, t.name) } err := client.Schemas.Alter(ctx, id, &sdk.AlterSchemaOptions{ - Unset: &sdk.SchemaUnset{ - Tag: unsetTags, - }, + UnsetTag: unsetTags, }) if err != nil { return fmt.Errorf("error dropping tags on %v", d.Id()) @@ -259,9 +257,7 @@ func UpdateSchema(d *schema.ResourceData, meta interface{}) error { } } err = client.Schemas.Alter(ctx, id, &sdk.AlterSchemaOptions{ - Set: &sdk.SchemaSet{ - Tag: setTags, - }, + SetTag: setTags, }) if err != nil { return fmt.Errorf("error setting tags on %v", d.Id()) diff --git a/pkg/sdk/accounts.go b/pkg/sdk/accounts.go index e3f2388058..9f08ad838e 100644 --- a/pkg/sdk/accounts.go +++ b/pkg/sdk/accounts.go @@ -90,10 +90,12 @@ type AlterAccountOptions struct { alter bool `ddl:"static" sql:"ALTER"` account bool `ddl:"static" sql:"ACCOUNT"` - Set *AccountSet `ddl:"keyword" sql:"SET"` - Unset *AccountUnset `ddl:"list,no_parentheses" sql:"UNSET"` - Rename *AccountRename `ddl:"-"` - Drop *AccountDrop `ddl:"-"` + Set *AccountSet `ddl:"keyword" sql:"SET"` + Unset *AccountUnset `ddl:"list,no_parentheses" sql:"UNSET"` + SetTag []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTag []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` + Rename *AccountRename `ddl:"-"` + Drop *AccountDrop `ddl:"-"` } func (opts *AlterAccountOptions) validate() error { @@ -101,8 +103,8 @@ func (opts *AlterAccountOptions) validate() error { return errors.Join(ErrNilOptions) } var errs []error - if !exactlyOneValueSet(opts.Set, opts.Unset, opts.Drop, opts.Rename) { - errs = append(errs, errExactlyOneOf("CreateAccountOptions", "Set", "Unset", "Drop", "Rename")) + if !exactlyOneValueSet(opts.Set, opts.Unset, opts.SetTag, opts.UnsetTag, opts.Drop, opts.Rename) { + errs = append(errs, errExactlyOneOf("CreateAccountOptions", "Set", "Unset", "SetTag", "UnsetTag", "Drop", "Rename")) } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { @@ -164,13 +166,12 @@ type AccountSet struct { ResourceMonitor AccountObjectIdentifier `ddl:"identifier,equals" sql:"RESOURCE_MONITOR"` PasswordPolicy SchemaObjectIdentifier `ddl:"identifier" sql:"PASSWORD POLICY"` SessionPolicy SchemaObjectIdentifier `ddl:"identifier" sql:"SESSION POLICY"` - Tag []TagAssociation `ddl:"keyword" sql:"TAG"` } func (opts *AccountSet) validate() error { var errs []error - if !exactlyOneValueSet(opts.Parameters, opts.ResourceMonitor, opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { - errs = append(errs, errExactlyOneOf("AccountSet", "Parameters", "ResourceMonitor", "PasswordPolicy", "SessionPolicy", "Tag")) + if !exactlyOneValueSet(opts.Parameters, opts.ResourceMonitor, opts.PasswordPolicy, opts.SessionPolicy) { + errs = append(errs, errExactlyOneOf("AccountSet", "Parameters", "ResourceMonitor", "PasswordPolicy", "SessionPolicy")) } if valueSet(opts.Parameters) { if err := opts.Parameters.validate(); err != nil { @@ -198,13 +199,12 @@ type AccountUnset 
struct { Parameters *AccountLevelParametersUnset `ddl:"list,no_parentheses"` PasswordPolicy *bool `ddl:"keyword" sql:"PASSWORD POLICY"` SessionPolicy *bool `ddl:"keyword" sql:"SESSION POLICY"` - Tag []ObjectIdentifier `ddl:"keyword" sql:"TAG"` } func (opts *AccountUnset) validate() error { var errs []error - if !exactlyOneValueSet(opts.Parameters, opts.PasswordPolicy, opts.SessionPolicy, opts.Tag) { - errs = append(errs, errExactlyOneOf("AccountUnset", "Parameters", "PasswordPolicy", "SessionPolicy", "Tag")) + if !exactlyOneValueSet(opts.Parameters, opts.PasswordPolicy, opts.SessionPolicy) { + errs = append(errs, errExactlyOneOf("AccountUnset", "Parameters", "PasswordPolicy", "SessionPolicy")) } if valueSet(opts.Parameters) { if err := opts.Parameters.validate(); err != nil { diff --git a/pkg/sdk/accounts_test.go b/pkg/sdk/accounts_test.go index f3e0b16a21..55ed1386bc 100644 --- a/pkg/sdk/accounts_test.go +++ b/pkg/sdk/accounts_test.go @@ -140,16 +140,14 @@ func TestAccountAlter(t *testing.T) { t.Run("with set tag", func(t *testing.T) { opts := &AlterAccountOptions{ - Set: &AccountSet{ - Tag: []TagAssociation{ - { - Name: NewSchemaObjectIdentifier("db", "schema", "tag1"), - Value: "v1", - }, - { - Name: NewSchemaObjectIdentifier("db", "schema", "tag2"), - Value: "v2", - }, + SetTag: []TagAssociation{ + { + Name: NewSchemaObjectIdentifier("db", "schema", "tag1"), + Value: "v1", + }, + { + Name: NewSchemaObjectIdentifier("db", "schema", "tag2"), + Value: "v2", }, }, } @@ -158,10 +156,8 @@ func TestAccountAlter(t *testing.T) { t.Run("with unset tag", func(t *testing.T) { opts := &AlterAccountOptions{ - Unset: &AccountUnset{ - Tag: []ObjectIdentifier{ - NewSchemaObjectIdentifier("db", "schema", "tag1"), - }, + UnsetTag: []ObjectIdentifier{ + NewSchemaObjectIdentifier("db", "schema", "tag1"), }, } assertOptsValidAndSQLEquals(t, opts, `ALTER ACCOUNT UNSET TAG "db"."schema"."tag1"`) diff --git a/pkg/sdk/errors.go b/pkg/sdk/errors.go index 3ba28aaa2f..4ca6475cc2 100644 --- a/pkg/sdk/errors.go +++ b/pkg/sdk/errors.go @@ -53,7 +53,7 @@ func errNotSet(structName string, fieldNames ...string) error { } func errExactlyOneOf(structName string, fieldNames ...string) error { - return newError(fmt.Sprintf("exactly one of %s fileds %v must be set", structName, fieldNames), 2) + return newError(fmt.Sprintf("exactly one of %s fields %v must be set", structName, fieldNames), 2) } func errAtLeastOneOf(structName string, fieldNames ...string) error { diff --git a/pkg/sdk/masking_policy.go b/pkg/sdk/masking_policy.go index ac9fe60b9f..40d7f34656 100644 --- a/pkg/sdk/masking_policy.go +++ b/pkg/sdk/masking_policy.go @@ -99,6 +99,8 @@ type AlterMaskingPolicyOptions struct { NewName *SchemaObjectIdentifier `ddl:"identifier" sql:"RENAME TO"` Set *MaskingPolicySet `ddl:"keyword" sql:"SET"` Unset *MaskingPolicyUnset `ddl:"keyword" sql:"UNSET"` + SetTag []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTag []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` } func (opts *AlterMaskingPolicyOptions) validate() error { @@ -112,8 +114,8 @@ func (opts *AlterMaskingPolicyOptions) validate() error { if opts.NewName != nil && !ValidObjectIdentifier(opts.NewName) { errs = append(errs, errInvalidIdentifier("AlterMaskingPolicyOptions", "NewName")) } - if !exactlyOneValueSet(opts.NewName, opts.Set, opts.Unset) { - errs = append(errs, errExactlyOneOf("AlterMaskingPolicyOptions", "NewName", "Set", "Unset")) + if !exactlyOneValueSet(opts.Set, opts.Unset, opts.SetTag, opts.UnsetTag, opts.NewName) { + errs = append(errs, 
errExactlyOneOf("AlterMaskingPolicyOptions", "Set", "Unset", "SetTag", "UnsetTag", "NewName")) } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { @@ -129,26 +131,24 @@ func (opts *AlterMaskingPolicyOptions) validate() error { } type MaskingPolicySet struct { - Body *string `ddl:"parameter,no_equals" sql:"BODY ->"` - Tag []TagAssociation `ddl:"keyword" sql:"TAG"` - Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` + Body *string `ddl:"parameter,no_equals" sql:"BODY ->"` + Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` } func (v *MaskingPolicySet) validate() error { - if !exactlyOneValueSet(v.Body, v.Tag, v.Comment) { - return errExactlyOneOf("MaskingPolicySet", "Body", "Tag", "Comment") + if !exactlyOneValueSet(v.Body, v.Comment) { + return errExactlyOneOf("MaskingPolicySet", "Body", "Comment") } return nil } type MaskingPolicyUnset struct { - Tag []ObjectIdentifier `ddl:"keyword" sql:"TAG"` - Comment *bool `ddl:"keyword" sql:"COMMENT"` + Comment *bool `ddl:"keyword" sql:"COMMENT"` } func (v *MaskingPolicyUnset) validate() error { - if !exactlyOneValueSet(v.Tag, v.Comment) { - return errExactlyOneOf("MaskingPolicyUnset", "Tag", "Comment") + if !exactlyOneValueSet(v.Comment) { + return errExactlyOneOf("MaskingPolicyUnset", "Comment") } return nil } diff --git a/pkg/sdk/masking_policy_test.go b/pkg/sdk/masking_policy_test.go index 23d9362387..becb6e615f 100644 --- a/pkg/sdk/masking_policy_test.go +++ b/pkg/sdk/masking_policy_test.go @@ -88,7 +88,7 @@ func TestMaskingPolicyAlter(t *testing.T) { opts := &AlterMaskingPolicyOptions{ name: id, } - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterMaskingPolicyOptions", "NewName", "Set", "Unset")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterMaskingPolicyOptions", "Set", "Unset", "SetTag", "UnsetTag", "NewName")) }) t.Run("with set", func(t *testing.T) { diff --git a/pkg/sdk/pipes.go b/pkg/sdk/pipes.go index a0cdee18cb..928d01992a 100644 --- a/pkg/sdk/pipes.go +++ b/pkg/sdk/pipes.go @@ -42,11 +42,11 @@ type AlterPipeOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` // One of - Set *PipeSet `ddl:"list,no_parentheses" sql:"SET"` - Unset *PipeUnset `ddl:"list,no_parentheses" sql:"UNSET"` - SetTags *PipeSetTags `ddl:"list,no_parentheses" sql:"SET TAG"` - UnsetTags *PipeUnsetTags `ddl:"list,no_parentheses" sql:"UNSET TAG"` - Refresh *PipeRefresh `ddl:"keyword" sql:"REFRESH"` + Set *PipeSet `ddl:"list,no_parentheses" sql:"SET"` + Unset *PipeUnset `ddl:"list,no_parentheses" sql:"UNSET"` + SetTag []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTag []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` + Refresh *PipeRefresh `ddl:"keyword" sql:"REFRESH"` } type PipeSet struct { @@ -60,14 +60,6 @@ type PipeUnset struct { Comment *bool `ddl:"keyword" sql:"COMMENT"` } -type PipeSetTags struct { - Tag []TagAssociation `ddl:"keyword"` -} - -type PipeUnsetTags struct { - Tag []ObjectIdentifier `ddl:"keyword"` -} - type PipeRefresh struct { Prefix *string `ddl:"parameter,single_quotes" sql:"PREFIX"` ModifiedAfter *string `ddl:"parameter,single_quotes" sql:"MODIFIED_AFTER"` diff --git a/pkg/sdk/pipes_test.go b/pkg/sdk/pipes_test.go index 773603072f..f6ccd7fdeb 100644 --- a/pkg/sdk/pipes_test.go +++ b/pkg/sdk/pipes_test.go @@ -70,7 +70,7 @@ func TestPipesAlter(t *testing.T) { t.Run("validation: no alter action", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTags", "UnsetTags", 
"Refresh")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTag", "UnsetTag", "Refresh")) }) t.Run("validation: multiple alter actions", func(t *testing.T) { @@ -81,7 +81,7 @@ func TestPipesAlter(t *testing.T) { opts.Unset = &PipeUnset{ Comment: Bool(true), } - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTags", "UnsetTags", "Refresh")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTag", "UnsetTag", "Refresh")) }) t.Run("validation: no property to set", func(t *testing.T) { @@ -90,36 +90,18 @@ func TestPipesAlter(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterPipeOptions.Set", "ErrorIntegration", "PipeExecutionPaused", "Comment")) }) - t.Run("validation: empty tags slice for set", func(t *testing.T) { - opts := defaultOpts() - opts.SetTags = &PipeSetTags{ - Tag: []TagAssociation{}, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("AlterPipeOptions.SetTags", "Tag")) - }) - t.Run("validation: no property to unset", func(t *testing.T) { opts := defaultOpts() opts.Unset = &PipeUnset{} assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterPipeOptions.Unset", "PipeExecutionPaused", "Comment")) }) - t.Run("validation: empty tags slice for unset", func(t *testing.T) { - opts := defaultOpts() - opts.UnsetTags = &PipeUnsetTags{ - Tag: []ObjectIdentifier{}, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("AlterPipeOptions.UnsetTags", "Tag")) - }) - t.Run("set tag: single", func(t *testing.T) { opts := defaultOpts() - opts.SetTags = &PipeSetTags{ - Tag: []TagAssociation{ - { - Name: NewAccountObjectIdentifier("tag_name1"), - Value: "v1", - }, + opts.SetTag = []TagAssociation{ + { + Name: NewAccountObjectIdentifier("tag_name1"), + Value: "v1", }, } assertOptsValidAndSQLEquals(t, opts, `ALTER PIPE %s SET TAG "tag_name1" = 'v1'`, id.FullyQualifiedName()) @@ -127,16 +109,14 @@ func TestPipesAlter(t *testing.T) { t.Run("set tag: multiple", func(t *testing.T) { opts := defaultOpts() - opts.SetTags = &PipeSetTags{ - Tag: []TagAssociation{ - { - Name: NewAccountObjectIdentifier("tag_name1"), - Value: "v1", - }, - { - Name: NewAccountObjectIdentifier("tag_name2"), - Value: "v2", - }, + opts.SetTag = []TagAssociation{ + { + Name: NewAccountObjectIdentifier("tag_name1"), + Value: "v1", + }, + { + Name: NewAccountObjectIdentifier("tag_name2"), + Value: "v2", }, } assertOptsValidAndSQLEquals(t, opts, `ALTER PIPE %s SET TAG "tag_name1" = 'v1', "tag_name2" = 'v2'`, id.FullyQualifiedName()) @@ -155,21 +135,17 @@ func TestPipesAlter(t *testing.T) { t.Run("unset tag: single", func(t *testing.T) { opts := defaultOpts() - opts.UnsetTags = &PipeUnsetTags{ - Tag: []ObjectIdentifier{ - NewAccountObjectIdentifier("tag_name1"), - }, + opts.UnsetTag = []ObjectIdentifier{ + NewAccountObjectIdentifier("tag_name1"), } assertOptsValidAndSQLEquals(t, opts, `ALTER PIPE %s UNSET TAG "tag_name1"`, id.FullyQualifiedName()) }) t.Run("unset tag: multi", func(t *testing.T) { opts := defaultOpts() - opts.UnsetTags = &PipeUnsetTags{ - Tag: []ObjectIdentifier{ - NewAccountObjectIdentifier("tag_name1"), - NewAccountObjectIdentifier("tag_name2"), - }, + opts.UnsetTag = []ObjectIdentifier{ + NewAccountObjectIdentifier("tag_name1"), + NewAccountObjectIdentifier("tag_name2"), } assertOptsValidAndSQLEquals(t, opts, `ALTER PIPE %s UNSET TAG "tag_name1", "tag_name2"`, id.FullyQualifiedName()) }) diff --git a/pkg/sdk/pipes_validations.go 
b/pkg/sdk/pipes_validations.go index 22b50dceac..f87272f466 100644 --- a/pkg/sdk/pipes_validations.go +++ b/pkg/sdk/pipes_validations.go @@ -34,8 +34,14 @@ func (opts *AlterPipeOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if ok := exactlyOneValueSet(opts.Set, opts.Unset, opts.SetTags, opts.UnsetTags, opts.Refresh); !ok { - errs = append(errs, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTags", "UnsetTags", "Refresh")) + if ok := exactlyOneValueSet( + opts.Set, + opts.Unset, + opts.SetTag, + opts.UnsetTag, + opts.Refresh, + ); !ok { + errs = append(errs, errExactlyOneOf("AlterPipeOptions", "Set", "Unset", "SetTag", "UnsetTag", "Refresh")) } if set := opts.Set; valueSet(set) { if !anyValueSet(set.ErrorIntegration, set.PipeExecutionPaused, set.Comment) { @@ -47,16 +53,6 @@ func (opts *AlterPipeOptions) validate() error { errs = append(errs, errAtLeastOneOf("AlterPipeOptions.Unset", "PipeExecutionPaused", "Comment")) } } - if setTags := opts.SetTags; valueSet(setTags) { - if !valueSet(setTags.Tag) { - errs = append(errs, errNotSet("AlterPipeOptions.SetTags", "Tag")) - } - } - if unsetTags := opts.UnsetTags; valueSet(unsetTags) { - if !valueSet(unsetTags.Tag) { - errs = append(errs, errNotSet("AlterPipeOptions.UnsetTags", "Tag")) - } - } return errors.Join(errs...) } diff --git a/pkg/sdk/schemas.go b/pkg/sdk/schemas.go index 5fbe0b6fd3..8e99c6dc61 100644 --- a/pkg/sdk/schemas.go +++ b/pkg/sdk/schemas.go @@ -152,6 +152,8 @@ type AlterSchemaOptions struct { SwapWith DatabaseObjectIdentifier `ddl:"identifier" sql:"SWAP WITH"` Set *SchemaSet `ddl:"list,no_parentheses" sql:"SET"` Unset *SchemaUnset `ddl:"list,no_parentheses" sql:"UNSET"` + SetTag []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTag []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` // One of EnableManagedAccess *bool `ddl:"keyword" sql:"ENABLE MANAGED ACCESS"` DisableManagedAccess *bool `ddl:"keyword" sql:"DISABLE MANAGED ACCESS"` @@ -165,8 +167,8 @@ func (opts *AlterSchemaOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if !exactlyOneValueSet(opts.NewName, opts.SwapWith, opts.Set, opts.Unset, opts.EnableManagedAccess, opts.DisableManagedAccess) { - errs = append(errs, errOneOf("NewName", "SwapWith", "Set", "Unset", "EnableManagedAccess", "DisableManagedAccess")) + if !exactlyOneValueSet(opts.NewName, opts.SwapWith, opts.Set, opts.Unset, opts.SetTag, opts.UnsetTag, opts.EnableManagedAccess, opts.DisableManagedAccess) { + errs = append(errs, errOneOf("NewName", "SwapWith", "Set", "Unset", "SetTag", "UnsetTag", "EnableManagedAccess", "DisableManagedAccess")) } if valueSet(opts.Set) { if err := opts.Set.validate(); err != nil { @@ -182,31 +184,29 @@ func (opts *AlterSchemaOptions) validate() error { } type SchemaSet struct { - DataRetentionTimeInDays *int `ddl:"parameter" sql:"DATA_RETENTION_TIME_IN_DAYS"` - MaxDataExtensionTimeInDays *int `ddl:"parameter" sql:"MAX_DATA_EXTENSION_TIME_IN_DAYS"` - DefaultDDLCollation *string `ddl:"parameter,single_quotes" sql:"DEFAULT_DDL_COLLATION"` - Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` - Tag []TagAssociation `ddl:"keyword" sql:"TAG"` + DataRetentionTimeInDays *int `ddl:"parameter" sql:"DATA_RETENTION_TIME_IN_DAYS"` + MaxDataExtensionTimeInDays *int `ddl:"parameter" sql:"MAX_DATA_EXTENSION_TIME_IN_DAYS"` + DefaultDDLCollation *string `ddl:"parameter,single_quotes" sql:"DEFAULT_DDL_COLLATION"` + Comment *string 
`ddl:"parameter,single_quotes" sql:"COMMENT"` } func (v *SchemaSet) validate() error { - if valueSet(v.Tag) && anyValueSet(v.DataRetentionTimeInDays, v.MaxDataExtensionTimeInDays, v.DefaultDDLCollation, v.Comment) { - return errors.New("tag field cannot be set with other options") + if !anyValueSet(v.DataRetentionTimeInDays, v.MaxDataExtensionTimeInDays, v.DefaultDDLCollation, v.Comment) { + return errAtLeastOneOf("SchemaSet", "DataRetentionTimeInDays", "MaxDataExtensionTimeInDays", "DefaultDDLCollation", "Comment") } return nil } type SchemaUnset struct { - DataRetentionTimeInDays *bool `ddl:"keyword" sql:"DATA_RETENTION_TIME_IN_DAYS"` - MaxDataExtensionTimeInDays *bool `ddl:"keyword" sql:"MAX_DATA_EXTENSION_TIME_IN_DAYS"` - DefaultDDLCollation *bool `ddl:"keyword" sql:"DEFAULT_DDL_COLLATION"` - Comment *bool `ddl:"keyword" sql:"COMMENT"` - Tag []ObjectIdentifier `ddl:"keyword" sql:"TAG"` + DataRetentionTimeInDays *bool `ddl:"keyword" sql:"DATA_RETENTION_TIME_IN_DAYS"` + MaxDataExtensionTimeInDays *bool `ddl:"keyword" sql:"MAX_DATA_EXTENSION_TIME_IN_DAYS"` + DefaultDDLCollation *bool `ddl:"keyword" sql:"DEFAULT_DDL_COLLATION"` + Comment *bool `ddl:"keyword" sql:"COMMENT"` } func (v *SchemaUnset) validate() error { - if valueSet(v.Tag) && anyValueSet(v.DataRetentionTimeInDays, v.MaxDataExtensionTimeInDays, v.DefaultDDLCollation, v.Comment) { - return errors.New("tag field cannot be set with other options") + if !anyValueSet(v.DataRetentionTimeInDays, v.MaxDataExtensionTimeInDays, v.DefaultDDLCollation, v.Comment) { + return errAtLeastOneOf("SchemaUnset", "DataRetentionTimeInDays", "MaxDataExtensionTimeInDays", "DefaultDDLCollation", "Comment") } return nil } diff --git a/pkg/sdk/schemas_test.go b/pkg/sdk/schemas_test.go index 26fecbbdbe..9d4080d1bc 100644 --- a/pkg/sdk/schemas_test.go +++ b/pkg/sdk/schemas_test.go @@ -78,16 +78,14 @@ func TestSchemasAlter(t *testing.T) { t.Run("set tags", func(t *testing.T) { opts := &AlterSchemaOptions{ name: NewDatabaseObjectIdentifier("database_name", "schema_name"), - Set: &SchemaSet{ - Tag: []TagAssociation{ - { - Name: NewAccountObjectIdentifier("tag1"), - Value: "value1", - }, - { - Name: NewAccountObjectIdentifier("tag2"), - Value: "value2", - }, + SetTag: []TagAssociation{ + { + Name: NewAccountObjectIdentifier("tag1"), + Value: "value1", + }, + { + Name: NewAccountObjectIdentifier("tag2"), + Value: "value2", }, }, } @@ -97,11 +95,9 @@ func TestSchemasAlter(t *testing.T) { t.Run("unset tags", func(t *testing.T) { opts := &AlterSchemaOptions{ name: NewDatabaseObjectIdentifier("database_name", "schema_name"), - Unset: &SchemaUnset{ - Tag: []ObjectIdentifier{ - NewAccountObjectIdentifier("tag1"), - NewAccountObjectIdentifier("tag2"), - }, + UnsetTag: []ObjectIdentifier{ + NewAccountObjectIdentifier("tag1"), + NewAccountObjectIdentifier("tag2"), }, } assertOptsValidAndSQLEquals(t, opts, `ALTER SCHEMA "database_name"."schema_name" UNSET TAG "tag1", "tag2"`) diff --git a/pkg/sdk/shares.go b/pkg/sdk/shares.go index a4946c875f..9f34bf2f13 100644 --- a/pkg/sdk/shares.go +++ b/pkg/sdk/shares.go @@ -179,6 +179,8 @@ type AlterShareOptions struct { Remove *ShareRemove `ddl:"keyword" sql:"REMOVE"` Set *ShareSet `ddl:"keyword" sql:"SET"` Unset *ShareUnset `ddl:"keyword" sql:"UNSET"` + SetTag []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTag []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` } func (opts *AlterShareOptions) validate() error { @@ -189,8 +191,8 @@ func (opts *AlterShareOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs 
= append(errs, ErrInvalidObjectIdentifier) } - if !exactlyOneValueSet(opts.Add, opts.Remove, opts.Set, opts.Unset) { - errs = append(errs, errExactlyOneOf("AlterShareOptions", "Add", "Remove", "Set", "Unset")) + if !exactlyOneValueSet(opts.Add, opts.Remove, opts.Set, opts.Unset, opts.SetTag, opts.UnsetTag) { + errs = append(errs, errExactlyOneOf("AlterShareOptions", "Add", "Remove", "Set", "Unset", "SetTag", "UnsetTag")) } if valueSet(opts.Add) { if err := opts.Add.validate(); err != nil { @@ -241,24 +243,22 @@ func (v *ShareRemove) validate() error { type ShareSet struct { Accounts []AccountIdentifier `ddl:"parameter" sql:"ACCOUNTS"` Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` - Tag []TagAssociation `ddl:"keyword" sql:"TAG"` } func (v *ShareSet) validate() error { - if valueSet(v.Tag) && anyValueSet(v.Accounts, v.Comment) { - return fmt.Errorf("accounts and comment cannot be set when tag is set") + if !anyValueSet(v.Accounts, v.Comment) { + return errAtLeastOneOf("ShareSet", "Accounts", "Comment") } return nil } type ShareUnset struct { - Tag []ObjectIdentifier `ddl:"keyword" sql:"TAG"` - Comment *bool `ddl:"keyword" sql:"COMMENT"` + Comment *bool `ddl:"keyword" sql:"COMMENT"` } func (v *ShareUnset) validate() error { - if !exactlyOneValueSet(v.Comment, v.Tag) { - return errExactlyOneOf("ShareUnset", "Comment", "Tag") + if !exactlyOneValueSet(v.Comment) { + return errExactlyOneOf("ShareUnset", "Comment") } return nil } diff --git a/pkg/sdk/shares_test.go b/pkg/sdk/shares_test.go index 94b52e95bc..ccb0fb8b0e 100644 --- a/pkg/sdk/shares_test.go +++ b/pkg/sdk/shares_test.go @@ -30,7 +30,7 @@ func TestShareAlter(t *testing.T) { opts := &AlterShareOptions{ name: NewAccountObjectIdentifier("myshare"), } - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterShareOptions", "Add", "Remove", "Set", "Unset")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterShareOptions", "Add", "Remove", "Set", "Unset", "SetTag", "UnsetTag")) }) t.Run("with add", func(t *testing.T) { @@ -76,12 +76,10 @@ func TestShareAlter(t *testing.T) { opts := &AlterShareOptions{ IfExists: Bool(true), name: NewAccountObjectIdentifier("myshare"), - Set: &ShareSet{ - Tag: []TagAssociation{ - { - Name: NewSchemaObjectIdentifier("db", "schema", "tag"), - Value: "v1", - }, + SetTag: []TagAssociation{ + { + Name: NewSchemaObjectIdentifier("db", "schema", "tag"), + Value: "v1", }, }, } @@ -103,10 +101,8 @@ func TestShareAlter(t *testing.T) { opts := &AlterShareOptions{ IfExists: Bool(true), name: NewAccountObjectIdentifier("myshare"), - Unset: &ShareUnset{ - Tag: []ObjectIdentifier{ - NewSchemaObjectIdentifier("db", "schema", "tag"), - }, + UnsetTag: []ObjectIdentifier{ + NewSchemaObjectIdentifier("db", "schema", "tag"), }, } assertOptsValidAndSQLEquals(t, opts, `ALTER SHARE IF EXISTS "myshare" UNSET TAG "db"."schema"."tag"`) diff --git a/pkg/sdk/testint/accounts_integration_test.go b/pkg/sdk/testint/accounts_integration_test.go index 2bd7cb035f..addb97b0d6 100644 --- a/pkg/sdk/testint/accounts_integration_test.go +++ b/pkg/sdk/testint/accounts_integration_test.go @@ -282,16 +282,14 @@ func TestInt_AccountAlter(t *testing.T) { t.Cleanup(tagCleanup2) opts := &sdk.AlterAccountOptions{ - Set: &sdk.AccountSet{ - Tag: []sdk.TagAssociation{ - { - Name: tagTest1.ID(), - Value: "abc", - }, - { - Name: tagTest2.ID(), - Value: "123", - }, + SetTag: []sdk.TagAssociation{ + { + Name: tagTest1.ID(), + Value: "abc", + }, + { + Name: tagTest2.ID(), + Value: "123", }, }, } diff --git 
a/pkg/sdk/testint/masking_policy_integration_test.go b/pkg/sdk/testint/masking_policy_integration_test.go index 16d487d27f..f1aeb4b0c5 100644 --- a/pkg/sdk/testint/masking_policy_integration_test.go +++ b/pkg/sdk/testint/masking_policy_integration_test.go @@ -342,9 +342,7 @@ func TestInt_MaskingPolicyAlter(t *testing.T) { tagAssociations := []sdk.TagAssociation{{Name: tag.ID(), Value: "value1"}, {Name: tag2.ID(), Value: "value2"}} alterOptions := &sdk.AlterMaskingPolicyOptions{ - Set: &sdk.MaskingPolicySet{ - Tag: tagAssociations, - }, + SetTag: tagAssociations, } err := client.MaskingPolicies.Alter(ctx, id, alterOptions) require.NoError(t, err) @@ -357,9 +355,7 @@ func TestInt_MaskingPolicyAlter(t *testing.T) { // unset tag alterOptions = &sdk.AlterMaskingPolicyOptions{ - Unset: &sdk.MaskingPolicyUnset{ - Tag: []sdk.ObjectIdentifier{tag.ID()}, - }, + UnsetTag: []sdk.ObjectIdentifier{tag.ID()}, } err = client.MaskingPolicies.Alter(ctx, id, alterOptions) require.NoError(t, err) diff --git a/pkg/sdk/testint/pipes_integration_test.go b/pkg/sdk/testint/pipes_integration_test.go index 59b13f5080..a440dbba0b 100644 --- a/pkg/sdk/testint/pipes_integration_test.go +++ b/pkg/sdk/testint/pipes_integration_test.go @@ -304,12 +304,10 @@ func TestInt_PipeAlter(t *testing.T) { tagValue := "abc" alterOptions := &sdk.AlterPipeOptions{ - SetTags: &sdk.PipeSetTags{ - Tag: []sdk.TagAssociation{ - { - Name: tag.ID(), - Value: tagValue, - }, + SetTag: []sdk.TagAssociation{ + { + Name: tag.ID(), + Value: tagValue, }, }, } @@ -323,10 +321,8 @@ func TestInt_PipeAlter(t *testing.T) { assert.Equal(t, tagValue, returnedTagValue) alterOptions = &sdk.AlterPipeOptions{ - UnsetTags: &sdk.PipeUnsetTags{ - Tag: []sdk.ObjectIdentifier{ - tag.ID(), - }, + UnsetTag: []sdk.ObjectIdentifier{ + tag.ID(), }, } diff --git a/pkg/sdk/testint/schemas_integration_test.go b/pkg/sdk/testint/schemas_integration_test.go index a4a8cf351c..f17b55aefa 100644 --- a/pkg/sdk/testint/schemas_integration_test.go +++ b/pkg/sdk/testint/schemas_integration_test.go @@ -225,12 +225,10 @@ func TestInt_SchemasAlter(t *testing.T) { tagValue := "tag-value" err = client.Schemas.Alter(ctx, schemaID, &sdk.AlterSchemaOptions{ - Set: &sdk.SchemaSet{ - Tag: []sdk.TagAssociation{ - { - Name: tag.ID(), - Value: tagValue, - }, + SetTag: []sdk.TagAssociation{ + { + Name: tag.ID(), + Value: tagValue, }, }, }) @@ -268,10 +266,8 @@ func TestInt_SchemasAlter(t *testing.T) { }) err = client.Schemas.Alter(ctx, schemaID, &sdk.AlterSchemaOptions{ - Unset: &sdk.SchemaUnset{ - Tag: []sdk.ObjectIdentifier{ - tagID, - }, + UnsetTag: []sdk.ObjectIdentifier{ + tagID, }, }) require.NoError(t, err) diff --git a/pkg/sdk/testint/shares_integration_test.go b/pkg/sdk/testint/shares_integration_test.go index 1a9724511a..f835f57ad5 100644 --- a/pkg/sdk/testint/shares_integration_test.go +++ b/pkg/sdk/testint/shares_integration_test.go @@ -299,9 +299,7 @@ func TestInt_SharesAlter(t *testing.T) { } err = client.Shares.Alter(ctx, shareTest.ID(), &sdk.AlterShareOptions{ IfExists: sdk.Bool(true), - Set: &sdk.ShareSet{ - Tag: tagAssociations, - }, + SetTag: tagAssociations, }) require.NoError(t, err) tagValue, err := client.SystemFunctions.GetTag(ctx, tagTest.ID(), shareTest.ID(), sdk.ObjectTypeShare) @@ -314,10 +312,8 @@ func TestInt_SharesAlter(t *testing.T) { // unset tags err = client.Shares.Alter(ctx, shareTest.ID(), &sdk.AlterShareOptions{ IfExists: sdk.Bool(true), - Unset: &sdk.ShareUnset{ - Tag: []sdk.ObjectIdentifier{ - tagTest.ID(), - }, + UnsetTag: []sdk.ObjectIdentifier{ + 
tagTest.ID(), }, }) require.NoError(t, err) diff --git a/pkg/sdk/testint/system_functions_integration_test.go b/pkg/sdk/testint/system_functions_integration_test.go index 305a47ae42..161c5e1050 100644 --- a/pkg/sdk/testint/system_functions_integration_test.go +++ b/pkg/sdk/testint/system_functions_integration_test.go @@ -22,12 +22,10 @@ func TestInt_GetTag(t *testing.T) { tagValue := random.String() err := client.MaskingPolicies.Alter(ctx, maskingPolicyTest.ID(), &sdk.AlterMaskingPolicyOptions{ - Set: &sdk.MaskingPolicySet{ - Tag: []sdk.TagAssociation{ - { - Name: tagTest.ID(), - Value: tagValue, - }, + SetTag: []sdk.TagAssociation{ + { + Name: tagTest.ID(), + Value: tagValue, }, }, }) diff --git a/pkg/sdk/testint/warehouses_integration_test.go b/pkg/sdk/testint/warehouses_integration_test.go index 4b96d45283..3b167945cc 100644 --- a/pkg/sdk/testint/warehouses_integration_test.go +++ b/pkg/sdk/testint/warehouses_integration_test.go @@ -440,16 +440,14 @@ func TestInt_WarehouseAlter(t *testing.T) { t.Cleanup(warehouseCleanup) alterOptions := &sdk.AlterWarehouseOptions{ - Set: &sdk.WarehouseSet{ - Tag: []sdk.TagAssociation{ - { - Name: tag.ID(), - Value: "val", - }, - { - Name: tag2.ID(), - Value: "val2", - }, + SetTag: []sdk.TagAssociation{ + { + Name: tag.ID(), + Value: "val", + }, + { + Name: tag2.ID(), + Value: "val2", }, }, } @@ -470,16 +468,14 @@ func TestInt_WarehouseAlter(t *testing.T) { t.Cleanup(warehouseCleanup) alterOptions := &sdk.AlterWarehouseOptions{ - Set: &sdk.WarehouseSet{ - Tag: []sdk.TagAssociation{ - { - Name: tag.ID(), - Value: "val1", - }, - { - Name: tag2.ID(), - Value: "val2", - }, + SetTag: []sdk.TagAssociation{ + { + Name: tag.ID(), + Value: "val1", + }, + { + Name: tag2.ID(), + Value: "val2", }, }, } @@ -493,11 +489,9 @@ func TestInt_WarehouseAlter(t *testing.T) { require.Equal(t, "val2", val2) alterOptions = &sdk.AlterWarehouseOptions{ - Unset: &sdk.WarehouseUnset{ - Tag: []sdk.ObjectIdentifier{ - tag.ID(), - tag2.ID(), - }, + UnsetTag: []sdk.ObjectIdentifier{ + tag.ID(), + tag2.ID(), }, } err = client.Warehouses.Alter(ctx, warehouse.ID(), alterOptions) diff --git a/pkg/sdk/users.go b/pkg/sdk/users.go index 2edcf06ea3..6438833332 100644 --- a/pkg/sdk/users.go +++ b/pkg/sdk/users.go @@ -273,6 +273,8 @@ type AlterUserOptions struct { RemoveDelegatedAuthorization *RemoveDelegatedAuthorization `ddl:"keyword"` Set *UserSet `ddl:"keyword" sql:"SET"` Unset *UserUnset `ddl:"keyword" sql:"UNSET"` + SetTag []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTag []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` } func (opts *AlterUserOptions) validate() error { @@ -283,8 +285,8 @@ func (opts *AlterUserOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if !exactlyOneValueSet(opts.NewName, opts.ResetPassword, opts.AbortAllQueries, opts.AddDelegatedAuthorization, opts.RemoveDelegatedAuthorization, opts.Set, opts.Unset) { - errs = append(errs, errExactlyOneOf("AlterUserOptions", "NewName", "ResetPassword", "AbortAllQueries", "AddDelegatedAuthorization", "RemoveDelegatedAuthorization", "Set", "Unset")) + if !exactlyOneValueSet(opts.NewName, opts.ResetPassword, opts.AbortAllQueries, opts.AddDelegatedAuthorization, opts.RemoveDelegatedAuthorization, opts.Set, opts.Unset, opts.SetTag, opts.UnsetTag) { + errs = append(errs, errExactlyOneOf("AlterUserOptions", "NewName", "ResetPassword", "AbortAllQueries", "AddDelegatedAuthorization", "RemoveDelegatedAuthorization", "Set", "Unset", "SetTag", "UnsetTag")) } if 
valueSet(opts.RemoveDelegatedAuthorization) { if err := opts.RemoveDelegatedAuthorization.validate(); err != nil { @@ -347,15 +349,14 @@ func (opts *RemoveDelegatedAuthorization) validate() error { type UserSet struct { PasswordPolicy *string `ddl:"parameter" sql:"PASSWORD POLICY"` SessionPolicy *string `ddl:"parameter" sql:"SESSION POLICY"` - Tags []TagAssociation `ddl:"keyword,parentheses" sql:"TAG"` ObjectProperties *UserObjectProperties `ddl:"keyword"` ObjectParameters *UserObjectParameters `ddl:"keyword"` SessionParameters *SessionParameters `ddl:"keyword"` } func (opts *UserSet) validate() error { - if !exactlyOneValueSet(opts.PasswordPolicy, opts.SessionPolicy, opts.Tags, opts.ObjectProperties, opts.ObjectParameters, opts.SessionParameters) { - return errExactlyOneOf("UserSet", "PasswordPolicy", "SessionPolicy", "Tags", "ObjectProperties", "ObjectParameters", "SessionParameters") + if !exactlyOneValueSet(opts.PasswordPolicy, opts.SessionPolicy, opts.ObjectProperties, opts.ObjectParameters, opts.SessionParameters) { + return errExactlyOneOf("UserSet", "PasswordPolicy", "SessionPolicy", "ObjectProperties", "ObjectParameters", "SessionParameters") } return nil } @@ -363,15 +364,14 @@ func (opts *UserSet) validate() error { type UserUnset struct { PasswordPolicy *bool `ddl:"keyword" sql:"PASSWORD POLICY"` SessionPolicy *bool `ddl:"keyword" sql:"SESSION POLICY"` - Tags *[]string `ddl:"keyword" sql:"TAG"` ObjectProperties *UserObjectPropertiesUnset `ddl:"list"` ObjectParameters *UserObjectParametersUnset `ddl:"list"` SessionParameters *SessionParametersUnset `ddl:"list"` } func (opts *UserUnset) validate() error { - if !exactlyOneValueSet(opts.Tags, opts.PasswordPolicy, opts.SessionPolicy, opts.ObjectProperties, opts.ObjectParameters, opts.SessionParameters) { - return errExactlyOneOf("UserUnset", "Tags", "PasswordPolicy", "SessionPolicy", "ObjectProperties", "ObjectParameters", "SessionParameters") + if !exactlyOneValueSet(opts.PasswordPolicy, opts.SessionPolicy, opts.ObjectProperties, opts.ObjectParameters, opts.SessionParameters) { + return errExactlyOneOf("UserUnset", "PasswordPolicy", "SessionPolicy", "ObjectProperties", "ObjectParameters", "SessionParameters") } return nil } diff --git a/pkg/sdk/users_test.go b/pkg/sdk/users_test.go index 82a13912a8..1ece392571 100644 --- a/pkg/sdk/users_test.go +++ b/pkg/sdk/users_test.go @@ -58,7 +58,7 @@ func TestUserAlter(t *testing.T) { opts := &AlterUserOptions{ name: id, } - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterUserOptions", "NewName", "ResetPassword", "AbortAllQueries", "AddDelegatedAuthorization", "RemoveDelegatedAuthorization", "Set", "Unset")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterUserOptions", "NewName", "ResetPassword", "AbortAllQueries", "AddDelegatedAuthorization", "RemoveDelegatedAuthorization", "Set", "Unset", "SetTag", "UnsetTag")) }) t.Run("with setting a policy", func(t *testing.T) { @@ -84,12 +84,10 @@ func TestUserAlter(t *testing.T) { }, } opts := &AlterUserOptions{ - name: id, - Set: &UserSet{ - Tags: tags, - }, + name: id, + SetTag: tags, } - assertOptsValidAndSQLEquals(t, opts, `ALTER USER %s SET TAG ("db"."schema"."tag1" = 'v1', "db"."schema"."tag2" = 'v2')`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER USER %s SET TAG "db"."schema"."tag1" = 'v1', "db"."schema"."tag2" = 'v2'`, id.FullyQualifiedName()) }) t.Run("with setting properties and parameters", func(t *testing.T) { @@ -171,15 +169,13 @@ func TestUserAlter(t *testing.T) { }) t.Run("with 
unsetting tags", func(t *testing.T) { - tag1 := "USER_TAG1" - tag2 := "USER_TAG2" + tag1 := NewSchemaObjectIdentifier("db", "schema", "USER_TAG1") + tag2 := NewSchemaObjectIdentifier("db", "schema", "USER_TAG2") opts := &AlterUserOptions{ - name: id, - Unset: &UserUnset{ - Tags: &[]string{tag1, tag2}, - }, + name: id, + UnsetTag: []ObjectIdentifier{tag1, tag2}, } - assertOptsValidAndSQLEquals(t, opts, "ALTER USER %s UNSET TAG %s, %s", id.FullyQualifiedName(), tag1, tag2) + assertOptsValidAndSQLEquals(t, opts, "ALTER USER %s UNSET TAG %s, %s", id.FullyQualifiedName(), tag1.FullyQualifiedName(), tag2.FullyQualifiedName()) }) t.Run("with unsetting properties", func(t *testing.T) { diff --git a/pkg/sdk/warehouses.go b/pkg/sdk/warehouses.go index 5e61ebf5b2..d76edddd71 100644 --- a/pkg/sdk/warehouses.go +++ b/pkg/sdk/warehouses.go @@ -165,8 +165,10 @@ type AlterWarehouseOptions struct { AbortAllQueries *bool `ddl:"keyword" sql:"ABORT ALL QUERIES"` NewName *AccountObjectIdentifier `ddl:"identifier" sql:"RENAME TO"` - Set *WarehouseSet `ddl:"keyword" sql:"SET"` - Unset *WarehouseUnset `ddl:"list,no_parentheses" sql:"UNSET"` + Set *WarehouseSet `ddl:"keyword" sql:"SET"` + Unset *WarehouseUnset `ddl:"list,no_parentheses" sql:"UNSET"` + SetTag []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTag []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` } func (opts *AlterWarehouseOptions) validate() error { @@ -177,8 +179,8 @@ func (opts *AlterWarehouseOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if !exactlyOneValueSet(opts.Suspend, opts.Resume, opts.AbortAllQueries, opts.NewName, opts.Set, opts.Unset) { - errs = append(errs, errExactlyOneOf("AlterWarehouseOptions", "Suspend", "Resume", "AbortAllQueries", "NewName", "Set", "Unset")) + if !exactlyOneValueSet(opts.Suspend, opts.Resume, opts.AbortAllQueries, opts.NewName, opts.Set, opts.Unset, opts.SetTag, opts.UnsetTag) { + errs = append(errs, errExactlyOneOf("AlterWarehouseOptions", "Suspend", "Resume", "AbortAllQueries", "NewName", "Set", "Unset", "SetTag", "UnsetTag")) } if everyValueSet(opts.Suspend, opts.Resume) && (*opts.Suspend && *opts.Resume) { errs = append(errs, errOneOf("AlterWarehouseOptions", "Suspend", "Resume")) @@ -218,8 +220,6 @@ type WarehouseSet struct { MaxConcurrencyLevel *int `ddl:"parameter" sql:"MAX_CONCURRENCY_LEVEL"` StatementQueuedTimeoutInSeconds *int `ddl:"parameter" sql:"STATEMENT_QUEUED_TIMEOUT_IN_SECONDS"` StatementTimeoutInSeconds *int `ddl:"parameter" sql:"STATEMENT_TIMEOUT_IN_SECONDS"` - - Tag []TagAssociation `ddl:"keyword" sql:"TAG"` } func (v *WarehouseSet) validate() error { @@ -244,8 +244,8 @@ func (v *WarehouseSet) validate() error { return fmt.Errorf("QueryAccelerationMaxScaleFactor must be between 0 and 100") } } - if valueSet(v.Tag) && !everyValueNil(v.AutoResume, v.EnableQueryAcceleration, v.MaxClusterCount, v.MinClusterCount, v.AutoSuspend, v.QueryAccelerationMaxScaleFactor) { - return fmt.Errorf("Tag cannot be set with any other Set parameter") + if everyValueNil(v.WarehouseType, v.WarehouseSize, v.WaitForCompletion, v.MaxClusterCount, v.MinClusterCount, v.ScalingPolicy, v.AutoSuspend, v.AutoResume, v.ResourceMonitor, v.Comment, v.EnableQueryAcceleration, v.QueryAccelerationMaxScaleFactor, v.MaxConcurrencyLevel, v.StatementQueuedTimeoutInSeconds, v.StatementTimeoutInSeconds) { + return errAtLeastOneOf("WarehouseSet", "WarehouseType", "WarehouseSize", "WaitForCompletion", "MaxClusterCount", "MinClusterCount", "ScalingPolicy", 
"AutoSuspend", "AutoResume", "ResourceMonitor", "Comment", "EnableQueryAcceleration", "QueryAccelerationMaxScaleFactor", "MaxConcurrencyLevel", "StatementQueuedTimeoutInSeconds", "StatementTimeoutInSeconds") } return nil } @@ -253,7 +253,6 @@ func (v *WarehouseSet) validate() error { type WarehouseUnset struct { // Object properties WarehouseType *bool `ddl:"keyword" sql:"WAREHOUSE_TYPE"` - WarehouseSize *bool `ddl:"keyword" sql:"WAREHOUSE_SIZE"` WaitForCompletion *bool `ddl:"keyword" sql:"WAIT_FOR_COMPLETION"` MaxClusterCount *bool `ddl:"keyword" sql:"MAX_CLUSTER_COUNT"` MinClusterCount *bool `ddl:"keyword" sql:"MIN_CLUSTER_COUNT"` @@ -266,15 +265,14 @@ type WarehouseUnset struct { QueryAccelerationMaxScaleFactor *bool `ddl:"keyword" sql:"QUERY_ACCELERATION_MAX_SCALE_FACTOR"` // Object params - MaxConcurrencyLevel *bool `ddl:"keyword" sql:"MAX_CONCURRENCY_LEVEL"` - StatementQueuedTimeoutInSeconds *bool `ddl:"keyword" sql:"STATEMENT_QUEUED_TIMEOUT_IN_SECONDS"` - StatementTimeoutInSeconds *bool `ddl:"keyword" sql:"STATEMENT_TIMEOUT_IN_SECONDS"` - Tag []ObjectIdentifier `ddl:"keyword" sql:"TAG"` + MaxConcurrencyLevel *bool `ddl:"keyword" sql:"MAX_CONCURRENCY_LEVEL"` + StatementQueuedTimeoutInSeconds *bool `ddl:"keyword" sql:"STATEMENT_QUEUED_TIMEOUT_IN_SECONDS"` + StatementTimeoutInSeconds *bool `ddl:"keyword" sql:"STATEMENT_TIMEOUT_IN_SECONDS"` } func (v *WarehouseUnset) validate() error { - if valueSet(v.Tag) && !everyValueNil(v.AutoResume, v.EnableQueryAcceleration, v.MaxClusterCount, v.MinClusterCount, v.AutoSuspend, v.QueryAccelerationMaxScaleFactor) { - return fmt.Errorf("Tag cannot be unset with any other Unset parameter") + if everyValueNil(v.WarehouseType, v.WaitForCompletion, v.MaxClusterCount, v.MinClusterCount, v.ScalingPolicy, v.AutoSuspend, v.AutoResume, v.ResourceMonitor, v.Comment, v.EnableQueryAcceleration, v.QueryAccelerationMaxScaleFactor, v.MaxConcurrencyLevel, v.StatementQueuedTimeoutInSeconds, v.StatementTimeoutInSeconds) { + return errAtLeastOneOf("WarehouseUnset", "WarehouseType", "WaitForCompletion", "MaxClusterCount", "MinClusterCount", "ScalingPolicy", "AutoSuspend", "AutoResume", "ResourceMonitor", "Comment", "EnableQueryAcceleration", "QueryAccelerationMaxScaleFactor", "MaxConcurrencyLevel", "StatementQueuedTimeoutInSeconds", "StatementTimeoutInSeconds") } return nil } diff --git a/pkg/sdk/warehouses_test.go b/pkg/sdk/warehouses_test.go index c13564c52c..9b0baafb02 100644 --- a/pkg/sdk/warehouses_test.go +++ b/pkg/sdk/warehouses_test.go @@ -114,16 +114,14 @@ func TestWarehouseAlter(t *testing.T) { t.Run("with set tag", func(t *testing.T) { opts := &AlterWarehouseOptions{ name: NewAccountObjectIdentifier("mywarehouse"), - Set: &WarehouseSet{ - Tag: []TagAssociation{ - { - Name: NewSchemaObjectIdentifier("db", "schema", "tag1"), - Value: "v1", - }, - { - Name: NewSchemaObjectIdentifier("db", "schema", "tag2"), - Value: "v2", - }, + SetTag: []TagAssociation{ + { + Name: NewSchemaObjectIdentifier("db", "schema", "tag1"), + Value: "v1", + }, + { + Name: NewSchemaObjectIdentifier("db", "schema", "tag2"), + Value: "v2", }, }, } @@ -133,10 +131,8 @@ func TestWarehouseAlter(t *testing.T) { t.Run("with unset tag", func(t *testing.T) { opts := &AlterWarehouseOptions{ name: NewAccountObjectIdentifier("mywarehouse"), - Unset: &WarehouseUnset{ - Tag: []ObjectIdentifier{ - NewSchemaObjectIdentifier("db", "schema", "tag1"), - }, + UnsetTag: []ObjectIdentifier{ + NewSchemaObjectIdentifier("db", "schema", "tag1"), }, } assertOptsValidAndSQLEquals(t, opts, `ALTER WAREHOUSE 
"mywarehouse" UNSET TAG "db"."schema"."tag1"`) @@ -146,12 +142,11 @@ func TestWarehouseAlter(t *testing.T) { opts := &AlterWarehouseOptions{ name: NewAccountObjectIdentifier("mywarehouse"), Unset: &WarehouseUnset{ - WarehouseSize: Bool(true), MaxClusterCount: Bool(true), AutoResume: Bool(true), }, } - assertOptsValidAndSQLEquals(t, opts, `ALTER WAREHOUSE "mywarehouse" UNSET WAREHOUSE_SIZE, MAX_CLUSTER_COUNT, AUTO_RESUME`) + assertOptsValidAndSQLEquals(t, opts, `ALTER WAREHOUSE "mywarehouse" UNSET MAX_CLUSTER_COUNT, AUTO_RESUME`) }) t.Run("rename", func(t *testing.T) { @@ -191,16 +186,14 @@ func TestWarehouseAlter(t *testing.T) { t.Run("with set tag", func(t *testing.T) { opts := &AlterWarehouseOptions{ name: NewAccountObjectIdentifier("mywarehouse"), - Set: &WarehouseSet{ - Tag: []TagAssociation{ - { - Name: NewSchemaObjectIdentifier("db1", "schema1", "tag1"), - Value: "v1", - }, - { - Name: NewSchemaObjectIdentifier("db2", "schema2", "tag2"), - Value: "v2", - }, + SetTag: []TagAssociation{ + { + Name: NewSchemaObjectIdentifier("db1", "schema1", "tag1"), + Value: "v1", + }, + { + Name: NewSchemaObjectIdentifier("db2", "schema2", "tag2"), + Value: "v2", }, }, } @@ -210,11 +203,9 @@ func TestWarehouseAlter(t *testing.T) { t.Run("with unset tag", func(t *testing.T) { opts := &AlterWarehouseOptions{ name: NewAccountObjectIdentifier("mywarehouse"), - Unset: &WarehouseUnset{ - Tag: []ObjectIdentifier{ - NewSchemaObjectIdentifier("db1", "schema1", "tag1"), - NewSchemaObjectIdentifier("db2", "schema2", "tag2"), - }, + UnsetTag: []ObjectIdentifier{ + NewSchemaObjectIdentifier("db1", "schema1", "tag1"), + NewSchemaObjectIdentifier("db2", "schema2", "tag2"), }, } assertOptsValidAndSQLEquals(t, opts, `ALTER WAREHOUSE "mywarehouse" UNSET TAG "db1"."schema1"."tag1", "db2"."schema2"."tag2"`) From 5148be46a57bf11cd1daeb481e1b402bafb0f52f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Tue, 31 Oct 2023 11:40:02 +0100 Subject: [PATCH 14/20] Export QueryStruct (#2160) --- pkg/sdk/network_policies_def.go | 14 ++++---- pkg/sdk/poc/example/database_role_def.go | 12 +++---- pkg/sdk/poc/generator/identifier_builders.go | 6 ++-- pkg/sdk/poc/generator/keyword_builders.go | 36 ++++++++++---------- pkg/sdk/poc/generator/operation.go | 16 ++++----- pkg/sdk/poc/generator/parameter_builders.go | 26 +++++++------- pkg/sdk/poc/generator/query_struct.go | 20 +++++------ pkg/sdk/poc/generator/static_builders.go | 12 +++---- pkg/sdk/session_policies_def.go | 14 ++++---- pkg/sdk/streams_def.go | 22 ++++++------ pkg/sdk/tasks_def.go | 20 +++++------ 11 files changed, 99 insertions(+), 99 deletions(-) diff --git a/pkg/sdk/network_policies_def.go b/pkg/sdk/network_policies_def.go index b2beaf222d..68ae2f1434 100644 --- a/pkg/sdk/network_policies_def.go +++ b/pkg/sdk/network_policies_def.go @@ -5,7 +5,7 @@ import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/gen //go:generate go run ./poc/main.go var ( - ip = g.QueryStruct("IP"). + ip = g.NewQueryStruct("IP"). Text("IP", g.KeywordOptions().SingleQuotes().Required()) NetworkPoliciesDef = g.NewInterface( @@ -15,7 +15,7 @@ var ( ). CreateOperation( "https://docs.snowflake.com/en/sql-reference/sql/create-network-policy", - g.QueryStruct("CreateNetworkPolicies"). + g.NewQueryStruct("CreateNetworkPolicies"). Create(). OrReplace(). SQL("NETWORK POLICY"). @@ -27,14 +27,14 @@ var ( ). AlterOperation( "https://docs.snowflake.com/en/sql-reference/sql/alter-network-policy", - g.QueryStruct("AlterNetworkPolicy"). 
+ g.NewQueryStruct("AlterNetworkPolicy"). Alter(). SQL("NETWORK POLICY"). IfExists(). Name(). OptionalQueryStructField( "Set", - g.QueryStruct("NetworkPolicySet"). + g.NewQueryStruct("NetworkPolicySet"). ListQueryStructField("AllowedIpList", ip, g.ParameterOptions().SQL("ALLOWED_IP_LIST").Parentheses()). ListQueryStructField("BlockedIpList", ip, g.ParameterOptions().SQL("BLOCKED_IP_LIST").Parentheses()). OptionalTextAssignment("COMMENT", g.ParameterOptions().SingleQuotes()). @@ -49,7 +49,7 @@ var ( ). DropOperation( "https://docs.snowflake.com/en/sql-reference/sql/drop-network-policy", - g.QueryStruct("DropNetworkPolicy"). + g.NewQueryStruct("DropNetworkPolicy"). Drop(). SQL("NETWORK POLICY"). IfExists(). @@ -70,7 +70,7 @@ var ( Field("Comment", "string"). Field("EntriesInAllowedIpList", "int"). Field("EntriesInBlockedIpList", "int"), - g.QueryStruct("ShowNetworkPolicies"). + g.NewQueryStruct("ShowNetworkPolicies"). Show(). SQL("NETWORK POLICIES"), ). @@ -83,7 +83,7 @@ var ( g.PlainStruct("NetworkPolicyDescription"). Field("Name", "string"). Field("Value", "string"), - g.QueryStruct("DescribeNetworkPolicy"). + g.NewQueryStruct("DescribeNetworkPolicy"). Describe(). SQL("NETWORK POLICY"). Name(). diff --git a/pkg/sdk/poc/example/database_role_def.go b/pkg/sdk/poc/example/database_role_def.go index af670efb71..a893a1f416 100644 --- a/pkg/sdk/poc/example/database_role_def.go +++ b/pkg/sdk/poc/example/database_role_def.go @@ -7,24 +7,24 @@ import ( //go:generate go run ../main.go var ( - dbRoleRename = g.QueryStruct("DatabaseRoleRename"). + dbRoleRename = g.NewQueryStruct("DatabaseRoleRename"). // Fields Identifier("Name", g.KindOfT[DatabaseObjectIdentifier](), g.IdentifierOptions().Required()). // Validations WithValidation(g.ValidIdentifier, "Name") - nestedThirdLevel = g.QueryStruct("NestedThirdLevel"). + nestedThirdLevel = g.NewQueryStruct("NestedThirdLevel"). // Fields Identifier("Field", g.KindOfT[DatabaseObjectIdentifier](), g.IdentifierOptions().Required()). // Validations WithValidation(g.AtLeastOneValueSet, "Field") - dbRoleSet = g.QueryStruct("DatabaseRoleSet"). + dbRoleSet = g.NewQueryStruct("DatabaseRoleSet"). // Fields TextAssignment("COMMENT", g.ParameterOptions().SingleQuotes().Required()). OptionalQueryStructField("NestedThirdLevel", nestedThirdLevel, g.ListOptions().NoParentheses().SQL("NESTED")) - dbRoleUnset = g.QueryStruct("DatabaseRoleUnset"). + dbRoleUnset = g.NewQueryStruct("DatabaseRoleUnset"). // Fields OptionalSQL("COMMENT"). // Validations @@ -37,7 +37,7 @@ var ( ). CreateOperation( "https://docs.snowflake.com/en/sql-reference/sql/create-database-role", - g.QueryStruct("CreateDatabaseRole"). + g.NewQueryStruct("CreateDatabaseRole"). // Fields Create(). OrReplace(). @@ -51,7 +51,7 @@ var ( ). AlterOperation( "https://docs.snowflake.com/en/sql-reference/sql/alter-database-role", - g.QueryStruct("AlterDatabaseRole"). + g.NewQueryStruct("AlterDatabaseRole"). // Fields Alter(). SQL("DATABASE ROLE"). 
diff --git a/pkg/sdk/poc/generator/identifier_builders.go b/pkg/sdk/poc/generator/identifier_builders.go index 1d697c6087..89c67da982 100644 --- a/pkg/sdk/poc/generator/identifier_builders.go +++ b/pkg/sdk/poc/generator/identifier_builders.go @@ -1,19 +1,19 @@ package generator // Name adds identifier with field name "name" and type will be inferred from interface definition -func (v *queryStruct) Name() *queryStruct { +func (v *QueryStruct) Name() *QueryStruct { identifier := NewField("name", "", Tags().Identifier(), IdentifierOptions().Required()) v.identifierField = identifier v.fields = append(v.fields, identifier) return v } -func (v *queryStruct) Identifier(fieldName string, kind string, transformer *IdentifierTransformer) *queryStruct { +func (v *QueryStruct) Identifier(fieldName string, kind string, transformer *IdentifierTransformer) *QueryStruct { v.fields = append(v.fields, NewField(fieldName, kind, Tags().Identifier(), transformer)) return v } -func (v *queryStruct) OptionalIdentifier(name string, kind string, transformer *IdentifierTransformer) *queryStruct { +func (v *QueryStruct) OptionalIdentifier(name string, kind string, transformer *IdentifierTransformer) *QueryStruct { if len(kind) > 0 && kind[0] != '*' { kind = KindOfPointer(kind) } diff --git a/pkg/sdk/poc/generator/keyword_builders.go b/pkg/sdk/poc/generator/keyword_builders.go index 73c5717232..5c37a65b24 100644 --- a/pkg/sdk/poc/generator/keyword_builders.go +++ b/pkg/sdk/poc/generator/keyword_builders.go @@ -1,87 +1,87 @@ package generator -func (v *queryStruct) OptionalSQL(sql string) *queryStruct { +func (v *QueryStruct) OptionalSQL(sql string) *QueryStruct { v.fields = append(v.fields, NewField(sqlToFieldName(sql, true), "*bool", Tags().Keyword().SQL(sql), nil)) return v } -func (v *queryStruct) OrReplace() *queryStruct { +func (v *QueryStruct) OrReplace() *QueryStruct { return v.OptionalSQL("OR REPLACE") } -func (v *queryStruct) IfNotExists() *queryStruct { +func (v *QueryStruct) IfNotExists() *QueryStruct { return v.OptionalSQL("IF NOT EXISTS") } -func (v *queryStruct) IfExists() *queryStruct { +func (v *QueryStruct) IfExists() *QueryStruct { return v.OptionalSQL("IF EXISTS") } -func (v *queryStruct) Terse() *queryStruct { +func (v *QueryStruct) Terse() *QueryStruct { return v.OptionalSQL("TERSE") } -func (v *queryStruct) Text(name string, transformer *KeywordTransformer) *queryStruct { +func (v *QueryStruct) Text(name string, transformer *KeywordTransformer) *QueryStruct { v.fields = append(v.fields, NewField(name, "string", Tags().Keyword(), transformer)) return v } -func (v *queryStruct) OptionalText(name string, transformer *KeywordTransformer) *queryStruct { +func (v *QueryStruct) OptionalText(name string, transformer *KeywordTransformer) *QueryStruct { v.fields = append(v.fields, NewField(name, "*string", Tags().Keyword(), transformer)) return v } // SessionParameters *SessionParameters `ddl:"list,no_parentheses"` -func (v *queryStruct) SessionParameters() *queryStruct { +func (v *QueryStruct) SessionParameters() *QueryStruct { v.fields = append(v.fields, NewField("SessionParameters", "*SessionParameters", Tags().List().NoParentheses(), nil).withValidations(NewValidation(ValidateValue, "SessionParameters"))) return v } -func (v *queryStruct) OptionalSessionParameters() *queryStruct { +func (v *QueryStruct) OptionalSessionParameters() *QueryStruct { v.fields = append(v.fields, NewField("SessionParameters", "*SessionParameters", Tags().List().NoParentheses(), 
nil).withValidations(NewValidation(ValidateValue, "SessionParameters"))) return v } -func (v *queryStruct) OptionalSessionParametersUnset() *queryStruct { +func (v *QueryStruct) OptionalSessionParametersUnset() *QueryStruct { v.fields = append(v.fields, NewField("SessionParametersUnset", "*SessionParametersUnset", Tags().List().NoParentheses(), nil).withValidations(NewValidation(ValidateValue, "SessionParametersUnset"))) return v } -func (v *queryStruct) WithTags() *queryStruct { +func (v *QueryStruct) WithTags() *QueryStruct { v.fields = append(v.fields, NewField("Tag", "[]TagAssociation", Tags().Keyword().Parentheses().SQL("TAG"), nil)) return v } -func (v *queryStruct) SetTags() *queryStruct { +func (v *QueryStruct) SetTags() *QueryStruct { v.fields = append(v.fields, NewField("SetTags", "[]TagAssociation", Tags().Keyword().SQL("SET TAG"), nil)) return v } -func (v *queryStruct) UnsetTags() *queryStruct { +func (v *QueryStruct) UnsetTags() *QueryStruct { v.fields = append(v.fields, NewField("UnsetTags", "[]ObjectIdentifier", Tags().Keyword().SQL("UNSET TAG"), nil)) return v } -func (v *queryStruct) OptionalLike() *queryStruct { +func (v *QueryStruct) OptionalLike() *QueryStruct { v.fields = append(v.fields, NewField("Like", "*Like", Tags().Keyword().SQL("LIKE"), nil)) return v } -func (v *queryStruct) OptionalIn() *queryStruct { +func (v *QueryStruct) OptionalIn() *QueryStruct { v.fields = append(v.fields, NewField("In", "*In", Tags().Keyword().SQL("IN"), nil)) return v } -func (v *queryStruct) OptionalStartsWith() *queryStruct { +func (v *QueryStruct) OptionalStartsWith() *QueryStruct { v.fields = append(v.fields, NewField("StartsWith", "*string", Tags().Parameter().NoEquals().SingleQuotes().SQL("STARTS WITH"), nil)) return v } -func (v *queryStruct) OptionalLimit() *queryStruct { +func (v *QueryStruct) OptionalLimit() *QueryStruct { v.fields = append(v.fields, NewField("Limit", "*LimitFrom", Tags().Keyword().SQL("LIMIT"), nil)) return v } -func (v *queryStruct) OptionalCopyGrants() *queryStruct { +func (v *QueryStruct) OptionalCopyGrants() *QueryStruct { return v.OptionalSQL("COPY GRANTS") } diff --git a/pkg/sdk/poc/generator/operation.go b/pkg/sdk/poc/generator/operation.go index 6dddf5d8d0..1ec861dfcd 100644 --- a/pkg/sdk/poc/generator/operation.go +++ b/pkg/sdk/poc/generator/operation.go @@ -90,7 +90,7 @@ func (i *Interface) newNoSqlOperation(kind string) *Interface { return i } -func (i *Interface) newSimpleOperation(kind string, doc string, queryStruct *queryStruct, helperStructs ...IntoField) *Interface { +func (i *Interface) newSimpleOperation(kind string, doc string, queryStruct *QueryStruct, helperStructs ...IntoField) *Interface { if queryStruct.identifierField != nil { queryStruct.identifierField.Kind = i.IdentifierKind } @@ -112,7 +112,7 @@ func (i *Interface) newOperationWithDBMapping( doc string, dbRepresentation *dbStruct, resourceRepresentation *plainStruct, - queryStruct *queryStruct, + queryStruct *QueryStruct, addMappingFunc func(op *Operation, from, to *Field), ) *Operation { db := dbRepresentation.IntoField() @@ -133,19 +133,19 @@ type IntoField interface { IntoField() *Field } -func (i *Interface) CreateOperation(doc string, queryStruct *queryStruct, helperStructs ...IntoField) *Interface { +func (i *Interface) CreateOperation(doc string, queryStruct *QueryStruct, helperStructs ...IntoField) *Interface { return i.newSimpleOperation(string(OperationKindCreate), doc, queryStruct, helperStructs...) 
} -func (i *Interface) AlterOperation(doc string, queryStruct *queryStruct) *Interface { +func (i *Interface) AlterOperation(doc string, queryStruct *QueryStruct) *Interface { return i.newSimpleOperation(string(OperationKindAlter), doc, queryStruct) } -func (i *Interface) DropOperation(doc string, queryStruct *queryStruct) *Interface { +func (i *Interface) DropOperation(doc string, queryStruct *QueryStruct) *Interface { return i.newSimpleOperation(string(OperationKindDrop), doc, queryStruct) } -func (i *Interface) ShowOperation(doc string, dbRepresentation *dbStruct, resourceRepresentation *plainStruct, queryStruct *queryStruct) *Interface { +func (i *Interface) ShowOperation(doc string, dbRepresentation *dbStruct, resourceRepresentation *plainStruct, queryStruct *QueryStruct) *Interface { i.newOperationWithDBMapping(string(OperationKindShow), doc, dbRepresentation, resourceRepresentation, queryStruct, addShowMapping) return i } @@ -154,12 +154,12 @@ func (i *Interface) ShowByIdOperation() *Interface { return i.newNoSqlOperation(string(OperationKindShowByID)) } -func (i *Interface) DescribeOperation(describeKind DescriptionMappingKind, doc string, dbRepresentation *dbStruct, resourceRepresentation *plainStruct, queryStruct *queryStruct) *Interface { +func (i *Interface) DescribeOperation(describeKind DescriptionMappingKind, doc string, dbRepresentation *dbStruct, resourceRepresentation *plainStruct, queryStruct *QueryStruct) *Interface { op := i.newOperationWithDBMapping(string(OperationKindDescribe), doc, dbRepresentation, resourceRepresentation, queryStruct, addDescriptionMapping) op.DescribeKind = &describeKind return i } -func (i *Interface) CustomOperation(kind string, doc string, queryStruct *queryStruct) *Interface { +func (i *Interface) CustomOperation(kind string, doc string, queryStruct *QueryStruct) *Interface { return i.newSimpleOperation(kind, doc, queryStruct) } diff --git a/pkg/sdk/poc/generator/parameter_builders.go b/pkg/sdk/poc/generator/parameter_builders.go index 8c5cf6d924..70e8fa6756 100644 --- a/pkg/sdk/poc/generator/parameter_builders.go +++ b/pkg/sdk/poc/generator/parameter_builders.go @@ -1,11 +1,11 @@ package generator -func (v *queryStruct) assignment(name string, kind string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) assignment(name string, kind string, transformer *ParameterTransformer) *QueryStruct { v.fields = append(v.fields, NewField(name, kind, Tags().Parameter(), transformer)) return v } -func (v *queryStruct) Assignment(sqlPrefix string, kind string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) Assignment(sqlPrefix string, kind string, transformer *ParameterTransformer) *QueryStruct { if transformer != nil { transformer = transformer.SQL(sqlPrefix) } else { @@ -14,49 +14,49 @@ func (v *queryStruct) Assignment(sqlPrefix string, kind string, transformer *Par return v.assignment(sqlToFieldName(sqlPrefix, true), kind, transformer) } -func (v *queryStruct) OptionalAssignment(sqlPrefix string, kind string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) OptionalAssignment(sqlPrefix string, kind string, transformer *ParameterTransformer) *QueryStruct { if len(kind) > 0 && kind[0] != '*' { kind = KindOfPointer(kind) } return v.Assignment(sqlPrefix, kind, transformer) } -func (v *queryStruct) ListAssignment(sqlPrefix string, listItemKind string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) ListAssignment(sqlPrefix string, listItemKind string, 
transformer *ParameterTransformer) *QueryStruct { return v.Assignment(sqlPrefix, KindOfSlice(listItemKind), transformer) } -func (v *queryStruct) NumberAssignment(sqlPrefix string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) NumberAssignment(sqlPrefix string, transformer *ParameterTransformer) *QueryStruct { return v.Assignment(sqlPrefix, "int", transformer) } -func (v *queryStruct) OptionalNumberAssignment(sqlPrefix string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) OptionalNumberAssignment(sqlPrefix string, transformer *ParameterTransformer) *QueryStruct { return v.Assignment(sqlPrefix, "*int", transformer) } -func (v *queryStruct) TextAssignment(sqlPrefix string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) TextAssignment(sqlPrefix string, transformer *ParameterTransformer) *QueryStruct { return v.Assignment(sqlPrefix, "string", transformer) } -func (v *queryStruct) OptionalTextAssignment(sqlPrefix string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) OptionalTextAssignment(sqlPrefix string, transformer *ParameterTransformer) *QueryStruct { return v.Assignment(sqlPrefix, "*string", transformer) } -func (v *queryStruct) BooleanAssignment(sqlPrefix string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) BooleanAssignment(sqlPrefix string, transformer *ParameterTransformer) *QueryStruct { return v.Assignment(sqlPrefix, "bool", transformer) } -func (v *queryStruct) OptionalBooleanAssignment(sqlPrefix string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) OptionalBooleanAssignment(sqlPrefix string, transformer *ParameterTransformer) *QueryStruct { return v.Assignment(sqlPrefix, "*bool", transformer) } -func (v *queryStruct) OptionalIdentifierAssignment(sqlPrefix string, identifierKind string, transformer *ParameterTransformer) *queryStruct { +func (v *QueryStruct) OptionalIdentifierAssignment(sqlPrefix string, identifierKind string, transformer *ParameterTransformer) *QueryStruct { return v.OptionalAssignment(sqlPrefix, identifierKind, transformer) } -func (v *queryStruct) OptionalComment() *queryStruct { +func (v *QueryStruct) OptionalComment() *QueryStruct { return v.OptionalTextAssignment("COMMENT", ParameterOptions().SingleQuotes()) } -func (v *queryStruct) SetComment() *queryStruct { +func (v *QueryStruct) SetComment() *QueryStruct { return v.OptionalTextAssignment("SET COMMENT", ParameterOptions().SingleQuotes()) } diff --git a/pkg/sdk/poc/generator/query_struct.go b/pkg/sdk/poc/generator/query_struct.go index 68278308bb..47dcf88487 100644 --- a/pkg/sdk/poc/generator/query_struct.go +++ b/pkg/sdk/poc/generator/query_struct.go @@ -2,49 +2,49 @@ package generator // TODO For Field abstractions use internal Field representation instead of copying only needed fields, e.g. // -// type queryStruct struct { +// type QueryStruct struct { // internalRepresentation *Field // ...additional fields that are not present in the Field // } -type queryStruct struct { +type QueryStruct struct { name string fields []*Field identifierField *Field validations []*Validation } -func QueryStruct(name string) *queryStruct { - return &queryStruct{ +func NewQueryStruct(name string) *QueryStruct { + return &QueryStruct{ name: name, fields: make([]*Field, 0), validations: make([]*Validation, 0), } } -func (v *queryStruct) IntoField() *Field { +func (v *QueryStruct) IntoField() *Field { return NewField(v.name, v.name, nil, nil). withFields(v.fields...). 
withValidations(v.validations...) } -func (v *queryStruct) WithValidation(validationType ValidationType, fieldNames ...string) *queryStruct { +func (v *QueryStruct) WithValidation(validationType ValidationType, fieldNames ...string) *QueryStruct { v.validations = append(v.validations, NewValidation(validationType, fieldNames...)) return v } -func (v *queryStruct) QueryStructField(name string, queryStruct *queryStruct, transformer FieldTransformer) *queryStruct { +func (v *QueryStruct) QueryStructField(name string, queryStruct *QueryStruct, transformer FieldTransformer) *QueryStruct { return v.queryStructField(name, queryStruct, "", transformer) } -func (v *queryStruct) ListQueryStructField(name string, queryStruct *queryStruct, transformer FieldTransformer) *queryStruct { +func (v *QueryStruct) ListQueryStructField(name string, queryStruct *QueryStruct, transformer FieldTransformer) *QueryStruct { return v.queryStructField(name, queryStruct, "[]", transformer) } -func (v *queryStruct) OptionalQueryStructField(name string, queryStruct *queryStruct, transformer FieldTransformer) *queryStruct { +func (v *QueryStruct) OptionalQueryStructField(name string, queryStruct *QueryStruct, transformer FieldTransformer) *QueryStruct { return v.queryStructField(name, queryStruct, "*", transformer) } -func (v *queryStruct) queryStructField(name string, queryStruct *queryStruct, kindPrefix string, transformer FieldTransformer) *queryStruct { +func (v *QueryStruct) queryStructField(name string, queryStruct *QueryStruct, kindPrefix string, transformer FieldTransformer) *QueryStruct { qs := queryStruct.IntoField() qs.Name = name qs.Kind = kindPrefix + qs.Kind diff --git a/pkg/sdk/poc/generator/static_builders.go b/pkg/sdk/poc/generator/static_builders.go index b47a7d89dc..a0635f01db 100644 --- a/pkg/sdk/poc/generator/static_builders.go +++ b/pkg/sdk/poc/generator/static_builders.go @@ -1,26 +1,26 @@ package generator -func (v *queryStruct) SQL(sql string) *queryStruct { +func (v *QueryStruct) SQL(sql string) *QueryStruct { v.fields = append(v.fields, NewField(sqlToFieldName(sql, false), "bool", Tags().Static().SQL(sql), nil)) return v } -func (v *queryStruct) Create() *queryStruct { +func (v *QueryStruct) Create() *QueryStruct { return v.SQL("CREATE") } -func (v *queryStruct) Alter() *queryStruct { +func (v *QueryStruct) Alter() *QueryStruct { return v.SQL("ALTER") } -func (v *queryStruct) Drop() *queryStruct { +func (v *QueryStruct) Drop() *QueryStruct { return v.SQL("DROP") } -func (v *queryStruct) Show() *queryStruct { +func (v *QueryStruct) Show() *QueryStruct { return v.SQL("SHOW") } -func (v *queryStruct) Describe() *queryStruct { +func (v *QueryStruct) Describe() *QueryStruct { return v.SQL("DESCRIBE") } diff --git a/pkg/sdk/session_policies_def.go b/pkg/sdk/session_policies_def.go index 86b86ae823..76c36c204d 100644 --- a/pkg/sdk/session_policies_def.go +++ b/pkg/sdk/session_policies_def.go @@ -11,7 +11,7 @@ var SessionPoliciesDef = g.NewInterface( ). CreateOperation( "https://docs.snowflake.com/en/sql-reference/sql/create-session-policy", - g.QueryStruct("CreateSessionPolicy"). + g.NewQueryStruct("CreateSessionPolicy"). Create(). OrReplace(). SQL("SESSION POLICY"). @@ -25,7 +25,7 @@ var SessionPoliciesDef = g.NewInterface( ). AlterOperation( "https://docs.snowflake.com/en/sql-reference/sql/alter-session-policy", - g.QueryStruct("AlterSessionPolicy"). + g.NewQueryStruct("AlterSessionPolicy"). Alter(). SQL("SESSION POLICY"). IfExists(). 
@@ -33,7 +33,7 @@ var SessionPoliciesDef = g.NewInterface( OptionalIdentifier("RenameTo", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("RENAME TO")). OptionalQueryStructField( "Set", - g.QueryStruct("SessionPolicySet"). + g.NewQueryStruct("SessionPolicySet"). OptionalNumberAssignment("SESSION_IDLE_TIMEOUT_MINS", g.ParameterOptions().NoQuotes()). OptionalNumberAssignment("SESSION_UI_IDLE_TIMEOUT_MINS", g.ParameterOptions().NoQuotes()). OptionalTextAssignment("COMMENT", g.ParameterOptions().SingleQuotes()). @@ -44,7 +44,7 @@ var SessionPoliciesDef = g.NewInterface( UnsetTags(). OptionalQueryStructField( "Unset", - g.QueryStruct("SessionPolicyUnset"). + g.NewQueryStruct("SessionPolicyUnset"). OptionalSQL("SESSION_IDLE_TIMEOUT_MINS"). OptionalSQL("SESSION_UI_IDLE_TIMEOUT_MINS"). OptionalSQL("COMMENT"). @@ -56,7 +56,7 @@ var SessionPoliciesDef = g.NewInterface( ). DropOperation( "https://docs.snowflake.com/en/sql-reference/sql/drop-session-policy", - g.QueryStruct("DropSessionPolicy"). + g.NewQueryStruct("DropSessionPolicy"). Drop(). SQL("SESSION POLICY"). IfExists(). @@ -83,7 +83,7 @@ var SessionPoliciesDef = g.NewInterface( Field("Owner", "string"). Field("Comment", "string"). Field("Options", "string"), - g.QueryStruct("ShowSessionPolicies"). + g.NewQueryStruct("ShowSessionPolicies"). Show(). SQL("SESSION POLICIES"), ). @@ -102,7 +102,7 @@ var SessionPoliciesDef = g.NewInterface( Field("SessionIdleTimeoutMins", "int"). Field("SessionUIIdleTimeoutMins", "int"). Field("Comment", "string"), - g.QueryStruct("DescribeSessionPolicy"). + g.NewQueryStruct("DescribeSessionPolicy"). Describe(). SQL("SESSION POLICY"). Name(). diff --git a/pkg/sdk/streams_def.go b/pkg/sdk/streams_def.go index 68bb7913cc..db9f3a89b7 100644 --- a/pkg/sdk/streams_def.go +++ b/pkg/sdk/streams_def.go @@ -5,12 +5,12 @@ import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/gen //go:generate go run ./poc/main.go var ( - onStreamDef = g.QueryStruct("OnStream"). + onStreamDef = g.NewQueryStruct("OnStream"). OptionalSQL("AT"). OptionalSQL("BEFORE"). QueryStructField( "Statement", - g.QueryStruct("OnStreamStatement"). + g.NewQueryStruct("OnStreamStatement"). OptionalTextAssignment("TIMESTAMP", g.ParameterOptions().ArrowEquals()). OptionalTextAssignment("OFFSET", g.ParameterOptions().ArrowEquals()). OptionalTextAssignment("STATEMENT", g.ParameterOptions().ArrowEquals()). @@ -64,7 +64,7 @@ var ( CustomOperation( "CreateOnTable", "https://docs.snowflake.com/en/sql-reference/sql/create-stream", - g.QueryStruct("CreateStreamOnTable"). + g.NewQueryStruct("CreateStreamOnTable"). Create(). OrReplace(). SQL("STREAM"). @@ -84,7 +84,7 @@ var ( CustomOperation( "CreateOnExternalTable", "https://docs.snowflake.com/en/sql-reference/sql/create-stream", - g.QueryStruct("CreateStreamOnExternalTable"). + g.NewQueryStruct("CreateStreamOnExternalTable"). Create(). OrReplace(). SQL("STREAM"). @@ -103,7 +103,7 @@ var ( CustomOperation( "CreateOnDirectoryTable", "https://docs.snowflake.com/en/sql-reference/sql/create-stream", - g.QueryStruct("CreateStreamOnDirectoryTable"). + g.NewQueryStruct("CreateStreamOnDirectoryTable"). Create(). OrReplace(). SQL("STREAM"). @@ -120,7 +120,7 @@ var ( CustomOperation( "CreateOnView", "https://docs.snowflake.com/en/sql-reference/sql/create-stream", - g.QueryStruct("CreateStreamOnView"). + g.NewQueryStruct("CreateStreamOnView"). Create(). OrReplace(). SQL("STREAM"). 
@@ -140,7 +140,7 @@ var ( CustomOperation( "Clone", "https://docs.snowflake.com/en/sql-reference/sql/create-stream#variant-syntax", - g.QueryStruct("CloneStream"). + g.NewQueryStruct("CloneStream"). Create(). OrReplace(). SQL("STREAM"). @@ -151,7 +151,7 @@ var ( ). AlterOperation( "https://docs.snowflake.com/en/sql-reference/sql/alter-stream", - g.QueryStruct("AlterStream"). + g.NewQueryStruct("AlterStream"). Alter(). SQL("STREAM"). IfExists(). @@ -166,7 +166,7 @@ var ( ). DropOperation( "https://docs.snowflake.com/en/sql-reference/sql/drop-stream", - g.QueryStruct("DropStream"). + g.NewQueryStruct("DropStream"). Drop(). SQL("STREAM"). IfExists(). @@ -177,7 +177,7 @@ var ( "https://docs.snowflake.com/en/sql-reference/sql/show-streams", showStreamDbRowDef, streamPlainStructDef, - g.QueryStruct("ShowStreams"). + g.NewQueryStruct("ShowStreams"). Show(). Terse(). SQL("STREAMS"). @@ -192,7 +192,7 @@ var ( "https://docs.snowflake.com/en/sql-reference/sql/desc-stream", showStreamDbRowDef, streamPlainStructDef, - g.QueryStruct("DescribeStream"). + g.NewQueryStruct("DescribeStream"). Describe(). SQL("STREAM"). Name(). diff --git a/pkg/sdk/tasks_def.go b/pkg/sdk/tasks_def.go index 11e6840967..0e80f9b2e8 100644 --- a/pkg/sdk/tasks_def.go +++ b/pkg/sdk/tasks_def.go @@ -55,7 +55,7 @@ var TasksDef = g.NewInterface( ). CreateOperation( "https://docs.snowflake.com/en/sql-reference/sql/create-task", - g.QueryStruct("CreateTask"). + g.NewQueryStruct("CreateTask"). Create(). OrReplace(). SQL("TASK"). @@ -63,7 +63,7 @@ var TasksDef = g.NewInterface( Name(). OptionalQueryStructField( "Warehouse", - g.QueryStruct("CreateTaskWarehouse"). + g.NewQueryStruct("CreateTaskWarehouse"). OptionalIdentifier("Warehouse", g.KindOfT[AccountObjectIdentifier](), g.IdentifierOptions().Equals().SQL("WAREHOUSE")). OptionalAssignment("USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE", "WarehouseSize", g.ParameterOptions().SingleQuotes()). WithValidation(g.ExactlyOneValueSet, "Warehouse", "UserTaskManagedInitialWarehouseSize"), @@ -89,7 +89,7 @@ var TasksDef = g.NewInterface( CustomOperation( "Clone", "https://docs.snowflake.com/en/sql-reference/sql/create-task#variant-syntax", - g.QueryStruct("CloneTask"). + g.NewQueryStruct("CloneTask"). Create(). OrReplace(). SQL("TASK"). @@ -102,7 +102,7 @@ var TasksDef = g.NewInterface( ). AlterOperation( "https://docs.snowflake.com/en/sql-reference/sql/alter-task", - g.QueryStruct("AlterTask"). + g.NewQueryStruct("AlterTask"). Alter(). SQL("TASK"). IfExists(). @@ -113,7 +113,7 @@ var TasksDef = g.NewInterface( ListAssignment("ADD AFTER", "SchemaObjectIdentifier", g.ParameterOptions().NoEquals()). OptionalQueryStructField( "Set", - g.QueryStruct("TaskSet"). + g.NewQueryStruct("TaskSet"). OptionalIdentifier("Warehouse", g.KindOfT[AccountObjectIdentifier](), g.IdentifierOptions().Equals().SQL("WAREHOUSE")). OptionalAssignment("USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE", "WarehouseSize", g.ParameterOptions().SingleQuotes()). OptionalTextAssignment("SCHEDULE", g.ParameterOptions().SingleQuotes()). @@ -130,7 +130,7 @@ var TasksDef = g.NewInterface( ). OptionalQueryStructField( "Unset", - g.QueryStruct("TaskUnset"). + g.NewQueryStruct("TaskUnset"). OptionalSQL("WAREHOUSE"). OptionalSQL("SCHEDULE"). OptionalSQL("CONFIG"). @@ -152,7 +152,7 @@ var TasksDef = g.NewInterface( ). DropOperation( "https://docs.snowflake.com/en/sql-reference/sql/drop-task", - g.QueryStruct("DropTask"). + g.NewQueryStruct("DropTask"). Drop(). SQL("TASK"). IfExists(). 
@@ -163,7 +163,7 @@ var TasksDef = g.NewInterface( "https://docs.snowflake.com/en/sql-reference/sql/show-tasks", taskDbRow, task, - g.QueryStruct("ShowTasks"). + g.NewQueryStruct("ShowTasks"). Show(). Terse(). SQL("TASKS"). @@ -179,7 +179,7 @@ var TasksDef = g.NewInterface( "https://docs.snowflake.com/en/sql-reference/sql/desc-task", taskDbRow, task, - g.QueryStruct("DescribeTask"). + g.NewQueryStruct("DescribeTask"). Describe(). SQL("TASK"). Name(). @@ -188,7 +188,7 @@ var TasksDef = g.NewInterface( CustomOperation( "Execute", "https://docs.snowflake.com/en/sql-reference/sql/execute-task", - g.QueryStruct("ExecuteTask"). + g.NewQueryStruct("ExecuteTask"). SQL("EXECUTE"). SQL("TASK"). Name(). From 7abb4dbb645f0eb0fcb1d39414b1ed0c322916c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 2 Nov 2023 17:02:18 +0100 Subject: [PATCH 15/20] feat: Migrate application role to new sdk (#2149) --- pkg/sdk/application_roles_def.go | 32 ++++++++ pkg/sdk/application_roles_dto_builders_gen.go | 29 +++++++ pkg/sdk/application_roles_dto_gen.go | 15 ++++ pkg/sdk/application_roles_gen.go | 41 ++++++++++ pkg/sdk/application_roles_gen_test.go | 34 ++++++++ pkg/sdk/application_roles_impl_gen.go | 54 ++++++++++++ pkg/sdk/application_roles_validations_gen.go | 16 ++++ pkg/sdk/client.go | 2 + pkg/sdk/poc/generator/keyword_builders.go | 15 ++++ pkg/sdk/poc/generator/operation.go | 10 +++ pkg/sdk/poc/generator/query_struct.go | 4 +- pkg/sdk/poc/generator/static_builders.go | 8 ++ pkg/sdk/poc/main.go | 11 +-- .../application_roles_gen_integration_test.go | 82 +++++++++++++++++++ pkg/sdk/testint/helpers_test.go | 42 ++++++++++ pkg/sdk/testint/testdata/manifest.yml | 5 ++ pkg/sdk/testint/testdata/setup.sql | 2 + 17 files changed, 396 insertions(+), 6 deletions(-) create mode 100644 pkg/sdk/application_roles_def.go create mode 100644 pkg/sdk/application_roles_dto_builders_gen.go create mode 100644 pkg/sdk/application_roles_dto_gen.go create mode 100644 pkg/sdk/application_roles_gen.go create mode 100644 pkg/sdk/application_roles_gen_test.go create mode 100644 pkg/sdk/application_roles_impl_gen.go create mode 100644 pkg/sdk/application_roles_validations_gen.go create mode 100644 pkg/sdk/testint/application_roles_gen_integration_test.go create mode 100644 pkg/sdk/testint/testdata/manifest.yml create mode 100644 pkg/sdk/testint/testdata/setup.sql diff --git a/pkg/sdk/application_roles_def.go b/pkg/sdk/application_roles_def.go new file mode 100644 index 0000000000..326ef3ab42 --- /dev/null +++ b/pkg/sdk/application_roles_def.go @@ -0,0 +1,32 @@ +package sdk + +import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/generator" + +//go:generate go run ./poc/main.go + +var ApplicationRolesDef = g.NewInterface( + "ApplicationRoles", + "ApplicationRole", + g.KindOfT[DatabaseObjectIdentifier](), +). + ShowOperation( + "https://docs.snowflake.com/en/sql-reference/sql/show-application-roles", + g.DbStruct("applicationRoleDbRow"). + Field("created_on", "time.Time"). + Field("name", "string"). + Field("owner", "string"). + Field("comment", "string"). + Field("owner_role_type", "string"), + g.PlainStruct("ApplicationRole"). + Field("CreatedOn", "time.Time"). + Field("Name", "string"). + Field("Owner", "string"). + Field("Comment", "string"). + Field("OwnerRoleType", "string"), + g.NewQueryStruct("ShowApplicationRoles"). + Show(). + SQL("APPLICATION ROLES IN APPLICATION"). + Identifier("ApplicationName", g.KindOfT[AccountObjectIdentifier](), g.IdentifierOptions()). + OptionalLimitFrom(). 
+			WithValidation(g.ValidIdentifier, "ApplicationName"),
+	)
diff --git a/pkg/sdk/application_roles_dto_builders_gen.go b/pkg/sdk/application_roles_dto_builders_gen.go
new file mode 100644
index 0000000000..0b3b1411f3
--- /dev/null
+++ b/pkg/sdk/application_roles_dto_builders_gen.go
@@ -0,0 +1,29 @@
+// Code generated by dto builder generator; DO NOT EDIT.
+
+package sdk
+
+import ()
+
+func NewShowApplicationRoleRequest() *ShowApplicationRoleRequest {
+	return &ShowApplicationRoleRequest{}
+}
+
+func (s *ShowApplicationRoleRequest) WithApplicationName(ApplicationName AccountObjectIdentifier) *ShowApplicationRoleRequest {
+	s.ApplicationName = ApplicationName
+	return s
+}
+
+func (s *ShowApplicationRoleRequest) WithLimit(Limit *LimitFrom) *ShowApplicationRoleRequest {
+	s.Limit = Limit
+	return s
+}
+
+func NewShowByIDApplicationRoleRequest(
+	name DatabaseObjectIdentifier,
+	ApplicationName AccountObjectIdentifier,
+) *ShowByIDApplicationRoleRequest {
+	s := ShowByIDApplicationRoleRequest{}
+	s.name = name
+	s.ApplicationName = ApplicationName
+	return &s
+}
diff --git a/pkg/sdk/application_roles_dto_gen.go b/pkg/sdk/application_roles_dto_gen.go
new file mode 100644
index 0000000000..c5345d9a68
--- /dev/null
+++ b/pkg/sdk/application_roles_dto_gen.go
@@ -0,0 +1,15 @@
+package sdk
+
+//go:generate go run ./dto-builder-generator/main.go
+
+var _ optionsProvider[ShowApplicationRoleOptions] = new(ShowApplicationRoleRequest)
+
+type ShowApplicationRoleRequest struct {
+	ApplicationName AccountObjectIdentifier
+	Limit           *LimitFrom
+}
+
+type ShowByIDApplicationRoleRequest struct {
+	name            DatabaseObjectIdentifier // required
+	ApplicationName AccountObjectIdentifier  // required
+}
diff --git a/pkg/sdk/application_roles_gen.go b/pkg/sdk/application_roles_gen.go
new file mode 100644
index 0000000000..148498054e
--- /dev/null
+++ b/pkg/sdk/application_roles_gen.go
@@ -0,0 +1,41 @@
+package sdk
+
+import (
+	"context"
+	"time"
+)
+
+// ApplicationRoles is an interface that allows for querying application roles.
+// It does not allow other DDL operations (CREATE, ALTER, DROP, ...) to be called, because they cannot
+// be called from the program level. Application roles are a special case: they are only usable
+// inside an application context (e.g. setup.sql). Right now, they can only be manipulated from the program context
+// by applying the debug_mode parameter to the application, but that is a hacky solution, and even then you are limited to GRANT and REVOKE options.
+// That's why we only expose SHOW operations, because they are the only operations allowed to be called from the program context.
+type ApplicationRoles interface {
+	Show(ctx context.Context, request *ShowApplicationRoleRequest) ([]ApplicationRole, error)
+	ShowByID(ctx context.Context, request *ShowByIDApplicationRoleRequest) (*ApplicationRole, error)
+}
+
+// ShowApplicationRoleOptions is based on https://docs.snowflake.com/en/sql-reference/sql/show-application-roles.
+type ShowApplicationRoleOptions struct { + show bool `ddl:"static" sql:"SHOW"` + applicationRolesInApplication bool `ddl:"static" sql:"APPLICATION ROLES IN APPLICATION"` + ApplicationName AccountObjectIdentifier `ddl:"identifier"` + Limit *LimitFrom `ddl:"keyword" sql:"LIMIT"` +} + +type applicationRoleDbRow struct { + CreatedOn time.Time `db:"created_on"` + Name string `db:"name"` + Owner string `db:"owner"` + Comment string `db:"comment"` + OwnerRoleType string `db:"owner_role_type"` +} + +type ApplicationRole struct { + CreatedOn time.Time + Name string + Owner string + Comment string + OwnerRoleType string +} diff --git a/pkg/sdk/application_roles_gen_test.go b/pkg/sdk/application_roles_gen_test.go new file mode 100644 index 0000000000..9134d085b1 --- /dev/null +++ b/pkg/sdk/application_roles_gen_test.go @@ -0,0 +1,34 @@ +package sdk + +import "testing" + +func TestApplicationRoles_Show(t *testing.T) { + appId := RandomAccountObjectIdentifier() + + // Minimal valid ShowApplicationRoleOptions + defaultOpts := func() *ShowApplicationRoleOptions { + return &ShowApplicationRoleOptions{ + ApplicationName: appId, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + var opts *ShowApplicationRoleOptions = nil + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: valid identifier for [opts.ApplicationName]", func(t *testing.T) { + opts := defaultOpts() + opts.ApplicationName = NewAccountObjectIdentifier("") + assertOptsInvalid(t, opts, errInvalidIdentifier("ShowApplicationRoleOptions", "ApplicationName")) + }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.Limit = &LimitFrom{ + Rows: Int(123), + From: String("some limit"), + } + assertOptsValidAndSQLEquals(t, opts, `SHOW APPLICATION ROLES IN APPLICATION %s LIMIT 123 FROM 'some limit'`, appId.FullyQualifiedName()) + }) +} diff --git a/pkg/sdk/application_roles_impl_gen.go b/pkg/sdk/application_roles_impl_gen.go new file mode 100644 index 0000000000..81fc6bc062 --- /dev/null +++ b/pkg/sdk/application_roles_impl_gen.go @@ -0,0 +1,54 @@ +package sdk + +import ( + "context" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" +) + +var _ ApplicationRoles = (*applicationRoles)(nil) + +type applicationRoles struct { + client *Client +} + +func (v *applicationRoles) Show(ctx context.Context, request *ShowApplicationRoleRequest) ([]ApplicationRole, error) { + opts := request.toOpts() + dbRows, err := validateAndQuery[applicationRoleDbRow](v.client, ctx, opts) + if err != nil { + return nil, err + } + resultList := convertRows[applicationRoleDbRow, ApplicationRole](dbRows) + return resultList, nil +} + +func (v *applicationRoles) ShowByID(ctx context.Context, request *ShowByIDApplicationRoleRequest) (*ApplicationRole, error) { + appRoles, err := v.client.ApplicationRoles.Show(ctx, NewShowApplicationRoleRequest().WithApplicationName(request.ApplicationName)) + if err != nil { + return nil, err + } + return collections.FindOne(appRoles, func(role ApplicationRole) bool { return role.Name == request.name.Name() }) +} + +func (r *ShowApplicationRoleRequest) toOpts() *ShowApplicationRoleOptions { + opts := &ShowApplicationRoleOptions{ + ApplicationName: r.ApplicationName, + } + if r.Limit != nil { + opts.Limit = &LimitFrom{ + Rows: r.Limit.Rows, + From: r.Limit.From, + } + } + return opts +} + +func (r applicationRoleDbRow) convert() *ApplicationRole { + return &ApplicationRole{ + CreatedOn: r.CreatedOn, + Name: r.Name, + Owner: r.Owner, + Comment: 
r.Comment, + OwnerRoleType: r.OwnerRoleType, + } +} diff --git a/pkg/sdk/application_roles_validations_gen.go b/pkg/sdk/application_roles_validations_gen.go new file mode 100644 index 0000000000..8f0be97d6a --- /dev/null +++ b/pkg/sdk/application_roles_validations_gen.go @@ -0,0 +1,16 @@ +package sdk + +import "errors" + +var _ validatable = new(ShowApplicationRoleOptions) + +func (opts *ShowApplicationRoleOptions) validate() error { + if opts == nil { + return errors.Join(ErrNilOptions) + } + var errs []error + if !ValidObjectIdentifier(opts.ApplicationName) { + errs = append(errs, errInvalidIdentifier("ShowApplicationRoleOptions", "ApplicationName")) + } + return errors.Join(errs...) +} diff --git a/pkg/sdk/client.go b/pkg/sdk/client.go index 1dc78da8cf..0e9db6dffc 100644 --- a/pkg/sdk/client.go +++ b/pkg/sdk/client.go @@ -28,6 +28,7 @@ type Client struct { // DDL Commands Accounts Accounts Alerts Alerts + ApplicationRoles ApplicationRoles Comments Comments DatabaseRoles DatabaseRoles Databases Databases @@ -141,6 +142,7 @@ func NewClientFromDB(db *sql.DB) *Client { func (c *Client) initialize() { c.Accounts = &accounts{client: c} c.Alerts = &alerts{client: c} + c.ApplicationRoles = &applicationRoles{client: c} c.Comments = &comments{client: c} c.ContextFunctions = &contextFunctions{client: c} c.ConversionFunctions = &conversionFunctions{client: c} diff --git a/pkg/sdk/poc/generator/keyword_builders.go b/pkg/sdk/poc/generator/keyword_builders.go index 5c37a65b24..4bcbb83d54 100644 --- a/pkg/sdk/poc/generator/keyword_builders.go +++ b/pkg/sdk/poc/generator/keyword_builders.go @@ -26,11 +26,26 @@ func (v *QueryStruct) Text(name string, transformer *KeywordTransformer) *QueryS return v } +func (v *QueryStruct) Number(name string, transformer *KeywordTransformer) *QueryStruct { + v.fields = append(v.fields, NewField(name, "int", Tags().Keyword(), transformer)) + return v +} + func (v *QueryStruct) OptionalText(name string, transformer *KeywordTransformer) *QueryStruct { v.fields = append(v.fields, NewField(name, "*string", Tags().Keyword(), transformer)) return v } +func (v *QueryStruct) OptionalNumber(name string, transformer *KeywordTransformer) *QueryStruct { + v.fields = append(v.fields, NewField(name, "*int", Tags().Keyword(), transformer)) + return v +} + +func (v *QueryStruct) OptionalLimitFrom() *QueryStruct { + v.fields = append(v.fields, NewField("Limit", "*LimitFrom", Tags().Keyword().SQL("LIMIT"), nil)) + return v +} + // SessionParameters *SessionParameters `ddl:"list,no_parentheses"` func (v *QueryStruct) SessionParameters() *QueryStruct { v.fields = append(v.fields, NewField("SessionParameters", "*SessionParameters", Tags().List().NoParentheses(), nil).withValidations(NewValidation(ValidateValue, "SessionParameters"))) diff --git a/pkg/sdk/poc/generator/operation.go b/pkg/sdk/poc/generator/operation.go index 1ec861dfcd..c62f99712a 100644 --- a/pkg/sdk/poc/generator/operation.go +++ b/pkg/sdk/poc/generator/operation.go @@ -9,6 +9,8 @@ const ( OperationKindShow OperationKind = "Show" OperationKindShowByID OperationKind = "ShowByID" OperationKindDescribe OperationKind = "Describe" + OperationKindGrant OperationKind = "Grant" + OperationKindRevoke OperationKind = "Revoke" ) type DescriptionMappingKind string @@ -145,6 +147,14 @@ func (i *Interface) DropOperation(doc string, queryStruct *QueryStruct) *Interfa return i.newSimpleOperation(string(OperationKindDrop), doc, queryStruct) } +func (i *Interface) GrantOperation(doc string, queryStruct *QueryStruct) *Interface { + return 
i.newSimpleOperation(string(OperationKindGrant), doc, queryStruct) +} + +func (i *Interface) RevokeOperation(doc string, queryStruct *QueryStruct) *Interface { + return i.newSimpleOperation(string(OperationKindRevoke), doc, queryStruct) +} + func (i *Interface) ShowOperation(doc string, dbRepresentation *dbStruct, resourceRepresentation *plainStruct, queryStruct *QueryStruct) *Interface { i.newOperationWithDBMapping(string(OperationKindShow), doc, dbRepresentation, resourceRepresentation, queryStruct, addShowMapping) return i diff --git a/pkg/sdk/poc/generator/query_struct.go b/pkg/sdk/poc/generator/query_struct.go index 47dcf88487..afe933db41 100644 --- a/pkg/sdk/poc/generator/query_struct.go +++ b/pkg/sdk/poc/generator/query_struct.go @@ -48,7 +48,9 @@ func (v *QueryStruct) queryStructField(name string, queryStruct *QueryStruct, ki qs := queryStruct.IntoField() qs.Name = name qs.Kind = kindPrefix + qs.Kind - qs = transformer.Transform(qs) + if transformer != nil { + qs = transformer.Transform(qs) + } v.fields = append(v.fields, qs) return v } diff --git a/pkg/sdk/poc/generator/static_builders.go b/pkg/sdk/poc/generator/static_builders.go index a0635f01db..2f025e56bb 100644 --- a/pkg/sdk/poc/generator/static_builders.go +++ b/pkg/sdk/poc/generator/static_builders.go @@ -24,3 +24,11 @@ func (v *QueryStruct) Show() *QueryStruct { func (v *QueryStruct) Describe() *QueryStruct { return v.SQL("DESCRIBE") } + +func (v *QueryStruct) Grant() *QueryStruct { + return v.SQL("GRANT") +} + +func (v *QueryStruct) Revoke() *QueryStruct { + return v.SQL("REVOKE") +} diff --git a/pkg/sdk/poc/main.go b/pkg/sdk/poc/main.go index a403ae1107..a634409e7e 100644 --- a/pkg/sdk/poc/main.go +++ b/pkg/sdk/poc/main.go @@ -16,11 +16,12 @@ import ( ) var definitionMapping = map[string]*generator.Interface{ - "database_role_def.go": example.DatabaseRole, - "network_policies_def.go": sdk.NetworkPoliciesDef, - "session_policies_def.go": sdk.SessionPoliciesDef, - "tasks_def.go": sdk.TasksDef, - "streams_def.go": sdk.StreamsDef, + "database_role_def.go": example.DatabaseRole, + "network_policies_def.go": sdk.NetworkPoliciesDef, + "session_policies_def.go": sdk.SessionPoliciesDef, + "tasks_def.go": sdk.TasksDef, + "streams_def.go": sdk.StreamsDef, + "application_roles_def.go": sdk.ApplicationRolesDef, } func main() { diff --git a/pkg/sdk/testint/application_roles_gen_integration_test.go b/pkg/sdk/testint/application_roles_gen_integration_test.go new file mode 100644 index 0000000000..e87f0f209d --- /dev/null +++ b/pkg/sdk/testint/application_roles_gen_integration_test.go @@ -0,0 +1,82 @@ +package testint + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestInt_ApplicationRoles setup is a little bit different from usual integration test, because of how native apps work. 
+// I will try to explain it in a short form, but check out this article for more detailed description (https://docs.snowflake.com/en/developer-guide/native-apps/tutorials/getting-started-tutorial#introduction) +// - create stage - it is where we will be keeping our application files +// - put native app specific stuff onto our stage (manifest.yml and setup.sql) +// - create an application package and a new version of our application +// - create an application with the application package and the version we just created +// - while creating the application, the setup.sql script will be run in our application context (and that is where application roles for our tests are created) +// - we're ready to query application roles we have just created +func TestInt_ApplicationRoles(t *testing.T) { + client := testClient(t) + + stageName := "stage_name" + stage, cleanupStage := createStage(t, client, testDb(t), testSchema(t), stageName) + t.Cleanup(cleanupStage) + + putOnStage(t, client, stage, "manifest.yml") + putOnStage(t, client, stage, "setup.sql") + + appPackageName := "snowflake_app_pkg" + versionName := "v1" + cleanupAppPackage := createApplicationPackage(t, client, appPackageName) + t.Cleanup(cleanupAppPackage) + addApplicationPackageVersion(t, client, stage, appPackageName, versionName) + + appName := "snowflake_app" + cleanupApp := createApplication(t, client, appName, appPackageName, versionName) + t.Cleanup(cleanupApp) + + assertApplicationRole := func(t *testing.T, appRole *sdk.ApplicationRole, name string, comment string) { + t.Helper() + assert.Equal(t, name, appRole.Name) + assert.Equal(t, appName, appRole.Owner) + assert.Equal(t, comment, appRole.Comment) + assert.Equal(t, "APPLICATION", appRole.OwnerRoleType) + } + + assertApplicationRoles := func(t *testing.T, appRoles []sdk.ApplicationRole, name string, comment string) { + t.Helper() + appRole, err := collections.FindOne(appRoles, func(role sdk.ApplicationRole) bool { + return role.Name == name + }) + require.NoError(t, err) + assertApplicationRole(t, appRole, name, comment) + } + + t.Run("Show by id", func(t *testing.T) { + name := "app_role_1" + id := sdk.NewDatabaseObjectIdentifier(appName, name) + ctx := context.Background() + + appRole, err := client.ApplicationRoles.ShowByID(ctx, sdk.NewShowByIDApplicationRoleRequest(id, sdk.NewAccountObjectIdentifier(appName))) + require.NoError(t, err) + + assertApplicationRole(t, appRole, name, "some comment") + }) + + t.Run("Show", func(t *testing.T) { + ctx := context.Background() + req := sdk.NewShowApplicationRoleRequest(). + WithApplicationName(sdk.NewAccountObjectIdentifier(appName)). 
+ WithLimit(&sdk.LimitFrom{ + Rows: sdk.Int(2), + }) + appRoles, err := client.ApplicationRoles.Show(ctx, req) + require.NoError(t, err) + + assertApplicationRoles(t, appRoles, "app_role_1", "some comment") + assertApplicationRoles(t, appRoles, "app_role_2", "some comment2") + }) +} diff --git a/pkg/sdk/testint/helpers_test.go b/pkg/sdk/testint/helpers_test.go index 97bd3fbb44..9211c7c7aa 100644 --- a/pkg/sdk/testint/helpers_test.go +++ b/pkg/sdk/testint/helpers_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "path/filepath" "testing" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -672,3 +673,44 @@ func createView(t *testing.T, client *sdk.Client, viewId sdk.SchemaObjectIdentif require.NoError(t, err) } } + +func putOnStage(t *testing.T, client *sdk.Client, stage *sdk.Stage, filename string) { + t.Helper() + ctx := context.Background() + + path, err := filepath.Abs("./testdata/" + filename) + require.NoError(t, err) + absPath := "file://" + path + + _, err = client.ExecForTests(ctx, fmt.Sprintf(`PUT '%s' @%s AUTO_COMPRESS = FALSE`, absPath, stage.ID().FullyQualifiedName())) + require.NoError(t, err) +} + +func createApplicationPackage(t *testing.T, client *sdk.Client, name string) func() { + t.Helper() + ctx := context.Background() + _, err := client.ExecForTests(ctx, fmt.Sprintf(`CREATE APPLICATION PACKAGE "%s"`, name)) + require.NoError(t, err) + return func() { + _, err := client.ExecForTests(ctx, fmt.Sprintf(`DROP APPLICATION PACKAGE "%s"`, name)) + require.NoError(t, err) + } +} + +func addApplicationPackageVersion(t *testing.T, client *sdk.Client, stage *sdk.Stage, appPackageName string, versionName string) { + t.Helper() + ctx := context.Background() + _, err := client.ExecForTests(ctx, fmt.Sprintf(`ALTER APPLICATION PACKAGE "%s" ADD VERSION %v USING '@%s'`, appPackageName, versionName, stage.ID().FullyQualifiedName())) + require.NoError(t, err) +} + +func createApplication(t *testing.T, client *sdk.Client, name string, packageName string, version string) func() { + t.Helper() + ctx := context.Background() + _, err := client.ExecForTests(ctx, fmt.Sprintf(`CREATE APPLICATION "%s" FROM APPLICATION PACKAGE "%s" USING VERSION %s`, name, packageName, version)) + require.NoError(t, err) + return func() { + _, err := client.ExecForTests(ctx, fmt.Sprintf(`DROP APPLICATION "%s"`, name)) + require.NoError(t, err) + } +} diff --git a/pkg/sdk/testint/testdata/manifest.yml b/pkg/sdk/testint/testdata/manifest.yml new file mode 100644 index 0000000000..83fb797263 --- /dev/null +++ b/pkg/sdk/testint/testdata/manifest.yml @@ -0,0 +1,5 @@ +manifest_version: 1 # required +version: + name: application_roles_test_app + label: "v1.0" + comment: "This application is used by Snowflake Terraform Provider for application role integration tests" diff --git a/pkg/sdk/testint/testdata/setup.sql b/pkg/sdk/testint/testdata/setup.sql new file mode 100644 index 0000000000..1c951878b2 --- /dev/null +++ b/pkg/sdk/testint/testdata/setup.sql @@ -0,0 +1,2 @@ +create application role "app_role_1" comment = 'some comment'; +create application role "app_role_2" comment = 'some comment2'; From ee0f6af54dbd269f8bfa2c3d73a396d98d10a6ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Fri, 3 Nov 2023 13:12:41 +0100 Subject: [PATCH 16/20] chore: Add migration guide (#2142) --- MIGRATION_GUIDE.md | 90 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 MIGRATION_GUIDE.md diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md new file 
mode 100644
index 0000000000..690f3269f0
--- /dev/null
+++ b/MIGRATION_GUIDE.md
@@ -0,0 +1,90 @@
+# Migration guide
+
+This document is meant to help you migrate your Terraform config to the newest version. In migration guides we will only
+describe deprecations or breaking changes and help you change your configuration to keep the same (or similar) behaviour
+across different versions.
+
+## v0.73.0 ➞ v0.74.0
+### Provider configuration changes
+
+In this change we have refactored the provider to make it more complete and customizable by supporting more options that
+were already available in the Golang Snowflake driver. This led to several attributes being added and a few being deprecated.
+We will focus on the deprecated ones and show you how to adapt your current configuration to the new changes.
+
+#### *(rename)* username ➞ user
+
+```terraform
+provider "snowflake" {
+  # before
+  username = "username"
+
+  # after
+  user = "username"
+}
+```
+
+#### *(structural change)* OAuth API
+
+```terraform
+provider "snowflake" {
+  # before
+  browser_auth = false
+  oauth_access_token = ""
+  oauth_refresh_token = ""
+  oauth_client_id = ""
+  oauth_client_secret = ""
+  oauth_endpoint = ""
+  oauth_redirect_url = ""
+
+  # after
+  authenticator = "ExternalBrowser"
+  token = ""
+  token_accessor {
+    refresh_token = ""
+    client_id = ""
+    client_secret = ""
+    token_endpoint = ""
+    redirect_uri = ""
+  }
+}
+```
+
+#### *(remove redundant information)* region
+
+Specifying a region is a legacy approach; according to https://docs.snowflake.com/en/user-guide/admin-account-identifier
+you can specify the region as part of the account parameter. Specifying the account parameter together with the region is also considered legacy,
+but with this approach it will be easier to later convert your account identifier to the new preferred way of specifying the account identifier.
+
+```terraform
+provider "snowflake" {
+  # before
+  region = ""
+
+  # after
+  account = "<account_locator>.<region>"
+} +``` + +#### *(todo)* private key path + +```terraform +provider "snowflake" { + # before + private_key_path = "" + + # after + private_key = file("") +} +``` + +#### *(rename)* session_params ➞ params + +```terraform +provider "snowflake" { + # before + session_params = {} + + # after + params = {} +} +``` From 01a774c564e4a74dea41385ddf775b7f9e9d1f7b Mon Sep 17 00:00:00 2001 From: Scott Winkler Date: Fri, 3 Nov 2023 10:36:46 -0700 Subject: [PATCH 17/20] fix private key auth (#2170) --- pkg/provider/provider.go | 6 +++++- pkg/provider/provider_helpers.go | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 226ff9c5d2..3453616873 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -707,7 +707,11 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { privateKeyPath := s.Get("private_key_path").(string) privateKey := s.Get("private_key").(string) privateKeyPassphrase := s.Get("private_key_passphrase").(string) - if v, err := getPrivateKey(privateKeyPath, privateKey, privateKeyPassphrase); err != nil && v != nil { + v, err := getPrivateKey(privateKeyPath, privateKey, privateKeyPassphrase) + if err != nil { + return nil, fmt.Errorf("could not retrieve private key: %w", err) + } + if v != nil { config.PrivateKey = v } diff --git a/pkg/provider/provider_helpers.go b/pkg/provider/provider_helpers.go index 480aa4c786..a7a18c6485 100644 --- a/pkg/provider/provider_helpers.go +++ b/pkg/provider/provider_helpers.go @@ -30,6 +30,9 @@ func mergeSchemas(schemaCollections ...map[string]*schema.Resource) map[string]* } func getPrivateKey(privateKeyPath, privateKeyString, privateKeyPassphrase string) (*rsa.PrivateKey, error) { + if privateKeyPath == "" && privateKeyString == "" { + return nil, nil + } privateKeyBytes := []byte(privateKeyString) var err error if len(privateKeyBytes) == 0 && privateKeyPath != "" { From 6f026f64e6e24638df2b9d4110362836a9071011 Mon Sep 17 00:00:00 2001 From: kenkoooo Date: Thu, 9 Nov 2023 23:15:55 +0900 Subject: [PATCH 18/20] feat: Add "CREATE DYNAMIC TABLE" to schema_grant (#2144) * feat: Add "CREATE DYNAMIC TABLE" to schema_grant * feat: Add CREATE DYNAMIC TABLE to privileges --- pkg/resources/privileges.go | 1 + pkg/resources/schema_grant.go | 1 + pkg/sdk/privileges.go | 1 + 3 files changed, 3 insertions(+) diff --git a/pkg/resources/privileges.go b/pkg/resources/privileges.go index 8a989b0a77..46adbfafe5 100644 --- a/pkg/resources/privileges.go +++ b/pkg/resources/privileges.go @@ -23,6 +23,7 @@ const ( privilegeCreateDatabase Privilege = "CREATE DATABASE" privilegeCreateDatabaseRole Privilege = "CREATE DATABASE ROLE" privilegeCreateDataExchangeListing Privilege = "CREATE DATA EXCHANGE LISTING" + privilegeCreateDynamicTable Privilege = "CREATE DYNAMIC TABLE" privilegeCreateExternalTable Privilege = "CREATE EXTERNAL TABLE" privilegeCreateFailoverGroup Privilege = "CREATE FAILOVER GROUP" privilegeCreateFileFormat Privilege = "CREATE FILE FORMAT" diff --git a/pkg/resources/schema_grant.go b/pkg/resources/schema_grant.go index f12960b352..4929d4d2f8 100644 --- a/pkg/resources/schema_grant.go +++ b/pkg/resources/schema_grant.go @@ -16,6 +16,7 @@ import ( var validSchemaPrivileges = NewPrivilegeSet( privilegeAddSearchOptimization, + privilegeCreateDynamicTable, privilegeCreateExternalTable, privilegeCreateFileFormat, privilegeCreateFunction, diff --git a/pkg/sdk/privileges.go b/pkg/sdk/privileges.go index 0d7a13d876..d4eb5958d9 100644 --- a/pkg/sdk/privileges.go +++ 
b/pkg/sdk/privileges.go @@ -127,6 +127,7 @@ const ( */ SchemaPrivilegeAddSearchOptimization SchemaPrivilege = "ADD SEARCH OPTIMIZATION" SchemaPrivilegeCreateAlert SchemaPrivilege = "CREATE ALERT" + SchemaPrivilegeCreateDynamicTable SchemaPrivilege = "CREATE DYNAMIC TABLE" SchemaPrivilegeCreateExternalTable SchemaPrivilege = "CREATE EXTERNAL TABLE" SchemaPrivilegeCreateFileFormat SchemaPrivilege = "CREATE FILE FORMAT" SchemaPrivilegeCreateFunction SchemaPrivilege = "CREATE FUNCTION" From ed079d3d06dc3af083da04ca18314c8e7b07308e Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Mon, 13 Nov 2023 09:38:24 +0100 Subject: [PATCH 19/20] feat: Add views to the SDK (#2171) * Prepare view definition (WIP) * Add column list and column masking policies (WIP) * Finish create definition (WIP) * Add test for multiple session params in task * Add alter action to views definition * Add describe operation to views definition * Define view and view row * Add generated files (without changes) * Make compile after generation * Pass create view unit tests (WIP) * Test full view create * Pass show, describe and drop tests; add alter branches tests (WIP) * Pass alter tests for view * Pass first integration test * Add more complicated create integration test * Add drop tests * Add show tests * Add describe test * Add second describe test * Add rename test * Add alter tests (WIP) * Add set and unset masking policy on column * Add set and unset tag on column * Add adding and dropping row access policies * Add row access policy to creation * Add comment * Change On to required (WIP) * Fix tests (WIP) * Add validation tests * Fix linter complaints * Adjust struct definitions to current types * Add convenience methods to db and plain structs * Rename list method and fix tags * Regenerate Options * Introduce optional variants for set and unset tags * Refactor slightly set and unset tags * Regenerate * Revert list to keyword * Add jira issue to TODO comment --- pkg/sdk/client.go | 2 + pkg/sdk/integration_test_imports.go | 7 + pkg/sdk/object_types.go | 1 + pkg/sdk/poc/generator/db_struct.go | 16 + pkg/sdk/poc/generator/field_transformers.go | 13 +- pkg/sdk/poc/generator/keyword_builders.go | 33 +- pkg/sdk/poc/generator/plain_struct.go | 16 + pkg/sdk/poc/main.go | 1 + pkg/sdk/session_policies_def.go | 4 +- pkg/sdk/streams_def.go | 4 +- pkg/sdk/tasks_def.go | 4 +- pkg/sdk/tasks_gen_test.go | 5 +- pkg/sdk/testint/helpers_test.go | 57 ++ pkg/sdk/testint/views_gen_integration_test.go | 496 ++++++++++++++++++ pkg/sdk/views_def.go | 215 ++++++++ pkg/sdk/views_dto_builders_gen.go | 332 ++++++++++++ pkg/sdk/views_dto_gen.go | 121 +++++ pkg/sdk/views_gen.go | 217 ++++++++ pkg/sdk/views_gen_test.go | 441 ++++++++++++++++ pkg/sdk/views_impl_gen.go | 249 +++++++++ pkg/sdk/views_validations_gen.go | 103 ++++ 21 files changed, 2322 insertions(+), 15 deletions(-) create mode 100644 pkg/sdk/testint/views_gen_integration_test.go create mode 100644 pkg/sdk/views_def.go create mode 100644 pkg/sdk/views_dto_builders_gen.go create mode 100644 pkg/sdk/views_dto_gen.go create mode 100644 pkg/sdk/views_gen.go create mode 100644 pkg/sdk/views_gen_test.go create mode 100644 pkg/sdk/views_impl_gen.go create mode 100644 pkg/sdk/views_validations_gen.go diff --git a/pkg/sdk/client.go b/pkg/sdk/client.go index 0e9db6dffc..eb056c0ec3 100644 --- a/pkg/sdk/client.go +++ b/pkg/sdk/client.go @@ -52,6 +52,7 @@ type Client struct { Tags Tags Tasks Tasks Users Users + Views Views Warehouses Warehouses } @@ -170,6 +171,7 @@ func (c *Client) initialize() { 
c.Tags = &tags{client: c} c.Tasks = &tasks{client: c} c.Users = &users{client: c} + c.Views = &views{client: c} c.Warehouses = &warehouses{client: c} } diff --git a/pkg/sdk/integration_test_imports.go b/pkg/sdk/integration_test_imports.go index 86f7dfc715..b109ec8aa6 100644 --- a/pkg/sdk/integration_test_imports.go +++ b/pkg/sdk/integration_test_imports.go @@ -20,6 +20,13 @@ func (c *Client) ExecForTests(ctx context.Context, sql string) (sql.Result, erro return result, decodeDriverError(err) } +// QueryOneForTests is an exact copy of queryOne (that is unexported), that some integration tests/helpers were using +// TODO: remove after introducing all resources using this +func (c *Client) QueryOneForTests(ctx context.Context, dest interface{}, sql string) error { + ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) + return decodeDriverError(c.db.GetContext(ctx, dest, sql)) +} + func ErrorsEqual(t *testing.T, expected error, actual error) { t.Helper() var expectedErr *Error diff --git a/pkg/sdk/object_types.go b/pkg/sdk/object_types.go index 3f866bccbe..44363d0afb 100644 --- a/pkg/sdk/object_types.go +++ b/pkg/sdk/object_types.go @@ -58,6 +58,7 @@ const ( ObjectTypeApplicationPackage ObjectType = "APPLICATION PACKAGE" ObjectTypeApplicationRole ObjectType = "APPLICATION ROLE" ObjectTypeStreamlit ObjectType = "STREAMLIT" + ObjectTypeColumn ObjectType = "COLUMN" ) func (o ObjectType) String() string { diff --git a/pkg/sdk/poc/generator/db_struct.go b/pkg/sdk/poc/generator/db_struct.go index d42cf83d47..03490dc4ae 100644 --- a/pkg/sdk/poc/generator/db_struct.go +++ b/pkg/sdk/poc/generator/db_struct.go @@ -25,6 +25,22 @@ func (v *dbStruct) Field(dbName string, kind string) *dbStruct { return v } +func (v *dbStruct) Text(dbName string) *dbStruct { + return v.Field(dbName, "string") +} + +func (v *dbStruct) OptionalText(dbName string) *dbStruct { + return v.Field(dbName, "sql.NullString") +} + +func (v *dbStruct) Bool(dbName string) *dbStruct { + return v.Field(dbName, "bool") +} + +func (v *dbStruct) OptionalBool(dbName string) *dbStruct { + return v.Field(dbName, "sql.NullBool") +} + func (v *dbStruct) IntoField() *Field { f := NewField(v.name, v.name, nil, nil) for _, field := range v.fields { diff --git a/pkg/sdk/poc/generator/field_transformers.go b/pkg/sdk/poc/generator/field_transformers.go index 9e404814c2..cdf0299c95 100644 --- a/pkg/sdk/poc/generator/field_transformers.go +++ b/pkg/sdk/poc/generator/field_transformers.go @@ -7,9 +7,10 @@ type FieldTransformer interface { } type KeywordTransformer struct { - required bool - sqlPrefix string - quotes string + required bool + sqlPrefix string + quotes string + parentheses string } func KeywordOptions() *KeywordTransformer { @@ -41,6 +42,11 @@ func (v *KeywordTransformer) DoubleQuotes() *KeywordTransformer { return v } +func (v *KeywordTransformer) Parentheses() *KeywordTransformer { + v.parentheses = "parentheses" + return v +} + func (v *KeywordTransformer) Transform(f *Field) *Field { addTagIfMissing(f.Tags, "ddl", "keyword") if v.required { @@ -48,6 +54,7 @@ func (v *KeywordTransformer) Transform(f *Field) *Field { } addTagIfMissing(f.Tags, "sql", v.sqlPrefix) addTagIfMissing(f.Tags, "ddl", v.quotes) + addTagIfMissing(f.Tags, "ddl", v.parentheses) return f } diff --git a/pkg/sdk/poc/generator/keyword_builders.go b/pkg/sdk/poc/generator/keyword_builders.go index 4bcbb83d54..07b6f02882 100644 --- a/pkg/sdk/poc/generator/keyword_builders.go +++ b/pkg/sdk/poc/generator/keyword_builders.go @@ -62,18 +62,43 @@ 
func (v *QueryStruct) OptionalSessionParametersUnset() *QueryStruct { return v } -func (v *QueryStruct) WithTags() *QueryStruct { - v.fields = append(v.fields, NewField("Tag", "[]TagAssociation", Tags().Keyword().Parentheses().SQL("TAG"), nil)) +func (v *QueryStruct) NamedListWithParens(sqlPrefix string, listItemKind string, transformer *KeywordTransformer) *QueryStruct { + if transformer != nil { + transformer = transformer.Parentheses().SQL(sqlPrefix) + } else { + transformer = KeywordOptions().Parentheses().SQL(sqlPrefix) + } + v.fields = append(v.fields, NewField(sqlToFieldName(sqlPrefix, true), KindOfSlice(listItemKind), Tags().Keyword(), transformer)) return v } +func (v *QueryStruct) WithTags() *QueryStruct { + return v.NamedListWithParens("TAG", "TagAssociation", nil) +} + func (v *QueryStruct) SetTags() *QueryStruct { - v.fields = append(v.fields, NewField("SetTags", "[]TagAssociation", Tags().Keyword().SQL("SET TAG"), nil)) + return v.setTags(KeywordOptions().Required()) +} + +func (v *QueryStruct) OptionalSetTags() *QueryStruct { + return v.setTags(nil) +} + +func (v *QueryStruct) setTags(transformer *KeywordTransformer) *QueryStruct { + v.fields = append(v.fields, NewField("SetTags", "[]TagAssociation", Tags().Keyword().SQL("SET TAG"), transformer)) return v } func (v *QueryStruct) UnsetTags() *QueryStruct { - v.fields = append(v.fields, NewField("UnsetTags", "[]ObjectIdentifier", Tags().Keyword().SQL("UNSET TAG"), nil)) + return v.unsetTags(KeywordOptions().Required()) +} + +func (v *QueryStruct) OptionalUnsetTags() *QueryStruct { + return v.unsetTags(nil) +} + +func (v *QueryStruct) unsetTags(transformer *KeywordTransformer) *QueryStruct { + v.fields = append(v.fields, NewField("UnsetTags", "[]ObjectIdentifier", Tags().Keyword().SQL("UNSET TAG"), transformer)) return v } diff --git a/pkg/sdk/poc/generator/plain_struct.go b/pkg/sdk/poc/generator/plain_struct.go index c011326bbd..f6fa3eb155 100644 --- a/pkg/sdk/poc/generator/plain_struct.go +++ b/pkg/sdk/poc/generator/plain_struct.go @@ -25,6 +25,22 @@ func (v *plainStruct) Field(name string, kind string) *plainStruct { return v } +func (v *plainStruct) Text(name string) *plainStruct { + return v.Field(name, "string") +} + +func (v *plainStruct) OptionalText(name string) *plainStruct { + return v.Field(name, "*string") +} + +func (v *plainStruct) Bool(name string) *plainStruct { + return v.Field(name, "bool") +} + +func (v *plainStruct) OptionalBool(name string) *plainStruct { + return v.Field(name, "*bool") +} + func (v *plainStruct) IntoField() *Field { f := NewField(v.name, v.name, nil, nil) for _, field := range v.fields { diff --git a/pkg/sdk/poc/main.go b/pkg/sdk/poc/main.go index a634409e7e..3f445de947 100644 --- a/pkg/sdk/poc/main.go +++ b/pkg/sdk/poc/main.go @@ -22,6 +22,7 @@ var definitionMapping = map[string]*generator.Interface{ "tasks_def.go": sdk.TasksDef, "streams_def.go": sdk.StreamsDef, "application_roles_def.go": sdk.ApplicationRolesDef, + "views_def.go": sdk.ViewsDef, } func main() { diff --git a/pkg/sdk/session_policies_def.go b/pkg/sdk/session_policies_def.go index 76c36c204d..2053e48d22 100644 --- a/pkg/sdk/session_policies_def.go +++ b/pkg/sdk/session_policies_def.go @@ -40,8 +40,8 @@ var SessionPoliciesDef = g.NewInterface( WithValidation(g.AtLeastOneValueSet, "SessionIdleTimeoutMins", "SessionUiIdleTimeoutMins", "Comment"), g.KeywordOptions().SQL("SET"), ). - SetTags(). - UnsetTags(). + OptionalSetTags(). + OptionalUnsetTags(). OptionalQueryStructField( "Unset", g.NewQueryStruct("SessionPolicyUnset"). 
diff --git a/pkg/sdk/streams_def.go b/pkg/sdk/streams_def.go index db9f3a89b7..135ecbc767 100644 --- a/pkg/sdk/streams_def.go +++ b/pkg/sdk/streams_def.go @@ -158,8 +158,8 @@ var ( Name(). OptionalTextAssignment("SET COMMENT", g.ParameterOptions().SingleQuotes()). OptionalSQL("UNSET COMMENT"). - SetTags(). - UnsetTags(). + OptionalSetTags(). + OptionalUnsetTags(). WithValidation(g.ValidIdentifier, "name"). WithValidation(g.ConflictingFields, "IfExists", "UnsetTags"). WithValidation(g.ExactlyOneValueSet, "SetComment", "UnsetComment", "SetTags", "UnsetTags"), diff --git a/pkg/sdk/tasks_def.go b/pkg/sdk/tasks_def.go index 0e80f9b2e8..10e92697f9 100644 --- a/pkg/sdk/tasks_def.go +++ b/pkg/sdk/tasks_def.go @@ -143,8 +143,8 @@ var TasksDef = g.NewInterface( WithValidation(g.AtLeastOneValueSet, "Warehouse", "Schedule", "Config", "AllowOverlappingExecution", "UserTaskTimeoutMs", "SuspendTaskAfterNumFailures", "ErrorIntegration", "Comment", "SessionParametersUnset"), g.KeywordOptions().SQL("UNSET"), ). - SetTags(). - UnsetTags(). + OptionalSetTags(). + OptionalUnsetTags(). OptionalTextAssignment("MODIFY AS", g.ParameterOptions().NoQuotes().NoEquals()). OptionalTextAssignment("MODIFY WHEN", g.ParameterOptions().NoQuotes().NoEquals()). WithValidation(g.ValidIdentifier, "name"). diff --git a/pkg/sdk/tasks_gen_test.go b/pkg/sdk/tasks_gen_test.go index 56bb1f65e9..7d0ac3b700 100644 --- a/pkg/sdk/tasks_gen_test.go +++ b/pkg/sdk/tasks_gen_test.go @@ -71,7 +71,8 @@ func TestTasks_Create(t *testing.T) { WithConfig(String(`$${"output_dir": "/temp/test_directory/", "learning_rate": 0.1}$$`)). WithAllowOverlappingExecution(Bool(true)). WithSessionParameters(&SessionParameters{ - JSONIndent: Int(10), + JSONIndent: Int(10), + LockTimeout: Int(5), }). WithUserTaskTimeoutMs(Int(5)). WithSuspendTaskAfterNumFailures(Int(6)). @@ -85,7 +86,7 @@ func TestTasks_Create(t *testing.T) { }}). 
WithWhen(String(`SYSTEM$STREAM_HAS_DATA('MYSTREAM')`)) - assertOptsValidAndSQLEquals(t, req.toOpts(), "CREATE OR REPLACE TASK %s WAREHOUSE = %s SCHEDULE = '10 MINUTE' CONFIG = $${\"output_dir\": \"/temp/test_directory/\", \"learning_rate\": 0.1}$$ ALLOW_OVERLAPPING_EXECUTION = true JSON_INDENT = 10 USER_TASK_TIMEOUT_MS = 5 SUSPEND_TASK_AFTER_NUM_FAILURES = 6 ERROR_INTEGRATION = some_error_integration COPY GRANTS COMMENT = 'some comment' AFTER %s TAG (%s = 'v1') WHEN SYSTEM$STREAM_HAS_DATA('MYSTREAM') AS SELECT CURRENT_TIMESTAMP", id.FullyQualifiedName(), warehouseId.FullyQualifiedName(), otherTaskId.FullyQualifiedName(), tagId.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, req.toOpts(), "CREATE OR REPLACE TASK %s WAREHOUSE = %s SCHEDULE = '10 MINUTE' CONFIG = $${\"output_dir\": \"/temp/test_directory/\", \"learning_rate\": 0.1}$$ ALLOW_OVERLAPPING_EXECUTION = true JSON_INDENT = 10, LOCK_TIMEOUT = 5 USER_TASK_TIMEOUT_MS = 5 SUSPEND_TASK_AFTER_NUM_FAILURES = 6 ERROR_INTEGRATION = some_error_integration COPY GRANTS COMMENT = 'some comment' AFTER %s TAG (%s = 'v1') WHEN SYSTEM$STREAM_HAS_DATA('MYSTREAM') AS SELECT CURRENT_TIMESTAMP", id.FullyQualifiedName(), warehouseId.FullyQualifiedName(), otherTaskId.FullyQualifiedName(), tagId.FullyQualifiedName()) }) } diff --git a/pkg/sdk/testint/helpers_test.go b/pkg/sdk/testint/helpers_test.go index 9211c7c7aa..77c6b102ca 100644 --- a/pkg/sdk/testint/helpers_test.go +++ b/pkg/sdk/testint/helpers_test.go @@ -2,6 +2,7 @@ package testint import ( "context" + "database/sql" "errors" "fmt" "path/filepath" @@ -461,6 +462,19 @@ func createMaskingPolicy(t *testing.T, client *sdk.Client, database *sdk.Databas return createMaskingPolicyWithOptions(t, client, database, schema, signature, sdk.DataTypeVARCHAR, expression, &sdk.CreateMaskingPolicyOptions{}) } +func createMaskingPolicyIdentity(t *testing.T, client *sdk.Client, database *sdk.Database, schema *sdk.Schema, columnType sdk.DataType) (*sdk.MaskingPolicy, func()) { + t.Helper() + name := "a" + signature := []sdk.TableColumnSignature{ + { + Name: name, + Type: columnType, + }, + } + expression := "a" + return createMaskingPolicyWithOptions(t, client, database, schema, signature, columnType, expression, &sdk.CreateMaskingPolicyOptions{}) +} + func createMaskingPolicyWithOptions(t *testing.T, client *sdk.Client, database *sdk.Database, schema *sdk.Schema, signature []sdk.TableColumnSignature, returns sdk.DataType, expression string, options *sdk.CreateMaskingPolicyOptions) (*sdk.MaskingPolicy, func()) { t.Helper() var databaseCleanup func() @@ -714,3 +728,46 @@ func createApplication(t *testing.T, client *sdk.Client, name string, packageNam require.NoError(t, err) } } + +func createRowAccessPolicy(t *testing.T, client *sdk.Client, schema *sdk.Schema) (sdk.SchemaObjectIdentifier, func()) { + t.Helper() + ctx := context.Background() + id := sdk.NewSchemaObjectIdentifier(schema.DatabaseName, schema.Name, random.String()) + _, err := client.ExecForTests(ctx, fmt.Sprintf(`CREATE ROW ACCESS POLICY %s AS (A NUMBER) RETURNS BOOLEAN -> TRUE`, id.FullyQualifiedName())) + require.NoError(t, err) + + return id, func() { + _, err := client.ExecForTests(ctx, fmt.Sprintf(`DROP ROW ACCESS POLICY %s`, id.FullyQualifiedName())) + require.NoError(t, err) + } +} + +// TODO: extract getting row access policies as resource (like getting tag in system functions) +// getRowAccessPolicyFor is based on https://docs.snowflake.com/en/user-guide/security-row-intro#obtain-database-objects-with-a-row-access-policy. 
+func getRowAccessPolicyFor(t *testing.T, client *sdk.Client, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) (*policyReference, error) { + t.Helper() + ctx := context.Background() + + s := &policyReference{} + policyReferencesId := sdk.NewSchemaObjectIdentifier(id.DatabaseName(), "INFORMATION_SCHEMA", "POLICY_REFERENCES") + err := client.QueryOneForTests(ctx, s, fmt.Sprintf(`SELECT * FROM TABLE(%s(REF_ENTITY_NAME => '%s', REF_ENTITY_DOMAIN => '%v'))`, policyReferencesId.FullyQualifiedName(), id.FullyQualifiedName(), objectType)) + + return s, err +} + +type policyReference struct { + PolicyDb string `db:"POLICY_DB"` + PolicySchema string `db:"POLICY_SCHEMA"` + PolicyName string `db:"POLICY_NAME"` + PolicyKind string `db:"POLICY_KIND"` + RefDatabaseName string `db:"REF_DATABASE_NAME"` + RefSchemaName string `db:"REF_SCHEMA_NAME"` + RefEntityName string `db:"REF_ENTITY_NAME"` + RefEntityDomain string `db:"REF_ENTITY_DOMAIN"` + RefColumnName sql.NullString `db:"REF_COLUMN_NAME"` + RefArgColumnNames string `db:"REF_ARG_COLUMN_NAMES"` + TagDatabase sql.NullString `db:"TAG_DATABASE"` + TagSchema sql.NullString `db:"TAG_SCHEMA"` + TagName sql.NullString `db:"TAG_NAME"` + PolicyStatus string `db:"POLICY_STATUS"` +} diff --git a/pkg/sdk/testint/views_gen_integration_test.go b/pkg/sdk/testint/views_gen_integration_test.go new file mode 100644 index 0000000000..39e250facf --- /dev/null +++ b/pkg/sdk/testint/views_gen_integration_test.go @@ -0,0 +1,496 @@ +package testint + +import ( + "fmt" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/random" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO: add tests for setting masking policy on creation +// TODO: add tests for setting recursive on creation +func TestInt_Views(t *testing.T) { + client := testClient(t) + ctx := testContext(t) + + table, tableCleanup := createTable(t, client, testDb(t), testSchema(t)) + t.Cleanup(tableCleanup) + + sql := fmt.Sprintf("SELECT id FROM %s", table.ID().FullyQualifiedName()) + + assertViewWithOptions := func(t *testing.T, view *sdk.View, id sdk.SchemaObjectIdentifier, isSecure bool, comment string) { + t.Helper() + assert.NotEmpty(t, view.CreatedOn) + assert.Equal(t, id.Name(), view.Name) + // Kind is filled out only in TERSE response. 
+ assert.Empty(t, view.Kind) + assert.Empty(t, view.Reserved) + assert.Equal(t, testDb(t).Name, view.DatabaseName) + assert.Equal(t, testSchema(t).Name, view.SchemaName) + assert.Equal(t, "ACCOUNTADMIN", view.Owner) + assert.Equal(t, comment, view.Comment) + assert.NotEmpty(t, view.Text) + assert.Equal(t, isSecure, view.IsSecure) + assert.Equal(t, false, view.IsMaterialized) + assert.Equal(t, "ROLE", view.OwnerRoleType) + assert.Equal(t, "OFF", view.ChangeTracking) + } + + assertView := func(t *testing.T, view *sdk.View, id sdk.SchemaObjectIdentifier) { + t.Helper() + assertViewWithOptions(t, view, id, false, "") + } + + assertViewTerse := func(t *testing.T, view *sdk.View, id sdk.SchemaObjectIdentifier) { + t.Helper() + assert.NotEmpty(t, view.CreatedOn) + assert.Equal(t, id.Name(), view.Name) + assert.Equal(t, "VIEW", view.Kind) + assert.Equal(t, testDb(t).Name, view.DatabaseName) + assert.Equal(t, testSchema(t).Name, view.SchemaName) + + // all below are not contained in the terse response, that's why all of them we expect to be empty + assert.Empty(t, view.Reserved) + assert.Empty(t, view.Owner) + assert.Empty(t, view.Comment) + assert.Empty(t, view.Text) + assert.Empty(t, view.IsSecure) + assert.Empty(t, view.IsMaterialized) + assert.Empty(t, view.OwnerRoleType) + assert.Empty(t, view.ChangeTracking) + } + + assertViewDetailsRow := func(t *testing.T, viewDetails *sdk.ViewDetails) { + t.Helper() + assert.Equal(t, sdk.ViewDetails{ + Name: "ID", + Type: "NUMBER(38,0)", + Kind: "COLUMN", + IsNullable: true, + Default: nil, + IsPrimary: false, + IsUnique: false, + Check: nil, + Expression: nil, + Comment: nil, + PolicyName: nil, + }, *viewDetails) + } + + cleanupViewProvider := func(id sdk.SchemaObjectIdentifier) func() { + return func() { + err := client.Views.Drop(ctx, sdk.NewDropViewRequest(id)) + require.NoError(t, err) + } + } + + createViewBasicRequest := func(t *testing.T) *sdk.CreateViewRequest { + t.Helper() + name := random.String() + id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, name) + + return sdk.NewCreateViewRequest(id, sql) + } + + createViewWithRequest := func(t *testing.T, request *sdk.CreateViewRequest) *sdk.View { + t.Helper() + id := request.GetName() + + err := client.Views.Create(ctx, request) + require.NoError(t, err) + t.Cleanup(cleanupViewProvider(id)) + + view, err := client.Views.ShowByID(ctx, id) + require.NoError(t, err) + + return view + } + + createView := func(t *testing.T) *sdk.View { + t.Helper() + return createViewWithRequest(t, createViewBasicRequest(t)) + } + + t.Run("create view: no optionals", func(t *testing.T) { + request := createViewBasicRequest(t) + + view := createViewWithRequest(t, request) + + assertView(t, view, request.GetName()) + }) + + t.Run("create view: almost complete case", func(t *testing.T) { + rowAccessPolicyId, rowAccessPolicyCleanup := createRowAccessPolicy(t, client, testSchema(t)) + t.Cleanup(rowAccessPolicyCleanup) + + tag, tagCleanup := createTag(t, client, testDb(t), testSchema(t)) + t.Cleanup(tagCleanup) + + request := createViewBasicRequest(t). + WithOrReplace(sdk.Bool(true)). + WithSecure(sdk.Bool(true)). + WithTemporary(sdk.Bool(true)). + WithColumns([]sdk.ViewColumnRequest{ + *sdk.NewViewColumnRequest("COLUMN_WITH_COMMENT").WithComment(sdk.String("column comment")), + }). + WithCopyGrants(sdk.Bool(true)). + WithComment(sdk.String("comment")). + WithRowAccessPolicy(sdk.NewViewRowAccessPolicyRequest(rowAccessPolicyId, []string{"column_with_comment"})). 
+ WithTag([]sdk.TagAssociation{{ + Name: tag.ID(), + Value: "v2", + }}) + + id := request.GetName() + + view := createViewWithRequest(t, request) + + assertViewWithOptions(t, view, id, true, "comment") + rowAccessPolicyReference, err := getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + require.NoError(t, err) + assert.Equal(t, rowAccessPolicyId.Name(), rowAccessPolicyReference.PolicyName) + assert.Equal(t, "ROW_ACCESS_POLICY", rowAccessPolicyReference.PolicyKind) + assert.Equal(t, view.ID().Name(), rowAccessPolicyReference.RefEntityName) + assert.Equal(t, "VIEW", rowAccessPolicyReference.RefEntityDomain) + assert.Equal(t, "ACTIVE", rowAccessPolicyReference.PolicyStatus) + }) + + t.Run("drop view: existing", func(t *testing.T) { + request := createViewBasicRequest(t) + id := request.GetName() + + err := client.Views.Create(ctx, request) + require.NoError(t, err) + + err = client.Views.Drop(ctx, sdk.NewDropViewRequest(id)) + require.NoError(t, err) + + _, err = client.Views.ShowByID(ctx, id) + assert.ErrorIs(t, err, collections.ErrObjectNotFound) + }) + + t.Run("drop view: non-existing", func(t *testing.T) { + id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, "does_not_exist") + + err := client.Views.Drop(ctx, sdk.NewDropViewRequest(id)) + assert.ErrorIs(t, err, sdk.ErrObjectNotExistOrAuthorized) + }) + + t.Run("alter view: rename", func(t *testing.T) { + createRequest := createViewBasicRequest(t) + id := createRequest.GetName() + + err := client.Views.Create(ctx, createRequest) + require.NoError(t, err) + + newName := random.String() + newId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, newName) + alterRequest := sdk.NewAlterViewRequest(id).WithRenameTo(&newId) + + err = client.Views.Alter(ctx, alterRequest) + if err != nil { + t.Cleanup(cleanupViewProvider(id)) + } else { + t.Cleanup(cleanupViewProvider(newId)) + } + require.NoError(t, err) + + _, err = client.Views.ShowByID(ctx, id) + assert.ErrorIs(t, err, collections.ErrObjectNotFound) + + view, err := client.Views.ShowByID(ctx, newId) + require.NoError(t, err) + + assertView(t, view, newId) + }) + + t.Run("alter view: set and unset values", func(t *testing.T) { + view := createView(t) + id := view.ID() + + alterRequest := sdk.NewAlterViewRequest(id).WithSetComment(sdk.String("new comment")) + err := client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + alteredView, err := client.Views.ShowByID(ctx, id) + require.NoError(t, err) + + assert.Equal(t, "new comment", alteredView.Comment) + + alterRequest = sdk.NewAlterViewRequest(id).WithSetSecure(sdk.Bool(true)) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + alteredView, err = client.Views.ShowByID(ctx, id) + require.NoError(t, err) + + assert.Equal(t, true, alteredView.IsSecure) + + alterRequest = sdk.NewAlterViewRequest(id).WithSetChangeTracking(sdk.Bool(true)) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + alteredView, err = client.Views.ShowByID(ctx, id) + require.NoError(t, err) + + assert.Equal(t, "ON", alteredView.ChangeTracking) + + alterRequest = sdk.NewAlterViewRequest(id).WithUnsetComment(sdk.Bool(true)) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + alteredView, err = client.Views.ShowByID(ctx, id) + require.NoError(t, err) + + assert.Equal(t, "", alteredView.Comment) + + alterRequest = sdk.NewAlterViewRequest(id).WithUnsetSecure(sdk.Bool(true)) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + 
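+		// Note: each property above is toggled with its own ALTER VIEW request, because
+		// AlterViewOptions accepts exactly one alter action per call (see the ExactlyOneValueSet
+		// validation in views_def.go and the corresponding unit tests in views_gen_test.go).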
+ alteredView, err = client.Views.ShowByID(ctx, id) + require.NoError(t, err) + + assert.Equal(t, false, alteredView.IsSecure) + + alterRequest = sdk.NewAlterViewRequest(id).WithSetChangeTracking(sdk.Bool(false)) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + alteredView, err = client.Views.ShowByID(ctx, id) + require.NoError(t, err) + + assert.Equal(t, "OFF", alteredView.ChangeTracking) + }) + + t.Run("alter view: set and unset tag", func(t *testing.T) { + tag, tagCleanup := createTag(t, client, testDb(t), testSchema(t)) + t.Cleanup(tagCleanup) + + view := createView(t) + id := view.ID() + + tagValue := "abc" + tags := []sdk.TagAssociation{ + { + Name: tag.ID(), + Value: tagValue, + }, + } + alterRequestSetTags := sdk.NewAlterViewRequest(id).WithSetTags(tags) + + err := client.Views.Alter(ctx, alterRequestSetTags) + require.NoError(t, err) + + // setting object type to view results in: + // SQL compilation error: Invalid value VIEW for argument OBJECT_TYPE. Please use object type TABLE for all kinds of table-like objects. + returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeTable) + require.NoError(t, err) + + assert.Equal(t, tagValue, returnedTagValue) + + unsetTags := []sdk.ObjectIdentifier{ + tag.ID(), + } + alterRequestUnsetTags := sdk.NewAlterViewRequest(id).WithUnsetTags(unsetTags) + + err = client.Views.Alter(ctx, alterRequestUnsetTags) + require.NoError(t, err) + + _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeTable) + require.Error(t, err) + }) + + t.Run("alter view: set and unset masking policy", func(t *testing.T) { + maskingPolicy, maskingPolicyCleanup := createMaskingPolicyIdentity(t, client, testDb(t), testSchema(t), sdk.DataTypeNumber) + t.Cleanup(maskingPolicyCleanup) + + view := createView(t) + id := view.ID() + + alterRequest := sdk.NewAlterViewRequest(id).WithSetMaskingPolicyOnColumn( + sdk.NewViewSetColumnMaskingPolicyRequest("id", maskingPolicy.ID()), + ) + err := client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + alteredViewDetails, err := client.Views.Describe(ctx, id) + require.NoError(t, err) + + assert.Equal(t, 1, len(alteredViewDetails)) + assert.Equal(t, maskingPolicy.ID().FullyQualifiedName(), *alteredViewDetails[0].PolicyName) + + alterRequest = sdk.NewAlterViewRequest(id).WithUnsetMaskingPolicyOnColumn( + sdk.NewViewUnsetColumnMaskingPolicyRequest("id"), + ) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + alteredViewDetails, err = client.Views.Describe(ctx, id) + require.NoError(t, err) + + assert.Equal(t, 1, len(alteredViewDetails)) + assert.Empty(t, alteredViewDetails[0].PolicyName) + }) + + t.Run("alter view: set and unset tags on column", func(t *testing.T) { + tag, tagCleanup := createTag(t, client, testDb(t), testSchema(t)) + t.Cleanup(tagCleanup) + + view := createView(t) + id := view.ID() + + tagValue := "abc" + tags := []sdk.TagAssociation{ + { + Name: tag.ID(), + Value: tagValue, + }, + } + + alterRequest := sdk.NewAlterViewRequest(id).WithSetTagsOnColumn( + sdk.NewViewSetColumnTagsRequest("id", tags), + ) + err := client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + columnId := sdk.NewTableColumnIdentifier(id.DatabaseName(), id.SchemaName(), id.Name(), "ID") + returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), columnId, sdk.ObjectTypeColumn) + require.NoError(t, err) + assert.Equal(t, tagValue, returnedTagValue) + + unsetTags := []sdk.ObjectIdentifier{ + tag.ID(), + } + + 
alterRequest = sdk.NewAlterViewRequest(id).WithUnsetTagsOnColumn( + sdk.NewViewUnsetColumnTagsRequest("id", unsetTags), + ) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), columnId, sdk.ObjectTypeColumn) + require.Error(t, err) + }) + + t.Run("alter view: add and drop row access policies", func(t *testing.T) { + rowAccessPolicyId, rowAccessPolicyCleanup := createRowAccessPolicy(t, client, testSchema(t)) + t.Cleanup(rowAccessPolicyCleanup) + rowAccessPolicy2Id, rowAccessPolicy2Cleanup := createRowAccessPolicy(t, client, testSchema(t)) + t.Cleanup(rowAccessPolicy2Cleanup) + + view := createView(t) + id := view.ID() + + // add policy + alterRequest := sdk.NewAlterViewRequest(id).WithAddRowAccessPolicy(sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicyId, []string{"ID"})) + err := client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + rowAccessPolicyReference, err := getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + require.NoError(t, err) + assert.Equal(t, rowAccessPolicyId.Name(), rowAccessPolicyReference.PolicyName) + assert.Equal(t, "ROW_ACCESS_POLICY", rowAccessPolicyReference.PolicyKind) + assert.Equal(t, view.ID().Name(), rowAccessPolicyReference.RefEntityName) + assert.Equal(t, "VIEW", rowAccessPolicyReference.RefEntityDomain) + assert.Equal(t, "ACTIVE", rowAccessPolicyReference.PolicyStatus) + + // remove policy + alterRequest = sdk.NewAlterViewRequest(id).WithDropRowAccessPolicy(sdk.NewViewDropRowAccessPolicyRequest(rowAccessPolicyId)) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + _, err = getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + require.Error(t, err, "no rows in result set") + + // add policy again + alterRequest = sdk.NewAlterViewRequest(id).WithAddRowAccessPolicy(sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicyId, []string{"ID"})) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + rowAccessPolicyReference, err = getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + require.NoError(t, err) + assert.Equal(t, rowAccessPolicyId.Name(), rowAccessPolicyReference.PolicyName) + + // drop and add other policy simultaneously + alterRequest = sdk.NewAlterViewRequest(id).WithDropAndAddRowAccessPolicy(sdk.NewViewDropAndAddRowAccessPolicyRequest( + *sdk.NewViewDropRowAccessPolicyRequest(rowAccessPolicyId), + *sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicy2Id, []string{"ID"}), + )) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + rowAccessPolicyReference, err = getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + require.NoError(t, err) + assert.Equal(t, rowAccessPolicy2Id.Name(), rowAccessPolicyReference.PolicyName) + + // drop all policies + alterRequest = sdk.NewAlterViewRequest(id).WithDropAllRowAccessPolicies(sdk.Bool(true)) + err = client.Views.Alter(ctx, alterRequest) + require.NoError(t, err) + + _, err = getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + require.Error(t, err, "no rows in result set") + }) + + t.Run("show view: default", func(t *testing.T) { + view1 := createView(t) + view2 := createView(t) + + showRequest := sdk.NewShowViewRequest() + returnedViews, err := client.Views.Show(ctx, showRequest) + require.NoError(t, err) + + assert.Equal(t, 2, len(returnedViews)) + assert.Contains(t, returnedViews, *view1) + assert.Contains(t, returnedViews, *view2) + }) + + t.Run("show view: terse", func(t *testing.T) { + 
view := createView(t) + + showRequest := sdk.NewShowViewRequest().WithTerse(sdk.Bool(true)) + returnedViews, err := client.Views.Show(ctx, showRequest) + require.NoError(t, err) + + assert.Equal(t, 1, len(returnedViews)) + assertViewTerse(t, &returnedViews[0], view.ID()) + }) + + t.Run("show view: with options", func(t *testing.T) { + view1 := createView(t) + view2 := createView(t) + + showRequest := sdk.NewShowViewRequest(). + WithLike(&sdk.Like{Pattern: &view1.Name}). + WithIn(&sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(testDb(t).Name, testSchema(t).Name)}). + WithLimit(&sdk.LimitFrom{Rows: sdk.Int(5)}) + returnedViews, err := client.Views.Show(ctx, showRequest) + + require.NoError(t, err) + assert.Equal(t, 1, len(returnedViews)) + assert.Contains(t, returnedViews, *view1) + assert.NotContains(t, returnedViews, *view2) + }) + + t.Run("describe view", func(t *testing.T) { + view := createView(t) + + returnedViewDetails, err := client.Views.Describe(ctx, view.ID()) + require.NoError(t, err) + + assert.Equal(t, 1, len(returnedViewDetails)) + assertViewDetailsRow(t, &returnedViewDetails[0]) + }) + + t.Run("describe view: non-existing", func(t *testing.T) { + id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, "does_not_exist") + + _, err := client.Views.Describe(ctx, id) + assert.ErrorIs(t, err, sdk.ErrObjectNotExistOrAuthorized) + }) +} diff --git a/pkg/sdk/views_def.go b/pkg/sdk/views_def.go new file mode 100644 index 0000000000..7620e37f8a --- /dev/null +++ b/pkg/sdk/views_def.go @@ -0,0 +1,215 @@ +package sdk + +import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/generator" + +//go:generate go run ./poc/main.go + +var viewDbRow = g.DbStruct("viewDBRow"). + Text("created_on"). + Text("name"). + OptionalText("kind"). + OptionalText("reserved"). + Text("database_name"). + Text("schema_name"). + OptionalText("owner"). + OptionalText("comment"). + OptionalText("text"). + OptionalBool("is_secure"). + OptionalBool("is_materialized"). + OptionalText("owner_role_type"). + OptionalText("change_tracking") + +var view = g.PlainStruct("View"). + Text("CreatedOn"). + Text("Name"). + Text("Kind"). + Text("Reserved"). + Text("DatabaseName"). + Text("SchemaName"). + Text("Owner"). + Text("Comment"). + Text("Text"). + Bool("IsSecure"). + Bool("IsMaterialized"). + Text("OwnerRoleType"). + Text("ChangeTracking") + +var viewDetailsDbRow = g.DbStruct("viewDetailsRow"). + Text("name"). + Field("type", "DataType"). + Text("kind"). + Text("null"). + OptionalText("default"). + Text("primary key"). + Text("unique key"). + OptionalText("check"). + OptionalText("expression"). + OptionalText("comment"). + OptionalText("policy name") + +var viewDetails = g.PlainStruct("ViewDetails"). + Text("Name"). + Field("Type", "DataType"). + Text("Kind"). + Bool("IsNullable"). + OptionalText("Default"). + Bool("IsPrimary"). + Bool("IsUnique"). + OptionalBool("Check"). + OptionalText("Expression"). + OptionalText("Comment"). + OptionalText("PolicyName") + +var viewColumn = g.NewQueryStruct("ViewColumn"). + Text("Name", g.KeywordOptions().DoubleQuotes().Required()). + OptionalTextAssignment("COMMENT", g.ParameterOptions().SingleQuotes().NoEquals()) + +var viewColumnMaskingPolicy = g.NewQueryStruct("ViewColumnMaskingPolicy"). + Text("Name", g.KeywordOptions().Required()). + Identifier("MaskingPolicy", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("MASKING POLICY").Required()). + NamedListWithParens("USING", g.KindOfT[string](), nil). // TODO: double quotes here? 
+ WithTags() + +var viewRowAccessPolicy = g.NewQueryStruct("ViewRowAccessPolicy"). + Identifier("RowAccessPolicy", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("ROW ACCESS POLICY").Required()). + NamedListWithParens("ON", g.KindOfT[string](), g.KeywordOptions().Required()). // TODO: double quotes here? + WithValidation(g.ValidIdentifier, "RowAccessPolicy") + +var viewAddRowAccessPolicy = g.NewQueryStruct("ViewAddRowAccessPolicy"). + SQL("ADD"). + Identifier("RowAccessPolicy", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("ROW ACCESS POLICY").Required()). + NamedListWithParens("ON", g.KindOfT[string](), g.KeywordOptions().Required()). // TODO: double quotes here? + WithValidation(g.ValidIdentifier, "RowAccessPolicy") + +var viewDropRowAccessPolicy = g.NewQueryStruct("ViewDropRowAccessPolicy"). + SQL("DROP"). + Identifier("RowAccessPolicy", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("ROW ACCESS POLICY").Required()). + WithValidation(g.ValidIdentifier, "RowAccessPolicy") + +var viewDropAndAddRowAccessPolicy = g.NewQueryStruct("ViewDropAndAddRowAccessPolicy"). + QueryStructField("Drop", viewDropRowAccessPolicy, g.KeywordOptions().Required()). + QueryStructField("Add", viewAddRowAccessPolicy, g.KeywordOptions().Required()) + +var viewSetColumnMaskingPolicy = g.NewQueryStruct("ViewSetColumnMaskingPolicy"). + // In the docs there is a MODIFY alternative, but for simplicity only one is supported here. + SQL("ALTER"). + SQL("COLUMN"). + Text("Name", g.KeywordOptions().Required()). + SQL("SET"). + Identifier("MaskingPolicy", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("MASKING POLICY").Required()). + NamedListWithParens("USING", g.KindOfT[string](), nil). // TODO: double quotes here? + OptionalSQL("FORCE") + +var viewUnsetColumnMaskingPolicy = g.NewQueryStruct("ViewUnsetColumnMaskingPolicy"). + // In the docs there is a MODIFY alternative, but for simplicity only one is supported here. + SQL("ALTER"). + SQL("COLUMN"). + Text("Name", g.KeywordOptions().Required()). + SQL("UNSET"). + SQL("MASKING POLICY") + +var viewSetColumnTags = g.NewQueryStruct("ViewSetColumnTags"). + // In the docs there is a MODIFY alternative, but for simplicity only one is supported here. + SQL("ALTER"). + SQL("COLUMN"). + Text("Name", g.KeywordOptions().Required()). + SetTags() + +var viewUnsetColumnTags = g.NewQueryStruct("ViewUnsetColumnTags"). + // In the docs there is a MODIFY alternative, but for simplicity only one is supported here. + SQL("ALTER"). + SQL("COLUMN"). + Text("Name", g.KeywordOptions().Required()). + UnsetTags() + +var ViewsDef = g.NewInterface( + "Views", + "View", + g.KindOfT[SchemaObjectIdentifier](), +). + CreateOperation( + "https://docs.snowflake.com/en/sql-reference/sql/create-view", + g.NewQueryStruct("CreateView"). + Create(). + OrReplace(). + OptionalSQL("SECURE"). + // There are multiple variants in the docs: { [ { LOCAL | GLOBAL } ] TEMP | TEMPORARY | VOLATILE } + // but from description they are all the same. For the sake of simplicity only one option is used here. + OptionalSQL("TEMPORARY"). + OptionalSQL("RECURSIVE"). + SQL("VIEW"). + IfNotExists(). + Name(). + ListQueryStructField("Columns", viewColumn, g.ListOptions().Parentheses()). + ListQueryStructField("ColumnsMaskingPolicies", viewColumnMaskingPolicy, g.ListOptions().NoParentheses().NoEquals()). + OptionalSQL("COPY GRANTS"). + OptionalTextAssignment("COMMENT", g.ParameterOptions().SingleQuotes()). 
+ // In the current docs ROW ACCESS POLICY and TAG are specified twice. + // It is a mistake probably so here they are present only once. + OptionalQueryStructField("RowAccessPolicy", viewRowAccessPolicy, g.KeywordOptions()). + WithTags(). + SQL("AS"). + Text("sql", g.KeywordOptions().NoQuotes().Required()). + WithValidation(g.ValidIdentifier, "name"). + WithValidation(g.ConflictingFields, "OrReplace", "IfNotExists"), + ). + AlterOperation( + "https://docs.snowflake.com/en/sql-reference/sql/alter-view", + g.NewQueryStruct("AlterView"). + Alter(). + SQL("VIEW"). + IfExists(). + Name(). + OptionalIdentifier("RenameTo", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("RENAME TO")). + OptionalTextAssignment("SET COMMENT", g.ParameterOptions().SingleQuotes()). + OptionalSQL("UNSET COMMENT"). + OptionalSQL("SET SECURE"). + OptionalBooleanAssignment("SET CHANGE_TRACKING", nil). + OptionalSQL("UNSET SECURE"). + OptionalSetTags(). + OptionalUnsetTags(). + OptionalQueryStructField("AddRowAccessPolicy", viewAddRowAccessPolicy, g.KeywordOptions()). + OptionalQueryStructField("DropRowAccessPolicy", viewDropRowAccessPolicy, g.KeywordOptions()). + OptionalQueryStructField("DropAndAddRowAccessPolicy", viewDropAndAddRowAccessPolicy, g.ListOptions().NoParentheses()). + OptionalSQL("DROP ALL ROW ACCESS POLICIES"). + OptionalQueryStructField("SetMaskingPolicyOnColumn", viewSetColumnMaskingPolicy, g.KeywordOptions()). + OptionalQueryStructField("UnsetMaskingPolicyOnColumn", viewUnsetColumnMaskingPolicy, g.KeywordOptions()). + OptionalQueryStructField("SetTagsOnColumn", viewSetColumnTags, g.KeywordOptions()). + OptionalQueryStructField("UnsetTagsOnColumn", viewUnsetColumnTags, g.KeywordOptions()). + WithValidation(g.ValidIdentifier, "name"). + WithValidation(g.ExactlyOneValueSet, "RenameTo", "SetComment", "UnsetComment", "SetSecure", "SetChangeTracking", "UnsetSecure", "SetTags", "UnsetTags", "AddRowAccessPolicy", "DropRowAccessPolicy", "DropAndAddRowAccessPolicy", "DropAllRowAccessPolicies", "SetMaskingPolicyOnColumn", "UnsetMaskingPolicyOnColumn", "SetTagsOnColumn", "UnsetTagsOnColumn"), + ). + DropOperation( + "https://docs.snowflake.com/en/sql-reference/sql/drop-view", + g.NewQueryStruct("DropView"). + Drop(). + SQL("VIEW"). + IfExists(). + Name(). + WithValidation(g.ValidIdentifier, "name"), + ). + ShowOperation( + "https://docs.snowflake.com/en/sql-reference/sql/show-views", + viewDbRow, + view, + g.NewQueryStruct("ShowViews"). + Show(). + Terse(). + SQL("VIEWS"). + OptionalLike(). + OptionalIn(). + OptionalStartsWith(). + OptionalLimit(), + ). + ShowByIdOperation(). + DescribeOperation( + g.DescriptionMappingKindSlice, + "https://docs.snowflake.com/en/sql-reference/sql/desc-view", + viewDetailsDbRow, + viewDetails, + g.NewQueryStruct("DescribeView"). + Describe(). + SQL("VIEW"). + Name(). + WithValidation(g.ValidIdentifier, "name"), + ) diff --git a/pkg/sdk/views_dto_builders_gen.go b/pkg/sdk/views_dto_builders_gen.go new file mode 100644 index 0000000000..d357ee3a27 --- /dev/null +++ b/pkg/sdk/views_dto_builders_gen.go @@ -0,0 +1,332 @@ +// Code generated by dto builder generator; DO NOT EDIT. 
+ +package sdk + +import () + +func NewCreateViewRequest( + name SchemaObjectIdentifier, + sql string, +) *CreateViewRequest { + s := CreateViewRequest{} + s.name = name + s.sql = sql + return &s +} + +func (s *CreateViewRequest) WithOrReplace(OrReplace *bool) *CreateViewRequest { + s.OrReplace = OrReplace + return s +} + +func (s *CreateViewRequest) WithSecure(Secure *bool) *CreateViewRequest { + s.Secure = Secure + return s +} + +func (s *CreateViewRequest) WithTemporary(Temporary *bool) *CreateViewRequest { + s.Temporary = Temporary + return s +} + +func (s *CreateViewRequest) WithRecursive(Recursive *bool) *CreateViewRequest { + s.Recursive = Recursive + return s +} + +func (s *CreateViewRequest) WithIfNotExists(IfNotExists *bool) *CreateViewRequest { + s.IfNotExists = IfNotExists + return s +} + +func (s *CreateViewRequest) WithColumns(Columns []ViewColumnRequest) *CreateViewRequest { + s.Columns = Columns + return s +} + +func (s *CreateViewRequest) WithColumnsMaskingPolicies(ColumnsMaskingPolicies []ViewColumnMaskingPolicyRequest) *CreateViewRequest { + s.ColumnsMaskingPolicies = ColumnsMaskingPolicies + return s +} + +func (s *CreateViewRequest) WithCopyGrants(CopyGrants *bool) *CreateViewRequest { + s.CopyGrants = CopyGrants + return s +} + +func (s *CreateViewRequest) WithComment(Comment *string) *CreateViewRequest { + s.Comment = Comment + return s +} + +func (s *CreateViewRequest) WithRowAccessPolicy(RowAccessPolicy *ViewRowAccessPolicyRequest) *CreateViewRequest { + s.RowAccessPolicy = RowAccessPolicy + return s +} + +func (s *CreateViewRequest) WithTag(Tag []TagAssociation) *CreateViewRequest { + s.Tag = Tag + return s +} + +func NewViewColumnRequest( + Name string, +) *ViewColumnRequest { + s := ViewColumnRequest{} + s.Name = Name + return &s +} + +func (s *ViewColumnRequest) WithComment(Comment *string) *ViewColumnRequest { + s.Comment = Comment + return s +} + +func NewViewColumnMaskingPolicyRequest( + Name string, + MaskingPolicy SchemaObjectIdentifier, +) *ViewColumnMaskingPolicyRequest { + s := ViewColumnMaskingPolicyRequest{} + s.Name = Name + s.MaskingPolicy = MaskingPolicy + return &s +} + +func (s *ViewColumnMaskingPolicyRequest) WithUsing(Using []string) *ViewColumnMaskingPolicyRequest { + s.Using = Using + return s +} + +func (s *ViewColumnMaskingPolicyRequest) WithTag(Tag []TagAssociation) *ViewColumnMaskingPolicyRequest { + s.Tag = Tag + return s +} + +func NewViewRowAccessPolicyRequest( + RowAccessPolicy SchemaObjectIdentifier, + On []string, +) *ViewRowAccessPolicyRequest { + s := ViewRowAccessPolicyRequest{} + s.RowAccessPolicy = RowAccessPolicy + s.On = On + return &s +} + +func NewAlterViewRequest( + name SchemaObjectIdentifier, +) *AlterViewRequest { + s := AlterViewRequest{} + s.name = name + return &s +} + +func (s *AlterViewRequest) WithIfExists(IfExists *bool) *AlterViewRequest { + s.IfExists = IfExists + return s +} + +func (s *AlterViewRequest) WithRenameTo(RenameTo *SchemaObjectIdentifier) *AlterViewRequest { + s.RenameTo = RenameTo + return s +} + +func (s *AlterViewRequest) WithSetComment(SetComment *string) *AlterViewRequest { + s.SetComment = SetComment + return s +} + +func (s *AlterViewRequest) WithUnsetComment(UnsetComment *bool) *AlterViewRequest { + s.UnsetComment = UnsetComment + return s +} + +func (s *AlterViewRequest) WithSetSecure(SetSecure *bool) *AlterViewRequest { + s.SetSecure = SetSecure + return s +} + +func (s *AlterViewRequest) WithSetChangeTracking(SetChangeTracking *bool) *AlterViewRequest { + s.SetChangeTracking = 
SetChangeTracking + return s +} + +func (s *AlterViewRequest) WithUnsetSecure(UnsetSecure *bool) *AlterViewRequest { + s.UnsetSecure = UnsetSecure + return s +} + +func (s *AlterViewRequest) WithSetTags(SetTags []TagAssociation) *AlterViewRequest { + s.SetTags = SetTags + return s +} + +func (s *AlterViewRequest) WithUnsetTags(UnsetTags []ObjectIdentifier) *AlterViewRequest { + s.UnsetTags = UnsetTags + return s +} + +func (s *AlterViewRequest) WithAddRowAccessPolicy(AddRowAccessPolicy *ViewAddRowAccessPolicyRequest) *AlterViewRequest { + s.AddRowAccessPolicy = AddRowAccessPolicy + return s +} + +func (s *AlterViewRequest) WithDropRowAccessPolicy(DropRowAccessPolicy *ViewDropRowAccessPolicyRequest) *AlterViewRequest { + s.DropRowAccessPolicy = DropRowAccessPolicy + return s +} + +func (s *AlterViewRequest) WithDropAndAddRowAccessPolicy(DropAndAddRowAccessPolicy *ViewDropAndAddRowAccessPolicyRequest) *AlterViewRequest { + s.DropAndAddRowAccessPolicy = DropAndAddRowAccessPolicy + return s +} + +func (s *AlterViewRequest) WithDropAllRowAccessPolicies(DropAllRowAccessPolicies *bool) *AlterViewRequest { + s.DropAllRowAccessPolicies = DropAllRowAccessPolicies + return s +} + +func (s *AlterViewRequest) WithSetMaskingPolicyOnColumn(SetMaskingPolicyOnColumn *ViewSetColumnMaskingPolicyRequest) *AlterViewRequest { + s.SetMaskingPolicyOnColumn = SetMaskingPolicyOnColumn + return s +} + +func (s *AlterViewRequest) WithUnsetMaskingPolicyOnColumn(UnsetMaskingPolicyOnColumn *ViewUnsetColumnMaskingPolicyRequest) *AlterViewRequest { + s.UnsetMaskingPolicyOnColumn = UnsetMaskingPolicyOnColumn + return s +} + +func (s *AlterViewRequest) WithSetTagsOnColumn(SetTagsOnColumn *ViewSetColumnTagsRequest) *AlterViewRequest { + s.SetTagsOnColumn = SetTagsOnColumn + return s +} + +func (s *AlterViewRequest) WithUnsetTagsOnColumn(UnsetTagsOnColumn *ViewUnsetColumnTagsRequest) *AlterViewRequest { + s.UnsetTagsOnColumn = UnsetTagsOnColumn + return s +} + +func NewViewAddRowAccessPolicyRequest( + RowAccessPolicy SchemaObjectIdentifier, + On []string, +) *ViewAddRowAccessPolicyRequest { + s := ViewAddRowAccessPolicyRequest{} + s.RowAccessPolicy = RowAccessPolicy + s.On = On + return &s +} + +func NewViewDropRowAccessPolicyRequest( + RowAccessPolicy SchemaObjectIdentifier, +) *ViewDropRowAccessPolicyRequest { + s := ViewDropRowAccessPolicyRequest{} + s.RowAccessPolicy = RowAccessPolicy + return &s +} + +func NewViewDropAndAddRowAccessPolicyRequest( + Drop ViewDropRowAccessPolicyRequest, + Add ViewAddRowAccessPolicyRequest, +) *ViewDropAndAddRowAccessPolicyRequest { + s := ViewDropAndAddRowAccessPolicyRequest{} + s.Drop = Drop + s.Add = Add + return &s +} + +func NewViewSetColumnMaskingPolicyRequest( + Name string, + MaskingPolicy SchemaObjectIdentifier, +) *ViewSetColumnMaskingPolicyRequest { + s := ViewSetColumnMaskingPolicyRequest{} + s.Name = Name + s.MaskingPolicy = MaskingPolicy + return &s +} + +func (s *ViewSetColumnMaskingPolicyRequest) WithUsing(Using []string) *ViewSetColumnMaskingPolicyRequest { + s.Using = Using + return s +} + +func (s *ViewSetColumnMaskingPolicyRequest) WithForce(Force *bool) *ViewSetColumnMaskingPolicyRequest { + s.Force = Force + return s +} + +func NewViewUnsetColumnMaskingPolicyRequest( + Name string, +) *ViewUnsetColumnMaskingPolicyRequest { + s := ViewUnsetColumnMaskingPolicyRequest{} + s.Name = Name + return &s +} + +func NewViewSetColumnTagsRequest( + Name string, + SetTags []TagAssociation, +) *ViewSetColumnTagsRequest { + s := ViewSetColumnTagsRequest{} + s.Name = Name + s.SetTags = 
SetTags + return &s +} + +func NewViewUnsetColumnTagsRequest( + Name string, + UnsetTags []ObjectIdentifier, +) *ViewUnsetColumnTagsRequest { + s := ViewUnsetColumnTagsRequest{} + s.Name = Name + s.UnsetTags = UnsetTags + return &s +} + +func NewDropViewRequest( + name SchemaObjectIdentifier, +) *DropViewRequest { + s := DropViewRequest{} + s.name = name + return &s +} + +func (s *DropViewRequest) WithIfExists(IfExists *bool) *DropViewRequest { + s.IfExists = IfExists + return s +} + +func NewShowViewRequest() *ShowViewRequest { + return &ShowViewRequest{} +} + +func (s *ShowViewRequest) WithTerse(Terse *bool) *ShowViewRequest { + s.Terse = Terse + return s +} + +func (s *ShowViewRequest) WithLike(Like *Like) *ShowViewRequest { + s.Like = Like + return s +} + +func (s *ShowViewRequest) WithIn(In *In) *ShowViewRequest { + s.In = In + return s +} + +func (s *ShowViewRequest) WithStartsWith(StartsWith *string) *ShowViewRequest { + s.StartsWith = StartsWith + return s +} + +func (s *ShowViewRequest) WithLimit(Limit *LimitFrom) *ShowViewRequest { + s.Limit = Limit + return s +} + +func NewDescribeViewRequest( + name SchemaObjectIdentifier, +) *DescribeViewRequest { + s := DescribeViewRequest{} + s.name = name + return &s +} diff --git a/pkg/sdk/views_dto_gen.go b/pkg/sdk/views_dto_gen.go new file mode 100644 index 0000000000..a3b2ac4cee --- /dev/null +++ b/pkg/sdk/views_dto_gen.go @@ -0,0 +1,121 @@ +package sdk + +//go:generate go run ./dto-builder-generator/main.go + +var ( + _ optionsProvider[CreateViewOptions] = new(CreateViewRequest) + _ optionsProvider[AlterViewOptions] = new(AlterViewRequest) + _ optionsProvider[DropViewOptions] = new(DropViewRequest) + _ optionsProvider[ShowViewOptions] = new(ShowViewRequest) + _ optionsProvider[DescribeViewOptions] = new(DescribeViewRequest) +) + +type CreateViewRequest struct { + OrReplace *bool + Secure *bool + Temporary *bool + Recursive *bool + IfNotExists *bool + name SchemaObjectIdentifier // required + Columns []ViewColumnRequest + ColumnsMaskingPolicies []ViewColumnMaskingPolicyRequest + CopyGrants *bool + Comment *string + RowAccessPolicy *ViewRowAccessPolicyRequest + Tag []TagAssociation + sql string // required +} + +func (r *CreateViewRequest) GetName() SchemaObjectIdentifier { + return r.name +} + +type ViewColumnRequest struct { + Name string // required + Comment *string +} + +type ViewColumnMaskingPolicyRequest struct { + Name string // required + MaskingPolicy SchemaObjectIdentifier // required + Using []string + Tag []TagAssociation +} + +type ViewRowAccessPolicyRequest struct { + RowAccessPolicy SchemaObjectIdentifier // required + On []string // required +} + +type AlterViewRequest struct { + IfExists *bool + name SchemaObjectIdentifier // required + RenameTo *SchemaObjectIdentifier + SetComment *string + UnsetComment *bool + SetSecure *bool + SetChangeTracking *bool + UnsetSecure *bool + SetTags []TagAssociation + UnsetTags []ObjectIdentifier + AddRowAccessPolicy *ViewAddRowAccessPolicyRequest + DropRowAccessPolicy *ViewDropRowAccessPolicyRequest + DropAndAddRowAccessPolicy *ViewDropAndAddRowAccessPolicyRequest + DropAllRowAccessPolicies *bool + SetMaskingPolicyOnColumn *ViewSetColumnMaskingPolicyRequest + UnsetMaskingPolicyOnColumn *ViewUnsetColumnMaskingPolicyRequest + SetTagsOnColumn *ViewSetColumnTagsRequest + UnsetTagsOnColumn *ViewUnsetColumnTagsRequest +} + +type ViewAddRowAccessPolicyRequest struct { + RowAccessPolicy SchemaObjectIdentifier // required + On []string // required +} + +type ViewDropRowAccessPolicyRequest 
struct { + RowAccessPolicy SchemaObjectIdentifier // required +} + +type ViewDropAndAddRowAccessPolicyRequest struct { + Drop ViewDropRowAccessPolicyRequest // required + Add ViewAddRowAccessPolicyRequest // required +} + +type ViewSetColumnMaskingPolicyRequest struct { + Name string // required + MaskingPolicy SchemaObjectIdentifier // required + Using []string + Force *bool +} + +type ViewUnsetColumnMaskingPolicyRequest struct { + Name string // required +} + +type ViewSetColumnTagsRequest struct { + Name string // required + SetTags []TagAssociation // required +} + +type ViewUnsetColumnTagsRequest struct { + Name string // required + UnsetTags []ObjectIdentifier // required +} + +type DropViewRequest struct { + IfExists *bool + name SchemaObjectIdentifier // required +} + +type ShowViewRequest struct { + Terse *bool + Like *Like + In *In + StartsWith *string + Limit *LimitFrom +} + +type DescribeViewRequest struct { + name SchemaObjectIdentifier // required +} diff --git a/pkg/sdk/views_gen.go b/pkg/sdk/views_gen.go new file mode 100644 index 0000000000..0debdea3e4 --- /dev/null +++ b/pkg/sdk/views_gen.go @@ -0,0 +1,217 @@ +package sdk + +import ( + "context" + "database/sql" +) + +type Views interface { + Create(ctx context.Context, request *CreateViewRequest) error + Alter(ctx context.Context, request *AlterViewRequest) error + Drop(ctx context.Context, request *DropViewRequest) error + Show(ctx context.Context, request *ShowViewRequest) ([]View, error) + ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*View, error) + Describe(ctx context.Context, id SchemaObjectIdentifier) ([]ViewDetails, error) +} + +// CreateViewOptions is based on https://docs.snowflake.com/en/sql-reference/sql/create-view. +type CreateViewOptions struct { + create bool `ddl:"static" sql:"CREATE"` + OrReplace *bool `ddl:"keyword" sql:"OR REPLACE"` + Secure *bool `ddl:"keyword" sql:"SECURE"` + Temporary *bool `ddl:"keyword" sql:"TEMPORARY"` + Recursive *bool `ddl:"keyword" sql:"RECURSIVE"` + view bool `ddl:"static" sql:"VIEW"` + IfNotExists *bool `ddl:"keyword" sql:"IF NOT EXISTS"` + name SchemaObjectIdentifier `ddl:"identifier"` + Columns []ViewColumn `ddl:"list,parentheses"` + ColumnsMaskingPolicies []ViewColumnMaskingPolicy `ddl:"list,no_parentheses,no_equals"` + CopyGrants *bool `ddl:"keyword" sql:"COPY GRANTS"` + Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` + RowAccessPolicy *ViewRowAccessPolicy `ddl:"keyword"` + Tag []TagAssociation `ddl:"keyword,parentheses" sql:"TAG"` + as bool `ddl:"static" sql:"AS"` + sql string `ddl:"keyword,no_quotes"` +} + +type ViewColumn struct { + Name string `ddl:"keyword,double_quotes"` + Comment *string `ddl:"parameter,single_quotes,no_equals" sql:"COMMENT"` +} + +type ViewColumnMaskingPolicy struct { + Name string `ddl:"keyword"` + MaskingPolicy SchemaObjectIdentifier `ddl:"identifier" sql:"MASKING POLICY"` + Using []string `ddl:"keyword,parentheses" sql:"USING"` + Tag []TagAssociation `ddl:"keyword,parentheses" sql:"TAG"` +} + +type ViewRowAccessPolicy struct { + RowAccessPolicy SchemaObjectIdentifier `ddl:"identifier" sql:"ROW ACCESS POLICY"` + On []string `ddl:"keyword,parentheses" sql:"ON"` +} + +// AlterViewOptions is based on https://docs.snowflake.com/en/sql-reference/sql/alter-view. 
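+// Exactly one alter action (RenameTo, SetComment, ..., UnsetTagsOnColumn) may be set per call;
+// this is enforced by the ExactlyOneValueSet validation defined in views_def.go.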
+type AlterViewOptions struct { + alter bool `ddl:"static" sql:"ALTER"` + view bool `ddl:"static" sql:"VIEW"` + IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` + name SchemaObjectIdentifier `ddl:"identifier"` + RenameTo *SchemaObjectIdentifier `ddl:"identifier" sql:"RENAME TO"` + SetComment *string `ddl:"parameter,single_quotes" sql:"SET COMMENT"` + UnsetComment *bool `ddl:"keyword" sql:"UNSET COMMENT"` + SetSecure *bool `ddl:"keyword" sql:"SET SECURE"` + SetChangeTracking *bool `ddl:"parameter" sql:"SET CHANGE_TRACKING"` + UnsetSecure *bool `ddl:"keyword" sql:"UNSET SECURE"` + SetTags []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTags []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` + AddRowAccessPolicy *ViewAddRowAccessPolicy `ddl:"keyword"` + DropRowAccessPolicy *ViewDropRowAccessPolicy `ddl:"keyword"` + DropAndAddRowAccessPolicy *ViewDropAndAddRowAccessPolicy `ddl:"list,no_parentheses"` + DropAllRowAccessPolicies *bool `ddl:"keyword" sql:"DROP ALL ROW ACCESS POLICIES"` + SetMaskingPolicyOnColumn *ViewSetColumnMaskingPolicy `ddl:"keyword"` + UnsetMaskingPolicyOnColumn *ViewUnsetColumnMaskingPolicy `ddl:"keyword"` + SetTagsOnColumn *ViewSetColumnTags `ddl:"keyword"` + UnsetTagsOnColumn *ViewUnsetColumnTags `ddl:"keyword"` +} + +type ViewAddRowAccessPolicy struct { + add bool `ddl:"static" sql:"ADD"` + RowAccessPolicy SchemaObjectIdentifier `ddl:"identifier" sql:"ROW ACCESS POLICY"` + On []string `ddl:"keyword,parentheses" sql:"ON"` +} + +type ViewDropRowAccessPolicy struct { + drop bool `ddl:"static" sql:"DROP"` + RowAccessPolicy SchemaObjectIdentifier `ddl:"identifier" sql:"ROW ACCESS POLICY"` +} + +type ViewDropAndAddRowAccessPolicy struct { + Drop ViewDropRowAccessPolicy `ddl:"keyword"` + Add ViewAddRowAccessPolicy `ddl:"keyword"` +} + +type ViewSetColumnMaskingPolicy struct { + alter bool `ddl:"static" sql:"ALTER"` + column bool `ddl:"static" sql:"COLUMN"` + Name string `ddl:"keyword"` + set bool `ddl:"static" sql:"SET"` + MaskingPolicy SchemaObjectIdentifier `ddl:"identifier" sql:"MASKING POLICY"` + Using []string `ddl:"keyword,parentheses" sql:"USING"` + Force *bool `ddl:"keyword" sql:"FORCE"` +} + +type ViewUnsetColumnMaskingPolicy struct { + alter bool `ddl:"static" sql:"ALTER"` + column bool `ddl:"static" sql:"COLUMN"` + Name string `ddl:"keyword"` + unset bool `ddl:"static" sql:"UNSET"` + maskingPolicy bool `ddl:"static" sql:"MASKING POLICY"` +} + +type ViewSetColumnTags struct { + alter bool `ddl:"static" sql:"ALTER"` + column bool `ddl:"static" sql:"COLUMN"` + Name string `ddl:"keyword"` + SetTags []TagAssociation `ddl:"keyword" sql:"SET TAG"` +} + +type ViewUnsetColumnTags struct { + alter bool `ddl:"static" sql:"ALTER"` + column bool `ddl:"static" sql:"COLUMN"` + Name string `ddl:"keyword"` + UnsetTags []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` +} + +// DropViewOptions is based on https://docs.snowflake.com/en/sql-reference/sql/drop-view. +type DropViewOptions struct { + drop bool `ddl:"static" sql:"DROP"` + view bool `ddl:"static" sql:"VIEW"` + IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` + name SchemaObjectIdentifier `ddl:"identifier"` +} + +// ShowViewOptions is based on https://docs.snowflake.com/en/sql-reference/sql/show-views. 
+type ShowViewOptions struct { + show bool `ddl:"static" sql:"SHOW"` + Terse *bool `ddl:"keyword" sql:"TERSE"` + views bool `ddl:"static" sql:"VIEWS"` + Like *Like `ddl:"keyword" sql:"LIKE"` + In *In `ddl:"keyword" sql:"IN"` + StartsWith *string `ddl:"parameter,no_equals,single_quotes" sql:"STARTS WITH"` + Limit *LimitFrom `ddl:"keyword" sql:"LIMIT"` +} + +type viewDBRow struct { + CreatedOn string `db:"created_on"` + Name string `db:"name"` + Kind sql.NullString `db:"kind"` + Reserved sql.NullString `db:"reserved"` + DatabaseName string `db:"database_name"` + SchemaName string `db:"schema_name"` + Owner sql.NullString `db:"owner"` + Comment sql.NullString `db:"comment"` + Text sql.NullString `db:"text"` + IsSecure sql.NullBool `db:"is_secure"` + IsMaterialized sql.NullBool `db:"is_materialized"` + OwnerRoleType sql.NullString `db:"owner_role_type"` + ChangeTracking sql.NullString `db:"change_tracking"` +} + +type View struct { + CreatedOn string + Name string + Kind string + Reserved string + DatabaseName string + SchemaName string + Owner string + Comment string + Text string + IsSecure bool + IsMaterialized bool + OwnerRoleType string + ChangeTracking string +} + +func (v *View) ID() SchemaObjectIdentifier { + return NewSchemaObjectIdentifier(v.DatabaseName, v.SchemaName, v.Name) +} + +// DescribeViewOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-view. +type DescribeViewOptions struct { + describe bool `ddl:"static" sql:"DESCRIBE"` + view bool `ddl:"static" sql:"VIEW"` + name SchemaObjectIdentifier `ddl:"identifier"` +} + +// TODO [SNOW-965322]: extract common type for describe +// viewDetailsRow is a copy of externalTableColumnDetailsRow. +type viewDetailsRow struct { + Name string `db:"name"` + Type DataType `db:"type"` + Kind string `db:"kind"` + IsNullable string `db:"null?"` + Default sql.NullString `db:"default"` + IsPrimary string `db:"primary key"` + IsUnique string `db:"unique key"` + Check sql.NullString `db:"check"` + Expression sql.NullString `db:"expression"` + Comment sql.NullString `db:"comment"` + PolicyName sql.NullString `db:"policy name"` +} + +// ViewDetails is a copy of ExternalTableColumnDetails. 
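+// It holds a single row of DESCRIBE VIEW output after conversion from viewDetailsRow.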
+type ViewDetails struct { + Name string + Type DataType + Kind string + IsNullable bool + Default *string + IsPrimary bool + IsUnique bool + Check *bool + Expression *string + Comment *string + PolicyName *string +} diff --git a/pkg/sdk/views_gen_test.go b/pkg/sdk/views_gen_test.go new file mode 100644 index 0000000000..e7aa732e1f --- /dev/null +++ b/pkg/sdk/views_gen_test.go @@ -0,0 +1,441 @@ +package sdk + +import ( + "testing" +) + +func TestViews_Create(t *testing.T) { + id := RandomSchemaObjectIdentifier() + sql := "SELECT id FROM t" + + // Minimal valid CreateViewOptions + defaultOpts := func() *CreateViewOptions { + return &CreateViewOptions{ + name: id, + sql: sql, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + var opts *CreateViewOptions = nil + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: valid identifier for [opts.name]", func(t *testing.T) { + opts := defaultOpts() + opts.name = NewSchemaObjectIdentifier("", "", "") + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: conflicting fields for [opts.OrReplace opts.IfNotExists]", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.IfNotExists = Bool(true) + assertOptsInvalidJoinedErrors(t, opts, errOneOf("CreateViewOptions", "OrReplace", "IfNotExists")) + }) + + t.Run("validation: valid identifier for [opts.RowAccessPolicy.RowAccessPolicy]", func(t *testing.T) { + opts := defaultOpts() + opts.RowAccessPolicy = &ViewRowAccessPolicy{ + RowAccessPolicy: NewSchemaObjectIdentifier("", "", ""), + } + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: empty columns for row access policy", func(t *testing.T) { + opts := defaultOpts() + opts.RowAccessPolicy = &ViewRowAccessPolicy{ + RowAccessPolicy: RandomSchemaObjectIdentifier(), + On: []string{}, + } + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateViewOptions.RowAccessPolicy", "On")) + }) + + t.Run("basic", func(t *testing.T) { + opts := defaultOpts() + assertOptsValidAndSQLEquals(t, opts, "CREATE VIEW %s AS %s", id.FullyQualifiedName(), sql) + }) + + t.Run("all options", func(t *testing.T) { + rowAccessPolicyId := RandomSchemaObjectIdentifier() + tag1Id := RandomSchemaObjectIdentifier() + tag2Id := RandomSchemaObjectIdentifier() + maskingPolicy1Id := RandomSchemaObjectIdentifier() + maskingPolicy2Id := RandomSchemaObjectIdentifier() + + req := NewCreateViewRequest(id, sql). + WithOrReplace(Bool(true)). + WithSecure(Bool(true)). + WithTemporary(Bool(true)). + WithRecursive(Bool(true)). + WithColumns([]ViewColumnRequest{ + *NewViewColumnRequest("column_without_comment"), + *NewViewColumnRequest("column_with_comment").WithComment(String("column 2 comment")), + }). + WithColumnsMaskingPolicies([]ViewColumnMaskingPolicyRequest{ + *NewViewColumnMaskingPolicyRequest("column", maskingPolicy1Id). + WithUsing([]string{"a", "b"}). + WithTag([]TagAssociation{{ + Name: tag1Id, + Value: "v1", + }}), + *NewViewColumnMaskingPolicyRequest("column 2", maskingPolicy2Id), + }). + WithCopyGrants(Bool(true)). + WithComment(String("comment")). + WithRowAccessPolicy(NewViewRowAccessPolicyRequest(rowAccessPolicyId, []string{"c", "d"})). 
+ WithTag([]TagAssociation{{ + Name: tag2Id, + Value: "v2", + }}) + + assertOptsValidAndSQLEquals(t, req.toOpts(), `CREATE OR REPLACE SECURE TEMPORARY RECURSIVE VIEW %s ("column_without_comment", "column_with_comment" COMMENT 'column 2 comment') column MASKING POLICY %s USING (a, b) TAG (%s = 'v1'), column 2 MASKING POLICY %s COPY GRANTS COMMENT = 'comment' ROW ACCESS POLICY %s ON (c, d) TAG (%s = 'v2') AS %s`, id.FullyQualifiedName(), maskingPolicy1Id.FullyQualifiedName(), tag1Id.FullyQualifiedName(), maskingPolicy2Id.FullyQualifiedName(), rowAccessPolicyId.FullyQualifiedName(), tag2Id.FullyQualifiedName(), sql) + }) +} + +func TestViews_Alter(t *testing.T) { + id := RandomSchemaObjectIdentifier() + + // Minimal valid AlterViewOptions + defaultOpts := func() *AlterViewOptions { + return &AlterViewOptions{ + name: id, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + var opts *AlterViewOptions = nil + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: valid identifier for [opts.name]", func(t *testing.T) { + opts := defaultOpts() + opts.name = NewSchemaObjectIdentifier("", "", "") + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: exactly one field from [opts.RenameTo opts.SetComment opts.UnsetComment opts.SetSecure opts.SetChangeTracking opts.UnsetSecure opts.SetTags opts.UnsetTags opts.AddRowAccessPolicy opts.DropRowAccessPolicy opts.DropAndAddRowAccessPolicy opts.DropAllRowAccessPolicies opts.SetMaskingPolicyOnColumn opts.UnsetMaskingPolicyOnColumn opts.SetTagsOnColumn opts.UnsetTagsOnColumn] should be present", func(t *testing.T) { + opts := defaultOpts() + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterViewOptions", "RenameTo", "SetComment", "UnsetComment", "SetSecure", "SetChangeTracking", "UnsetSecure", "SetTags", "UnsetTags", "AddRowAccessPolicy", "DropRowAccessPolicy", "DropAndAddRowAccessPolicy", "DropAllRowAccessPolicies", "SetMaskingPolicyOnColumn", "UnsetMaskingPolicyOnColumn", "SetTagsOnColumn", "UnsetTagsOnColumn")) + }) + + t.Run("validation: exactly one field from [opts.RenameTo opts.SetComment opts.UnsetComment opts.SetSecure opts.SetChangeTracking opts.UnsetSecure opts.SetTags opts.UnsetTags opts.AddRowAccessPolicy opts.DropRowAccessPolicy opts.DropAndAddRowAccessPolicy opts.DropAllRowAccessPolicies opts.SetMaskingPolicyOnColumn opts.UnsetMaskingPolicyOnColumn opts.SetTagsOnColumn opts.UnsetTagsOnColumn] should be present - more present", func(t *testing.T) { + opts := defaultOpts() + opts.SetChangeTracking = Bool(true) + opts.DropAllRowAccessPolicies = Bool(true) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterViewOptions", "RenameTo", "SetComment", "UnsetComment", "SetSecure", "SetChangeTracking", "UnsetSecure", "SetTags", "UnsetTags", "AddRowAccessPolicy", "DropRowAccessPolicy", "DropAndAddRowAccessPolicy", "DropAllRowAccessPolicies", "SetMaskingPolicyOnColumn", "UnsetMaskingPolicyOnColumn", "SetTagsOnColumn", "UnsetTagsOnColumn")) + }) + + t.Run("validation: valid identifier for [opts.DropRowAccessPolicy.RowAccessPolicy]", func(t *testing.T) { + opts := defaultOpts() + opts.DropRowAccessPolicy = &ViewDropRowAccessPolicy{ + RowAccessPolicy: NewSchemaObjectIdentifier("", "", ""), + } + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: valid identifier for [opts.AddRowAccessPolicy.RowAccessPolicy]", func(t *testing.T) { + opts := defaultOpts() + opts.AddRowAccessPolicy = &ViewAddRowAccessPolicy{ + 
RowAccessPolicy: NewSchemaObjectIdentifier("", "", ""), + } + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: empty columns for row access policy (add)", func(t *testing.T) { + opts := defaultOpts() + opts.AddRowAccessPolicy = &ViewAddRowAccessPolicy{ + RowAccessPolicy: RandomSchemaObjectIdentifier(), + On: []string{}, + } + assertOptsInvalidJoinedErrors(t, opts, errNotSet("AlterViewOptions.AddRowAccessPolicy", "On")) + }) + + t.Run("validation: valid identifier for [opts.DropAndAddRowAccessPolicy.Drop.RowAccessPolicy]", func(t *testing.T) { + opts := defaultOpts() + opts.DropAndAddRowAccessPolicy = &ViewDropAndAddRowAccessPolicy{ + Drop: ViewDropRowAccessPolicy{ + RowAccessPolicy: NewSchemaObjectIdentifier("", "", ""), + }, + Add: ViewAddRowAccessPolicy{ + RowAccessPolicy: RandomSchemaObjectIdentifier(), + }, + } + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: valid identifier for [opts.DropAndAddRowAccessPolicy.Add.RowAccessPolicy]", func(t *testing.T) { + opts := defaultOpts() + opts.DropAndAddRowAccessPolicy = &ViewDropAndAddRowAccessPolicy{ + Drop: ViewDropRowAccessPolicy{ + RowAccessPolicy: RandomSchemaObjectIdentifier(), + }, + Add: ViewAddRowAccessPolicy{ + RowAccessPolicy: NewSchemaObjectIdentifier("", "", ""), + }, + } + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: empty columns for row access policy (drop and add)", func(t *testing.T) { + opts := defaultOpts() + opts.DropAndAddRowAccessPolicy = &ViewDropAndAddRowAccessPolicy{ + Drop: ViewDropRowAccessPolicy{ + RowAccessPolicy: RandomSchemaObjectIdentifier(), + }, + Add: ViewAddRowAccessPolicy{ + RowAccessPolicy: RandomSchemaObjectIdentifier(), + On: []string{}, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errNotSet("AlterViewOptions.DropAndAddRowAccessPolicy.Add", "On")) + }) + + t.Run("rename", func(t *testing.T) { + newId := RandomSchemaObjectIdentifier() + + opts := defaultOpts() + opts.RenameTo = &newId + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s RENAME TO %s", id.FullyQualifiedName(), newId.FullyQualifiedName()) + }) + + t.Run("set comment", func(t *testing.T) { + opts := defaultOpts() + opts.SetComment = String("comment") + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s SET COMMENT = 'comment'", id.FullyQualifiedName()) + }) + + t.Run("unset comment", func(t *testing.T) { + opts := defaultOpts() + opts.UnsetComment = Bool(true) + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s UNSET COMMENT", id.FullyQualifiedName()) + }) + + t.Run("set secure", func(t *testing.T) { + opts := defaultOpts() + opts.SetSecure = Bool(true) + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s SET SECURE", id.FullyQualifiedName()) + }) + + t.Run("set change tracking: true", func(t *testing.T) { + opts := defaultOpts() + opts.SetChangeTracking = Bool(true) + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s SET CHANGE_TRACKING = true", id.FullyQualifiedName()) + }) + + t.Run("set change tracking: false", func(t *testing.T) { + opts := defaultOpts() + opts.SetChangeTracking = Bool(false) + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s SET CHANGE_TRACKING = false", id.FullyQualifiedName()) + }) + + t.Run("unset secure", func(t *testing.T) { + opts := defaultOpts() + opts.UnsetSecure = Bool(true) + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s UNSET SECURE", id.FullyQualifiedName()) + }) + + t.Run("set tags", func(t *testing.T) { + opts := defaultOpts() + opts.SetTags 
= []TagAssociation{ + { + Name: NewAccountObjectIdentifier("tag1"), + Value: "value1", + }, + { + Name: NewAccountObjectIdentifier("tag2"), + Value: "value2", + }, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER VIEW %s SET TAG "tag1" = 'value1', "tag2" = 'value2'`, id.FullyQualifiedName()) + }) + + t.Run("unset tags", func(t *testing.T) { + opts := defaultOpts() + opts.UnsetTags = []ObjectIdentifier{ + NewAccountObjectIdentifier("tag1"), + NewAccountObjectIdentifier("tag2"), + } + assertOptsValidAndSQLEquals(t, opts, `ALTER VIEW %s UNSET TAG "tag1", "tag2"`, id.FullyQualifiedName()) + }) + + t.Run("add row access policy", func(t *testing.T) { + rowAccessPolicyId := RandomSchemaObjectIdentifier() + + opts := defaultOpts() + opts.AddRowAccessPolicy = &ViewAddRowAccessPolicy{ + RowAccessPolicy: rowAccessPolicyId, + On: []string{"a", "b"}, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s ADD ROW ACCESS POLICY %s ON (a, b)", id.FullyQualifiedName(), rowAccessPolicyId.FullyQualifiedName()) + }) + + t.Run("drop row access policy", func(t *testing.T) { + rowAccessPolicyId := RandomSchemaObjectIdentifier() + + opts := defaultOpts() + opts.DropRowAccessPolicy = &ViewDropRowAccessPolicy{ + RowAccessPolicy: rowAccessPolicyId, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s DROP ROW ACCESS POLICY %s", id.FullyQualifiedName(), rowAccessPolicyId.FullyQualifiedName()) + }) + + t.Run("drop and add row access policy", func(t *testing.T) { + rowAccessPolicy1Id := RandomSchemaObjectIdentifier() + rowAccessPolicy2Id := RandomSchemaObjectIdentifier() + + opts := defaultOpts() + opts.DropAndAddRowAccessPolicy = &ViewDropAndAddRowAccessPolicy{ + Drop: ViewDropRowAccessPolicy{ + RowAccessPolicy: rowAccessPolicy1Id, + }, + Add: ViewAddRowAccessPolicy{ + RowAccessPolicy: rowAccessPolicy2Id, + On: []string{"a", "b"}, + }, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s DROP ROW ACCESS POLICY %s, ADD ROW ACCESS POLICY %s ON (a, b)", id.FullyQualifiedName(), rowAccessPolicy1Id.FullyQualifiedName(), rowAccessPolicy2Id.FullyQualifiedName()) + }) + + t.Run("drop all row access policies", func(t *testing.T) { + opts := defaultOpts() + opts.DropAllRowAccessPolicies = Bool(true) + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s DROP ALL ROW ACCESS POLICIES", id.FullyQualifiedName()) + }) + + t.Run("set masking policy on column", func(t *testing.T) { + maskingPolicyId := RandomSchemaObjectIdentifier() + + opts := defaultOpts() + opts.SetMaskingPolicyOnColumn = &ViewSetColumnMaskingPolicy{ + Name: "column", + MaskingPolicy: maskingPolicyId, + Using: []string{"a", "b"}, + Force: Bool(true), + } + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s ALTER COLUMN column SET MASKING POLICY %s USING (a, b) FORCE", id.FullyQualifiedName(), maskingPolicyId.FullyQualifiedName()) + }) + + t.Run("unset masking policy on column", func(t *testing.T) { + opts := defaultOpts() + opts.UnsetMaskingPolicyOnColumn = &ViewUnsetColumnMaskingPolicy{ + Name: "column", + } + assertOptsValidAndSQLEquals(t, opts, "ALTER VIEW %s ALTER COLUMN column UNSET MASKING POLICY", id.FullyQualifiedName()) + }) + + t.Run("set tags on column", func(t *testing.T) { + opts := defaultOpts() + opts.SetTagsOnColumn = &ViewSetColumnTags{ + Name: "column", + SetTags: []TagAssociation{ + { + Name: NewAccountObjectIdentifier("tag1"), + Value: "value1", + }, + { + Name: NewAccountObjectIdentifier("tag2"), + Value: "value2", + }, + }, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER VIEW %s ALTER COLUMN column SET TAG "tag1" = 'value1', 
"tag2" = 'value2'`, id.FullyQualifiedName()) + }) + + t.Run("unset tags on column", func(t *testing.T) { + opts := defaultOpts() + opts.UnsetTagsOnColumn = &ViewUnsetColumnTags{ + Name: "column", + UnsetTags: []ObjectIdentifier{ + NewAccountObjectIdentifier("tag1"), + NewAccountObjectIdentifier("tag2"), + }, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER VIEW %s ALTER COLUMN column UNSET TAG "tag1", "tag2"`, id.FullyQualifiedName()) + }) +} + +func TestViews_Drop(t *testing.T) { + id := RandomSchemaObjectIdentifier() + + // Minimal valid DropViewOptions + defaultOpts := func() *DropViewOptions { + return &DropViewOptions{ + name: id, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + var opts *DropViewOptions = nil + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: valid identifier for [opts.name]", func(t *testing.T) { + opts := defaultOpts() + opts.name = NewSchemaObjectIdentifier("", "", "") + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("basic", func(t *testing.T) { + opts := defaultOpts() + assertOptsValidAndSQLEquals(t, opts, "DROP VIEW %s", id.FullyQualifiedName()) + }) +} + +func TestViews_Show(t *testing.T) { + // Minimal valid ShowViewOptions + defaultOpts := func() *ShowViewOptions { + return &ShowViewOptions{} + } + + t.Run("validation: nil options", func(t *testing.T) { + var opts *ShowViewOptions = nil + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("basic", func(t *testing.T) { + opts := defaultOpts() + assertOptsValidAndSQLEquals(t, opts, "SHOW VIEWS") + }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.Terse = Bool(true) + opts.Like = &Like{ + Pattern: String("myaccount"), + } + opts.In = &In{ + Account: Bool(true), + } + opts.StartsWith = String("abc") + opts.Limit = &LimitFrom{Rows: Int(10)} + assertOptsValidAndSQLEquals(t, opts, "SHOW TERSE VIEWS LIKE 'myaccount' IN ACCOUNT STARTS WITH 'abc' LIMIT 10") + }) +} + +func TestViews_Describe(t *testing.T) { + id := RandomSchemaObjectIdentifier() + + // Minimal valid DescribeViewOptions + defaultOpts := func() *DescribeViewOptions { + return &DescribeViewOptions{ + name: id, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + var opts *DescribeViewOptions = nil + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: valid identifier for [opts.name]", func(t *testing.T) { + opts := defaultOpts() + opts.name = NewSchemaObjectIdentifier("", "", "") + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("basic", func(t *testing.T) { + opts := defaultOpts() + assertOptsValidAndSQLEquals(t, opts, "DESCRIBE VIEW %s", id.FullyQualifiedName()) + }) +} diff --git a/pkg/sdk/views_impl_gen.go b/pkg/sdk/views_impl_gen.go new file mode 100644 index 0000000000..c77278b471 --- /dev/null +++ b/pkg/sdk/views_impl_gen.go @@ -0,0 +1,249 @@ +package sdk + +import ( + "context" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" +) + +var _ Views = (*views)(nil) + +type views struct { + client *Client +} + +func (v *views) Create(ctx context.Context, request *CreateViewRequest) error { + opts := request.toOpts() + return validateAndExec(v.client, ctx, opts) +} + +func (v *views) Alter(ctx context.Context, request *AlterViewRequest) error { + opts := request.toOpts() + return validateAndExec(v.client, ctx, opts) +} + +func (v *views) Drop(ctx context.Context, request *DropViewRequest) error { + opts 
:= request.toOpts() + return validateAndExec(v.client, ctx, opts) +} + +func (v *views) Show(ctx context.Context, request *ShowViewRequest) ([]View, error) { + opts := request.toOpts() + dbRows, err := validateAndQuery[viewDBRow](v.client, ctx, opts) + if err != nil { + return nil, err + } + resultList := convertRows[viewDBRow, View](dbRows) + return resultList, nil +} + +func (v *views) ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*View, error) { + request := NewShowViewRequest().WithIn(&In{Database: NewAccountObjectIdentifier(id.DatabaseName())}).WithLike(&Like{String(id.Name())}) + views, err := v.Show(ctx, request) + if err != nil { + return nil, err + } + return collections.FindOne(views, func(r View) bool { return r.Name == id.Name() }) +} + +func (v *views) Describe(ctx context.Context, id SchemaObjectIdentifier) ([]ViewDetails, error) { + opts := &DescribeViewOptions{ + name: id, + } + rows, err := validateAndQuery[viewDetailsRow](v.client, ctx, opts) + if err != nil { + return nil, err + } + return convertRows[viewDetailsRow, ViewDetails](rows), nil +} + +func (r *CreateViewRequest) toOpts() *CreateViewOptions { + opts := &CreateViewOptions{ + OrReplace: r.OrReplace, + Secure: r.Secure, + Temporary: r.Temporary, + Recursive: r.Recursive, + IfNotExists: r.IfNotExists, + name: r.name, + + CopyGrants: r.CopyGrants, + Comment: r.Comment, + + Tag: r.Tag, + sql: r.sql, + } + if r.Columns != nil { + s := make([]ViewColumn, len(r.Columns)) + for i, v := range r.Columns { + s[i] = ViewColumn(v) + } + opts.Columns = s + } + if r.ColumnsMaskingPolicies != nil { + s := make([]ViewColumnMaskingPolicy, len(r.ColumnsMaskingPolicies)) + for i, v := range r.ColumnsMaskingPolicies { + s[i] = ViewColumnMaskingPolicy(v) + } + opts.ColumnsMaskingPolicies = s + } + if r.RowAccessPolicy != nil { + opts.RowAccessPolicy = &ViewRowAccessPolicy{ + RowAccessPolicy: r.RowAccessPolicy.RowAccessPolicy, + On: r.RowAccessPolicy.On, + } + } + return opts +} + +func (r *AlterViewRequest) toOpts() *AlterViewOptions { + opts := &AlterViewOptions{ + IfExists: r.IfExists, + name: r.name, + RenameTo: r.RenameTo, + SetComment: r.SetComment, + UnsetComment: r.UnsetComment, + SetSecure: r.SetSecure, + SetChangeTracking: r.SetChangeTracking, + UnsetSecure: r.UnsetSecure, + SetTags: r.SetTags, + UnsetTags: r.UnsetTags, + DropAllRowAccessPolicies: r.DropAllRowAccessPolicies, + } + if r.AddRowAccessPolicy != nil { + opts.AddRowAccessPolicy = &ViewAddRowAccessPolicy{ + RowAccessPolicy: r.AddRowAccessPolicy.RowAccessPolicy, + On: r.AddRowAccessPolicy.On, + } + } + if r.DropRowAccessPolicy != nil { + opts.DropRowAccessPolicy = &ViewDropRowAccessPolicy{ + RowAccessPolicy: r.DropRowAccessPolicy.RowAccessPolicy, + } + } + if r.DropAndAddRowAccessPolicy != nil { + opts.DropAndAddRowAccessPolicy = &ViewDropAndAddRowAccessPolicy{} + opts.DropAndAddRowAccessPolicy.Drop = ViewDropRowAccessPolicy{ + RowAccessPolicy: r.DropAndAddRowAccessPolicy.Drop.RowAccessPolicy, + } + opts.DropAndAddRowAccessPolicy.Add = ViewAddRowAccessPolicy{ + RowAccessPolicy: r.DropAndAddRowAccessPolicy.Add.RowAccessPolicy, + On: r.DropAndAddRowAccessPolicy.Add.On, + } + } + if r.SetMaskingPolicyOnColumn != nil { + opts.SetMaskingPolicyOnColumn = &ViewSetColumnMaskingPolicy{ + Name: r.SetMaskingPolicyOnColumn.Name, + MaskingPolicy: r.SetMaskingPolicyOnColumn.MaskingPolicy, + Using: r.SetMaskingPolicyOnColumn.Using, + Force: r.SetMaskingPolicyOnColumn.Force, + } + } + if r.UnsetMaskingPolicyOnColumn != nil { + opts.UnsetMaskingPolicyOnColumn = 
&ViewUnsetColumnMaskingPolicy{ + Name: r.UnsetMaskingPolicyOnColumn.Name, + } + } + if r.SetTagsOnColumn != nil { + opts.SetTagsOnColumn = &ViewSetColumnTags{ + Name: r.SetTagsOnColumn.Name, + SetTags: r.SetTagsOnColumn.SetTags, + } + } + if r.UnsetTagsOnColumn != nil { + opts.UnsetTagsOnColumn = &ViewUnsetColumnTags{ + Name: r.UnsetTagsOnColumn.Name, + UnsetTags: r.UnsetTagsOnColumn.UnsetTags, + } + } + return opts +} + +func (r *DropViewRequest) toOpts() *DropViewOptions { + opts := &DropViewOptions{ + IfExists: r.IfExists, + name: r.name, + } + return opts +} + +func (r *ShowViewRequest) toOpts() *ShowViewOptions { + opts := &ShowViewOptions{ + Terse: r.Terse, + Like: r.Like, + In: r.In, + StartsWith: r.StartsWith, + Limit: r.Limit, + } + return opts +} + +func (r viewDBRow) convert() *View { + view := View{ + CreatedOn: r.CreatedOn, + Name: r.Name, + DatabaseName: r.DatabaseName, + SchemaName: r.SchemaName, + } + if r.Kind.Valid { + view.Kind = r.Kind.String + } + if r.Reserved.Valid { + view.Reserved = r.Reserved.String + } + if r.Owner.Valid { + view.Owner = r.Owner.String + } + if r.Comment.Valid { + view.Comment = r.Comment.String + } + if r.Text.Valid { + view.Text = r.Text.String + } + if r.IsSecure.Valid { + view.IsSecure = r.IsSecure.Bool + } + if r.IsMaterialized.Valid { + view.IsMaterialized = r.IsMaterialized.Bool + } + if r.OwnerRoleType.Valid { + view.OwnerRoleType = r.OwnerRoleType.String + } + if r.ChangeTracking.Valid { + view.ChangeTracking = r.ChangeTracking.String + } + return &view +} + +func (r *DescribeViewRequest) toOpts() *DescribeViewOptions { + opts := &DescribeViewOptions{ + name: r.name, + } + return opts +} + +func (r viewDetailsRow) convert() *ViewDetails { + details := &ViewDetails{ + Name: r.Name, + Type: r.Type, + Kind: r.Kind, + IsNullable: r.IsNullable == "Y", + IsPrimary: r.IsPrimary == "Y", + IsUnique: r.IsUnique == "Y", + } + if r.Default.Valid { + details.Default = String(r.Default.String) + } + if r.Check.Valid { + details.Check = Bool(r.Check.String == "Y") + } + if r.Expression.Valid { + details.Expression = String(r.Expression.String) + } + if r.Comment.Valid { + details.Comment = String(r.Comment.String) + } + if r.PolicyName.Valid { + details.PolicyName = String(r.PolicyName.String) + } + return details +} diff --git a/pkg/sdk/views_validations_gen.go b/pkg/sdk/views_validations_gen.go new file mode 100644 index 0000000000..cbc47adb61 --- /dev/null +++ b/pkg/sdk/views_validations_gen.go @@ -0,0 +1,103 @@ +package sdk + +var ( + _ validatable = new(CreateViewOptions) + _ validatable = new(AlterViewOptions) + _ validatable = new(DropViewOptions) + _ validatable = new(ShowViewOptions) + _ validatable = new(DescribeViewOptions) +) + +func (opts *CreateViewOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + if everyValueSet(opts.OrReplace, opts.IfNotExists) { + errs = append(errs, errOneOf("CreateViewOptions", "OrReplace", "IfNotExists")) + } + if valueSet(opts.RowAccessPolicy) { + if !ValidObjectIdentifier(opts.RowAccessPolicy.RowAccessPolicy) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + if !valueSet(opts.RowAccessPolicy.On) { + errs = append(errs, errNotSet("CreateViewOptions.RowAccessPolicy", "On")) + } + } + return JoinErrors(errs...) 
+} + +func (opts *AlterViewOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + if !exactlyOneValueSet(opts.RenameTo, opts.SetComment, opts.UnsetComment, opts.SetSecure, opts.SetChangeTracking, opts.UnsetSecure, opts.SetTags, opts.UnsetTags, opts.AddRowAccessPolicy, opts.DropRowAccessPolicy, opts.DropAndAddRowAccessPolicy, opts.DropAllRowAccessPolicies, opts.SetMaskingPolicyOnColumn, opts.UnsetMaskingPolicyOnColumn, opts.SetTagsOnColumn, opts.UnsetTagsOnColumn) { + errs = append(errs, errExactlyOneOf("AlterViewOptions", "RenameTo", "SetComment", "UnsetComment", "SetSecure", "SetChangeTracking", "UnsetSecure", "SetTags", "UnsetTags", "AddRowAccessPolicy", "DropRowAccessPolicy", "DropAndAddRowAccessPolicy", "DropAllRowAccessPolicies", "SetMaskingPolicyOnColumn", "UnsetMaskingPolicyOnColumn", "SetTagsOnColumn", "UnsetTagsOnColumn")) + } + if valueSet(opts.AddRowAccessPolicy) { + if !ValidObjectIdentifier(opts.AddRowAccessPolicy.RowAccessPolicy) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + if !valueSet(opts.AddRowAccessPolicy.On) { + errs = append(errs, errNotSet("AlterViewOptions.AddRowAccessPolicy", "On")) + } + } + if valueSet(opts.DropRowAccessPolicy) { + if !ValidObjectIdentifier(opts.DropRowAccessPolicy.RowAccessPolicy) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + } + if valueSet(opts.DropAndAddRowAccessPolicy) { + if valueSet(opts.DropAndAddRowAccessPolicy.Drop) { + if !ValidObjectIdentifier(opts.DropAndAddRowAccessPolicy.Drop.RowAccessPolicy) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + } + if valueSet(opts.DropAndAddRowAccessPolicy.Add) { + if !ValidObjectIdentifier(opts.DropAndAddRowAccessPolicy.Add.RowAccessPolicy) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + if !valueSet(opts.DropAndAddRowAccessPolicy.Add.On) { + errs = append(errs, errNotSet("AlterViewOptions.DropAndAddRowAccessPolicy.Add", "On")) + } + } + } + return JoinErrors(errs...) +} + +func (opts *DropViewOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + return JoinErrors(errs...) +} + +func (opts *ShowViewOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + return JoinErrors(errs...) +} + +func (opts *DescribeViewOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + return JoinErrors(errs...) 
+} From 4765410d8c0d7a1a848e383a8faa97f51c5b1c71 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Mon, 13 Nov 2023 09:41:05 +0100 Subject: [PATCH 20/20] Fix the typo (#2183) --- docs/resources/role_ownership_grant.md | 35 +++++++++++++++++++ .../import.sh | 0 .../resource.tf | 0 3 files changed, 35 insertions(+) rename examples/resources/{snowflake_role_ownership_grants => snowflake_role_ownership_grant}/import.sh (100%) rename examples/resources/{snowflake_role_ownership_grants => snowflake_role_ownership_grant}/resource.tf (100%) diff --git a/docs/resources/role_ownership_grant.md b/docs/resources/role_ownership_grant.md index fca231a8dc..954ab9b212 100644 --- a/docs/resources/role_ownership_grant.md +++ b/docs/resources/role_ownership_grant.md @@ -10,7 +10,34 @@ description: |- +## Example Usage +```terraform +resource "snowflake_role" "role" { + name = "rking_test_role" + comment = "for testing" +} + +resource "snowflake_role" "other_role" { + name = "rking_test_role2" +} + +# ensure the Terraform user inherits ownership privileges for the rking_test_role role +# otherwise Terraform will fail to destroy the rking_test_role2 role due to insufficient privileges +resource "snowflake_role_grants" "grants" { + role_name = snowflake_role.role.name + + roles = [ + "ACCOUNTADMIN", + ] +} + +resource "snowflake_role_ownership_grant" "grant" { + on_role_name = snowflake_role.role.name + to_role_name = snowflake_role.other_role.name + current_grants = "COPY" +} +``` ## Schema @@ -28,3 +55,11 @@ description: |- ### Read-Only - `id` (String) The ID of this resource. + +## Import + +Import is supported using the following syntax: + +```shell +terraform import snowflake_role_ownership_grant.example rolename +``` diff --git a/examples/resources/snowflake_role_ownership_grants/import.sh b/examples/resources/snowflake_role_ownership_grant/import.sh similarity index 100% rename from examples/resources/snowflake_role_ownership_grants/import.sh rename to examples/resources/snowflake_role_ownership_grant/import.sh diff --git a/examples/resources/snowflake_role_ownership_grants/resource.tf b/examples/resources/snowflake_role_ownership_grant/resource.tf similarity index 100% rename from examples/resources/snowflake_role_ownership_grants/resource.tf rename to examples/resources/snowflake_role_ownership_grant/resource.tf