From 01c97c85a0dac2a9eb4a5f5944d1e2be2a4cfb70 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Tue, 3 Dec 2024 14:15:43 +0100 Subject: [PATCH 01/29] use data types in sql starts here From 8301db3b49b7356c5e1063c09643ff935ff2399d Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Tue, 3 Dec 2024 15:19:31 +0100 Subject: [PATCH 02/29] Add special handling for DataTYpe interface in sql builder (PoC) --- pkg/sdk/sql_builder.go | 8 ++++++- pkg/sdk/sql_builder_test.go | 45 +++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/pkg/sdk/sql_builder.go b/pkg/sdk/sql_builder.go index bae8ef485c..6385781640 100644 --- a/pkg/sdk/sql_builder.go +++ b/pkg/sdk/sql_builder.go @@ -6,6 +6,8 @@ import ( "strings" "time" "unsafe" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type modifierType string @@ -642,7 +644,11 @@ func (v sqlParameterClause) String() string { if v.value == nil { return s } + var value = v.value + if dataType, ok := value.(datatypes.DataType); ok { + value = dataType.ToSql() + } // key = "value" - s += v.qm.Modify(v.value) + s += v.qm.Modify(value) return s } diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index b5eac3de58..a7cfd04075 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -1,9 +1,11 @@ package sdk import ( + "fmt" "reflect" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -472,3 +474,46 @@ func TestBuilder_sql(t *testing.T) { assert.Equal(t, "EXAMPLE_STATIC EXAMPLE_KEYWORD = example", s) }) } + +func TestBuilder_DataType(t *testing.T) { + + type dataTypeTestHelper struct { + DataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` + } + + dataTypes := []struct { + dataType string + expectedSql string + }{ + {dataType: "VARCHAR(20)", expectedSql: "VARCHAR(20)"}, + {dataType: "VARCHAR", expectedSql: "VARCHAR(16777216)"}, + {dataType: "CHAR", expectedSql: "CHAR(1)"}, + {dataType: "NUMBER", expectedSql: "NUMBER(38, 0)"}, + } + + t.Run("test data type empty", func(t *testing.T) { + opts := dataTypeTestHelper{} + + s, err := structToSQL(opts) + + require.NoError(t, err) + assert.Equal(t, "", s) + }) + + for _, tc := range dataTypes { + tc := tc + t.Run(fmt.Sprintf(`cheking building SQL for data type "%s, expecting "%s"`, tc.dataType, tc.expectedSql), func(t *testing.T) { + dataType, err := datatypes.ParseDataType(tc.dataType) + require.NoError(t, err) + + opts := dataTypeTestHelper{ + DataType: dataType, + } + + s, err := structToSQL(opts) + + require.NoError(t, err) + assert.Equal(t, tc.expectedSql, s) + }) + } +} From f157cfb8887bbb446b0ecd3b1819d768e07080cf Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 00:10:51 +0100 Subject: [PATCH 03/29] Add TODOs --- pkg/sdk/sql_builder_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index a7cfd04075..de9e16ef70 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -475,6 +475,11 @@ func TestBuilder_sql(t *testing.T) { }) } +// TODO [this PR]: add constructors for each data type? +// TODO [this PR]: test printing all data types +// TODO [this PR]: add optional alternatives to functions and procedures (arguments and return types) +// TODO [this PR]: integration tests for both options +// TODO [this PR]: integration test to check all data types in a new way + reading from snowflake? func TestBuilder_DataType(t *testing.T) { type dataTypeTestHelper struct { From d29bb7271f34518f53819b8889b34f73ea92f765 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 11:15:22 +0100 Subject: [PATCH 04/29] Test all data types in sql builder --- pkg/sdk/datatypes/number.go | 3 +++ pkg/sdk/sql_builder_test.go | 51 ++++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/pkg/sdk/datatypes/number.go b/pkg/sdk/datatypes/number.go index 14ac2696fc..378c9e3205 100644 --- a/pkg/sdk/datatypes/number.go +++ b/pkg/sdk/datatypes/number.go @@ -24,6 +24,9 @@ type NumberDataType struct { } func (t *NumberDataType) ToSql() string { + if slices.Contains(NumberDataTypeSubTypes, t.underlyingType) { + return t.underlyingType + } return fmt.Sprintf("%s(%d, %d)", t.underlyingType, t.precision, t.scale) } diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index de9e16ef70..f708799159 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -475,8 +475,6 @@ func TestBuilder_sql(t *testing.T) { }) } -// TODO [this PR]: add constructors for each data type? -// TODO [this PR]: test printing all data types // TODO [this PR]: add optional alternatives to functions and procedures (arguments and return types) // TODO [this PR]: integration tests for both options // TODO [this PR]: integration test to check all data types in a new way + reading from snowflake? @@ -490,10 +488,57 @@ func TestBuilder_DataType(t *testing.T) { dataType string expectedSql string }{ + {dataType: "ARRAY", expectedSql: "ARRAY"}, + {dataType: "array", expectedSql: "ARRAY"}, + {dataType: "BINARY", expectedSql: "BINARY(8388608)"}, + {dataType: "binary(120)", expectedSql: "BINARY(120)"}, + {dataType: "BOOLEAN", expectedSql: "BOOLEAN"}, + {dataType: "boolean", expectedSql: "BOOLEAN"}, + {dataType: "DATE", expectedSql: "DATE"}, + {dataType: "date", expectedSql: "DATE"}, + {dataType: "FLOAT", expectedSql: "FLOAT"}, + {dataType: "float4", expectedSql: "FLOAT4"}, + {dataType: "real", expectedSql: "REAL"}, + {dataType: "GEOGRAPHY", expectedSql: "GEOGRAPHY"}, + {dataType: "geography", expectedSql: "GEOGRAPHY"}, + {dataType: "GEOMETRY", expectedSql: "GEOMETRY"}, + {dataType: "geometry", expectedSql: "GEOMETRY"}, + {dataType: "NUMBER", expectedSql: "NUMBER(38, 0)"}, + {dataType: "NUMBER(36)", expectedSql: "NUMBER(36, 0)"}, + {dataType: "NUMBER(36, 2)", expectedSql: "NUMBER(36, 2)"}, + {dataType: "number(36, 2)", expectedSql: "NUMBER(36, 2)"}, + {dataType: "INT", expectedSql: "INT"}, + {dataType: "integer", expectedSql: "INTEGER"}, + {dataType: "OBJECT", expectedSql: "OBJECT"}, + {dataType: "object", expectedSql: "OBJECT"}, {dataType: "VARCHAR(20)", expectedSql: "VARCHAR(20)"}, {dataType: "VARCHAR", expectedSql: "VARCHAR(16777216)"}, + {dataType: "varchar", expectedSql: "VARCHAR(16777216)"}, {dataType: "CHAR", expectedSql: "CHAR(1)"}, - {dataType: "NUMBER", expectedSql: "NUMBER(38, 0)"}, + {dataType: "char(34)", expectedSql: "CHAR(34)"}, + {dataType: "TIME", expectedSql: "TIME(9)"}, + {dataType: "time", expectedSql: "TIME(9)"}, + {dataType: "time(5)", expectedSql: "TIME(5)"}, + {dataType: "TIMESTAMP_LTZ", expectedSql: "TIMESTAMP_LTZ(9)"}, + {dataType: "timestamp_ltz", expectedSql: "TIMESTAMP_LTZ(9)"}, + {dataType: "timestampltz", expectedSql: "TIMESTAMPLTZ(9)"}, + {dataType: "timestampltz(5)", expectedSql: "TIMESTAMPLTZ(5)"}, + {dataType: "TIMESTAMP_NTZ", expectedSql: "TIMESTAMP_NTZ(9)"}, + {dataType: "timestamp_ntz", expectedSql: "TIMESTAMP_NTZ(9)"}, + {dataType: "timestamp_ntz(5)", expectedSql: "TIMESTAMP_NTZ(5)"}, + {dataType: "timestampntz", expectedSql: "TIMESTAMPNTZ(9)"}, + {dataType: "timestampntz(5)", expectedSql: "TIMESTAMPNTZ(5)"}, + {dataType: "TIMESTAMP_TZ", expectedSql: "TIMESTAMP_TZ(9)"}, + {dataType: "timestamp_tz", expectedSql: "TIMESTAMP_TZ(9)"}, + {dataType: "timestamp_tz(5)", expectedSql: "TIMESTAMP_TZ(5)"}, + {dataType: "timestamptz", expectedSql: "TIMESTAMPTZ(9)"}, + {dataType: "timestamptz(5)", expectedSql: "TIMESTAMPTZ(5)"}, + {dataType: "VARIANT", expectedSql: "VARIANT"}, + {dataType: "variant", expectedSql: "VARIANT"}, + {dataType: "VECTOR(INT, 20)", expectedSql: "VECTOR(INT, 20)"}, + {dataType: "VECTOR(FLOAT, 20)", expectedSql: "VECTOR(FLOAT, 20)"}, + {dataType: "VECTOR(int, 20)", expectedSql: "VECTOR(INT, 20)"}, + {dataType: "VECTOR(float, 20)", expectedSql: "VECTOR(FLOAT, 20)"}, } t.Run("test data type empty", func(t *testing.T) { From bb5e6717476a07750b1cfd7cc589c0f6b3350f6b Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 11:19:27 +0100 Subject: [PATCH 05/29] Alter functions def (old data type) --- pkg/sdk/functions_def.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/sdk/functions_def.go b/pkg/sdk/functions_def.go index 017ccd8335..5e2f289415 100644 --- a/pkg/sdk/functions_def.go +++ b/pkg/sdk/functions_def.go @@ -6,18 +6,18 @@ import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/gen var functionArgument = g.NewQueryStruct("FunctionArgument"). Text("ArgName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ArgDataType", "DataType", g.KeywordOptions().NoQuotes().Required()). + PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()). PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")) var functionColumn = g.NewQueryStruct("FunctionColumn"). Text("ColumnName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ColumnDataType", "DataType", g.KeywordOptions().NoQuotes().Required()) + PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()) var functionReturns = g.NewQueryStruct("FunctionReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("FunctionReturnsResultDataType"). - PredefinedQueryStructField("ResultDataType", "DataType", g.KeywordOptions().NoQuotes().Required()), + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()), g.KeywordOptions(), ). OptionalQueryStructField( @@ -174,7 +174,7 @@ var FunctionsDef = g.NewInterface( functionArgument, g.ListOptions().MustParentheses()). OptionalSQL("COPY GRANTS"). - PredefinedQueryStructField("ResultDataType", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). PredefinedQueryStructField("ReturnNullValues", "*ReturnNullValues", g.KeywordOptions()). SQL("LANGUAGE SCALA"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). From 9a50470238bb9dae8d8ceb535d58e58bcee046fb Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 11:40:00 +0100 Subject: [PATCH 06/29] Regenerate --- pkg/resources/function.go | 62 +++++++++---------- pkg/schemas/function_gen.go | 2 +- pkg/sdk/functions_def.go | 2 +- pkg/sdk/functions_dto_builders_gen.go | 16 ++--- pkg/sdk/functions_dto_gen.go | 14 ++--- pkg/sdk/functions_ext.go | 5 ++ pkg/sdk/functions_gen.go | 20 +++--- pkg/sdk/functions_impl_gen.go | 12 ++-- pkg/sdk/functions_validations_gen.go | 3 + pkg/sdk/testint/functions_integration_test.go | 6 +- 10 files changed, 73 insertions(+), 69 deletions(-) create mode 100644 pkg/sdk/functions_ext.go diff --git a/pkg/resources/function.go b/pkg/resources/function.go index 19415f91f1..c41a4e8977 100644 --- a/pkg/resources/function.go +++ b/pkg/resources/function.go @@ -227,9 +227,9 @@ func CreateContextFunction(ctx context.Context, d *schema.ResourceData, meta int func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string)) @@ -266,14 +266,14 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.FunctionImportRequest{} + var imports []sdk.FunctionImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string))) } request.WithImports(imports) } if _, ok := d.GetOk("packages"); ok { - packages := []sdk.FunctionPackageRequest{} + var packages []sdk.FunctionPackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string))) } @@ -288,9 +288,9 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -298,9 +298,9 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returnType := d.Get("return_type").(string) @@ -338,14 +338,14 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.FunctionImportRequest{} + var imports []sdk.FunctionImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string))) } request.WithImports(imports) } if _, ok := d.GetOk("packages"); ok { - packages := []sdk.FunctionPackageRequest{} + var packages []sdk.FunctionPackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string))) } @@ -360,9 +360,9 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -370,9 +370,9 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string)) @@ -406,9 +406,9 @@ func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interfa } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -416,9 +416,9 @@ func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interfa func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string)) @@ -454,14 +454,14 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.FunctionImportRequest{} + var imports []sdk.FunctionImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string))) } request.WithImports(imports) } if _, ok := d.GetOk("packages"); ok { - packages := []sdk.FunctionPackageRequest{} + var packages []sdk.FunctionPackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string))) } @@ -473,9 +473,9 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -483,9 +483,9 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string)) @@ -522,9 +522,9 @@ func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -575,7 +575,7 @@ func ReadContextFunction(ctx context.Context, d *schema.ResourceData, meta inter if value != "" { // Do nothing for functions without arguments pairs := strings.Split(value, ", ") - arguments := []interface{}{} + var arguments []interface{} for _, pair := range pairs { item := strings.Split(pair, " ") argument := map[string]interface{}{} @@ -739,7 +739,7 @@ func parseFunctionArguments(d *schema.ResourceData) ([]sdk.FunctionArgumentReque if diags != nil { return nil, diags } - args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) + args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataTypeOld: sdk.LegacyDataTypeFrom(argDataType)}) } } return args, nil @@ -764,8 +764,8 @@ func convertFunctionColumns(s string) ([]sdk.FunctionColumn, diag.Diagnostics) { return nil, diag.FromErr(err) } columns = append(columns, sdk.FunctionColumn{ - ColumnName: match[1], - ColumnDataType: sdk.LegacyDataTypeFrom(dataType), + ColumnName: match[1], + ColumnDataTypeOld: sdk.LegacyDataTypeFrom(dataType), }) } } @@ -781,7 +781,7 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di } var cr []sdk.FunctionColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, item.ColumnDataType)) + cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewFunctionReturnsTableRequest().WithColumns(cr)) } else { diff --git a/pkg/schemas/function_gen.go b/pkg/schemas/function_gen.go index ecd7d71f53..a211866ea5 100644 --- a/pkg/schemas/function_gen.go +++ b/pkg/schemas/function_gen.go @@ -95,7 +95,7 @@ func FunctionToSchema(function *sdk.Function) map[string]any { functionSchema["is_ansi"] = function.IsAnsi functionSchema["min_num_arguments"] = function.MinNumArguments functionSchema["max_num_arguments"] = function.MaxNumArguments - functionSchema["arguments"] = function.Arguments + functionSchema["arguments"] = function.ArgumentsOld functionSchema["arguments_raw"] = function.ArgumentsRaw functionSchema["description"] = function.Description functionSchema["catalog_name"] = function.CatalogName diff --git a/pkg/sdk/functions_def.go b/pkg/sdk/functions_def.go index 5e2f289415..0f2774f800 100644 --- a/pkg/sdk/functions_def.go +++ b/pkg/sdk/functions_def.go @@ -282,7 +282,7 @@ var FunctionsDef = g.NewInterface( Field("IsAnsi", "bool"). Field("MinNumArguments", "int"). Field("MaxNumArguments", "int"). - Field("Arguments", "string"). + Field("ArgumentsRaw", "string"). Field("Description", "string"). Field("CatalogName", "string"). Field("IsTableFunction", "bool"). diff --git a/pkg/sdk/functions_dto_builders_gen.go b/pkg/sdk/functions_dto_builders_gen.go index 0aef014932..cc74344faf 100644 --- a/pkg/sdk/functions_dto_builders_gen.go +++ b/pkg/sdk/functions_dto_builders_gen.go @@ -103,11 +103,11 @@ func (s *CreateForJavaFunctionRequest) WithFunctionDefinition(FunctionDefinition func NewFunctionArgumentRequest( ArgName string, - ArgDataType DataType, + ArgDataTypeOld DataType, ) *FunctionArgumentRequest { s := FunctionArgumentRequest{} s.ArgName = ArgName - s.ArgDataType = ArgDataType + s.ArgDataTypeOld = ArgDataTypeOld return &s } @@ -131,10 +131,10 @@ func (s *FunctionReturnsRequest) WithTable(Table FunctionReturnsTableRequest) *F } func NewFunctionReturnsResultDataTypeRequest( - ResultDataType DataType, + ResultDataTypeOld DataType, ) *FunctionReturnsResultDataTypeRequest { s := FunctionReturnsResultDataTypeRequest{} - s.ResultDataType = ResultDataType + s.ResultDataTypeOld = ResultDataTypeOld return &s } @@ -149,11 +149,11 @@ func (s *FunctionReturnsTableRequest) WithColumns(Columns []FunctionColumnReques func NewFunctionColumnRequest( ColumnName string, - ColumnDataType DataType, + ColumnDataTypeOld DataType, ) *FunctionColumnRequest { s := FunctionColumnRequest{} s.ColumnName = ColumnName - s.ColumnDataType = ColumnDataType + s.ColumnDataTypeOld = ColumnDataTypeOld return &s } @@ -323,12 +323,12 @@ func (s *CreateForPythonFunctionRequest) WithFunctionDefinition(FunctionDefiniti func NewCreateForScalaFunctionRequest( name SchemaObjectIdentifier, - ResultDataType DataType, + ResultDataTypeOld DataType, Handler string, ) *CreateForScalaFunctionRequest { s := CreateForScalaFunctionRequest{} s.name = name - s.ResultDataType = ResultDataType + s.ResultDataTypeOld = ResultDataTypeOld s.Handler = Handler return &s } diff --git a/pkg/sdk/functions_dto_gen.go b/pkg/sdk/functions_dto_gen.go index 3fa0ff387b..0cc20bd792 100644 --- a/pkg/sdk/functions_dto_gen.go +++ b/pkg/sdk/functions_dto_gen.go @@ -38,9 +38,9 @@ type CreateForJavaFunctionRequest struct { } type FunctionArgumentRequest struct { - ArgName string // required - ArgDataType DataType // required - DefaultValue *string + ArgName string // required + ArgDataTypeOld DataType // required + DefaultValue *string } type FunctionReturnsRequest struct { @@ -49,7 +49,7 @@ type FunctionReturnsRequest struct { } type FunctionReturnsResultDataTypeRequest struct { - ResultDataType DataType // required + ResultDataTypeOld DataType // required } type FunctionReturnsTableRequest struct { @@ -57,8 +57,8 @@ type FunctionReturnsTableRequest struct { } type FunctionColumnRequest struct { - ColumnName string // required - ColumnDataType DataType // required + ColumnName string // required + ColumnDataTypeOld DataType // required } type FunctionImportRequest struct { @@ -114,7 +114,7 @@ type CreateForScalaFunctionRequest struct { name SchemaObjectIdentifier // required Arguments []FunctionArgumentRequest CopyGrants *bool - ResultDataType DataType // required + ResultDataTypeOld DataType // required ReturnNullValues *ReturnNullValues NullInputBehavior *NullInputBehavior ReturnResultsBehavior *ReturnResultsBehavior diff --git a/pkg/sdk/functions_ext.go b/pkg/sdk/functions_ext.go new file mode 100644 index 0000000000..4fe8a9524d --- /dev/null +++ b/pkg/sdk/functions_ext.go @@ -0,0 +1,5 @@ +package sdk + +func (v *Function) ID() SchemaObjectIdentifierWithArguments { + return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.ArgumentsOld...) +} diff --git a/pkg/sdk/functions_gen.go b/pkg/sdk/functions_gen.go index 85f2b8378d..29e642d799 100644 --- a/pkg/sdk/functions_gen.go +++ b/pkg/sdk/functions_gen.go @@ -46,9 +46,9 @@ type CreateForJavaFunctionOptions struct { } type FunctionArgument struct { - ArgName string `ddl:"keyword,no_quotes"` - ArgDataType DataType `ddl:"keyword,no_quotes"` - DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` + ArgName string `ddl:"keyword,no_quotes"` + ArgDataTypeOld DataType `ddl:"keyword,no_quotes"` + DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` } type FunctionReturns struct { @@ -57,7 +57,7 @@ type FunctionReturns struct { } type FunctionReturnsResultDataType struct { - ResultDataType DataType `ddl:"keyword,no_quotes"` + ResultDataTypeOld DataType `ddl:"keyword,no_quotes"` } type FunctionReturnsTable struct { @@ -65,8 +65,8 @@ type FunctionReturnsTable struct { } type FunctionColumn struct { - ColumnName string `ddl:"keyword,no_quotes"` - ColumnDataType DataType `ddl:"keyword,no_quotes"` + ColumnName string `ddl:"keyword,no_quotes"` + ColumnDataTypeOld DataType `ddl:"keyword,no_quotes"` } type FunctionImport struct { @@ -133,7 +133,7 @@ type CreateForScalaFunctionOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` Arguments []FunctionArgument `ddl:"list,must_parentheses"` CopyGrants *bool `ddl:"keyword" sql:"COPY GRANTS"` - ResultDataType DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals" sql:"RETURNS"` ReturnNullValues *ReturnNullValues `ddl:"keyword"` languageScala bool `ddl:"static" sql:"LANGUAGE SCALA"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` @@ -229,7 +229,7 @@ type Function struct { IsAnsi bool MinNumArguments int MaxNumArguments int - Arguments []DataType + ArgumentsOld []DataType ArgumentsRaw string Description string CatalogName string @@ -241,10 +241,6 @@ type Function struct { IsMemoizable bool } -func (v *Function) ID() SchemaObjectIdentifierWithArguments { - return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.Arguments...) -} - // DescribeFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-function. type DescribeFunctionOptions struct { describe bool `ddl:"static" sql:"DESCRIBE"` diff --git a/pkg/sdk/functions_impl_gen.go b/pkg/sdk/functions_impl_gen.go index 2abf41c1e6..721d405e99 100644 --- a/pkg/sdk/functions_impl_gen.go +++ b/pkg/sdk/functions_impl_gen.go @@ -110,7 +110,7 @@ func (r *CreateForJavaFunctionRequest) toOpts() *CreateForJavaFunctionOptions { opts.Returns = FunctionReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, } } if r.Returns.Table != nil { @@ -165,7 +165,7 @@ func (r *CreateForJavascriptFunctionRequest) toOpts() *CreateForJavascriptFuncti opts.Returns = FunctionReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, } } if r.Returns.Table != nil { @@ -212,7 +212,7 @@ func (r *CreateForPythonFunctionRequest) toOpts() *CreateForPythonFunctionOption opts.Returns = FunctionReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, } } if r.Returns.Table != nil { @@ -251,7 +251,7 @@ func (r *CreateForScalaFunctionRequest) toOpts() *CreateForScalaFunctionOptions name: r.name, CopyGrants: r.CopyGrants, - ResultDataType: r.ResultDataType, + ResultDataTypeOld: r.ResultDataTypeOld, ReturnNullValues: r.ReturnNullValues, NullInputBehavior: r.NullInputBehavior, ReturnResultsBehavior: r.ReturnResultsBehavior, @@ -311,7 +311,7 @@ func (r *CreateForSQLFunctionRequest) toOpts() *CreateForSQLFunctionOptions { opts.Returns = FunctionReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, } } if r.Returns.Table != nil { @@ -386,7 +386,7 @@ func (r functionRow) convert() *Function { if err != nil { log.Printf("[DEBUG] failed to parse function arguments, err = %s", err) } else { - e.Arguments = dataTypes + e.ArgumentsOld = dataTypes } if r.IsSecure.Valid { diff --git a/pkg/sdk/functions_validations_gen.go b/pkg/sdk/functions_validations_gen.go index 3bf1a29ff9..52863802a8 100644 --- a/pkg/sdk/functions_validations_gen.go +++ b/pkg/sdk/functions_validations_gen.go @@ -31,6 +31,7 @@ func (opts *CreateForJavaFunctionOptions) validate() error { errs = append(errs, errExactlyOneOf("CreateForJavaFunctionOptions.Returns", "ResultDataType", "Table")) } } + // added manually if opts.FunctionDefinition == nil { if opts.TargetPath != nil { errs = append(errs, NewError("TARGET_PATH must be nil when AS is nil")) @@ -86,6 +87,7 @@ func (opts *CreateForPythonFunctionOptions) validate() error { errs = append(errs, errExactlyOneOf("CreateForPythonFunctionOptions.Returns", "ResultDataType", "Table")) } } + // added manually if opts.FunctionDefinition == nil { if len(opts.Imports) == 0 { errs = append(errs, NewError("IMPORTS must not be empty when AS is nil")) @@ -108,6 +110,7 @@ func (opts *CreateForScalaFunctionOptions) validate() error { if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateForScalaFunctionOptions", "OrReplace", "IfNotExists")) } + // added manually if opts.FunctionDefinition == nil { if opts.TargetPath != nil { errs = append(errs, NewError("TARGET_PATH must be nil when AS is nil")) diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index 44bb8b898a..b61afbb037 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -209,7 +209,7 @@ func TestInt_OtherFunctions(t *testing.T) { assert.Equal(t, 0, function.MaxNumArguments) } assert.NotEmpty(t, function.ArgumentsRaw) - assert.NotEmpty(t, function.Arguments) + assert.NotEmpty(t, function.ArgumentsOld) assert.NotEmpty(t, function.Description) assert.NotEmpty(t, function.CatalogName) assert.Equal(t, false, function.IsTableFunction) @@ -542,7 +542,7 @@ func TestInt_FunctionsShowByID(t *testing.T) { dataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - dataType, err := datatypes.ParseDataType(string(arg.ArgDataType)) + dataType, err := datatypes.ParseDataType(string(arg.ArgDataTypeOld)) require.NoError(t, err) dataTypes[i] = sdk.LegacyDataTypeFrom(dataType) } @@ -550,6 +550,6 @@ func TestInt_FunctionsShowByID(t *testing.T) { function, err := client.Functions.ShowByID(ctx, idWithArguments) require.NoError(t, err) - require.Equal(t, dataTypes, function.Arguments) + require.Equal(t, dataTypes, function.ArgumentsOld) }) } From 33982f4ebf90b530c4ee41dc60252c8ecd663b54 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 11:45:18 +0100 Subject: [PATCH 07/29] Update function def (old data type optional, new - required) --- pkg/sdk/functions_def.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pkg/sdk/functions_def.go b/pkg/sdk/functions_def.go index 0f2774f800..af3dbcdbdb 100644 --- a/pkg/sdk/functions_def.go +++ b/pkg/sdk/functions_def.go @@ -6,18 +6,21 @@ import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/gen var functionArgument = g.NewQueryStruct("FunctionArgument"). Text("ArgName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()). + PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ArgDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")) var functionColumn = g.NewQueryStruct("FunctionColumn"). Text("ColumnName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()) + PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ColumnDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()) var functionReturns = g.NewQueryStruct("FunctionReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("FunctionReturnsResultDataType"). - PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()), + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.KeywordOptions().NoQuotes().Required()), g.KeywordOptions(), ). OptionalQueryStructField( @@ -174,7 +177,8 @@ var FunctionsDef = g.NewInterface( functionArgument, g.ListOptions().MustParentheses()). OptionalSQL("COPY GRANTS"). - PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS")). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). PredefinedQueryStructField("ReturnNullValues", "*ReturnNullValues", g.KeywordOptions()). SQL("LANGUAGE SCALA"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). From d7d6ebce653d9e38867380dbabf58906052fd68e Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 12:31:37 +0100 Subject: [PATCH 08/29] Regenerate function with new data type (WIP) --- pkg/acceptance/helpers/function_client.go | 6 +- pkg/resources/function.go | 6 +- pkg/sdk/datatypes/boolean.go | 3 + pkg/sdk/functions_dto_builders_gen.go | 41 +++++++--- pkg/sdk/functions_dto_gen.go | 18 +++-- pkg/sdk/functions_gen.go | 19 +++-- pkg/sdk/functions_gen_test.go | 62 +++++++------- pkg/sdk/functions_impl_gen.go | 5 ++ pkg/sdk/poc/README.md | 1 + pkg/sdk/sql_builder_test.go | 13 +++ pkg/sdk/testint/functions_integration_test.go | 81 ++++++++++--------- 11 files changed, 157 insertions(+), 98 deletions(-) diff --git a/pkg/acceptance/helpers/function_client.go b/pkg/acceptance/helpers/function_client.go index 5b23afaf9f..9a904de51a 100644 --- a/pkg/acceptance/helpers/function_client.go +++ b/pkg/acceptance/helpers/function_client.go @@ -35,7 +35,7 @@ func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectI return c.CreateWithRequest(t, id, sdk.NewCreateForSQLFunctionRequest( id.SchemaObjectId(), - *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVariant)), "SELECT 1", ), ) @@ -48,7 +48,7 @@ func (c *FunctionClient) CreateSecure(t *testing.T, arguments ...sdk.DataType) * return c.CreateWithRequest(t, id, sdk.NewCreateForSQLFunctionRequest( id.SchemaObjectId(), - *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)), "SELECT 1", ).WithSecure(true), ) @@ -59,7 +59,7 @@ func (c *FunctionClient) CreateWithRequest(t *testing.T, id sdk.SchemaObjectIden ctx := context.Background() argumentRequests := make([]sdk.FunctionArgumentRequest, len(id.ArgumentDataTypes())) for i, argumentDataType := range id.ArgumentDataTypes() { - argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), argumentDataType) + argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), nil).WithArgDataTypeOld(argumentDataType) } err := c.client().CreateForSQL(ctx, req.WithArguments(argumentRequests)) require.NoError(t, err) diff --git a/pkg/resources/function.go b/pkg/resources/function.go index c41a4e8977..ba91184217 100644 --- a/pkg/resources/function.go +++ b/pkg/resources/function.go @@ -311,7 +311,7 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter functionDefinition := d.Get("statement").(string) handler := d.Get("handler").(string) // create request with required - request := sdk.NewCreateForScalaFunctionRequest(id, sdk.LegacyDataTypeFrom(returnDataType), handler) + request := sdk.NewCreateForScalaFunctionRequest(id, nil, handler).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType)) request.WithFunctionDefinition(functionDefinition) // Set optionals @@ -781,7 +781,7 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di } var cr []sdk.FunctionColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, item.ColumnDataTypeOld)) + cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, nil).WithColumnDataTypeOld(item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewFunctionReturnsTableRequest().WithColumns(cr)) } else { @@ -789,7 +789,7 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) + returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } diff --git a/pkg/sdk/datatypes/boolean.go b/pkg/sdk/datatypes/boolean.go index 4e84979f40..4a6c617bcf 100644 --- a/pkg/sdk/datatypes/boolean.go +++ b/pkg/sdk/datatypes/boolean.go @@ -7,6 +7,9 @@ type BooleanDataType struct { } func (t *BooleanDataType) ToSql() string { + if t == nil { + return "" + } return t.underlyingType } diff --git a/pkg/sdk/functions_dto_builders_gen.go b/pkg/sdk/functions_dto_builders_gen.go index cc74344faf..3bb40dfd0e 100644 --- a/pkg/sdk/functions_dto_builders_gen.go +++ b/pkg/sdk/functions_dto_builders_gen.go @@ -2,7 +2,10 @@ package sdk -import () +// imports edited manually +import ( + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" +) func NewCreateForJavaFunctionRequest( name SchemaObjectIdentifier, @@ -103,14 +106,19 @@ func (s *CreateForJavaFunctionRequest) WithFunctionDefinition(FunctionDefinition func NewFunctionArgumentRequest( ArgName string, - ArgDataTypeOld DataType, + ArgDataType datatypes.DataType, ) *FunctionArgumentRequest { s := FunctionArgumentRequest{} s.ArgName = ArgName - s.ArgDataTypeOld = ArgDataTypeOld + s.ArgDataType = ArgDataType return &s } +func (s *FunctionArgumentRequest) WithArgDataTypeOld(ArgDataTypeOld DataType) *FunctionArgumentRequest { + s.ArgDataTypeOld = ArgDataTypeOld + return s +} + func (s *FunctionArgumentRequest) WithDefaultValue(DefaultValue string) *FunctionArgumentRequest { s.DefaultValue = &DefaultValue return s @@ -131,13 +139,18 @@ func (s *FunctionReturnsRequest) WithTable(Table FunctionReturnsTableRequest) *F } func NewFunctionReturnsResultDataTypeRequest( - ResultDataTypeOld DataType, + ResultDataType datatypes.DataType, ) *FunctionReturnsResultDataTypeRequest { s := FunctionReturnsResultDataTypeRequest{} - s.ResultDataTypeOld = ResultDataTypeOld + s.ResultDataType = ResultDataType return &s } +func (s *FunctionReturnsResultDataTypeRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *FunctionReturnsResultDataTypeRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func NewFunctionReturnsTableRequest() *FunctionReturnsTableRequest { return &FunctionReturnsTableRequest{} } @@ -149,14 +162,19 @@ func (s *FunctionReturnsTableRequest) WithColumns(Columns []FunctionColumnReques func NewFunctionColumnRequest( ColumnName string, - ColumnDataTypeOld DataType, + ColumnDataType datatypes.DataType, ) *FunctionColumnRequest { s := FunctionColumnRequest{} s.ColumnName = ColumnName - s.ColumnDataTypeOld = ColumnDataTypeOld + s.ColumnDataType = ColumnDataType return &s } +func (s *FunctionColumnRequest) WithColumnDataTypeOld(ColumnDataTypeOld DataType) *FunctionColumnRequest { + s.ColumnDataTypeOld = ColumnDataTypeOld + return s +} + func NewFunctionImportRequest() *FunctionImportRequest { return &FunctionImportRequest{} } @@ -323,12 +341,12 @@ func (s *CreateForPythonFunctionRequest) WithFunctionDefinition(FunctionDefiniti func NewCreateForScalaFunctionRequest( name SchemaObjectIdentifier, - ResultDataTypeOld DataType, + ResultDataType datatypes.DataType, Handler string, ) *CreateForScalaFunctionRequest { s := CreateForScalaFunctionRequest{} s.name = name - s.ResultDataTypeOld = ResultDataTypeOld + s.ResultDataType = ResultDataType s.Handler = Handler return &s } @@ -363,6 +381,11 @@ func (s *CreateForScalaFunctionRequest) WithCopyGrants(CopyGrants bool) *CreateF return s } +func (s *CreateForScalaFunctionRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *CreateForScalaFunctionRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func (s *CreateForScalaFunctionRequest) WithReturnNullValues(ReturnNullValues ReturnNullValues) *CreateForScalaFunctionRequest { s.ReturnNullValues = &ReturnNullValues return s diff --git a/pkg/sdk/functions_dto_gen.go b/pkg/sdk/functions_dto_gen.go index 0cc20bd792..4ff74dcd73 100644 --- a/pkg/sdk/functions_dto_gen.go +++ b/pkg/sdk/functions_dto_gen.go @@ -1,5 +1,7 @@ package sdk +import "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" + //go:generate go run ./dto-builder-generator/main.go var ( @@ -38,8 +40,9 @@ type CreateForJavaFunctionRequest struct { } type FunctionArgumentRequest struct { - ArgName string // required - ArgDataTypeOld DataType // required + ArgName string // required + ArgDataTypeOld DataType + ArgDataType datatypes.DataType // required DefaultValue *string } @@ -49,7 +52,8 @@ type FunctionReturnsRequest struct { } type FunctionReturnsResultDataTypeRequest struct { - ResultDataTypeOld DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required } type FunctionReturnsTableRequest struct { @@ -57,8 +61,9 @@ type FunctionReturnsTableRequest struct { } type FunctionColumnRequest struct { - ColumnName string // required - ColumnDataTypeOld DataType // required + ColumnName string // required + ColumnDataTypeOld DataType + ColumnDataType datatypes.DataType // required } type FunctionImportRequest struct { @@ -114,7 +119,8 @@ type CreateForScalaFunctionRequest struct { name SchemaObjectIdentifier // required Arguments []FunctionArgumentRequest CopyGrants *bool - ResultDataTypeOld DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required ReturnNullValues *ReturnNullValues NullInputBehavior *NullInputBehavior ReturnResultsBehavior *ReturnResultsBehavior diff --git a/pkg/sdk/functions_gen.go b/pkg/sdk/functions_gen.go index 29e642d799..78ab2e1bd5 100644 --- a/pkg/sdk/functions_gen.go +++ b/pkg/sdk/functions_gen.go @@ -1,8 +1,11 @@ package sdk +// imports edited manually import ( "context" "database/sql" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type Functions interface { @@ -46,9 +49,10 @@ type CreateForJavaFunctionOptions struct { } type FunctionArgument struct { - ArgName string `ddl:"keyword,no_quotes"` - ArgDataTypeOld DataType `ddl:"keyword,no_quotes"` - DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` + ArgName string `ddl:"keyword,no_quotes"` + ArgDataTypeOld DataType `ddl:"keyword,no_quotes"` + ArgDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` + DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` } type FunctionReturns struct { @@ -57,7 +61,8 @@ type FunctionReturns struct { } type FunctionReturnsResultDataType struct { - ResultDataTypeOld DataType `ddl:"keyword,no_quotes"` + ResultDataTypeOld DataType `ddl:"keyword,no_quotes"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` } type FunctionReturnsTable struct { @@ -65,8 +70,9 @@ type FunctionReturnsTable struct { } type FunctionColumn struct { - ColumnName string `ddl:"keyword,no_quotes"` - ColumnDataTypeOld DataType `ddl:"keyword,no_quotes"` + ColumnName string `ddl:"keyword,no_quotes"` + ColumnDataTypeOld DataType `ddl:"keyword,no_quotes"` + ColumnDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` } type FunctionImport struct { @@ -134,6 +140,7 @@ type CreateForScalaFunctionOptions struct { Arguments []FunctionArgument `ddl:"list,must_parentheses"` CopyGrants *bool `ddl:"keyword" sql:"COPY GRANTS"` ResultDataTypeOld DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals" sql:"RETURNS"` ReturnNullValues *ReturnNullValues `ddl:"keyword"` languageScala bool `ddl:"static" sql:"LANGUAGE SCALA"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index b0c1c5b0b5..36f8a82892 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -47,7 +47,7 @@ func TestFunctions_CreateForJava(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaFunctionOptions", "Handler")) @@ -60,13 +60,13 @@ func TestFunctions_CreateForJava(t *testing.T) { opts.Secure = Bool(true) opts.Arguments = []FunctionArgument{ { - ArgName: "id", - ArgDataType: DataTypeNumber, + ArgName: "id", + ArgDataTypeOld: DataTypeNumber, }, { - ArgName: "name", - ArgDataType: DataTypeVARCHAR, - DefaultValue: String("'test'"), + ArgName: "name", + ArgDataTypeOld: DataTypeVARCHAR, + DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) @@ -74,12 +74,12 @@ func TestFunctions_CreateForJava(t *testing.T) { Table: &FunctionReturnsTable{ Columns: []FunctionColumn{ { - ColumnName: "country_code", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "country_code", + ColumnDataTypeOld: DataTypeVARCHAR, }, { - ColumnName: "country_name", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "country_name", + ColumnDataTypeOld: DataTypeVARCHAR, }, }, }, @@ -149,7 +149,7 @@ func TestFunctions_CreateForJavascript(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavascriptFunctionOptions", "FunctionDefinition")) @@ -162,15 +162,15 @@ func TestFunctions_CreateForJavascript(t *testing.T) { opts.Secure = Bool(true) opts.Arguments = []FunctionArgument{ { - ArgName: "d", - ArgDataType: DataTypeFloat, - DefaultValue: String("1.0"), + ArgName: "d", + ArgDataTypeOld: DataTypeFloat, + DefaultValue: String("1.0"), }, } opts.CopyGrants = Bool(true) opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) @@ -212,7 +212,7 @@ func TestFunctions_CreateForPython(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonFunctionOptions", "RuntimeVersion")) @@ -236,15 +236,15 @@ func TestFunctions_CreateForPython(t *testing.T) { opts.Secure = Bool(true) opts.Arguments = []FunctionArgument{ { - ArgName: "i", - ArgDataType: DataTypeNumber, - DefaultValue: String("1"), + ArgName: "i", + ArgDataTypeOld: DataTypeNumber, + DefaultValue: String("1"), }, } opts.CopyGrants = Bool(true) opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVariant, + ResultDataTypeOld: DataTypeVariant, }, } opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) @@ -322,7 +322,7 @@ func TestFunctions_CreateForScala(t *testing.T) { t.Run("validation: options are missing", func(t *testing.T) { opts := defaultOpts() - opts.ResultDataType = DataTypeVARCHAR + opts.ResultDataTypeOld = DataTypeVARCHAR assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaFunctionOptions", "Handler")) }) @@ -333,13 +333,13 @@ func TestFunctions_CreateForScala(t *testing.T) { opts.Secure = Bool(true) opts.Arguments = []FunctionArgument{ { - ArgName: "x", - ArgDataType: DataTypeVARCHAR, - DefaultValue: String("'test'"), + ArgName: "x", + ArgDataTypeOld: DataTypeVARCHAR, + DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) - opts.ResultDataType = DataTypeVARCHAR + opts.ResultDataTypeOld = DataTypeVARCHAR opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) @@ -386,7 +386,7 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForSQLFunctionOptions", "FunctionDefinition")) @@ -396,7 +396,7 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } opts.FunctionDefinition = "3.141592654::FLOAT" @@ -410,15 +410,15 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts.Secure = Bool(true) opts.Arguments = []FunctionArgument{ { - ArgName: "message", - ArgDataType: "VARCHAR", - DefaultValue: String("'test'"), + ArgName: "message", + ArgDataTypeOld: "VARCHAR", + DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) diff --git a/pkg/sdk/functions_impl_gen.go b/pkg/sdk/functions_impl_gen.go index 721d405e99..ca17781139 100644 --- a/pkg/sdk/functions_impl_gen.go +++ b/pkg/sdk/functions_impl_gen.go @@ -111,6 +111,7 @@ func (r *CreateForJavaFunctionRequest) toOpts() *CreateForJavaFunctionOptions { if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -166,6 +167,7 @@ func (r *CreateForJavascriptFunctionRequest) toOpts() *CreateForJavascriptFuncti if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -213,6 +215,7 @@ func (r *CreateForPythonFunctionRequest) toOpts() *CreateForPythonFunctionOption if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -252,6 +255,7 @@ func (r *CreateForScalaFunctionRequest) toOpts() *CreateForScalaFunctionOptions CopyGrants: r.CopyGrants, ResultDataTypeOld: r.ResultDataTypeOld, + ResultDataType: r.ResultDataType, ReturnNullValues: r.ReturnNullValues, NullInputBehavior: r.NullInputBehavior, ReturnResultsBehavior: r.ReturnResultsBehavior, @@ -312,6 +316,7 @@ func (r *CreateForSQLFunctionRequest) toOpts() *CreateForSQLFunctionOptions { if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { diff --git a/pkg/sdk/poc/README.md b/pkg/sdk/poc/README.md index 44af1e130b..46cb6e16b9 100644 --- a/pkg/sdk/poc/README.md +++ b/pkg/sdk/poc/README.md @@ -109,6 +109,7 @@ find a better solution to solve the issue (add more logic to the templates ?) - there should be no need to define custom types every time - more clear definition of lists that can be empty vs cannot be empty - add empty ids in generated tests (TODO in random_test.go) +- add optional imports (currently they have to be added manually, e.g. `datatypes.DataType`) ##### Known issues - generating two converts when Show and Desc use the same data structure diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index f708799159..5ac8d5a8a1 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -550,6 +550,19 @@ func TestBuilder_DataType(t *testing.T) { assert.Equal(t, "", s) }) + // TODO [this PR]: test all types as nil + t.Run("test data type nil", func(t *testing.T) { + var a *datatypes.BooleanDataType + opts := dataTypeTestHelper{ + DataType: a, + } + + s, err := structToSQL(opts) + + require.NoError(t, err) + assert.Equal(t, "", s) + }) + for _, tc := range dataTypes { tc := tc t.Run(fmt.Sprintf(`cheking building SQL for data type "%s, expecting "%s"`, tc.dataType, tc.expectedSql), func(t *testing.T) { diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index b61afbb037..a4ac7fa502 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -44,9 +44,9 @@ func TestInt_CreateFunctions(t *testing.T) { } }` target := fmt.Sprintf("@~/tf-%d.jar", time.Now().Unix()) - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeVARCHAR).WithDefaultValue("'abc'") + argument := sdk.NewFunctionArgumentRequest("x", nil).WithDefaultValue("'abc'").WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForJavaFunctionRequest(id.SchemaObjectId(), *returns, "TestFunc.echoVarchar"). WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). @@ -77,9 +77,9 @@ func TestInt_CreateFunctions(t *testing.T) { return result; }` - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("d", sdk.DataTypeFloat) + argument := sdk.NewFunctionArgumentRequest("d", nil).WithArgDataTypeOld(sdk.DataTypeFloat) request := sdk.NewCreateForJavascriptFunctionRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). @@ -100,9 +100,9 @@ func TestInt_CreateFunctions(t *testing.T) { definition := ` def dump(i): print("Hello World!")` - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVariant) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVariant) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("i", sdk.DataTypeNumber) + argument := sdk.NewFunctionArgumentRequest("i", nil).WithArgDataTypeOld(sdk.DataTypeNumber) request := sdk.NewCreateForPythonFunctionRequest(id.SchemaObjectId(), *returns, "3.8", "dump"). WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). @@ -127,8 +127,9 @@ def dump(i): } }` - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForScalaFunctionRequest(id.SchemaObjectId(), sdk.DataTypeVARCHAR, "Echo.echoVarchar"). + argument := sdk.NewFunctionArgumentRequest("x", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + request := sdk.NewCreateForScalaFunctionRequest(id.SchemaObjectId(), nil, "Echo.echoVarchar"). + WithResultDataTypeOld(sdk.DataTypeVARCHAR). WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). WithRuntimeVersion("2.12"). @@ -148,9 +149,9 @@ def dump(i): definition := "3.141592654::FLOAT" - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) + argument := sdk.NewFunctionArgumentRequest("x", nil).WithArgDataTypeOld(sdk.DataTypeFloat) request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). WithArguments([]sdk.FunctionArgumentRequest{*argument}). WithOrReplace(true). @@ -170,7 +171,7 @@ def dump(i): definition := "3.141592654::FLOAT" - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). @@ -241,12 +242,12 @@ func TestInt_OtherFunctions(t *testing.T) { definition := "3.141592654::FLOAT" - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true) if withArguments { - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) + argument := sdk.NewFunctionArgumentRequest("x", nil).WithArgDataTypeOld(sdk.DataTypeFloat) request = request.WithArguments([]sdk.FunctionArgumentRequest{*argument}) } err := client.Functions.CreateForSQL(ctx, request) @@ -438,11 +439,11 @@ func TestInt_FunctionsShowByID(t *testing.T) { t.Helper() definition := "3.141592654::FLOAT" - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition).WithOrReplace(true) - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) + argument := sdk.NewFunctionArgumentRequest("x", nil).WithArgDataTypeOld(sdk.DataTypeFloat) request = request.WithArguments([]sdk.FunctionArgumentRequest{*argument}) err := client.Functions.CreateForSQL(ctx, request) require.NoError(t, err) @@ -504,34 +505,34 @@ func TestInt_FunctionsShowByID(t *testing.T) { id := testClientHelper().Ids.RandomSchemaObjectIdentifier() args := []sdk.FunctionArgumentRequest{ - *sdk.NewFunctionArgumentRequest("A", "NUMBER(2, 0)"), - *sdk.NewFunctionArgumentRequest("B", "DECIMAL"), - *sdk.NewFunctionArgumentRequest("C", "INTEGER"), - *sdk.NewFunctionArgumentRequest("D", sdk.DataTypeFloat), - *sdk.NewFunctionArgumentRequest("E", "DOUBLE"), - *sdk.NewFunctionArgumentRequest("F", "VARCHAR(20)"), - *sdk.NewFunctionArgumentRequest("G", "CHAR"), - *sdk.NewFunctionArgumentRequest("H", sdk.DataTypeString), - *sdk.NewFunctionArgumentRequest("I", "TEXT"), - *sdk.NewFunctionArgumentRequest("J", sdk.DataTypeBinary), - *sdk.NewFunctionArgumentRequest("K", "VARBINARY"), - *sdk.NewFunctionArgumentRequest("L", sdk.DataTypeBoolean), - *sdk.NewFunctionArgumentRequest("M", sdk.DataTypeDate), - *sdk.NewFunctionArgumentRequest("N", "DATETIME"), - *sdk.NewFunctionArgumentRequest("O", sdk.DataTypeTime), - *sdk.NewFunctionArgumentRequest("R", sdk.DataTypeTimestampLTZ), - *sdk.NewFunctionArgumentRequest("S", sdk.DataTypeTimestampNTZ), - *sdk.NewFunctionArgumentRequest("T", sdk.DataTypeTimestampTZ), - *sdk.NewFunctionArgumentRequest("U", sdk.DataTypeVariant), - *sdk.NewFunctionArgumentRequest("V", sdk.DataTypeObject), - *sdk.NewFunctionArgumentRequest("W", sdk.DataTypeArray), - *sdk.NewFunctionArgumentRequest("X", sdk.DataTypeGeography), - *sdk.NewFunctionArgumentRequest("Y", sdk.DataTypeGeometry), - *sdk.NewFunctionArgumentRequest("Z", "VECTOR(INT, 16)"), + *sdk.NewFunctionArgumentRequest("A", nil).WithArgDataTypeOld("NUMBER(2, 0)"), + *sdk.NewFunctionArgumentRequest("B", nil).WithArgDataTypeOld("DECIMAL"), + *sdk.NewFunctionArgumentRequest("C", nil).WithArgDataTypeOld("INTEGER"), + *sdk.NewFunctionArgumentRequest("D", nil).WithArgDataTypeOld(sdk.DataTypeFloat), + *sdk.NewFunctionArgumentRequest("E", nil).WithArgDataTypeOld("DOUBLE"), + *sdk.NewFunctionArgumentRequest("F", nil).WithArgDataTypeOld("VARCHAR(20)"), + *sdk.NewFunctionArgumentRequest("G", nil).WithArgDataTypeOld("CHAR"), + *sdk.NewFunctionArgumentRequest("H", nil).WithArgDataTypeOld(sdk.DataTypeString), + *sdk.NewFunctionArgumentRequest("I", nil).WithArgDataTypeOld("TEXT"), + *sdk.NewFunctionArgumentRequest("J", nil).WithArgDataTypeOld(sdk.DataTypeBinary), + *sdk.NewFunctionArgumentRequest("K", nil).WithArgDataTypeOld("VARBINARY"), + *sdk.NewFunctionArgumentRequest("L", nil).WithArgDataTypeOld(sdk.DataTypeBoolean), + *sdk.NewFunctionArgumentRequest("M", nil).WithArgDataTypeOld(sdk.DataTypeDate), + *sdk.NewFunctionArgumentRequest("N", nil).WithArgDataTypeOld("DATETIME"), + *sdk.NewFunctionArgumentRequest("O", nil).WithArgDataTypeOld(sdk.DataTypeTime), + *sdk.NewFunctionArgumentRequest("R", nil).WithArgDataTypeOld(sdk.DataTypeTimestampLTZ), + *sdk.NewFunctionArgumentRequest("S", nil).WithArgDataTypeOld(sdk.DataTypeTimestampNTZ), + *sdk.NewFunctionArgumentRequest("T", nil).WithArgDataTypeOld(sdk.DataTypeTimestampTZ), + *sdk.NewFunctionArgumentRequest("U", nil).WithArgDataTypeOld(sdk.DataTypeVariant), + *sdk.NewFunctionArgumentRequest("V", nil).WithArgDataTypeOld(sdk.DataTypeObject), + *sdk.NewFunctionArgumentRequest("W", nil).WithArgDataTypeOld(sdk.DataTypeArray), + *sdk.NewFunctionArgumentRequest("X", nil).WithArgDataTypeOld(sdk.DataTypeGeography), + *sdk.NewFunctionArgumentRequest("Y", nil).WithArgDataTypeOld(sdk.DataTypeGeometry), + *sdk.NewFunctionArgumentRequest("Z", nil).WithArgDataTypeOld("VECTOR(INT, 16)"), } err := client.Functions.CreateForPython(ctx, sdk.NewCreateForPythonFunctionRequest( id, - *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVariant)), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVariant)), "3.8", "add", ). From c215f981e5e38853e44a71a18b6d4ac0f7de185a Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 12:37:41 +0100 Subject: [PATCH 09/29] Pass all function integration tests --- pkg/acceptance/helpers/function_client.go | 2 +- pkg/sdk/functions_def.go | 7 ++++--- pkg/sdk/functions_gen.go | 5 +++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/acceptance/helpers/function_client.go b/pkg/acceptance/helpers/function_client.go index 9a904de51a..3e6fe5a294 100644 --- a/pkg/acceptance/helpers/function_client.go +++ b/pkg/acceptance/helpers/function_client.go @@ -35,7 +35,7 @@ func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectI return c.CreateWithRequest(t, id, sdk.NewCreateForSQLFunctionRequest( id.SchemaObjectId(), - *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVariant)), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)), "SELECT 1", ), ) diff --git a/pkg/sdk/functions_def.go b/pkg/sdk/functions_def.go index af3dbcdbdb..7f1aa0be59 100644 --- a/pkg/sdk/functions_def.go +++ b/pkg/sdk/functions_def.go @@ -20,7 +20,7 @@ var functionReturns = g.NewQueryStruct("FunctionReturns"). "ResultDataType", g.NewQueryStruct("FunctionReturnsResultDataType"). PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). - PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.KeywordOptions().NoQuotes().Required()), + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()), g.KeywordOptions(), ). OptionalQueryStructField( @@ -177,8 +177,9 @@ var FunctionsDef = g.NewInterface( functionArgument, g.ListOptions().MustParentheses()). OptionalSQL("COPY GRANTS"). - PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS")). - PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + SQL("RETURNS"). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). PredefinedQueryStructField("ReturnNullValues", "*ReturnNullValues", g.KeywordOptions()). SQL("LANGUAGE SCALA"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). diff --git a/pkg/sdk/functions_gen.go b/pkg/sdk/functions_gen.go index 78ab2e1bd5..ab7ca62170 100644 --- a/pkg/sdk/functions_gen.go +++ b/pkg/sdk/functions_gen.go @@ -139,8 +139,9 @@ type CreateForScalaFunctionOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` Arguments []FunctionArgument `ddl:"list,must_parentheses"` CopyGrants *bool `ddl:"keyword" sql:"COPY GRANTS"` - ResultDataTypeOld DataType `ddl:"parameter,no_equals" sql:"RETURNS"` - ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals" sql:"RETURNS"` + returns bool `ddl:"static" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` ReturnNullValues *ReturnNullValues `ddl:"keyword"` languageScala bool `ddl:"static" sql:"LANGUAGE SCALA"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` From 1605e9d47dc3552326b51a5b7c8ce1fd85c861b8 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 12:38:45 +0100 Subject: [PATCH 10/29] Add TODO --- pkg/sdk/functions_gen_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index 36f8a82892..6456719f32 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -4,6 +4,7 @@ import ( "testing" ) +// TODO [this PR]: unit test new data types func TestFunctions_CreateForJava(t *testing.T) { id := randomSchemaObjectIdentifier() From ba0a3dcfc76f8f7a16c61c645951b42977cd28d2 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 13:00:29 +0100 Subject: [PATCH 11/29] Test all nil data types in SQL builder --- pkg/sdk/datatypes/boolean.go | 3 --- pkg/sdk/sql_builder.go | 4 ++++ pkg/sdk/sql_builder_test.go | 41 +++++++++++++++++++++++++++--------- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/pkg/sdk/datatypes/boolean.go b/pkg/sdk/datatypes/boolean.go index 4a6c617bcf..4e84979f40 100644 --- a/pkg/sdk/datatypes/boolean.go +++ b/pkg/sdk/datatypes/boolean.go @@ -7,9 +7,6 @@ type BooleanDataType struct { } func (t *BooleanDataType) ToSql() string { - if t == nil { - return "" - } return t.underlyingType } diff --git a/pkg/sdk/sql_builder.go b/pkg/sdk/sql_builder.go index 6385781640..b6d9f7469b 100644 --- a/pkg/sdk/sql_builder.go +++ b/pkg/sdk/sql_builder.go @@ -646,6 +646,10 @@ func (v sqlParameterClause) String() string { } var value = v.value if dataType, ok := value.(datatypes.DataType); ok { + // We check like this and not by `dataType == nil` because for e.g. `var *datatypes.ArrayDataType` return false in a normal nil check + if reflect.ValueOf(dataType).IsZero() { + return s + } value = dataType.ToSql() } // key = "value" diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index 5ac8d5a8a1..62aaa239d9 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -541,6 +541,26 @@ func TestBuilder_DataType(t *testing.T) { {dataType: "VECTOR(float, 20)", expectedSql: "VECTOR(FLOAT, 20)"}, } + nilTestCases := func() []datatypes.DataType { + var a *datatypes.ArrayDataType + var b *datatypes.BinaryDataType + var c *datatypes.BooleanDataType + var d *datatypes.DateDataType + var e *datatypes.FloatDataType + var f *datatypes.GeographyDataType + var g *datatypes.GeometryDataType + var h *datatypes.NumberDataType + var i *datatypes.ObjectDataType + var j *datatypes.TextDataType + var k *datatypes.TimeDataType + var l *datatypes.TimestampLtzDataType + var m *datatypes.TimestampNtzDataType + var n *datatypes.TimestampTzDataType + var o *datatypes.VariantDataType + var p *datatypes.VectorDataType + + return []datatypes.DataType{a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p} + }() t.Run("test data type empty", func(t *testing.T) { opts := dataTypeTestHelper{} @@ -550,18 +570,19 @@ func TestBuilder_DataType(t *testing.T) { assert.Equal(t, "", s) }) - // TODO [this PR]: test all types as nil - t.Run("test data type nil", func(t *testing.T) { - var a *datatypes.BooleanDataType - opts := dataTypeTestHelper{ - DataType: a, - } + for _, tc := range nilTestCases { + tc := tc + t.Run(fmt.Sprintf(`test for nil data type "%s"`, reflect.TypeOf(tc)), func(t *testing.T) { + opts := dataTypeTestHelper{ + DataType: tc, + } - s, err := structToSQL(opts) + s, err := structToSQL(opts) - require.NoError(t, err) - assert.Equal(t, "", s) - }) + require.NoError(t, err) + assert.Equal(t, "", s) + }) + } for _, tc := range dataTypes { tc := tc From c7471b2b99eb0ca910a1c602ad511a403d13c804 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 14:28:14 +0100 Subject: [PATCH 12/29] Update function creation unit tests --- pkg/sdk/functions_gen_test.go | 232 ++++++++++++++++++++++++++++++++-- pkg/sdk/random_test.go | 7 + 2 files changed, 227 insertions(+), 12 deletions(-) diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index 6456719f32..865a2911bc 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -4,7 +4,6 @@ import ( "testing" ) -// TODO [this PR]: unit test new data types func TestFunctions_CreateForJava(t *testing.T) { id := randomSchemaObjectIdentifier() @@ -48,13 +47,14 @@ func TestFunctions_CreateForJava(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaFunctionOptions", "Handler")) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Temporary = Bool(true) @@ -118,6 +118,71 @@ func TestFunctions_CreateForJava(t *testing.T) { opts.FunctionDefinition = String("return id + name;") assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (id NUMBER, name VARCHAR DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR, country_name VARCHAR) NOT NULL LANGUAGE JAVA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@~/my_decrement_udf_package_dir/my_decrement_udf_jar.jar') PACKAGES = ('com.snowflake:snowpark:1.2.0') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' AS 'return id + name;'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "id", + ArgDataType: dataTypeNumber, + }, + { + ArgName: "name", + ArgDataType: dataTypeVarchar, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + { + ColumnName: "country_code", + ColumnDataType: dataTypeVarchar, + }, + { + ColumnName: "country_name", + ColumnDataType: dataTypeVarchar, + }, + }, + }, + } + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.RuntimeVersion = String("2.0") + opts.Comment = String("comment") + opts.Imports = []FunctionImport{ + { + Import: "@~/my_decrement_udf_package_dir/my_decrement_udf_jar.jar", + }, + } + opts.Packages = []FunctionPackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Handler = "TestFunc.echoVarchar" + opts.ExternalAccessIntegrations = []AccountObjectIdentifier{ + NewAccountObjectIdentifier("ext_integration"), + } + opts.Secrets = []SecretReference{ + { + VariableName: "variable1", + Name: "name1", + }, + { + VariableName: "variable2", + Name: "name2", + }, + } + opts.TargetPath = String("@~/testfunc.jar") + opts.FunctionDefinition = String("return id + name;") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (id NUMBER(36, 2), name VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR(100), country_name VARCHAR(100)) NOT NULL LANGUAGE JAVA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@~/my_decrement_udf_package_dir/my_decrement_udf_jar.jar') PACKAGES = ('com.snowflake:snowpark:1.2.0') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' AS 'return id + name;'`, id.FullyQualifiedName()) + }) } func TestFunctions_CreateForJavascript(t *testing.T) { @@ -150,13 +215,14 @@ func TestFunctions_CreateForJavascript(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavascriptFunctionOptions", "FunctionDefinition")) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Temporary = Bool(true) @@ -181,6 +247,33 @@ func TestFunctions_CreateForJavascript(t *testing.T) { opts.FunctionDefinition = "return 1;" assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (d FLOAT DEFAULT 1.0) COPY GRANTS RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT CALLED ON NULL INPUT IMMUTABLE COMMENT = 'comment' AS 'return 1;'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "d", + ArgDataType: dataTypeFloat, + DefaultValue: String("1.0"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataType: dataTypeFloat, + }, + } + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.Comment = String("comment") + opts.FunctionDefinition = "return 1;" + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (d FLOAT DEFAULT 1.0) COPY GRANTS RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT CALLED ON NULL INPUT IMMUTABLE COMMENT = 'comment' AS 'return 1;'`, id.FullyQualifiedName()) + }) + } func TestFunctions_CreateForPython(t *testing.T) { @@ -213,7 +306,7 @@ func TestFunctions_CreateForPython(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonFunctionOptions", "RuntimeVersion")) @@ -230,7 +323,8 @@ func TestFunctions_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, NewError("IMPORTS must not be empty when AS is nil")) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Temporary = Bool(true) @@ -286,6 +380,63 @@ func TestFunctions_CreateForPython(t *testing.T) { opts.FunctionDefinition = String("import numpy as np") assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (i NUMBER DEFAULT 1) COPY GRANTS RETURNS VARIANT NOT NULL LANGUAGE PYTHON CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '3.8' COMMENT = 'comment' IMPORTS = ('numpy', 'pandas') PACKAGES = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) AS 'import numpy as np'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "i", + ArgDataType: dataTypeNumber, + DefaultValue: String("1"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataType: dataTypeVariant, + }, + } + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.RuntimeVersion = "3.8" + opts.Comment = String("comment") + opts.Imports = []FunctionImport{ + { + Import: "numpy", + }, + { + Import: "pandas", + }, + } + opts.Packages = []FunctionPackage{ + { + Package: "numpy", + }, + { + Package: "pandas", + }, + } + opts.Handler = "udf" + opts.ExternalAccessIntegrations = []AccountObjectIdentifier{ + NewAccountObjectIdentifier("ext_integration"), + } + opts.Secrets = []SecretReference{ + { + VariableName: "variable1", + Name: "name1", + }, + { + VariableName: "variable2", + Name: "name2", + }, + } + opts.FunctionDefinition = String("import numpy as np") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (i NUMBER(36, 2) DEFAULT 1) COPY GRANTS RETURNS VARIANT NOT NULL LANGUAGE PYTHON CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '3.8' COMMENT = 'comment' IMPORTS = ('numpy', 'pandas') PACKAGES = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) AS 'import numpy as np'`, id.FullyQualifiedName()) + }) } func TestFunctions_CreateForScala(t *testing.T) { @@ -323,11 +474,12 @@ func TestFunctions_CreateForScala(t *testing.T) { t.Run("validation: options are missing", func(t *testing.T) { opts := defaultOpts() - opts.ResultDataTypeOld = DataTypeVARCHAR + opts.ResultDataType = dataTypeVarchar assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaFunctionOptions", "Handler")) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Temporary = Bool(true) @@ -355,6 +507,35 @@ func TestFunctions_CreateForScala(t *testing.T) { opts.FunctionDefinition = String("return x") assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (x VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SCALA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' AS 'return x'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "x", + ArgDataType: dataTypeVarchar, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.ResultDataType = dataTypeVarchar + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.RuntimeVersion = String("2.0") + opts.Comment = String("comment") + opts.Imports = []FunctionImport{ + { + Import: "@udf_libs/echohandler.jar", + }, + } + opts.Handler = "Echo.echoVarchar" + opts.FunctionDefinition = String("return x") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (x VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS VARCHAR(100) NOT NULL LANGUAGE SCALA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' AS 'return x'`, id.FullyQualifiedName()) + }) } func TestFunctions_CreateForSQL(t *testing.T) { @@ -387,7 +568,7 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForSQLFunctionOptions", "FunctionDefinition")) @@ -397,14 +578,15 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } opts.FunctionDefinition = "3.141592654::FLOAT" assertOptsValidAndSQLEquals(t, opts, `CREATE FUNCTION %s () RETURNS FLOAT AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Temporary = Bool(true) @@ -429,6 +611,32 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts.FunctionDefinition = "3.141592654::FLOAT" assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (message VARCHAR DEFAULT 'test') COPY GRANTS RETURNS FLOAT NOT NULL IMMUTABLE MEMOIZABLE COMMENT = 'comment' AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "message", + ArgDataType: dataTypeVarchar, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataType: dataTypeFloat, + }, + } + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.Memoizable = Bool(true) + opts.Comment = String("comment") + opts.FunctionDefinition = "3.141592654::FLOAT" + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (message VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS FLOAT NOT NULL IMMUTABLE MEMOIZABLE COMMENT = 'comment' AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) + }) } func TestFunctions_Drop(t *testing.T) { diff --git a/pkg/sdk/random_test.go b/pkg/sdk/random_test.go index 552eb68c15..83880167df 100644 --- a/pkg/sdk/random_test.go +++ b/pkg/sdk/random_test.go @@ -2,6 +2,7 @@ package sdk import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) var ( @@ -14,6 +15,12 @@ var ( emptyDatabaseObjectIdentifier = NewDatabaseObjectIdentifier("", "") emptySchemaObjectIdentifier = NewSchemaObjectIdentifier("", "", "") emptySchemaObjectIdentifierWithArguments = NewSchemaObjectIdentifierWithArguments("", "", "") + + // TODO [SNOW-1843440]: create using constructors (when we add them)? + dataTypeNumber, _ = datatypes.ParseDataType("NUMBER(36, 2)") + dataTypeVarchar, _ = datatypes.ParseDataType("VARCHAR(100)") + dataTypeFloat, _ = datatypes.ParseDataType("FLOAT") + dataTypeVariant, _ = datatypes.ParseDataType("VARIANT") ) func randomSchemaObjectIdentifierWithArguments(argumentDataTypes ...DataType) SchemaObjectIdentifierWithArguments { From aa02897fd1b565be256094f6af3d8fb3d51079ec Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 15:30:40 +0100 Subject: [PATCH 13/29] Pass functions integration test with different data types --- pkg/sdk/datatypes/array.go | 4 + pkg/sdk/datatypes/binary.go | 4 + pkg/sdk/datatypes/boolean.go | 4 + pkg/sdk/datatypes/data_types.go | 1 + pkg/sdk/datatypes/date.go | 4 + pkg/sdk/datatypes/float.go | 4 + pkg/sdk/datatypes/geography.go | 4 + pkg/sdk/datatypes/geometry.go | 4 + pkg/sdk/datatypes/number.go | 4 + pkg/sdk/datatypes/object.go | 4 + pkg/sdk/datatypes/text.go | 4 + pkg/sdk/datatypes/time.go | 4 + pkg/sdk/datatypes/timestamp_ltz.go | 4 + pkg/sdk/datatypes/timestamp_ntz.go | 4 + pkg/sdk/datatypes/timestamp_tz.go | 4 + pkg/sdk/datatypes/variant.go | 4 + pkg/sdk/datatypes/vector.go | 4 + pkg/sdk/sql_builder_test.go | 4 +- pkg/sdk/testint/functions_integration_test.go | 77 ++++++++++++++++++- 19 files changed, 143 insertions(+), 3 deletions(-) diff --git a/pkg/sdk/datatypes/array.go b/pkg/sdk/datatypes/array.go index eb7247f6e6..835f48ae29 100644 --- a/pkg/sdk/datatypes/array.go +++ b/pkg/sdk/datatypes/array.go @@ -14,6 +14,10 @@ func (t *ArrayDataType) ToLegacyDataTypeSql() string { return ArrayLegacyDataType } +func (t *ArrayDataType) Canonical() string { + return ArrayLegacyDataType +} + var ArrayDataTypeSynonyms = []string{ArrayLegacyDataType} func parseArrayDataTypeRaw(raw sanitizedDataTypeRaw) (*ArrayDataType, error) { diff --git a/pkg/sdk/datatypes/binary.go b/pkg/sdk/datatypes/binary.go index c50dba0570..07181e3aaf 100644 --- a/pkg/sdk/datatypes/binary.go +++ b/pkg/sdk/datatypes/binary.go @@ -25,6 +25,10 @@ func (t *BinaryDataType) ToLegacyDataTypeSql() string { return BinaryLegacyDataType } +func (t *BinaryDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", BinaryLegacyDataType, t.size) +} + var BinaryDataTypeSynonyms = []string{BinaryLegacyDataType, "VARBINARY"} func parseBinaryDataTypeRaw(raw sanitizedDataTypeRaw) (*BinaryDataType, error) { diff --git a/pkg/sdk/datatypes/boolean.go b/pkg/sdk/datatypes/boolean.go index 4e84979f40..56f84a4064 100644 --- a/pkg/sdk/datatypes/boolean.go +++ b/pkg/sdk/datatypes/boolean.go @@ -14,6 +14,10 @@ func (t *BooleanDataType) ToLegacyDataTypeSql() string { return BooleanLegacyDataType } +func (t *BooleanDataType) Canonical() string { + return BooleanLegacyDataType +} + var BooleanDataTypeSynonyms = []string{BooleanLegacyDataType} func parseBooleanDataTypeRaw(raw sanitizedDataTypeRaw) (*BooleanDataType, error) { diff --git a/pkg/sdk/datatypes/data_types.go b/pkg/sdk/datatypes/data_types.go index e1c0065855..1449c91f75 100644 --- a/pkg/sdk/datatypes/data_types.go +++ b/pkg/sdk/datatypes/data_types.go @@ -16,6 +16,7 @@ import ( type DataType interface { ToSql() string ToLegacyDataTypeSql() string + Canonical() string } type sanitizedDataTypeRaw struct { diff --git a/pkg/sdk/datatypes/date.go b/pkg/sdk/datatypes/date.go index 92ee7c27bc..c962a4a831 100644 --- a/pkg/sdk/datatypes/date.go +++ b/pkg/sdk/datatypes/date.go @@ -14,6 +14,10 @@ func (t *DateDataType) ToLegacyDataTypeSql() string { return DateLegacyDataType } +func (t *DateDataType) Canonical() string { + return DateLegacyDataType +} + var DateDataTypeSynonyms = []string{DateLegacyDataType} func parseDateDataTypeRaw(raw sanitizedDataTypeRaw) (*DateDataType, error) { diff --git a/pkg/sdk/datatypes/float.go b/pkg/sdk/datatypes/float.go index a0ca84863b..36fe0d9be0 100644 --- a/pkg/sdk/datatypes/float.go +++ b/pkg/sdk/datatypes/float.go @@ -14,6 +14,10 @@ func (t *FloatDataType) ToLegacyDataTypeSql() string { return FloatLegacyDataType } +func (t *FloatDataType) Canonical() string { + return FloatLegacyDataType +} + var FloatDataTypeSynonyms = []string{"FLOAT8", "FLOAT4", FloatLegacyDataType, "DOUBLE PRECISION", "DOUBLE", "REAL"} func parseFloatDataTypeRaw(raw sanitizedDataTypeRaw) (*FloatDataType, error) { diff --git a/pkg/sdk/datatypes/geography.go b/pkg/sdk/datatypes/geography.go index 4a024a20b0..43ee148212 100644 --- a/pkg/sdk/datatypes/geography.go +++ b/pkg/sdk/datatypes/geography.go @@ -14,6 +14,10 @@ func (t *GeographyDataType) ToLegacyDataTypeSql() string { return GeographyLegacyDataType } +func (t *GeographyDataType) Canonical() string { + return GeographyLegacyDataType +} + var GeographyDataTypeSynonyms = []string{GeographyLegacyDataType} func parseGeographyDataTypeRaw(raw sanitizedDataTypeRaw) (*GeographyDataType, error) { diff --git a/pkg/sdk/datatypes/geometry.go b/pkg/sdk/datatypes/geometry.go index d09ebd3eea..8ab62e817b 100644 --- a/pkg/sdk/datatypes/geometry.go +++ b/pkg/sdk/datatypes/geometry.go @@ -14,6 +14,10 @@ func (t *GeometryDataType) ToLegacyDataTypeSql() string { return GeometryLegacyDataType } +func (t *GeometryDataType) Canonical() string { + return GeometryLegacyDataType +} + var GeometryDataTypeSynonyms = []string{GeometryLegacyDataType} func parseGeometryDataTypeRaw(raw sanitizedDataTypeRaw) (*GeometryDataType, error) { diff --git a/pkg/sdk/datatypes/number.go b/pkg/sdk/datatypes/number.go index 378c9e3205..cd11467717 100644 --- a/pkg/sdk/datatypes/number.go +++ b/pkg/sdk/datatypes/number.go @@ -34,6 +34,10 @@ func (t *NumberDataType) ToLegacyDataTypeSql() string { return NumberLegacyDataType } +func (t *NumberDataType) Canonical() string { + return fmt.Sprintf("%s(%d,%d)", NumberLegacyDataType, t.precision, t.scale) +} + var ( NumberDataTypeSynonyms = []string{NumberLegacyDataType, "DECIMAL", "DEC", "NUMERIC"} NumberDataTypeSubTypes = []string{"INTEGER", "INT", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"} diff --git a/pkg/sdk/datatypes/object.go b/pkg/sdk/datatypes/object.go index fe333aa7b0..098b04b0be 100644 --- a/pkg/sdk/datatypes/object.go +++ b/pkg/sdk/datatypes/object.go @@ -14,6 +14,10 @@ func (t *ObjectDataType) ToLegacyDataTypeSql() string { return ObjectLegacyDataType } +func (t *ObjectDataType) Canonical() string { + return ObjectLegacyDataType +} + var ObjectDataTypeSynonyms = []string{ObjectLegacyDataType} func parseObjectDataTypeRaw(raw sanitizedDataTypeRaw) (*ObjectDataType, error) { diff --git a/pkg/sdk/datatypes/text.go b/pkg/sdk/datatypes/text.go index 2598253101..c05d64f18c 100644 --- a/pkg/sdk/datatypes/text.go +++ b/pkg/sdk/datatypes/text.go @@ -30,6 +30,10 @@ func (t *TextDataType) ToLegacyDataTypeSql() string { return VarcharLegacyDataType } +func (t *TextDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", VarcharLegacyDataType, t.length) +} + var ( TextDataTypeSynonyms = []string{VarcharLegacyDataType, "STRING", "TEXT", "NVARCHAR2", "NVARCHAR", "CHAR VARYING", "NCHAR VARYING"} TextDataTypeSubtypes = []string{"CHARACTER", "CHAR", "NCHAR"} diff --git a/pkg/sdk/datatypes/time.go b/pkg/sdk/datatypes/time.go index ee79421122..e33223c104 100644 --- a/pkg/sdk/datatypes/time.go +++ b/pkg/sdk/datatypes/time.go @@ -25,6 +25,10 @@ func (t *TimeDataType) ToLegacyDataTypeSql() string { return TimeLegacyDataType } +func (t *TimeDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", TimeLegacyDataType, t.precision) +} + var TimeDataTypeSynonyms = []string{TimeLegacyDataType} func parseTimeDataTypeRaw(raw sanitizedDataTypeRaw) (*TimeDataType, error) { diff --git a/pkg/sdk/datatypes/timestamp_ltz.go b/pkg/sdk/datatypes/timestamp_ltz.go index f844ec537f..41961bfdb7 100644 --- a/pkg/sdk/datatypes/timestamp_ltz.go +++ b/pkg/sdk/datatypes/timestamp_ltz.go @@ -23,6 +23,10 @@ func (t *TimestampLtzDataType) ToLegacyDataTypeSql() string { return TimestampLtzLegacyDataType } +func (t *TimestampLtzDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", TimestampLtzLegacyDataType, t.precision) +} + var TimestampLtzDataTypeSynonyms = []string{TimestampLtzLegacyDataType, "TIMESTAMPLTZ", "TIMESTAMP WITH LOCAL TIME ZONE"} func parseTimestampLtzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampLtzDataType, error) { diff --git a/pkg/sdk/datatypes/timestamp_ntz.go b/pkg/sdk/datatypes/timestamp_ntz.go index 86aa5f0a0c..e11ed41b08 100644 --- a/pkg/sdk/datatypes/timestamp_ntz.go +++ b/pkg/sdk/datatypes/timestamp_ntz.go @@ -23,6 +23,10 @@ func (t *TimestampNtzDataType) ToLegacyDataTypeSql() string { return TimestampNtzLegacyDataType } +func (t *TimestampNtzDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", TimestampNtzLegacyDataType, t.precision) +} + var TimestampNtzDataTypeSynonyms = []string{TimestampNtzLegacyDataType, "TIMESTAMPNTZ", "TIMESTAMP WITHOUT TIME ZONE", "DATETIME"} func parseTimestampNtzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampNtzDataType, error) { diff --git a/pkg/sdk/datatypes/timestamp_tz.go b/pkg/sdk/datatypes/timestamp_tz.go index 44e6cafeb6..0c99944bf8 100644 --- a/pkg/sdk/datatypes/timestamp_tz.go +++ b/pkg/sdk/datatypes/timestamp_tz.go @@ -23,6 +23,10 @@ func (t *TimestampTzDataType) ToLegacyDataTypeSql() string { return TimestampTzLegacyDataType } +func (t *TimestampTzDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", TimestampTzLegacyDataType, t.precision) +} + var TimestampTzDataTypeSynonyms = []string{TimestampTzLegacyDataType, "TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"} func parseTimestampTzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampTzDataType, error) { diff --git a/pkg/sdk/datatypes/variant.go b/pkg/sdk/datatypes/variant.go index b096084934..538ca2921d 100644 --- a/pkg/sdk/datatypes/variant.go +++ b/pkg/sdk/datatypes/variant.go @@ -14,6 +14,10 @@ func (t *VariantDataType) ToLegacyDataTypeSql() string { return VariantLegacyDataType } +func (t *VariantDataType) Canonical() string { + return VariantLegacyDataType +} + var VariantDataTypeSynonyms = []string{VariantLegacyDataType} func parseVariantDataTypeRaw(raw sanitizedDataTypeRaw) (*VariantDataType, error) { diff --git a/pkg/sdk/datatypes/vector.go b/pkg/sdk/datatypes/vector.go index a535ca2b58..1fbe420c9f 100644 --- a/pkg/sdk/datatypes/vector.go +++ b/pkg/sdk/datatypes/vector.go @@ -26,6 +26,10 @@ func (t *VectorDataType) ToLegacyDataTypeSql() string { return t.ToSql() } +func (t *VectorDataType) Canonical() string { + return fmt.Sprintf("%s(%s, %d)", t.underlyingType, t.innerType, t.dimension) +} + var ( VectorDataTypeSynonyms = []string{"VECTOR"} VectorAllowedInnerTypes = []string{"INT", "FLOAT"} diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index 62aaa239d9..5f46e364ba 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -475,8 +475,8 @@ func TestBuilder_sql(t *testing.T) { }) } -// TODO [this PR]: add optional alternatives to functions and procedures (arguments and return types) -// TODO [this PR]: integration tests for both options +// TODO [this PR]: add optional alternatives to procedures (arguments and return types) +// TODO [this PR]: integration tests for procedures // TODO [this PR]: integration test to check all data types in a new way + reading from snowflake? func TestBuilder_DataType(t *testing.T) { diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index a4ac7fa502..cbfbfe8514 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -498,7 +498,8 @@ func TestInt_FunctionsShowByID(t *testing.T) { require.Equal(t, *e, *es) }) - t.Run("function returns non detailed data types of arguments", func(t *testing.T) { + // TODO [next PR]: remove with old function removal for V1 + t.Run("function returns non detailed data types of arguments - old data types", func(t *testing.T) { // This test proves that every detailed data types (e.g. VARCHAR(20) and NUMBER(10, 0)) are generalized // on Snowflake side (to e.g. VARCHAR and NUMBER) and that sdk.ToDataType mapping function maps detailed types // correctly to their generalized counterparts (same as in Snowflake). @@ -553,4 +554,78 @@ func TestInt_FunctionsShowByID(t *testing.T) { require.NoError(t, err) require.Equal(t, dataTypes, function.ArgumentsOld) }) + + // This test shows behavior of detailed types (e.g. VARCHAR(20) and NUMBER(10, 0)) on Snowflake side. + // For SHOW, data type is generalized both for argument and return type (to e.g. VARCHAR and NUMBER). + // FOR DESCRIBE, data type is generalized for argument and works weirdly for the return type: type is generalized to the canonical one, but we get also the attributes. + for _, tc := range []string{ + "NUMBER(36, 5)", + "NUMBER(36)", + "NUMBER", + "DECIMAL", + "INTEGER", + "FLOAT", + "DOUBLE", + "VARCHAR", + "VARCHAR(20)", + "CHAR", + "CHAR(10)", + "TEXT", + "BINARY", + "BINARY(1000)", + "VARBINARY", + "BOOLEAN", + "DATE", + "DATETIME", + "TIME", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + "VARIANT", + "OBJECT", + "ARRAY", + "GEOGRAPHY", + "GEOMETRY", + "VECTOR(INT, 16)", + "VECTOR(FLOAT, 8)", + } { + tc := tc + t.Run(fmt.Sprintf("function returns non detailed data types of arguments for %s", tc), func(t *testing.T) { + id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + argName := "A" + dataType, err := datatypes.ParseDataType(tc) + require.NoError(t, err) + args := []sdk.FunctionArgumentRequest{ + *sdk.NewFunctionArgumentRequest(argName, dataType), + } + + err = client.Functions.CreateForPython(ctx, sdk.NewCreateForPythonFunctionRequest( + id, + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(dataType)), + "3.8", + "add", + ). + WithArguments(args). + WithFunctionDefinition(fmt.Sprintf("def add(%[1]s): %[1]s", argName)), + ) + require.NoError(t, err) + + oldDataType := sdk.LegacyDataTypeFrom(dataType) + idWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), oldDataType) + + function, err := client.Functions.ShowByID(ctx, idWithArguments) + require.NoError(t, err) + assert.Equal(t, []sdk.DataType{oldDataType}, function.ArgumentsOld) + assert.Equal(t, fmt.Sprintf("%[1]s(%[2]s) RETURN %[2]s", id.Name(), oldDataType), function.ArgumentsRaw) + + details, err := client.Functions.Describe(ctx, idWithArguments) + require.NoError(t, err) + pairs := make(map[string]string) + for _, detail := range details { + pairs[detail.Property] = detail.Value + } + assert.Equal(t, fmt.Sprintf("(%s %s)", argName, oldDataType), pairs["signature"]) + assert.Equal(t, dataType.Canonical(), pairs["returns"]) + }) + } } From 1b2615e2aaf3ead078fd7cf5f0b07d822d93ab1e Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 15:42:02 +0100 Subject: [PATCH 14/29] Test canonical method on all data types --- pkg/sdk/datatypes/data_types_test.go | 23 ++++++++++++++++++++++- pkg/sdk/datatypes/vector.go | 2 +- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pkg/sdk/datatypes/data_types_test.go b/pkg/sdk/datatypes/data_types_test.go index 21525fded8..cfb3845ef1 100644 --- a/pkg/sdk/datatypes/data_types_test.go +++ b/pkg/sdk/datatypes/data_types_test.go @@ -2,6 +2,7 @@ package datatypes import ( "fmt" + "slices" "strings" "testing" @@ -91,7 +92,12 @@ func Test_ParseDataType_Number(t *testing.T) { assert.Equal(t, tc.expectedUnderlyingType, parsed.(*NumberDataType).underlyingType) assert.Equal(t, NumberLegacyDataType, parsed.ToLegacyDataTypeSql()) - assert.Equal(t, fmt.Sprintf("%s(%d, %d)", parsed.(*NumberDataType).underlyingType, parsed.(*NumberDataType).precision, parsed.(*NumberDataType).scale), parsed.ToSql()) + if slices.Contains(NumberDataTypeSubTypes, parsed.(*NumberDataType).underlyingType) { + assert.Equal(t, parsed.(*NumberDataType).underlyingType, parsed.ToSql()) + } else { + assert.Equal(t, fmt.Sprintf("%s(%d, %d)", parsed.(*NumberDataType).underlyingType, parsed.(*NumberDataType).precision, parsed.(*NumberDataType).scale), parsed.ToSql()) + } + assert.Equal(t, fmt.Sprintf("%s(%d,%d)", NumberLegacyDataType, parsed.(*NumberDataType).precision, parsed.(*NumberDataType).scale), parsed.Canonical()) }) } @@ -158,6 +164,7 @@ func Test_ParseDataType_Float(t *testing.T) { assert.Equal(t, FloatLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, FloatLegacyDataType, parsed.Canonical()) }) } @@ -267,6 +274,7 @@ func Test_ParseDataType_Text(t *testing.T) { assert.Equal(t, VarcharLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TextDataType).underlyingType, parsed.(*TextDataType).length), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", VarcharLegacyDataType, parsed.(*TextDataType).length), parsed.Canonical()) }) } @@ -338,6 +346,7 @@ func Test_ParseDataType_Binary(t *testing.T) { assert.Equal(t, BinaryLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*BinaryDataType).underlyingType, parsed.(*BinaryDataType).size), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", BinaryLegacyDataType, parsed.(*BinaryDataType).size), parsed.Canonical()) }) } @@ -396,6 +405,7 @@ func Test_ParseDataType_Boolean(t *testing.T) { assert.Equal(t, BooleanLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, BooleanLegacyDataType, parsed.Canonical()) }) } @@ -452,6 +462,7 @@ func Test_ParseDataType_Date(t *testing.T) { assert.Equal(t, DateLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, DateLegacyDataType, parsed.Canonical()) }) } @@ -512,6 +523,7 @@ func Test_ParseDataType_Time(t *testing.T) { assert.Equal(t, TimeLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", tc.expectedUnderlyingType, tc.expectedPrecision), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", TimeLegacyDataType, tc.expectedPrecision), parsed.Canonical()) }) } @@ -581,6 +593,7 @@ func Test_ParseDataType_TimestampLtz(t *testing.T) { assert.Equal(t, TimestampLtzLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampLtzDataType).underlyingType, parsed.(*TimestampLtzDataType).precision), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", TimestampLtzLegacyDataType, parsed.(*TimestampLtzDataType).precision), parsed.Canonical()) }) } @@ -652,6 +665,7 @@ func Test_ParseDataType_TimestampNtz(t *testing.T) { assert.Equal(t, TimestampNtzLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampNtzDataType).underlyingType, parsed.(*TimestampNtzDataType).precision), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", TimestampNtzLegacyDataType, parsed.(*TimestampNtzDataType).precision), parsed.Canonical()) }) } @@ -721,6 +735,7 @@ func Test_ParseDataType_TimestampTz(t *testing.T) { assert.Equal(t, TimestampTzLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampTzDataType).underlyingType, parsed.(*TimestampTzDataType).precision), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", TimestampTzLegacyDataType, parsed.(*TimestampTzDataType).precision), parsed.Canonical()) }) } @@ -777,6 +792,7 @@ func Test_ParseDataType_Variant(t *testing.T) { assert.Equal(t, VariantLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, VariantLegacyDataType, parsed.Canonical()) }) } @@ -833,6 +849,7 @@ func Test_ParseDataType_Object(t *testing.T) { assert.Equal(t, ObjectLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, ObjectLegacyDataType, parsed.Canonical()) }) } @@ -889,6 +906,7 @@ func Test_ParseDataType_Array(t *testing.T) { assert.Equal(t, ArrayLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, ArrayLegacyDataType, parsed.Canonical()) }) } @@ -945,6 +963,7 @@ func Test_ParseDataType_Geography(t *testing.T) { assert.Equal(t, GeographyLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, GeographyLegacyDataType, parsed.Canonical()) }) } @@ -1001,6 +1020,7 @@ func Test_ParseDataType_Geometry(t *testing.T) { assert.Equal(t, GeometryLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, GeometryLegacyDataType, parsed.Canonical()) }) } @@ -1060,6 +1080,7 @@ func Test_ParseDataType_Vector(t *testing.T) { assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.Canonical()) }) } diff --git a/pkg/sdk/datatypes/vector.go b/pkg/sdk/datatypes/vector.go index 1fbe420c9f..035249af64 100644 --- a/pkg/sdk/datatypes/vector.go +++ b/pkg/sdk/datatypes/vector.go @@ -27,7 +27,7 @@ func (t *VectorDataType) ToLegacyDataTypeSql() string { } func (t *VectorDataType) Canonical() string { - return fmt.Sprintf("%s(%s, %d)", t.underlyingType, t.innerType, t.dimension) + return t.ToSql() } var ( From 5dc629d9dad66634c7baaa8a79761c4498eca083 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 17:01:02 +0100 Subject: [PATCH 15/29] Regenerate procedures with renaming old data types --- pkg/sdk/procedures_def.go | 16 ++++----- pkg/sdk/procedures_dto_builders_gen.go | 20 +++++------ pkg/sdk/procedures_dto_gen.go | 20 +++++------ pkg/sdk/procedures_ext.go | 5 +++ pkg/sdk/procedures_gen.go | 24 ++++++------- pkg/sdk/procedures_impl_gen.go | 48 +++++++++++++------------- pkg/sdk/procedures_validations_gen.go | 11 ++++-- 7 files changed, 76 insertions(+), 68 deletions(-) create mode 100644 pkg/sdk/procedures_ext.go diff --git a/pkg/sdk/procedures_def.go b/pkg/sdk/procedures_def.go index 3b5eb69882..9870e20694 100644 --- a/pkg/sdk/procedures_def.go +++ b/pkg/sdk/procedures_def.go @@ -6,18 +6,18 @@ import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/gen var procedureArgument = g.NewQueryStruct("ProcedureArgument"). Text("ArgName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ArgDataType", "DataType", g.KeywordOptions().NoQuotes().Required()). + PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()). PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")) var procedureColumn = g.NewQueryStruct("ProcedureColumn"). Text("ColumnName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ColumnDataType", "DataType", g.KeywordOptions().NoQuotes().Required()) + PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()) var procedureReturns = g.NewQueryStruct("ProcedureReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("ProcedureReturnsResultDataType"). - PredefinedQueryStructField("ResultDataType", "DataType", g.KeywordOptions().NoQuotes().Required()). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()). OptionalSQL("NULL").OptionalSQL("NOT NULL"), g.KeywordOptions(), ). @@ -36,7 +36,7 @@ var procedureSQLReturns = g.NewQueryStruct("ProcedureSQLReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("ProcedureReturnsResultDataType"). - PredefinedQueryStructField("ResultDataType", "DataType", g.KeywordOptions().NoQuotes().Required()), + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()), g.KeywordOptions(), ). OptionalQueryStructField( @@ -126,7 +126,7 @@ var ProceduresDef = g.NewInterface( g.ListOptions().MustParentheses(), ). OptionalSQL("COPY GRANTS"). - PredefinedQueryStructField("ResultDataType", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). OptionalSQL("NOT NULL"). SQL("LANGUAGE JAVASCRIPT"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). @@ -299,7 +299,7 @@ var ProceduresDef = g.NewInterface( Field("IsAnsi", "bool"). Field("MinNumArguments", "int"). Field("MaxNumArguments", "int"). - Field("Arguments", "string"). + Field("ArgumentsRaw", "string"). Field("Description", "string"). Field("CatalogName", "string"). Field("IsTableFunction", "bool"). @@ -437,7 +437,7 @@ var ProceduresDef = g.NewInterface( procedureArgument, g.ListOptions().MustParentheses(), ). - PredefinedQueryStructField("ResultDataType", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). OptionalSQL("NOT NULL"). SQL("LANGUAGE JAVASCRIPT"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). @@ -452,7 +452,7 @@ var ProceduresDef = g.NewInterface( PredefinedQueryStructField("CallArguments", "[]string", g.KeywordOptions().MustParentheses()). PredefinedQueryStructField("ScriptingVariable", "*string", g.ParameterOptions().NoEquals().NoQuotes().SQL("INTO")). WithValidation(g.ValidateValueSet, "ProcedureDefinition"). - WithValidation(g.ValidateValueSet, "ResultDataType"). + WithValidation(g.ValidateValueSet, "ResultDataTypeOld"). WithValidation(g.ValidIdentifier, "ProcedureName"). WithValidation(g.ValidIdentifier, "Name"), ).CustomOperation( diff --git a/pkg/sdk/procedures_dto_builders_gen.go b/pkg/sdk/procedures_dto_builders_gen.go index c88c8bd9ac..8170230f4a 100644 --- a/pkg/sdk/procedures_dto_builders_gen.go +++ b/pkg/sdk/procedures_dto_builders_gen.go @@ -82,11 +82,11 @@ func (s *CreateForJavaProcedureRequest) WithProcedureDefinition(ProcedureDefinit func NewProcedureArgumentRequest( ArgName string, - ArgDataType DataType, + ArgDataTypeOld DataType, ) *ProcedureArgumentRequest { s := ProcedureArgumentRequest{} s.ArgName = ArgName - s.ArgDataType = ArgDataType + s.ArgDataTypeOld = ArgDataTypeOld return &s } @@ -110,10 +110,10 @@ func (s *ProcedureReturnsRequest) WithTable(Table ProcedureReturnsTableRequest) } func NewProcedureReturnsResultDataTypeRequest( - ResultDataType DataType, + ResultDataTypeOld DataType, ) *ProcedureReturnsResultDataTypeRequest { s := ProcedureReturnsResultDataTypeRequest{} - s.ResultDataType = ResultDataType + s.ResultDataTypeOld = ResultDataTypeOld return &s } @@ -138,11 +138,11 @@ func (s *ProcedureReturnsTableRequest) WithColumns(Columns []ProcedureColumnRequ func NewProcedureColumnRequest( ColumnName string, - ColumnDataType DataType, + ColumnDataTypeOld DataType, ) *ProcedureColumnRequest { s := ProcedureColumnRequest{} s.ColumnName = ColumnName - s.ColumnDataType = ColumnDataType + s.ColumnDataTypeOld = ColumnDataTypeOld return &s } @@ -164,12 +164,12 @@ func NewProcedureImportRequest( func NewCreateForJavaScriptProcedureRequest( name SchemaObjectIdentifier, - ResultDataType DataType, + ResultDataTypeOld DataType, ProcedureDefinition string, ) *CreateForJavaScriptProcedureRequest { s := CreateForJavaScriptProcedureRequest{} s.name = name - s.ResultDataType = ResultDataType + s.ResultDataTypeOld = ResultDataTypeOld s.ProcedureDefinition = ProcedureDefinition return &s } @@ -646,13 +646,13 @@ func (s *CreateAndCallForScalaProcedureRequest) WithScriptingVariable(ScriptingV func NewCreateAndCallForJavaScriptProcedureRequest( Name AccountObjectIdentifier, - ResultDataType DataType, + ResultDataTypeOld DataType, ProcedureDefinition string, ProcedureName AccountObjectIdentifier, ) *CreateAndCallForJavaScriptProcedureRequest { s := CreateAndCallForJavaScriptProcedureRequest{} s.Name = Name - s.ResultDataType = ResultDataType + s.ResultDataTypeOld = ResultDataTypeOld s.ProcedureDefinition = ProcedureDefinition s.ProcedureName = ProcedureName return &s diff --git a/pkg/sdk/procedures_dto_gen.go b/pkg/sdk/procedures_dto_gen.go index 8ad24b86e6..339fbff4b7 100644 --- a/pkg/sdk/procedures_dto_gen.go +++ b/pkg/sdk/procedures_dto_gen.go @@ -41,9 +41,9 @@ type CreateForJavaProcedureRequest struct { } type ProcedureArgumentRequest struct { - ArgName string // required - ArgDataType DataType // required - DefaultValue *string + ArgName string // required + ArgDataTypeOld DataType // required + DefaultValue *string } type ProcedureReturnsRequest struct { @@ -52,9 +52,9 @@ type ProcedureReturnsRequest struct { } type ProcedureReturnsResultDataTypeRequest struct { - ResultDataType DataType // required - Null *bool - NotNull *bool + ResultDataTypeOld DataType // required + Null *bool + NotNull *bool } type ProcedureReturnsTableRequest struct { @@ -62,8 +62,8 @@ type ProcedureReturnsTableRequest struct { } type ProcedureColumnRequest struct { - ColumnName string // required - ColumnDataType DataType // required + ColumnName string // required + ColumnDataTypeOld DataType // required } type ProcedurePackageRequest struct { @@ -80,7 +80,7 @@ type CreateForJavaScriptProcedureRequest struct { name SchemaObjectIdentifier // required Arguments []ProcedureArgumentRequest CopyGrants *bool - ResultDataType DataType // required + ResultDataTypeOld DataType // required NotNull *bool NullInputBehavior *NullInputBehavior Comment *string @@ -218,7 +218,7 @@ type CreateAndCallForScalaProcedureRequest struct { type CreateAndCallForJavaScriptProcedureRequest struct { Name AccountObjectIdentifier // required Arguments []ProcedureArgumentRequest - ResultDataType DataType // required + ResultDataTypeOld DataType // required NotNull *bool NullInputBehavior *NullInputBehavior ProcedureDefinition string // required diff --git a/pkg/sdk/procedures_ext.go b/pkg/sdk/procedures_ext.go new file mode 100644 index 0000000000..055e422501 --- /dev/null +++ b/pkg/sdk/procedures_ext.go @@ -0,0 +1,5 @@ +package sdk + +func (v *Procedure) ID() SchemaObjectIdentifierWithArguments { + return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.Arguments...) +} diff --git a/pkg/sdk/procedures_gen.go b/pkg/sdk/procedures_gen.go index e265558f70..fbe2bc8861 100644 --- a/pkg/sdk/procedures_gen.go +++ b/pkg/sdk/procedures_gen.go @@ -49,9 +49,9 @@ type CreateForJavaProcedureOptions struct { } type ProcedureArgument struct { - ArgName string `ddl:"keyword,no_quotes"` - ArgDataType DataType `ddl:"keyword,no_quotes"` - DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` + ArgName string `ddl:"keyword,no_quotes"` + ArgDataTypeOld DataType `ddl:"keyword,no_quotes"` + DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` } type ProcedureReturns struct { @@ -60,9 +60,9 @@ type ProcedureReturns struct { } type ProcedureReturnsResultDataType struct { - ResultDataType DataType `ddl:"keyword,no_quotes"` - Null *bool `ddl:"keyword" sql:"NULL"` - NotNull *bool `ddl:"keyword" sql:"NOT NULL"` + ResultDataTypeOld DataType `ddl:"keyword,no_quotes"` + Null *bool `ddl:"keyword" sql:"NULL"` + NotNull *bool `ddl:"keyword" sql:"NOT NULL"` } type ProcedureReturnsTable struct { @@ -70,8 +70,8 @@ type ProcedureReturnsTable struct { } type ProcedureColumn struct { - ColumnName string `ddl:"keyword,no_quotes"` - ColumnDataType DataType `ddl:"keyword,no_quotes"` + ColumnName string `ddl:"keyword,no_quotes"` + ColumnDataTypeOld DataType `ddl:"keyword,no_quotes"` } type ProcedurePackage struct { @@ -91,7 +91,7 @@ type CreateForJavaScriptProcedureOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` Arguments []ProcedureArgument `ddl:"list,must_parentheses"` CopyGrants *bool `ddl:"keyword" sql:"COPY GRANTS"` - ResultDataType DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals" sql:"RETURNS"` NotNull *bool `ddl:"keyword" sql:"NOT NULL"` languageJavascript bool `ddl:"static" sql:"LANGUAGE JAVASCRIPT"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` @@ -235,10 +235,6 @@ type Procedure struct { IsSecure bool } -func (v *Procedure) ID() SchemaObjectIdentifierWithArguments { - return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.Arguments...) -} - // DescribeProcedureOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-procedure. type DescribeProcedureOptions struct { describe bool `ddl:"static" sql:"DESCRIBE"` @@ -318,7 +314,7 @@ type CreateAndCallForJavaScriptProcedureOptions struct { Name AccountObjectIdentifier `ddl:"identifier"` asProcedure bool `ddl:"static" sql:"AS PROCEDURE"` Arguments []ProcedureArgument `ddl:"list,must_parentheses"` - ResultDataType DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals" sql:"RETURNS"` NotNull *bool `ddl:"keyword" sql:"NOT NULL"` languageJavascript bool `ddl:"static" sql:"LANGUAGE JAVASCRIPT"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` diff --git a/pkg/sdk/procedures_impl_gen.go b/pkg/sdk/procedures_impl_gen.go index 80cb096373..ebb24b874c 100644 --- a/pkg/sdk/procedures_impl_gen.go +++ b/pkg/sdk/procedures_impl_gen.go @@ -137,9 +137,9 @@ func (r *CreateForJavaProcedureRequest) toOpts() *CreateForJavaProcedureOptions opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -176,7 +176,7 @@ func (r *CreateForJavaScriptProcedureRequest) toOpts() *CreateForJavaScriptProce name: r.name, CopyGrants: r.CopyGrants, - ResultDataType: r.ResultDataType, + ResultDataTypeOld: r.ResultDataTypeOld, NotNull: r.NotNull, NullInputBehavior: r.NullInputBehavior, Comment: r.Comment, @@ -221,9 +221,9 @@ func (r *CreateForPythonProcedureRequest) toOpts() *CreateForPythonProcedureOpti opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -280,9 +280,9 @@ func (r *CreateForScalaProcedureRequest) toOpts() *CreateForScalaProcedureOption opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -337,7 +337,7 @@ func (r *CreateForSQLProcedureRequest) toOpts() *CreateForSQLProcedureOptions { } if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, } } if r.Returns.Table != nil { @@ -466,9 +466,9 @@ func (r *CreateAndCallForJavaProcedureRequest) toOpts() *CreateAndCallForJavaPro opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -529,9 +529,9 @@ func (r *CreateAndCallForScalaProcedureRequest) toOpts() *CreateAndCallForScalaP opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -576,7 +576,7 @@ func (r *CreateAndCallForJavaScriptProcedureRequest) toOpts() *CreateAndCallForJ opts := &CreateAndCallForJavaScriptProcedureOptions{ Name: r.Name, - ResultDataType: r.ResultDataType, + ResultDataTypeOld: r.ResultDataTypeOld, NotNull: r.NotNull, NullInputBehavior: r.NullInputBehavior, ProcedureDefinition: r.ProcedureDefinition, @@ -630,9 +630,9 @@ func (r *CreateAndCallForPythonProcedureRequest) toOpts() *CreateAndCallForPytho opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -694,9 +694,9 @@ func (r *CreateAndCallForSQLProcedureRequest) toOpts() *CreateAndCallForSQLProce opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { diff --git a/pkg/sdk/procedures_validations_gen.go b/pkg/sdk/procedures_validations_gen.go index c397a41500..a153e05b36 100644 --- a/pkg/sdk/procedures_validations_gen.go +++ b/pkg/sdk/procedures_validations_gen.go @@ -40,6 +40,7 @@ func (opts *CreateForJavaProcedureOptions) validate() error { errs = append(errs, errExactlyOneOf("CreateForJavaProcedureOptions.Returns", "ResultDataType", "Table")) } } + // added manually if opts.ProcedureDefinition == nil && opts.TargetPath != nil { errs = append(errs, NewError("TARGET_PATH must be nil when AS is nil")) } @@ -107,6 +108,7 @@ func (opts *CreateForScalaProcedureOptions) validate() error { errs = append(errs, errExactlyOneOf("CreateForScalaProcedureOptions.Returns", "ResultDataType", "Table")) } } + // added manually if opts.ProcedureDefinition == nil && opts.TargetPath != nil { errs = append(errs, NewError("TARGET_PATH must be nil when AS is nil")) } @@ -205,6 +207,7 @@ func (opts *CreateAndCallForJavaProcedureOptions) validate() error { errs = append(errs, errNotSet("CreateAndCallForJavaProcedureOptions", "Handler")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForJavaProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { @@ -233,6 +236,7 @@ func (opts *CreateAndCallForScalaProcedureOptions) validate() error { errs = append(errs, errNotSet("CreateAndCallForScalaProcedureOptions", "Handler")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForScalaProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { @@ -254,10 +258,11 @@ func (opts *CreateAndCallForJavaScriptProcedureOptions) validate() error { if !valueSet(opts.ProcedureDefinition) { errs = append(errs, errNotSet("CreateAndCallForJavaScriptProcedureOptions", "ProcedureDefinition")) } - if !valueSet(opts.ResultDataType) { - errs = append(errs, errNotSet("CreateAndCallForJavaScriptProcedureOptions", "ResultDataType")) + if !valueSet(opts.ResultDataTypeOld) { + errs = append(errs, errNotSet("CreateAndCallForJavaScriptProcedureOptions", "ResultDataTypeOld")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForJavaScriptProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { @@ -281,6 +286,7 @@ func (opts *CreateAndCallForPythonProcedureOptions) validate() error { errs = append(errs, errNotSet("CreateAndCallForPythonProcedureOptions", "Handler")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForPythonProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { @@ -303,6 +309,7 @@ func (opts *CreateAndCallForSQLProcedureOptions) validate() error { errs = append(errs, errNotSet("CreateAndCallForSQLProcedureOptions", "ProcedureDefinition")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForSQLProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { From 27b75b24bcd977863217aab1cb2405f9ccc4ba46 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 17:13:06 +0100 Subject: [PATCH 16/29] Adjust tests --- pkg/resources/procedure.go | 54 +++++------ pkg/sdk/procedures_gen_test.go | 170 +++++++++++++++++---------------- 2 files changed, 117 insertions(+), 107 deletions(-) diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index aa8b557250..c16f19ceef 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -243,7 +243,7 @@ func CreateContextProcedure(ctx context.Context, d *schema.ResourceData, meta in func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -251,9 +251,9 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) if diags != nil { @@ -261,7 +261,7 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} + var packages []sdk.ProcedurePackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } @@ -287,7 +287,7 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter req.WithSecure(v.(bool)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} + var imports []sdk.ProcedureImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) } @@ -304,7 +304,7 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -312,9 +312,9 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returnType := d.Get("return_type").(string) returnDataType, diags := convertProcedureDataType(returnType) @@ -355,7 +355,7 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -363,9 +363,9 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) if diags != nil { @@ -373,7 +373,7 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} + var packages []sdk.ProcedurePackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } @@ -399,7 +399,7 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte req.WithSecure(v.(bool)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} + var imports []sdk.ProcedureImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) } @@ -416,7 +416,7 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -424,9 +424,9 @@ func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interf } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returns, diags := parseProcedureSQLReturnsRequest(d.Get("return_type").(string)) if diags != nil { @@ -466,7 +466,7 @@ func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interf func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -474,9 +474,9 @@ func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta int } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) if diags != nil { @@ -484,7 +484,7 @@ func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta int } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} + var packages []sdk.ProcedurePackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } @@ -518,7 +518,7 @@ func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta int req.WithSecure(v.(bool)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} + var imports []sdk.ProcedureImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) } @@ -577,7 +577,7 @@ func ReadContextProcedure(ctx context.Context, d *schema.ResourceData, meta inte if args != "" { // Do nothing for functions without arguments argPairs := strings.Split(args, ", ") - args := []interface{}{} + var args []any for _, argPair := range argPairs { argItem := strings.Split(argPair, " ") @@ -735,7 +735,7 @@ func getProcedureArguments(d *schema.ResourceData) ([]sdk.ProcedureArgumentReque if diags != nil { return nil, diags } - args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) + args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataTypeOld: sdk.LegacyDataTypeFrom(argDataType)}) } } return args, nil @@ -760,8 +760,8 @@ func convertProcedureColumns(s string) ([]sdk.ProcedureColumn, diag.Diagnostics) return nil, diag.FromErr(err) } columns = append(columns, sdk.ProcedureColumn{ - ColumnName: match[1], - ColumnDataType: sdk.LegacyDataTypeFrom(dataType), + ColumnName: match[1], + ColumnDataTypeOld: sdk.LegacyDataTypeFrom(dataType), }) } } @@ -777,7 +777,7 @@ func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag. } var cr []sdk.ProcedureColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) + cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) } else { @@ -799,7 +799,7 @@ func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, } var cr []sdk.ProcedureColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) + cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) } else { diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index 7717308d51..9ae7ce3915 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -45,26 +45,27 @@ func TestProcedures_CreateForJava(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "Handler")) assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "RuntimeVersion")) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Secure = Bool(true) opts.Arguments = []ProcedureArgument{ { - ArgName: "id", - ArgDataType: DataTypeNumber, + ArgName: "id", + ArgDataTypeOld: DataTypeNumber, }, { - ArgName: "name", - ArgDataType: DataTypeVARCHAR, - DefaultValue: String("'test'"), + ArgName: "name", + ArgDataTypeOld: DataTypeVARCHAR, + DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) @@ -72,8 +73,8 @@ func TestProcedures_CreateForJava(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "country_code", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "country_code", + ColumnDataTypeOld: DataTypeVARCHAR, }, }, }, @@ -137,19 +138,20 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaScriptProcedureOptions", "ProcedureDefinition")) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Secure = Bool(true) opts.Arguments = []ProcedureArgument{ { - ArgName: "d", - ArgDataType: "DOUBLE", - DefaultValue: String("1.0"), + ArgName: "d", + ArgDataTypeOld: "DOUBLE", + DefaultValue: String("1.0"), }, } opts.CopyGrants = Bool(true) - opts.ResultDataType = "DOUBLE" + opts.ResultDataTypeOld = "DOUBLE" opts.NotNull = Bool(true) opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) opts.Comment = String("test comment") @@ -189,29 +191,30 @@ func TestProcedures_CreateForPython(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "Handler")) assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "RuntimeVersion")) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Secure = Bool(true) opts.Arguments = []ProcedureArgument{ { - ArgName: "i", - ArgDataType: "int", - DefaultValue: String("1"), + ArgName: "i", + ArgDataTypeOld: "int", + DefaultValue: String("1"), }, } opts.CopyGrants = Bool(true) opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: "VARIANT", - Null: Bool(true), + ResultDataTypeOld: "VARIANT", + Null: Bool(true), }, } opts.RuntimeVersion = "3.8" @@ -294,29 +297,30 @@ func TestProcedures_CreateForScala(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "Handler")) assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "RuntimeVersion")) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Secure = Bool(true) opts.Arguments = []ProcedureArgument{ { - ArgName: "x", - ArgDataType: "VARCHAR", - DefaultValue: String("'test'"), + ArgName: "x", + ArgDataTypeOld: "VARCHAR", + DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: "VARCHAR", - NotNull: Bool(true), + ResultDataTypeOld: "VARCHAR", + NotNull: Bool(true), }, } opts.RuntimeVersion = "2.0" @@ -369,28 +373,29 @@ func TestProcedures_CreateForSQL(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureSQLReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } opts.ProcedureDefinition = "3.141592654::FLOAT" assertOptsValidAndSQLEquals(t, opts, `CREATE PROCEDURE %s () RETURNS FLOAT LANGUAGE SQL AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Secure = Bool(true) opts.Arguments = []ProcedureArgument{ { - ArgName: "message", - ArgDataType: "VARCHAR", - DefaultValue: String("'test'"), + ArgName: "message", + ArgDataTypeOld: "VARCHAR", + DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) opts.Returns = ProcedureSQLReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: "VARCHAR", + ResultDataTypeOld: "VARCHAR", }, NotNull: Bool(true), } @@ -666,13 +671,13 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "name", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "name", + ColumnDataTypeOld: DataTypeVARCHAR, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForJavaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -682,7 +687,7 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForJavaProcedureOptions", "Handler")) @@ -706,24 +711,25 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:latest') HANDLER = 'TestFunc.echoVarchar' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { - ArgName: "id", - ArgDataType: DataTypeNumber, + ArgName: "id", + ArgDataTypeOld: DataTypeNumber, }, { - ArgName: "name", - ArgDataType: DataTypeVARCHAR, + ArgName: "name", + ArgDataTypeOld: DataTypeVARCHAR, }, } opts.Returns = ProcedureReturns{ Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "country_code", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "country_code", + ColumnDataTypeOld: DataTypeVARCHAR, }, }, }, @@ -787,13 +793,13 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "name", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "name", + ColumnDataTypeOld: DataTypeVARCHAR, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForScalaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -803,7 +809,7 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForScalaProcedureOptions", "Handler")) @@ -827,24 +833,25 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE SCALA RUNTIME_VERSION = '2.12' PACKAGES = ('com.snowflake:snowpark:1.2.0') HANDLER = 'TestFunc.echoVarchar' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { - ArgName: "id", - ArgDataType: DataTypeNumber, + ArgName: "id", + ArgDataTypeOld: DataTypeNumber, }, { - ArgName: "name", - ArgDataType: DataTypeVARCHAR, + ArgName: "name", + ArgDataTypeOld: DataTypeVARCHAR, }, } opts.Returns = ProcedureReturns{ Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "country_code", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "country_code", + ColumnDataTypeOld: DataTypeVARCHAR, }, }, }, @@ -910,13 +917,13 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "name", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "name", + ColumnDataTypeOld: DataTypeVARCHAR, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForPythonProcedureOptions.Returns", "ResultDataType", "Table")) @@ -926,7 +933,7 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: DataTypeVARCHAR, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForPythonProcedureOptions", "Handler")) @@ -950,19 +957,20 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('snowflake-snowpark-python') HANDLER = 'udf' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { - ArgName: "i", - ArgDataType: "int", - DefaultValue: String("1"), + ArgName: "i", + ArgDataTypeOld: "int", + DefaultValue: String("1"), }, } opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: "VARIANT", - Null: Bool(true), + ResultDataTypeOld: "VARIANT", + Null: Bool(true), }, } opts.RuntimeVersion = "3.8" @@ -1028,22 +1036,23 @@ func TestProcedures_CreateAndCallForJavaScript(t *testing.T) { t.Run("no arguments", func(t *testing.T) { opts := defaultOpts() - opts.ResultDataType = "DOUBLE" + opts.ResultDataTypeOld = "DOUBLE" opts.ProcedureDefinition = "return 1;" opts.ProcedureName = id assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS DOUBLE LANGUAGE JAVASCRIPT AS 'return 1;' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { - ArgName: "d", - ArgDataType: "DOUBLE", - DefaultValue: String("1.0"), + ArgName: "d", + ArgDataTypeOld: "DOUBLE", + DefaultValue: String("1.0"), }, } - opts.ResultDataType = "DOUBLE" + opts.ResultDataTypeOld = "DOUBLE" opts.NotNull = Bool(true) opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) opts.ProcedureDefinition = "return 1;" @@ -1094,13 +1103,13 @@ func TestProcedures_CreateAndCallForSQL(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "name", - ColumnDataType: DataTypeVARCHAR, + ColumnName: "name", + ColumnDataTypeOld: DataTypeVARCHAR, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns", "ResultDataType", "Table")) @@ -1122,18 +1131,19 @@ func TestProcedures_CreateAndCallForSQL(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE SQL AS '3.141592654::FLOAT' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - t.Run("all options", func(t *testing.T) { + // TODO [next PR]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { - ArgName: "message", - ArgDataType: "VARCHAR", - DefaultValue: String("'test'"), + ArgName: "message", + ArgDataTypeOld: "VARCHAR", + DefaultValue: String("'test'"), }, } opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataTypeOld: DataTypeFloat, }, } opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) From 0339f5ee5cacaf03ab8b31305ca65409bd0253da Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 17:21:32 +0100 Subject: [PATCH 17/29] Add new data type to procedure definition --- pkg/sdk/procedures_def.go | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/pkg/sdk/procedures_def.go b/pkg/sdk/procedures_def.go index 9870e20694..d30ff67e94 100644 --- a/pkg/sdk/procedures_def.go +++ b/pkg/sdk/procedures_def.go @@ -6,18 +6,21 @@ import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/gen var procedureArgument = g.NewQueryStruct("ProcedureArgument"). Text("ArgName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()). + PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ArgDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")) var procedureColumn = g.NewQueryStruct("ProcedureColumn"). Text("ColumnName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()) + PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ColumnDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()) var procedureReturns = g.NewQueryStruct("ProcedureReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("ProcedureReturnsResultDataType"). - PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). OptionalSQL("NULL").OptionalSQL("NOT NULL"), g.KeywordOptions(), ). @@ -36,7 +39,8 @@ var procedureSQLReturns = g.NewQueryStruct("ProcedureSQLReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("ProcedureReturnsResultDataType"). - PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes().Required()), + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()), g.KeywordOptions(), ). OptionalQueryStructField( @@ -126,7 +130,9 @@ var ProceduresDef = g.NewInterface( g.ListOptions().MustParentheses(), ). OptionalSQL("COPY GRANTS"). - PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + SQL("RETURNS"). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). OptionalSQL("NOT NULL"). SQL("LANGUAGE JAVASCRIPT"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). @@ -437,7 +443,9 @@ var ProceduresDef = g.NewInterface( procedureArgument, g.ListOptions().MustParentheses(), ). - PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + SQL("RETURNS"). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). OptionalSQL("NOT NULL"). SQL("LANGUAGE JAVASCRIPT"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). @@ -452,7 +460,7 @@ var ProceduresDef = g.NewInterface( PredefinedQueryStructField("CallArguments", "[]string", g.KeywordOptions().MustParentheses()). PredefinedQueryStructField("ScriptingVariable", "*string", g.ParameterOptions().NoEquals().NoQuotes().SQL("INTO")). WithValidation(g.ValidateValueSet, "ProcedureDefinition"). - WithValidation(g.ValidateValueSet, "ResultDataTypeOld"). + WithValidation(g.AtLeastOneValueSet, "ResultDataTypeOld", "ResultDataType"). WithValidation(g.ValidIdentifier, "ProcedureName"). WithValidation(g.ValidIdentifier, "Name"), ).CustomOperation( From 8b440c68a53b347b43848611f697ddc4b7fd28d3 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 17:45:51 +0100 Subject: [PATCH 18/29] Regenerate procedures with new data types --- pkg/sdk/procedures_dto_builders_gen.go | 50 ++++++++--- pkg/sdk/procedures_dto_gen.go | 22 +++-- pkg/sdk/procedures_gen.go | 30 ++++--- pkg/sdk/procedures_gen_test.go | 116 +++++++++++++++++++++++-- pkg/sdk/procedures_impl_gen.go | 10 +++ pkg/sdk/procedures_validations_gen.go | 4 +- 6 files changed, 194 insertions(+), 38 deletions(-) diff --git a/pkg/sdk/procedures_dto_builders_gen.go b/pkg/sdk/procedures_dto_builders_gen.go index 8170230f4a..373852a62a 100644 --- a/pkg/sdk/procedures_dto_builders_gen.go +++ b/pkg/sdk/procedures_dto_builders_gen.go @@ -2,7 +2,10 @@ package sdk -import () +// imports added manually +import ( + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" +) func NewCreateForJavaProcedureRequest( name SchemaObjectIdentifier, @@ -82,14 +85,19 @@ func (s *CreateForJavaProcedureRequest) WithProcedureDefinition(ProcedureDefinit func NewProcedureArgumentRequest( ArgName string, - ArgDataTypeOld DataType, + ArgDataType datatypes.DataType, ) *ProcedureArgumentRequest { s := ProcedureArgumentRequest{} s.ArgName = ArgName - s.ArgDataTypeOld = ArgDataTypeOld + s.ArgDataType = ArgDataType return &s } +func (s *ProcedureArgumentRequest) WithArgDataTypeOld(ArgDataTypeOld DataType) *ProcedureArgumentRequest { + s.ArgDataTypeOld = ArgDataTypeOld + return s +} + func (s *ProcedureArgumentRequest) WithDefaultValue(DefaultValue string) *ProcedureArgumentRequest { s.DefaultValue = &DefaultValue return s @@ -110,13 +118,18 @@ func (s *ProcedureReturnsRequest) WithTable(Table ProcedureReturnsTableRequest) } func NewProcedureReturnsResultDataTypeRequest( - ResultDataTypeOld DataType, + ResultDataType datatypes.DataType, ) *ProcedureReturnsResultDataTypeRequest { s := ProcedureReturnsResultDataTypeRequest{} - s.ResultDataTypeOld = ResultDataTypeOld + s.ResultDataType = ResultDataType return &s } +func (s *ProcedureReturnsResultDataTypeRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *ProcedureReturnsResultDataTypeRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func (s *ProcedureReturnsResultDataTypeRequest) WithNull(Null bool) *ProcedureReturnsResultDataTypeRequest { s.Null = &Null return s @@ -138,14 +151,19 @@ func (s *ProcedureReturnsTableRequest) WithColumns(Columns []ProcedureColumnRequ func NewProcedureColumnRequest( ColumnName string, - ColumnDataTypeOld DataType, + ColumnDataType datatypes.DataType, ) *ProcedureColumnRequest { s := ProcedureColumnRequest{} s.ColumnName = ColumnName - s.ColumnDataTypeOld = ColumnDataTypeOld + s.ColumnDataType = ColumnDataType return &s } +func (s *ProcedureColumnRequest) WithColumnDataTypeOld(ColumnDataTypeOld DataType) *ProcedureColumnRequest { + s.ColumnDataTypeOld = ColumnDataTypeOld + return s +} + func NewProcedurePackageRequest( Package string, ) *ProcedurePackageRequest { @@ -164,12 +182,12 @@ func NewProcedureImportRequest( func NewCreateForJavaScriptProcedureRequest( name SchemaObjectIdentifier, - ResultDataTypeOld DataType, + ResultDataType datatypes.DataType, ProcedureDefinition string, ) *CreateForJavaScriptProcedureRequest { s := CreateForJavaScriptProcedureRequest{} s.name = name - s.ResultDataTypeOld = ResultDataTypeOld + s.ResultDataType = ResultDataType s.ProcedureDefinition = ProcedureDefinition return &s } @@ -194,6 +212,11 @@ func (s *CreateForJavaScriptProcedureRequest) WithCopyGrants(CopyGrants bool) *C return s } +func (s *CreateForJavaScriptProcedureRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *CreateForJavaScriptProcedureRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func (s *CreateForJavaScriptProcedureRequest) WithNotNull(NotNull bool) *CreateForJavaScriptProcedureRequest { s.NotNull = &NotNull return s @@ -646,13 +669,13 @@ func (s *CreateAndCallForScalaProcedureRequest) WithScriptingVariable(ScriptingV func NewCreateAndCallForJavaScriptProcedureRequest( Name AccountObjectIdentifier, - ResultDataTypeOld DataType, + ResultDataType datatypes.DataType, ProcedureDefinition string, ProcedureName AccountObjectIdentifier, ) *CreateAndCallForJavaScriptProcedureRequest { s := CreateAndCallForJavaScriptProcedureRequest{} s.Name = Name - s.ResultDataTypeOld = ResultDataTypeOld + s.ResultDataType = ResultDataType s.ProcedureDefinition = ProcedureDefinition s.ProcedureName = ProcedureName return &s @@ -663,6 +686,11 @@ func (s *CreateAndCallForJavaScriptProcedureRequest) WithArguments(Arguments []P return s } +func (s *CreateAndCallForJavaScriptProcedureRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *CreateAndCallForJavaScriptProcedureRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func (s *CreateAndCallForJavaScriptProcedureRequest) WithNotNull(NotNull bool) *CreateAndCallForJavaScriptProcedureRequest { s.NotNull = &NotNull return s diff --git a/pkg/sdk/procedures_dto_gen.go b/pkg/sdk/procedures_dto_gen.go index 339fbff4b7..bf3e0a8d72 100644 --- a/pkg/sdk/procedures_dto_gen.go +++ b/pkg/sdk/procedures_dto_gen.go @@ -1,5 +1,8 @@ package sdk +// imports added manually +import "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" + //go:generate go run ./dto-builder-generator/main.go var ( @@ -41,8 +44,9 @@ type CreateForJavaProcedureRequest struct { } type ProcedureArgumentRequest struct { - ArgName string // required - ArgDataTypeOld DataType // required + ArgName string // required + ArgDataTypeOld DataType + ArgDataType datatypes.DataType // required DefaultValue *string } @@ -52,7 +56,8 @@ type ProcedureReturnsRequest struct { } type ProcedureReturnsResultDataTypeRequest struct { - ResultDataTypeOld DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required Null *bool NotNull *bool } @@ -62,8 +67,9 @@ type ProcedureReturnsTableRequest struct { } type ProcedureColumnRequest struct { - ColumnName string // required - ColumnDataTypeOld DataType // required + ColumnName string // required + ColumnDataTypeOld DataType + ColumnDataType datatypes.DataType // required } type ProcedurePackageRequest struct { @@ -80,7 +86,8 @@ type CreateForJavaScriptProcedureRequest struct { name SchemaObjectIdentifier // required Arguments []ProcedureArgumentRequest CopyGrants *bool - ResultDataTypeOld DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required NotNull *bool NullInputBehavior *NullInputBehavior Comment *string @@ -218,7 +225,8 @@ type CreateAndCallForScalaProcedureRequest struct { type CreateAndCallForJavaScriptProcedureRequest struct { Name AccountObjectIdentifier // required Arguments []ProcedureArgumentRequest - ResultDataTypeOld DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required NotNull *bool NullInputBehavior *NullInputBehavior ProcedureDefinition string // required diff --git a/pkg/sdk/procedures_gen.go b/pkg/sdk/procedures_gen.go index fbe2bc8861..93ecbbc02c 100644 --- a/pkg/sdk/procedures_gen.go +++ b/pkg/sdk/procedures_gen.go @@ -3,6 +3,9 @@ package sdk import ( "context" "database/sql" + + // import added manually + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type Procedures interface { @@ -49,9 +52,10 @@ type CreateForJavaProcedureOptions struct { } type ProcedureArgument struct { - ArgName string `ddl:"keyword,no_quotes"` - ArgDataTypeOld DataType `ddl:"keyword,no_quotes"` - DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` + ArgName string `ddl:"keyword,no_quotes"` + ArgDataTypeOld DataType `ddl:"keyword,no_quotes"` + ArgDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` + DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` } type ProcedureReturns struct { @@ -60,9 +64,10 @@ type ProcedureReturns struct { } type ProcedureReturnsResultDataType struct { - ResultDataTypeOld DataType `ddl:"keyword,no_quotes"` - Null *bool `ddl:"keyword" sql:"NULL"` - NotNull *bool `ddl:"keyword" sql:"NOT NULL"` + ResultDataTypeOld DataType `ddl:"keyword,no_quotes"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` + Null *bool `ddl:"keyword" sql:"NULL"` + NotNull *bool `ddl:"keyword" sql:"NOT NULL"` } type ProcedureReturnsTable struct { @@ -70,8 +75,9 @@ type ProcedureReturnsTable struct { } type ProcedureColumn struct { - ColumnName string `ddl:"keyword,no_quotes"` - ColumnDataTypeOld DataType `ddl:"keyword,no_quotes"` + ColumnName string `ddl:"keyword,no_quotes"` + ColumnDataTypeOld DataType `ddl:"keyword,no_quotes"` + ColumnDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` } type ProcedurePackage struct { @@ -91,7 +97,9 @@ type CreateForJavaScriptProcedureOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` Arguments []ProcedureArgument `ddl:"list,must_parentheses"` CopyGrants *bool `ddl:"keyword" sql:"COPY GRANTS"` - ResultDataTypeOld DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + returns bool `ddl:"static" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` NotNull *bool `ddl:"keyword" sql:"NOT NULL"` languageJavascript bool `ddl:"static" sql:"LANGUAGE JAVASCRIPT"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` @@ -314,7 +322,9 @@ type CreateAndCallForJavaScriptProcedureOptions struct { Name AccountObjectIdentifier `ddl:"identifier"` asProcedure bool `ddl:"static" sql:"AS PROCEDURE"` Arguments []ProcedureArgument `ddl:"list,must_parentheses"` - ResultDataTypeOld DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + returns bool `ddl:"static" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` NotNull *bool `ddl:"keyword" sql:"NOT NULL"` languageJavascript bool `ddl:"static" sql:"LANGUAGE JAVASCRIPT"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index 9ae7ce3915..d68cc90a0c 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -9,7 +9,14 @@ func TestProcedures_CreateForJava(t *testing.T) { defaultOpts := func() *CreateForJavaProcedureOptions { return &CreateForJavaProcedureOptions{ - name: id, + name: id, + Handler: "TestFunc.echoVarchar", + Packages: []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + }, + RuntimeVersion: "1.8", } } @@ -24,7 +31,25 @@ func TestProcedures_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: returns", func(t *testing.T) { + t.Run("validation: [opts.RuntimeVersion] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.RuntimeVersion = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "RuntimeVersion")) + }) + + t.Run("validation: [opts.Packages] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Packages = nil + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "Packages")) + }) + + t.Run("validation: [opts.Handler] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Handler = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "Handler")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -118,7 +143,8 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { defaultOpts := func() *CreateForJavaScriptProcedureOptions { return &CreateForJavaScriptProcedureOptions{ - name: id, + name: id, + ProcedureDefinition: "return 1;", } } @@ -127,6 +153,12 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) + t.Run("validation: [opts.ProcedureDefinition] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.ProcedureDefinition = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaScriptProcedureOptions", "ProcedureDefinition")) + }) + t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = emptySchemaObjectIdentifier @@ -166,7 +198,14 @@ func TestProcedures_CreateForPython(t *testing.T) { defaultOpts := func() *CreateForPythonProcedureOptions { return &CreateForPythonProcedureOptions{ - name: id, + name: id, + RuntimeVersion: "3.8", + Packages: []ProcedurePackage{ + { + Package: "numpy", + }, + }, + Handler: "udf", } } @@ -175,13 +214,31 @@ func TestProcedures_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) + t.Run("validation: [opts.RuntimeVersion] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.RuntimeVersion = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "RuntimeVersion")) + }) + + t.Run("validation: [opts.Packages] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Packages = nil + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "Packages")) + }) + + t.Run("validation: [opts.Handler] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Handler = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "Handler")) + }) + t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = emptySchemaObjectIdentifier assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: returns", func(t *testing.T) { + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Returns", "ResultDataType", "Table")) @@ -261,7 +318,14 @@ func TestProcedures_CreateForScala(t *testing.T) { defaultOpts := func() *CreateForScalaProcedureOptions { return &CreateForScalaProcedureOptions{ - name: id, + name: id, + RuntimeVersion: "2.0", + Packages: []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + }, + Handler: "Echo.echoVarchar", } } @@ -270,13 +334,31 @@ func TestProcedures_CreateForScala(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) + t.Run("validation: [opts.RuntimeVersion] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.RuntimeVersion = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "RuntimeVersion")) + }) + + t.Run("validation: [opts.Packages] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Packages = nil + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "Packages")) + }) + + t.Run("validation: [opts.Handler] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Handler = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "Handler")) + }) + t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = emptySchemaObjectIdentifier assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: returns", func(t *testing.T) { + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -349,7 +431,13 @@ func TestProcedures_CreateForSQL(t *testing.T) { defaultOpts := func() *CreateForSQLProcedureOptions { return &CreateForSQLProcedureOptions{ - name: id, + name: id, + ProcedureDefinition: "3.141592654::FLOAT", + Returns: ProcedureSQLReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: "VARCHAR", + }, + }, } } @@ -358,6 +446,12 @@ func TestProcedures_CreateForSQL(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) + t.Run("validation: [opts.ProcedureDefinition] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.ProcedureDefinition = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForSQLProcedureOptions", "ProcedureDefinition")) + }) + t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = emptySchemaObjectIdentifier @@ -380,6 +474,12 @@ func TestProcedures_CreateForSQL(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `CREATE PROCEDURE %s () RETURNS FLOAT LANGUAGE SQL AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{} + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns", "ResultDataType", "Table")) + }) + // TODO [next PR]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() diff --git a/pkg/sdk/procedures_impl_gen.go b/pkg/sdk/procedures_impl_gen.go index ebb24b874c..6f627d9f2e 100644 --- a/pkg/sdk/procedures_impl_gen.go +++ b/pkg/sdk/procedures_impl_gen.go @@ -138,6 +138,7 @@ func (r *CreateForJavaProcedureRequest) toOpts() *CreateForJavaProcedureOptions if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, Null: r.Returns.ResultDataType.Null, NotNull: r.Returns.ResultDataType.NotNull, } @@ -177,6 +178,7 @@ func (r *CreateForJavaScriptProcedureRequest) toOpts() *CreateForJavaScriptProce CopyGrants: r.CopyGrants, ResultDataTypeOld: r.ResultDataTypeOld, + ResultDataType: r.ResultDataType, NotNull: r.NotNull, NullInputBehavior: r.NullInputBehavior, Comment: r.Comment, @@ -222,6 +224,7 @@ func (r *CreateForPythonProcedureRequest) toOpts() *CreateForPythonProcedureOpti if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, Null: r.Returns.ResultDataType.Null, NotNull: r.Returns.ResultDataType.NotNull, } @@ -281,6 +284,7 @@ func (r *CreateForScalaProcedureRequest) toOpts() *CreateForScalaProcedureOption if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, Null: r.Returns.ResultDataType.Null, NotNull: r.Returns.ResultDataType.NotNull, } @@ -338,6 +342,7 @@ func (r *CreateForSQLProcedureRequest) toOpts() *CreateForSQLProcedureOptions { if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -467,6 +472,7 @@ func (r *CreateAndCallForJavaProcedureRequest) toOpts() *CreateAndCallForJavaPro if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, Null: r.Returns.ResultDataType.Null, NotNull: r.Returns.ResultDataType.NotNull, } @@ -530,6 +536,7 @@ func (r *CreateAndCallForScalaProcedureRequest) toOpts() *CreateAndCallForScalaP if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, Null: r.Returns.ResultDataType.Null, NotNull: r.Returns.ResultDataType.NotNull, } @@ -577,6 +584,7 @@ func (r *CreateAndCallForJavaScriptProcedureRequest) toOpts() *CreateAndCallForJ Name: r.Name, ResultDataTypeOld: r.ResultDataTypeOld, + ResultDataType: r.ResultDataType, NotNull: r.NotNull, NullInputBehavior: r.NullInputBehavior, ProcedureDefinition: r.ProcedureDefinition, @@ -631,6 +639,7 @@ func (r *CreateAndCallForPythonProcedureRequest) toOpts() *CreateAndCallForPytho if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, Null: r.Returns.ResultDataType.Null, NotNull: r.Returns.ResultDataType.NotNull, } @@ -695,6 +704,7 @@ func (r *CreateAndCallForSQLProcedureRequest) toOpts() *CreateAndCallForSQLProce if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, Null: r.Returns.ResultDataType.Null, NotNull: r.Returns.ResultDataType.NotNull, } diff --git a/pkg/sdk/procedures_validations_gen.go b/pkg/sdk/procedures_validations_gen.go index a153e05b36..1630153487 100644 --- a/pkg/sdk/procedures_validations_gen.go +++ b/pkg/sdk/procedures_validations_gen.go @@ -258,8 +258,8 @@ func (opts *CreateAndCallForJavaScriptProcedureOptions) validate() error { if !valueSet(opts.ProcedureDefinition) { errs = append(errs, errNotSet("CreateAndCallForJavaScriptProcedureOptions", "ProcedureDefinition")) } - if !valueSet(opts.ResultDataTypeOld) { - errs = append(errs, errNotSet("CreateAndCallForJavaScriptProcedureOptions", "ResultDataTypeOld")) + if !anyValueSet(opts.ResultDataTypeOld, opts.ResultDataType) { + errs = append(errs, errAtLeastOneOf("CreateAndCallForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) } if !ValidObjectIdentifier(opts.ProcedureName) { // altered manually From cef1eef3f063502ca8e473facb551fed6766b38c Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 18:11:32 +0100 Subject: [PATCH 19/29] Add unit tests for new data types in procedures --- pkg/sdk/procedures_gen_test.go | 447 +++++++++++++++++++++++++++++---- 1 file changed, 404 insertions(+), 43 deletions(-) diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index d68cc90a0c..ceac49a462 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -66,17 +66,6 @@ func TestProcedures_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, NewError("TARGET_PATH must be nil when AS is nil")) }) - t.Run("validation: options are missing", func(t *testing.T) { - opts := defaultOpts() - opts.Returns = ProcedureReturns{ - ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, - }, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "Handler")) - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "RuntimeVersion")) - }) - // TODO [next PR]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() @@ -136,6 +125,65 @@ func TestProcedures_CreateForJava(t *testing.T) { opts.ProcedureDefinition = String("return id + name;") assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (id NUMBER, name VARCHAR DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return id + name;'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "id", + ArgDataType: dataTypeNumber, + }, + { + ArgName: "name", + ArgDataType: dataTypeVarchar, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + { + ColumnName: "country_code", + ColumnDataType: dataTypeVarchar, + }, + }, + }, + } + opts.RuntimeVersion = "1.8" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "test_jar.jar", + }, + } + opts.Handler = "TestFunc.echoVarchar" + opts.ExternalAccessIntegrations = []AccountObjectIdentifier{ + NewAccountObjectIdentifier("ext_integration"), + } + opts.Secrets = []SecretReference{ + { + VariableName: "variable1", + Name: "name1", + }, + { + VariableName: "variable2", + Name: "name2", + }, + } + opts.TargetPath = String("@~/testfunc.jar") + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = String("return id + name;") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (id NUMBER(36, 2), name VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR(100)) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return id + name;'`, id.FullyQualifiedName()) + }) } func TestProcedures_CreateForJavaScript(t *testing.T) { @@ -165,11 +213,6 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: options are missing", func(t *testing.T) { - opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaScriptProcedureOptions", "ProcedureDefinition")) - }) - // TODO [next PR]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() @@ -191,6 +234,28 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { opts.ProcedureDefinition = "return 1;" assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (d DOUBLE DEFAULT 1.0) COPY GRANTS RETURNS DOUBLE NOT NULL LANGUAGE JAVASCRIPT STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return 1;'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "d", + ArgDataType: dataTypeFloat, + DefaultValue: String("1.0"), + }, + } + opts.CopyGrants = Bool(true) + opts.ResultDataType = dataTypeFloat + opts.NotNull = Bool(true) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = "return 1;" + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (d FLOAT DEFAULT 1.0) COPY GRANTS RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return 1;'`, id.FullyQualifiedName()) + }) + } func TestProcedures_CreateForPython(t *testing.T) { @@ -244,17 +309,6 @@ func TestProcedures_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Returns", "ResultDataType", "Table")) }) - t.Run("validation: options are missing", func(t *testing.T) { - opts := defaultOpts() - opts.Returns = ProcedureReturns{ - ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, - }, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "Handler")) - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "RuntimeVersion")) - }) - // TODO [next PR]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() @@ -311,6 +365,62 @@ func TestProcedures_CreateForPython(t *testing.T) { opts.ProcedureDefinition = String("import numpy as np") assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (i int DEFAULT 1) COPY GRANTS RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'import numpy as np'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "i", + ArgDataType: dataTypeNumber, + DefaultValue: String("1"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataType: dataTypeVariant, + Null: Bool(true), + }, + } + opts.RuntimeVersion = "3.8" + opts.Packages = []ProcedurePackage{ + { + Package: "numpy", + }, + { + Package: "pandas", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "numpy", + }, + { + Import: "pandas", + }, + } + opts.Handler = "udf" + opts.ExternalAccessIntegrations = []AccountObjectIdentifier{ + NewAccountObjectIdentifier("ext_integration"), + } + opts.Secrets = []SecretReference{ + { + VariableName: "variable1", + Name: "name1", + }, + { + VariableName: "variable2", + Name: "name2", + }, + } + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = String("import numpy as np") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (i NUMBER(36, 2) DEFAULT 1) COPY GRANTS RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'import numpy as np'`, id.FullyQualifiedName()) + }) } func TestProcedures_CreateForScala(t *testing.T) { @@ -375,17 +485,6 @@ func TestProcedures_CreateForScala(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, NewError("TARGET_PATH must be nil when AS is nil")) }) - t.Run("validation: options are missing", func(t *testing.T) { - opts := defaultOpts() - opts.Returns = ProcedureReturns{ - ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, - }, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "Handler")) - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "RuntimeVersion")) - }) - // TODO [next PR]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() @@ -424,6 +523,44 @@ func TestProcedures_CreateForScala(t *testing.T) { opts.ProcedureDefinition = String("return x") assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (x VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SCALA RUNTIME_VERSION = '2.0' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return x'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "x", + ArgDataType: dataTypeVarchar, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataType: dataTypeVarchar, + NotNull: Bool(true), + }, + } + opts.RuntimeVersion = "2.0" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "@udf_libs/echohandler.jar", + }, + } + opts.Handler = "Echo.echoVarchar" + opts.TargetPath = String("@~/testfunc.jar") + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = String("return x") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (x VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS VARCHAR(100) NOT NULL LANGUAGE SCALA RUNTIME_VERSION = '2.0' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return x'`, id.FullyQualifiedName()) + }) } func TestProcedures_CreateForSQL(t *testing.T) { @@ -458,11 +595,6 @@ func TestProcedures_CreateForSQL(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: options are missing", func(t *testing.T) { - opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForSQLProcedureOptions", "ProcedureDefinition")) - }) - t.Run("create with no arguments", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureSQLReturns{ @@ -505,6 +637,31 @@ func TestProcedures_CreateForSQL(t *testing.T) { opts.ProcedureDefinition = "3.141592654::FLOAT" assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (message VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SQL STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "message", + ArgDataType: dataTypeVarchar, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = ProcedureSQLReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataType: dataTypeVarchar, + }, + NotNull: Bool(true), + } + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = "3.141592654::FLOAT" + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (message VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS VARCHAR(100) NOT NULL LANGUAGE SQL STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) + }) } func TestProcedures_Drop(t *testing.T) { @@ -859,6 +1016,54 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { opts.CallArguments = []string{"1", "rnd"} assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER, name VARCHAR) RETURNS TABLE (country_code VARCHAR) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "id", + ArgDataType: dataTypeNumber, + }, + { + ArgName: "name", + ArgDataType: dataTypeVarchar, + }, + } + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + { + ColumnName: "country_code", + ColumnDataType: dataTypeVarchar, + }, + }, + }, + } + opts.RuntimeVersion = "1.8" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "test_jar.jar", + }, + } + opts.Handler = "TestFunc.echoVarchar" + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = String("return id + name;") + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClause = &ProcedureWithClause{ + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1", "rnd"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER(36, 2), name VARCHAR(100)) RETURNS TABLE (country_code VARCHAR(100)) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) } func TestProcedures_CreateAndCallForScala(t *testing.T) { @@ -983,6 +1188,56 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { opts.CallArguments = []string{"1", "rnd"} assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER, name VARCHAR) RETURNS TABLE (country_code VARCHAR) LANGUAGE SCALA RUNTIME_VERSION = '2.12' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "id", + ArgDataType: dataTypeNumber, + }, + { + ArgName: "name", + ArgDataType: dataTypeVarchar, + }, + } + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + { + ColumnName: "country_code", + ColumnDataType: dataTypeVarchar, + }, + }, + }, + } + opts.RuntimeVersion = "2.12" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "test_jar.jar", + }, + } + opts.Handler = "TestFunc.echoVarchar" + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = String("return id + name;") + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClauses = []ProcedureWithClause{ + { + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + }, + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1", "rnd"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER(36, 2), name VARCHAR(100)) RETURNS TABLE (country_code VARCHAR(100)) LANGUAGE SCALA RUNTIME_VERSION = '2.12' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) } func TestProcedures_CreateAndCallForPython(t *testing.T) { @@ -1106,6 +1361,55 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { opts.CallArguments = []string{"1"} assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (i int DEFAULT 1) RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' STRICT AS 'import numpy as np' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "i", + ArgDataType: dataTypeNumber, + DefaultValue: String("1"), + }, + } + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataType: dataTypeVariant, + Null: Bool(true), + }, + } + opts.RuntimeVersion = "3.8" + opts.Packages = []ProcedurePackage{ + { + Package: "numpy", + }, + { + Package: "pandas", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "numpy", + }, + { + Import: "pandas", + }, + } + opts.Handler = "udf" + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = String("import numpy as np") + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClauses = []ProcedureWithClause{ + { + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + }, + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (i NUMBER(36, 2) DEFAULT 1) RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' STRICT AS 'import numpy as np' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) } func TestProcedures_CreateAndCallForJavaScript(t *testing.T) { @@ -1169,6 +1473,33 @@ func TestProcedures_CreateAndCallForJavaScript(t *testing.T) { opts.CallArguments = []string{"1"} assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (d DOUBLE DEFAULT 1.0) RETURNS DOUBLE NOT NULL LANGUAGE JAVASCRIPT STRICT AS 'return 1;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "d", + ArgDataType: dataTypeFloat, + DefaultValue: String("1.0"), + }, + } + opts.ResultDataType = dataTypeFloat + opts.NotNull = Bool(true) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = "return 1;" + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClauses = []ProcedureWithClause{ + { + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + }, + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (d FLOAT DEFAULT 1.0) RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT STRICT AS 'return 1;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) } func TestProcedures_CreateAndCallForSQL(t *testing.T) { @@ -1261,4 +1592,34 @@ func TestProcedures_CreateAndCallForSQL(t *testing.T) { opts.CallArguments = []string{"1"} assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (message VARCHAR DEFAULT 'test') RETURNS FLOAT LANGUAGE SQL STRICT AS '3.141592654::FLOAT' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "message", + ArgDataType: dataTypeVarchar, + DefaultValue: String("'test'"), + }, + } + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataType: dataTypeFloat, + }, + } + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = "3.141592654::FLOAT" + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClauses = []ProcedureWithClause{ + { + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + }, + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (message VARCHAR(100) DEFAULT 'test') RETURNS FLOAT LANGUAGE SQL STRICT AS '3.141592654::FLOAT' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) } From da1397656d4f1f0b46b938af43e7d37630d43317 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 18:17:03 +0100 Subject: [PATCH 20/29] Update unit tests for procedures --- pkg/sdk/procedures_gen_test.go | 38 +++++++++++++++++----------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index ceac49a462..212f922427 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -572,7 +572,7 @@ func TestProcedures_CreateForSQL(t *testing.T) { ProcedureDefinition: "3.141592654::FLOAT", Returns: ProcedureSQLReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: "VARCHAR", + ResultDataType: dataTypeVarchar, }, }, } @@ -599,7 +599,7 @@ func TestProcedures_CreateForSQL(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureSQLReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } opts.ProcedureDefinition = "3.141592654::FLOAT" @@ -928,13 +928,13 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "name", - ColumnDataTypeOld: DataTypeVARCHAR, + ColumnName: "name", + ColumnDataType: dataTypeVarchar, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForJavaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -944,7 +944,7 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForJavaProcedureOptions", "Handler")) @@ -1098,13 +1098,13 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "name", - ColumnDataTypeOld: DataTypeVARCHAR, + ColumnName: "name", + ColumnDataType: dataTypeVarchar, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForScalaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -1114,7 +1114,7 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForScalaProcedureOptions", "Handler")) @@ -1272,13 +1272,13 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "name", - ColumnDataTypeOld: DataTypeVARCHAR, + ColumnName: "name", + ColumnDataType: dataTypeVarchar, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForPythonProcedureOptions.Returns", "ResultDataType", "Table")) @@ -1288,7 +1288,7 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForPythonProcedureOptions", "Handler")) @@ -1440,10 +1440,10 @@ func TestProcedures_CreateAndCallForJavaScript(t *testing.T) { t.Run("no arguments", func(t *testing.T) { opts := defaultOpts() - opts.ResultDataTypeOld = "DOUBLE" + opts.ResultDataType = dataTypeFloat opts.ProcedureDefinition = "return 1;" opts.ProcedureName = id - assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS DOUBLE LANGUAGE JAVASCRIPT AS 'return 1;' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS FLOAT LANGUAGE JAVASCRIPT AS 'return 1;' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) // TODO [next PR]: remove with old procedure removal for V1 @@ -1534,13 +1534,13 @@ func TestProcedures_CreateAndCallForSQL(t *testing.T) { Table: &ProcedureReturnsTable{ Columns: []ProcedureColumn{ { - ColumnName: "name", - ColumnDataTypeOld: DataTypeVARCHAR, + ColumnName: "name", + ColumnDataType: dataTypeVarchar, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns", "ResultDataType", "Table")) From 11175b21d3b0ffdded24e5a76ef4ae712a42cc09 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 18:31:37 +0100 Subject: [PATCH 21/29] Make procedure integration tests pass --- pkg/acceptance/helpers/procedure_client.go | 4 +- pkg/resources/procedure.go | 10 +- .../testint/procedures_integration_test.go | 145 +++++++++--------- 3 files changed, 81 insertions(+), 78 deletions(-) diff --git a/pkg/acceptance/helpers/procedure_client.go b/pkg/acceptance/helpers/procedure_client.go index e9a4375f2d..34aec170f7 100644 --- a/pkg/acceptance/helpers/procedure_client.go +++ b/pkg/acceptance/helpers/procedure_client.go @@ -34,12 +34,12 @@ func (c *ProcedureClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObject ctx := context.Background() argumentRequests := make([]sdk.ProcedureArgumentRequest, len(id.ArgumentDataTypes())) for i, argumentDataType := range id.ArgumentDataTypes() { - argumentRequests[i] = *sdk.NewProcedureArgumentRequest(c.ids.Alpha(), argumentDataType) + argumentRequests[i] = *sdk.NewProcedureArgumentRequest(c.ids.Alpha(), nil).WithArgDataTypeOld(argumentDataType) } err := c.client().CreateForSQL(ctx, sdk.NewCreateForSQLProcedureRequest( id.SchemaObjectId(), - *sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeInt)), + *sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)), `BEGIN RETURN 1; END`).WithArguments(argumentRequests), ) require.NoError(t, err) diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index c16f19ceef..c1e7a95a5b 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -322,7 +322,7 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta return diags } procedureDefinition := d.Get("statement").(string) - req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.LegacyDataTypeFrom(returnDataType), procedureDefinition) + req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, procedureDefinition).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType)) if len(args) > 0 { req.WithArguments(args) } @@ -777,7 +777,7 @@ func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag. } var cr []sdk.ProcedureColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataTypeOld)) + cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, nil).WithColumnDataTypeOld(item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) } else { @@ -785,7 +785,7 @@ func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag. if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } @@ -799,7 +799,7 @@ func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, } var cr []sdk.ProcedureColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataTypeOld)) + cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, nil).WithColumnDataTypeOld(item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) } else { @@ -807,7 +807,7 @@ func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index 309a3db9f9..6ec16c2983 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -46,9 +46,9 @@ func TestInt_CreateProcedures(t *testing.T) { } }` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewProcedureArgumentRequest("input", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("input", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "FileReader.execute"). WithOrReplace(true). @@ -77,13 +77,13 @@ func TestInt_CreateProcedures(t *testing.T) { return filteredRows; } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("table_name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "Filter.filterByRole"). WithOrReplace(true). @@ -114,8 +114,9 @@ func TestInt_CreateProcedures(t *testing.T) { catch (err) { return "Failed: " + err; // Return a success/error indicator. }` - argument := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) - request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeString, definition). + argument := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", nil).WithArgDataTypeOld(sdk.DataTypeFloat) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, definition). + WithResultDataTypeOld(sdk.DataTypeString). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) @@ -134,7 +135,7 @@ func TestInt_CreateProcedures(t *testing.T) { id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name) definition := `return 3.1415926;` - request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeFloat, definition).WithNotNull(true).WithOrReplace(true) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, definition).WithResultDataTypeOld(sdk.DataTypeFloat).WithNotNull(true).WithOrReplace(true) err := client.Procedures.CreateForJavaScript(ctx, request) require.NoError(t, err) t.Cleanup(cleanupProcedureHandle(id)) @@ -160,9 +161,9 @@ func TestInt_CreateProcedures(t *testing.T) { return new String(input.readAllBytes(), StandardCharsets.UTF_8) } }` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewProcedureArgumentRequest("input", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("input", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "FileReader.execute"). WithOrReplace(true). @@ -192,13 +193,13 @@ func TestInt_CreateProcedures(t *testing.T) { return filteredRows } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("table_name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "Filter.filterByRole"). WithOrReplace(true). @@ -225,9 +226,9 @@ def joblib_multiprocessing(session, i): result = joblib.Parallel(n_jobs=-1)(joblib.delayed(sqrt)(i ** 2) for i in range(10)) return str(result)` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeString) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeString) returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewProcedureArgumentRequest("i", sdk.DataTypeInt) + argument := sdk.NewProcedureArgumentRequest("i", nil).WithArgDataTypeOld(sdk.DataTypeInt) packages := []sdk.ProcedurePackageRequest{ *sdk.NewProcedurePackageRequest("snowflake-snowpark-python"), *sdk.NewProcedurePackageRequest("joblib"), @@ -255,13 +256,13 @@ from snowflake.snowpark.functions import col def filter_by_role(session, table_name, role): df = session.table(table_name) return df.filter(col("role") == role)` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("table_name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} request := sdk.NewCreateForPythonProcedureRequest(id.SchemaObjectId(), *returns, "3.8", packages, "filter_by_role"). WithOrReplace(true). @@ -286,9 +287,9 @@ def filter_by_role(session, table_name, role): RETURN message; END;` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). // Suddenly this is erroring out, when it used to not have an problem. Must be an error with the Snowflake API. @@ -318,11 +319,11 @@ def filter_by_role(session, table_name, role): BEGIN RETURN TABLE(res); END;` - column1 := sdk.NewProcedureColumnRequest("id", "INTEGER") - column2 := sdk.NewProcedureColumnRequest("price", "NUMBER(12,2)") + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER") + column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("NUMBER(12,2)") returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2}) returns := sdk.NewProcedureSQLReturnsRequest().WithTable(*returnsTable) - argument := sdk.NewProcedureArgumentRequest("id", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("id", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). // SNOW-1051627 todo: uncomment once null input behavior working again @@ -382,9 +383,9 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { RETURN message; END;` id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithSecure(true). WithOrReplace(true). @@ -520,9 +521,9 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { RETURN message; END;` id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). @@ -575,9 +576,9 @@ func TestInt_CallProcedure(t *testing.T) { RETURN message; END;` id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithSecure(true). WithOrReplace(true). @@ -619,13 +620,13 @@ func TestInt_CallProcedure(t *testing.T) { return filteredRows; } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "Filter.filterByRole"). WithOrReplace(true). @@ -658,8 +659,8 @@ func TestInt_CallProcedure(t *testing.T) { }` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "Filter.filterByRole"). WithOrReplace(true). @@ -690,8 +691,9 @@ func TestInt_CallProcedure(t *testing.T) { catch (err) { return "Failed: " + err; // Return a success/error indicator. }` - arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) - request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeString, definition). + arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", nil).WithArgDataTypeOld(sdk.DataTypeFloat) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, definition). + WithResultDataTypeOld(sdk.DataTypeString). WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg}). WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). @@ -710,7 +712,7 @@ func TestInt_CallProcedure(t *testing.T) { id := sdk.NewSchemaObjectIdentifierWithArguments(databaseId.Name(), schemaId.Name(), name) definition := `return 3.1415926;` - request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeFloat, definition).WithNotNull(true).WithOrReplace(true) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, definition).WithResultDataTypeOld(sdk.DataTypeFloat).WithNotNull(true).WithOrReplace(true) err := client.Procedures.CreateForJavaScript(ctx, request) require.NoError(t, err) t.Cleanup(cleanupProcedureHandle(id)) @@ -730,8 +732,8 @@ def filter_by_role(session, name, role): return df.filter(col("role") == role)` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} request := sdk.NewCreateForPythonProcedureRequest(id.SchemaObjectId(), *returns, "3.8", packages, "filter_by_role"). WithOrReplace(true). @@ -783,13 +785,13 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { return filteredRows; } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForJavaProcedureRequest(name, *returns, "11", packages, "Filter.filterByRole", name). @@ -816,13 +818,13 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { return filteredRows } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForScalaProcedureRequest(name, *returns, "2.12", packages, "Filter.filterByRole", name). @@ -849,8 +851,9 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { catch (err) { return "Failed: " + err; // Return a success/error indicator. }` - arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) - request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, sdk.DataTypeString, definition, name). + arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", nil).WithArgDataTypeOld(sdk.DataTypeFloat) + request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, nil, definition, name). + WithResultDataTypeOld(sdk.DataTypeString). WithArguments([]sdk.ProcedureArgumentRequest{*arg}). WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). WithCallArguments([]string{"5.14::FLOAT"}) @@ -864,7 +867,7 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { name := sdk.NewAccountObjectIdentifier("sp_pi") definition := `return 3.1415926;` - request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, sdk.DataTypeFloat, definition, name).WithNotNull(true) + request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, nil, definition, name).WithResultDataTypeOld(sdk.DataTypeFloat).WithNotNull(true) err := client.Procedures.CreateAndCallForJavaScript(ctx, request) require.NoError(t, err) }) @@ -876,9 +879,9 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { END;` name := testClientHelper().Ids.RandomAccountObjectIdentifier() - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateAndCallForSQLProcedureRequest(name, *returns, definition, name). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). WithCallArguments([]string{"message => 'hi'"}) @@ -897,8 +900,8 @@ def filter_by_role(session, name, role): return df.filter(col("role") == role)` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForPythonProcedureRequest(name, *returns, "3.8", packages, "filter_by_role", name). @@ -922,13 +925,13 @@ def filter_by_role(session, name, role): return filteredRows; } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} @@ -967,9 +970,9 @@ func TestInt_ProceduresShowByID(t *testing.T) { BEGIN RETURN message; END;` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) From ba97b9c79d6a43c8d9f25db129b190ca66f81332 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 18:50:30 +0100 Subject: [PATCH 22/29] Add data types behavior for procedures --- pkg/schemas/procedure_gen.go | 2 +- pkg/sdk/procedures_ext.go | 2 +- pkg/sdk/procedures_gen.go | 2 +- pkg/sdk/procedures_impl_gen.go | 2 +- pkg/sdk/sql_builder_test.go | 3 - pkg/sdk/testint/functions_integration_test.go | 4 +- .../testint/procedures_integration_test.go | 82 ++++++++++++++++++- 7 files changed, 86 insertions(+), 11 deletions(-) diff --git a/pkg/schemas/procedure_gen.go b/pkg/schemas/procedure_gen.go index 2499f043aa..38d5937273 100644 --- a/pkg/schemas/procedure_gen.go +++ b/pkg/schemas/procedure_gen.go @@ -83,7 +83,7 @@ func ProcedureToSchema(procedure *sdk.Procedure) map[string]any { procedureSchema["is_ansi"] = procedure.IsAnsi procedureSchema["min_num_arguments"] = procedure.MinNumArguments procedureSchema["max_num_arguments"] = procedure.MaxNumArguments - procedureSchema["arguments"] = procedure.Arguments + procedureSchema["arguments"] = procedure.ArgumentsOld procedureSchema["arguments_raw"] = procedure.ArgumentsRaw procedureSchema["description"] = procedure.Description procedureSchema["catalog_name"] = procedure.CatalogName diff --git a/pkg/sdk/procedures_ext.go b/pkg/sdk/procedures_ext.go index 055e422501..31307bc2fb 100644 --- a/pkg/sdk/procedures_ext.go +++ b/pkg/sdk/procedures_ext.go @@ -1,5 +1,5 @@ package sdk func (v *Procedure) ID() SchemaObjectIdentifierWithArguments { - return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.Arguments...) + return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.ArgumentsOld...) } diff --git a/pkg/sdk/procedures_gen.go b/pkg/sdk/procedures_gen.go index 93ecbbc02c..c65e95e94a 100644 --- a/pkg/sdk/procedures_gen.go +++ b/pkg/sdk/procedures_gen.go @@ -234,7 +234,7 @@ type Procedure struct { IsAnsi bool MinNumArguments int MaxNumArguments int - Arguments []DataType + ArgumentsOld []DataType ArgumentsRaw string Description string CatalogName string diff --git a/pkg/sdk/procedures_impl_gen.go b/pkg/sdk/procedures_impl_gen.go index 6f627d9f2e..e63cf1f386 100644 --- a/pkg/sdk/procedures_impl_gen.go +++ b/pkg/sdk/procedures_impl_gen.go @@ -412,7 +412,7 @@ func (r procedureRow) convert() *Procedure { if err != nil { log.Printf("[DEBUG] failed to parse procedure arguments, err = %s", err) } else { - e.Arguments = dataTypes + e.ArgumentsOld = dataTypes } if r.IsSecure.Valid { e.IsSecure = r.IsSecure.String == "Y" diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index 5f46e364ba..df38f0a69a 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -475,9 +475,6 @@ func TestBuilder_sql(t *testing.T) { }) } -// TODO [this PR]: add optional alternatives to procedures (arguments and return types) -// TODO [this PR]: integration tests for procedures -// TODO [this PR]: integration test to check all data types in a new way + reading from snowflake? func TestBuilder_DataType(t *testing.T) { type dataTypeTestHelper struct { diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index cbfbfe8514..b45337a5d7 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -555,9 +555,9 @@ func TestInt_FunctionsShowByID(t *testing.T) { require.Equal(t, dataTypes, function.ArgumentsOld) }) - // This test shows behavior of detailed types (e.g. VARCHAR(20) and NUMBER(10, 0)) on Snowflake side. + // This test shows behavior of detailed types (e.g. VARCHAR(20) and NUMBER(10, 0)) on Snowflake side for functions. // For SHOW, data type is generalized both for argument and return type (to e.g. VARCHAR and NUMBER). - // FOR DESCRIBE, data type is generalized for argument and works weirdly for the return type: type is generalized to the canonical one, but we get also the attributes. + // FOR DESCRIBE, data type is generalized for argument and works weirdly for the return type: type is generalized to the canonical one, but we also get the attributes. for _, tc := range []string{ "NUMBER(36, 5)", "NUMBER(36)", diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index 6ec16c2983..4543791c4d 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -7,6 +7,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -356,7 +357,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { assert.Equal(t, false, procedure.IsAnsi) assert.Equal(t, 1, procedure.MinNumArguments) assert.Equal(t, 1, procedure.MaxNumArguments) - assert.NotEmpty(t, procedure.Arguments) + assert.NotEmpty(t, procedure.ArgumentsOld) assert.NotEmpty(t, procedure.ArgumentsRaw) assert.NotEmpty(t, procedure.Description) assert.NotEmpty(t, procedure.CatalogName) @@ -499,7 +500,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { require.Equal(t, 0, len(procedures)) }) - t.Run("describe function for SQL", func(t *testing.T) { + t.Run("describe procedure for SQL", func(t *testing.T) { f := createProcedureForSQLHandle(t, true) id := f.ID() @@ -1022,4 +1023,81 @@ func TestInt_ProceduresShowByID(t *testing.T) { require.NoError(t, err) require.Equal(t, *e, *es) }) + + // This test shows behavior of detailed types (e.g. VARCHAR(20) and NUMBER(10, 0)) on Snowflake side for procedures. + // For SHOW, data type is generalized both for argument and return type (to e.g. VARCHAR and NUMBER). + // FOR DESCRIBE, data type is generalized for argument and works weirdly for the return type: type is generalized to the canonical one, but we also get the attributes. + for _, tc := range []string{ + "NUMBER(36, 5)", + "NUMBER(36)", + "NUMBER", + "DECIMAL", + "INTEGER", + "FLOAT", + "DOUBLE", + "VARCHAR", + "VARCHAR(20)", + "CHAR", + "CHAR(10)", + "TEXT", + "BINARY", + "BINARY(1000)", + "VARBINARY", + "BOOLEAN", + "DATE", + "DATETIME", + "TIME", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + "VARIANT", + "OBJECT", + "ARRAY", + "GEOGRAPHY", + "GEOMETRY", + "VECTOR(INT, 16)", + "VECTOR(FLOAT, 8)", + } { + tc := tc + t.Run(fmt.Sprintf("procedure returns non detailed data types of arguments for %s", tc), func(t *testing.T) { + procName := "add" + argName := "A" + dataType, err := datatypes.ParseDataType(tc) + require.NoError(t, err) + args := []sdk.ProcedureArgumentRequest{ + *sdk.NewProcedureArgumentRequest(argName, dataType), + } + oldDataType := sdk.LegacyDataTypeFrom(dataType) + idWithArguments := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(oldDataType) + + packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} + definition := fmt.Sprintf("def add(%[1]s): %[1]s", argName) + + err = client.Procedures.CreateForPython(ctx, sdk.NewCreateForPythonProcedureRequest( + idWithArguments.SchemaObjectId(), + *sdk.NewProcedureReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(dataType)), + "3.8", + packages, + procName, + ). + WithArguments(args). + WithProcedureDefinition(definition), + ) + require.NoError(t, err) + + procedure, err := client.Procedures.ShowByID(ctx, idWithArguments) + require.NoError(t, err) + assert.Equal(t, []sdk.DataType{oldDataType}, procedure.ArgumentsOld) + assert.Equal(t, fmt.Sprintf("%[1]s(%[2]s) RETURN %[2]s", idWithArguments.Name(), oldDataType), procedure.ArgumentsRaw) + + details, err := client.Procedures.Describe(ctx, idWithArguments) + require.NoError(t, err) + pairs := make(map[string]string) + for _, detail := range details { + pairs[detail.Property] = detail.Value + } + assert.Equal(t, fmt.Sprintf("(%s %s)", argName, oldDataType), pairs["signature"]) + assert.Equal(t, dataType.Canonical(), pairs["returns"]) + }) + } } From 30ecd5dd1bd9a2037eddd85b97a7d3ebef54110c Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 18:53:35 +0100 Subject: [PATCH 23/29] Add issue numbers to TODOs --- pkg/sdk/functions_gen_test.go | 10 +++++----- pkg/sdk/procedures_gen_test.go | 20 +++++++++---------- pkg/sdk/testint/functions_integration_test.go | 2 +- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index 865a2911bc..5ad79d64cb 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -53,7 +53,7 @@ func TestFunctions_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaFunctionOptions", "Handler")) }) - // TODO [next PR]: remove with old function removal for V1 + // TODO [SNOW-1348103]: remove with old function removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -221,7 +221,7 @@ func TestFunctions_CreateForJavascript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavascriptFunctionOptions", "FunctionDefinition")) }) - // TODO [next PR]: remove with old function removal for V1 + // TODO [SNOW-1348103]: remove with old function removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -323,7 +323,7 @@ func TestFunctions_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, NewError("IMPORTS must not be empty when AS is nil")) }) - // TODO [next PR]: remove with old function removal for V1 + // TODO [SNOW-1348103]: remove with old function removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -478,7 +478,7 @@ func TestFunctions_CreateForScala(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaFunctionOptions", "Handler")) }) - // TODO [next PR]: remove with old function removal for V1 + // TODO [SNOW-1348103]: remove with old function removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -585,7 +585,7 @@ func TestFunctions_CreateForSQL(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `CREATE FUNCTION %s () RETURNS FLOAT AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) - // TODO [next PR]: remove with old function removal for V1 + // TODO [SNOW-1348103]: remove with old function removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index 212f922427..026bf329a6 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -66,7 +66,7 @@ func TestProcedures_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, NewError("TARGET_PATH must be nil when AS is nil")) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -213,7 +213,7 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -309,7 +309,7 @@ func TestProcedures_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Returns", "ResultDataType", "Table")) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -485,7 +485,7 @@ func TestProcedures_CreateForScala(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, NewError("TARGET_PATH must be nil when AS is nil")) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -612,7 +612,7 @@ func TestProcedures_CreateForSQL(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns", "ResultDataType", "Table")) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -968,7 +968,7 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:latest') HANDLER = 'TestFunc.echoVarchar' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ @@ -1138,7 +1138,7 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE SCALA RUNTIME_VERSION = '2.12' PACKAGES = ('com.snowflake:snowpark:1.2.0') HANDLER = 'TestFunc.echoVarchar' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ @@ -1312,7 +1312,7 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('snowflake-snowpark-python') HANDLER = 'udf' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ @@ -1446,7 +1446,7 @@ func TestProcedures_CreateAndCallForJavaScript(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS FLOAT LANGUAGE JAVASCRIPT AS 'return 1;' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ @@ -1562,7 +1562,7 @@ func TestProcedures_CreateAndCallForSQL(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE SQL AS '3.141592654::FLOAT' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) - // TODO [next PR]: remove with old procedure removal for V1 + // TODO [SNOW-1348106]: remove with old procedure removal for V1 t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index b45337a5d7..5c19d66af4 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -498,7 +498,7 @@ func TestInt_FunctionsShowByID(t *testing.T) { require.Equal(t, *e, *es) }) - // TODO [next PR]: remove with old function removal for V1 + // TODO [SNOW-1348103]: remove with old function removal for V1 t.Run("function returns non detailed data types of arguments - old data types", func(t *testing.T) { // This test proves that every detailed data types (e.g. VARCHAR(20) and NUMBER(10, 0)) are generalized // on Snowflake side (to e.g. VARCHAR and NUMBER) and that sdk.ToDataType mapping function maps detailed types From 03e2400494cb61df3ee9ede8ab3efd6d45e828e3 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Wed, 4 Dec 2024 18:56:47 +0100 Subject: [PATCH 24/29] Run make pre-push --- pkg/resources/procedure.go | 6 +++--- pkg/sdk/functions_gen_test.go | 1 - pkg/sdk/procedures_gen_test.go | 1 - pkg/sdk/sql_builder.go | 2 +- pkg/sdk/sql_builder_test.go | 1 - 5 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index c1e7a95a5b..f7577833f9 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -261,7 +261,7 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - var packages []sdk.ProcedurePackageRequest + packages := make([]sdk.ProcedurePackageRequest, 0) for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } @@ -373,7 +373,7 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - var packages []sdk.ProcedurePackageRequest + packages := make([]sdk.ProcedurePackageRequest, 0) for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } @@ -484,7 +484,7 @@ func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta int } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - var packages []sdk.ProcedurePackageRequest + packages := make([]sdk.ProcedurePackageRequest, 0) for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index 5ad79d64cb..fe9f855040 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -273,7 +273,6 @@ func TestFunctions_CreateForJavascript(t *testing.T) { opts.FunctionDefinition = "return 1;" assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (d FLOAT DEFAULT 1.0) COPY GRANTS RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT CALLED ON NULL INPUT IMMUTABLE COMMENT = 'comment' AS 'return 1;'`, id.FullyQualifiedName()) }) - } func TestFunctions_CreateForPython(t *testing.T) { diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index 026bf329a6..2c2a9d8ec4 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -255,7 +255,6 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { opts.ProcedureDefinition = "return 1;" assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (d FLOAT DEFAULT 1.0) COPY GRANTS RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return 1;'`, id.FullyQualifiedName()) }) - } func TestProcedures_CreateForPython(t *testing.T) { diff --git a/pkg/sdk/sql_builder.go b/pkg/sdk/sql_builder.go index b6d9f7469b..ba50ac65f6 100644 --- a/pkg/sdk/sql_builder.go +++ b/pkg/sdk/sql_builder.go @@ -644,7 +644,7 @@ func (v sqlParameterClause) String() string { if v.value == nil { return s } - var value = v.value + value := v.value if dataType, ok := value.(datatypes.DataType); ok { // We check like this and not by `dataType == nil` because for e.g. `var *datatypes.ArrayDataType` return false in a normal nil check if reflect.ValueOf(dataType).IsZero() { diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index df38f0a69a..48289cdad1 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -476,7 +476,6 @@ func TestBuilder_sql(t *testing.T) { } func TestBuilder_DataType(t *testing.T) { - type dataTypeTestHelper struct { DataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` } From 1ff110818283d3668cee0a5a613e507d4e4017d4 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Thu, 5 Dec 2024 11:16:08 +0100 Subject: [PATCH 25/29] Bump integration tests timeout --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 1f914dc5d9..e8bd91a003 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ test-acceptance: ## run acceptance tests TF_ACC=1 SF_TF_ACC_TEST_CONFIGURE_CLIENT_ONCE=true TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestAcc_" -v -cover -timeout=120m ./... test-integration: ## run SDK integration tests - TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestInt_" -v -cover -timeout=45m ./... + TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestInt_" -v -cover -timeout=60m ./... test-architecture: ## check architecture constraints between packages go test ./pkg/architests/... -v From 4663f978f3a05569bb45192d4a19d6057f09b7ac Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Thu, 5 Dec 2024 15:43:38 +0100 Subject: [PATCH 26/29] Modify function and procedure definitions with additional validations --- pkg/sdk/functions_def.go | 12 ++++++++---- pkg/sdk/procedures_def.go | 17 +++++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/pkg/sdk/functions_def.go b/pkg/sdk/functions_def.go index 7f1aa0be59..825c1d2551 100644 --- a/pkg/sdk/functions_def.go +++ b/pkg/sdk/functions_def.go @@ -8,19 +8,22 @@ var functionArgument = g.NewQueryStruct("FunctionArgument"). Text("ArgName", g.KeywordOptions().NoQuotes().Required()). PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). PredefinedQueryStructField("ArgDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). - PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")) + PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")). + WithValidation(g.ExactlyOneValueSet, "ArgDataTypeOld", "ArgDataType") var functionColumn = g.NewQueryStruct("FunctionColumn"). Text("ColumnName", g.KeywordOptions().NoQuotes().Required()). PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). - PredefinedQueryStructField("ColumnDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()) + PredefinedQueryStructField("ColumnDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + WithValidation(g.ExactlyOneValueSet, "ColumnDataTypeOld", "ColumnDataType") var functionReturns = g.NewQueryStruct("FunctionReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("FunctionReturnsResultDataType"). PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). - PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()), + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), g.KeywordOptions(), ). OptionalQueryStructField( @@ -201,7 +204,8 @@ var FunctionsDef = g.NewInterface( PredefinedQueryStructField("FunctionDefinition", "*string", g.ParameterOptions().NoEquals().SingleQuotes().SQL("AS")). WithValidation(g.ValidIdentifier, "name"). WithValidation(g.ValidateValueSet, "Handler"). - WithValidation(g.ConflictingFields, "OrReplace", "IfNotExists"), + WithValidation(g.ConflictingFields, "OrReplace", "IfNotExists"). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), ).CustomOperation( "CreateForSQL", "https://docs.snowflake.com/en/sql-reference/sql/create-function#sql-handler", diff --git a/pkg/sdk/procedures_def.go b/pkg/sdk/procedures_def.go index d30ff67e94..0485b7711b 100644 --- a/pkg/sdk/procedures_def.go +++ b/pkg/sdk/procedures_def.go @@ -8,12 +8,14 @@ var procedureArgument = g.NewQueryStruct("ProcedureArgument"). Text("ArgName", g.KeywordOptions().NoQuotes().Required()). PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). PredefinedQueryStructField("ArgDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). - PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")) + PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")). + WithValidation(g.ExactlyOneValueSet, "ArgDataTypeOld", "ArgDataType") var procedureColumn = g.NewQueryStruct("ProcedureColumn"). Text("ColumnName", g.KeywordOptions().NoQuotes().Required()). PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). - PredefinedQueryStructField("ColumnDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()) + PredefinedQueryStructField("ColumnDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + WithValidation(g.ExactlyOneValueSet, "ColumnDataTypeOld", "ColumnDataType") var procedureReturns = g.NewQueryStruct("ProcedureReturns"). OptionalQueryStructField( @@ -21,7 +23,8 @@ var procedureReturns = g.NewQueryStruct("ProcedureReturns"). g.NewQueryStruct("ProcedureReturnsResultDataType"). PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). - OptionalSQL("NULL").OptionalSQL("NOT NULL"), + OptionalSQL("NULL").OptionalSQL("NOT NULL"). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), g.KeywordOptions(), ). OptionalQueryStructField( @@ -40,7 +43,8 @@ var procedureSQLReturns = g.NewQueryStruct("ProcedureSQLReturns"). "ResultDataType", g.NewQueryStruct("ProcedureReturnsResultDataType"). PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). - PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()), + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), g.KeywordOptions(), ). OptionalQueryStructField( @@ -140,7 +144,8 @@ var ProceduresDef = g.NewInterface( PredefinedQueryStructField("ExecuteAs", "*ExecuteAs", g.KeywordOptions()). PredefinedQueryStructField("ProcedureDefinition", "string", g.ParameterOptions().NoEquals().SingleQuotes().SQL("AS").Required()). WithValidation(g.ValidateValueSet, "ProcedureDefinition"). - WithValidation(g.ValidIdentifier, "name"), + WithValidation(g.ValidIdentifier, "name"). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), ).CustomOperation( "CreateForPython", "https://docs.snowflake.com/en/sql-reference/sql/create-procedure#python-handler", @@ -460,7 +465,7 @@ var ProceduresDef = g.NewInterface( PredefinedQueryStructField("CallArguments", "[]string", g.KeywordOptions().MustParentheses()). PredefinedQueryStructField("ScriptingVariable", "*string", g.ParameterOptions().NoEquals().NoQuotes().SQL("INTO")). WithValidation(g.ValidateValueSet, "ProcedureDefinition"). - WithValidation(g.AtLeastOneValueSet, "ResultDataTypeOld", "ResultDataType"). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"). WithValidation(g.ValidIdentifier, "ProcedureName"). WithValidation(g.ValidIdentifier, "Name"), ).CustomOperation( From 80d38bc02f831c5d45982935d5e61a099b158ff0 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Thu, 5 Dec 2024 16:26:04 +0100 Subject: [PATCH 27/29] Update function validations --- pkg/sdk/functions_gen_test.go | 361 +++++++++++++++++++++++++++ pkg/sdk/functions_validations_gen.go | 103 ++++++++ 2 files changed, 464 insertions(+) diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index fe9f855040..e4688e1564 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -24,12 +24,93 @@ func TestFunctions_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + t.Run("validation: returns", func(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{} assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Returns", "ResultDataType", "Table")) }) + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: function definition", func(t *testing.T) { opts := defaultOpts() opts.TargetPath = String("@~/testfunc.jar") @@ -205,6 +286,87 @@ func TestFunctions_CreateForJavascript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavascriptFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavascriptFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavascriptFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: returns", func(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{} @@ -295,6 +457,87 @@ func TestFunctions_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one correct, one incorrect", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: returns", func(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{} @@ -458,6 +701,43 @@ func TestFunctions_CreateForScala(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.ResultDataTypeOld opts.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.ResultDataTypeOld opts.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.ResultDataTypeOld = DataTypeFloat + opts.ResultDataType = dataTypeFloat + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions", "ResultDataTypeOld", "ResultDataType")) + }) + t.Run("validation: function definition", func(t *testing.T) { opts := defaultOpts() opts.TargetPath = String("@~/testfunc.jar") @@ -557,6 +837,87 @@ func TestFunctions_CreateForSQL(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: returns", func(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{} diff --git a/pkg/sdk/functions_validations_gen.go b/pkg/sdk/functions_validations_gen.go index 52863802a8..78970158e8 100644 --- a/pkg/sdk/functions_validations_gen.go +++ b/pkg/sdk/functions_validations_gen.go @@ -26,10 +26,33 @@ func (opts *CreateForJavaFunctionOptions) validate() error { if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateForJavaFunctionOptions", "OrReplace", "IfNotExists")) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForJavaFunctionOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } // added manually if opts.FunctionDefinition == nil { @@ -57,10 +80,33 @@ func (opts *CreateForJavascriptFunctionOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavascriptFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForJavascriptFunctionOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -82,10 +128,33 @@ func (opts *CreateForPythonFunctionOptions) validate() error { if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateForPythonFunctionOptions", "OrReplace", "IfNotExists")) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForPythonFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForPythonFunctionOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } // added manually if opts.FunctionDefinition == nil { @@ -110,6 +179,17 @@ func (opts *CreateForScalaFunctionOptions) validate() error { if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateForScalaFunctionOptions", "OrReplace", "IfNotExists")) } + if !exactlyOneValueSet(opts.ResultDataTypeOld, opts.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForScalaFunctionOptions", "ResultDataTypeOld", "ResultDataType")) + } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForScalaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } // added manually if opts.FunctionDefinition == nil { if opts.TargetPath != nil { @@ -136,10 +216,33 @@ func (opts *CreateForSQLFunctionOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } From 9ca5692fc9280acc98432f886ba3f4b24370eb3b Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Thu, 5 Dec 2024 16:58:36 +0100 Subject: [PATCH 28/29] Update procedure validations --- pkg/sdk/functions_gen_test.go | 2 +- pkg/sdk/procedures_gen_test.go | 361 ++++++++++++++++++++++++++ pkg/sdk/procedures_validations_gen.go | 207 ++++++++++++++- 3 files changed, 567 insertions(+), 3 deletions(-) diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index e4688e1564..95c21d9204 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -43,7 +43,7 @@ func TestFunctions_CreateForJava(t *testing.T) { t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []FunctionArgument{ - {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, {ArgName: "arg"}, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index 2c2a9d8ec4..994181b59b 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -49,6 +49,87 @@ func TestProcedures_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "Handler")) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} @@ -207,6 +288,43 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaScriptProcedureOptions", "ProcedureDefinition")) }) + t.Run("validation: exactly one field from [opts.ResultDataTypeOld opts.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.ResultDataTypeOld opts.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.ResultDataTypeOld = DataTypeFloat + opts.ResultDataType = dataTypeFloat + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one correct, one incorrect", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = emptySchemaObjectIdentifier @@ -302,6 +420,87 @@ func TestProcedures_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} @@ -467,6 +666,87 @@ func TestProcedures_CreateForScala(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} @@ -605,6 +885,87 @@ func TestProcedures_CreateForSQL(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `CREATE PROCEDURE %s () RETURNS FLOAT LANGUAGE SQL AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + ResultDataType: &ProcedureReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureSQLReturns{} diff --git a/pkg/sdk/procedures_validations_gen.go b/pkg/sdk/procedures_validations_gen.go index 1630153487..5e7557176f 100644 --- a/pkg/sdk/procedures_validations_gen.go +++ b/pkg/sdk/procedures_validations_gen.go @@ -35,10 +35,33 @@ func (opts *CreateForJavaProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForJavaProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } // added manually if opts.ProcedureDefinition == nil && opts.TargetPath != nil { @@ -58,6 +81,17 @@ func (opts *CreateForJavaScriptProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if !exactlyOneValueSet(opts.ResultDataTypeOld, opts.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) + } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } return JoinErrors(errs...) } @@ -78,10 +112,33 @@ func (opts *CreateForPythonProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForPythonProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -103,10 +160,33 @@ func (opts *CreateForScalaProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForScalaProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } // added manually if opts.ProcedureDefinition == nil && opts.TargetPath != nil { @@ -126,10 +206,33 @@ func (opts *CreateForSQLProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForSQLProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -213,10 +316,33 @@ func (opts *CreateAndCallForJavaProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateAndCallForJavaProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -242,10 +368,33 @@ func (opts *CreateAndCallForScalaProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateAndCallForScalaProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -258,8 +407,8 @@ func (opts *CreateAndCallForJavaScriptProcedureOptions) validate() error { if !valueSet(opts.ProcedureDefinition) { errs = append(errs, errNotSet("CreateAndCallForJavaScriptProcedureOptions", "ProcedureDefinition")) } - if !anyValueSet(opts.ResultDataTypeOld, opts.ResultDataType) { - errs = append(errs, errAtLeastOneOf("CreateAndCallForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) + if !exactlyOneValueSet(opts.ResultDataTypeOld, opts.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) } if !ValidObjectIdentifier(opts.ProcedureName) { // altered manually @@ -268,6 +417,14 @@ func (opts *CreateAndCallForJavaScriptProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } return JoinErrors(errs...) } @@ -292,10 +449,33 @@ func (opts *CreateAndCallForPythonProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateAndCallForPythonProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -315,10 +495,33 @@ func (opts *CreateAndCallForSQLProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } From 15c842aac8791a1049baad91c15dce8ce03b0d39 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Thu, 5 Dec 2024 17:05:05 +0100 Subject: [PATCH 29/29] Add comments to new data type methods --- pkg/sdk/datatypes/data_types.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/sdk/datatypes/data_types.go b/pkg/sdk/datatypes/data_types.go index 1449c91f75..be58f978f2 100644 --- a/pkg/sdk/datatypes/data_types.go +++ b/pkg/sdk/datatypes/data_types.go @@ -14,8 +14,11 @@ import ( // DataType is the common interface that represents all Snowflake datatypes documented in https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. type DataType interface { + // ToSql formats data type explicitly specifying all arguments and using the given type (e.g. CHAR(29) for CHAR(29)). ToSql() string + // ToLegacyDataTypeSql formats data type using its base type without any attributes (e.g. VARCHAR for CHAR(29)). ToLegacyDataTypeSql() string + // Canonical formats the data type between ToSql and ToLegacyDataTypeSql: it uses base type but with arguments (e.g. VARCHAR(29) for CHAR(29)). Canonical() string }