diff --git a/pkg/acceptance/bettertestspoc/assert/objectassert/function_snowflake_ext.go b/pkg/acceptance/bettertestspoc/assert/objectassert/function_snowflake_ext.go index aa8d17a022..8836ff49d5 100644 --- a/pkg/acceptance/bettertestspoc/assert/objectassert/function_snowflake_ext.go +++ b/pkg/acceptance/bettertestspoc/assert/objectassert/function_snowflake_ext.go @@ -65,3 +65,14 @@ func (f *FunctionAssert) HasExactlySecrets(expectedSecrets map[string]sdk.Schema }) return f } + +func (f *FunctionAssert) HasArgumentsRawContains(substring string) *FunctionAssert { + f.AddAssertion(func(t *testing.T, o *sdk.Function) error { + t.Helper() + if !strings.Contains(o.ArgumentsRaw, substring) { + return fmt.Errorf("expected arguments raw contain: %v, to contain: %v", o.ArgumentsRaw, substring) + } + return nil + }) + return f +} diff --git a/pkg/acceptance/bettertestspoc/assert/objectassert/procedure_snowflake_ext.go b/pkg/acceptance/bettertestspoc/assert/objectassert/procedure_snowflake_ext.go index 12d5a384cf..4ce244f856 100644 --- a/pkg/acceptance/bettertestspoc/assert/objectassert/procedure_snowflake_ext.go +++ b/pkg/acceptance/bettertestspoc/assert/objectassert/procedure_snowflake_ext.go @@ -57,3 +57,14 @@ func (f *ProcedureAssert) HasExactlyExternalAccessIntegrations(integrations ...s }) return f } + +func (p *ProcedureAssert) HasArgumentsRawContains(substring string) *ProcedureAssert { + p.AddAssertion(func(t *testing.T, o *sdk.Procedure) error { + t.Helper() + if !strings.Contains(o.ArgumentsRaw, substring) { + return fmt.Errorf("expected arguments raw contain: %v, to contain: %v", o.ArgumentsRaw, substring) + } + return nil + }) + return p +} diff --git a/pkg/sdk/datatypes/data_types.go b/pkg/sdk/datatypes/data_types.go index be58f978f2..2371770a94 100644 --- a/pkg/sdk/datatypes/data_types.go +++ b/pkg/sdk/datatypes/data_types.go @@ -80,6 +80,9 @@ func ParseDataType(raw string) (DataType, error) { if idx := slices.IndexFunc(VectorDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { return parseVectorDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, VectorDataTypeSynonyms[idx]}) } + if idx := slices.IndexFunc(TableDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTableDataTypeRaw(sanitizedDataTypeRaw{strings.TrimSpace(raw), TableDataTypeSynonyms[idx]}) + } return nil, fmt.Errorf("invalid data type: %s", raw) } @@ -118,6 +121,8 @@ func AreTheSame(a DataType, b DataType) bool { return castSuccessfully(v, b, areNumberDataTypesTheSame) case *ObjectDataType: return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *TableDataType: + return castSuccessfully(v, b, areTableDataTypesTheSame) case *TextDataType: return castSuccessfully(v, b, areTextDataTypesTheSame) case *TimeDataType: diff --git a/pkg/sdk/datatypes/data_types_test.go b/pkg/sdk/datatypes/data_types_test.go index cfb3845ef1..7e6382e63f 100644 --- a/pkg/sdk/datatypes/data_types_test.go +++ b/pkg/sdk/datatypes/data_types_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1095,6 +1097,98 @@ func Test_ParseDataType_Vector(t *testing.T) { } } +func Test_ParseDataType_Table(t *testing.T) { + type column struct { + Name string + Type string + } + type test struct { + input string + expectedColumns []column + } + + positiveTestCases := []test{ + {input: "TABLE()", expectedColumns: []column{}}, + {input: "TABLE ()", expectedColumns: []column{}}, + {input: "TABLE ( )", expectedColumns: []column{}}, + {input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", "NUMBER"}}}, + {input: "TABLE(arg_name double precision, arg_name_2 NUMBER)", expectedColumns: []column{{"arg_name", "double precision"}, {"arg_name_2", "NUMBER"}}}, + {input: "TABLE(arg_name NUMBER(38))", expectedColumns: []column{{"arg_name", "NUMBER(38)"}}}, + {input: "TABLE(arg_name NUMBER(38), arg_name_2 VARCHAR)", expectedColumns: []column{{"arg_name", "NUMBER(38)"}, {"arg_name_2", "VARCHAR"}}}, + {input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", "number"}, {"second", "float"}, {"third", "GEOGRAPHY"}}}, + {input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", "varchar"}, {"second", "date"}, {"third", "time"}}}, + // TODO: Support types with parameters (for now, only legacy types are supported because Snowflake returns only with this output), e.g. TABLE(ARG NUMBER(38, 0)) + // TODO: Support nested tables, e.g. TABLE(ARG NUMBER, NESTED TABLE(A VARCHAR, B GEOMETRY)) + // TODO: Support complex argument names (with quotes / spaces / special characters / etc) + } + + negativeTestCases := []test{ + {input: "TABLE())"}, + {input: "TABLE(1, 2)"}, + {input: "TABLE(INT, INT)"}, + {input: "TABLE(a b)"}, + {input: "TABLE(1)"}, + {input: "TABLE(2, INT)"}, + {input: "TABLE"}, + {input: "TABLE(INT, 2, 3)"}, + {input: "TABLE(INT)"}, + {input: "TABLE(x, 2)"}, + {input: "TABLE("}, + {input: "TABLE)"}, + {input: "TA BLE"}, + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TableDataType{}, parsed) + + assert.Equal(t, "TABLE", parsed.(*TableDataType).underlyingType) + assert.Equal(t, len(tc.expectedColumns), len(parsed.(*TableDataType).columns)) + for i, column := range tc.expectedColumns { + assert.Equal(t, column.Name, parsed.(*TableDataType).columns[i].name) + parsedType, err := ParseDataType(column.Type) + require.NoError(t, err) + assert.Equal(t, parsedType.ToLegacyDataTypeSql(), parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql()) + } + + legacyColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { + parsedType, err := ParseDataType(col.Type) + require.NoError(t, err) + return fmt.Sprintf("%s %s", col.Name, parsedType.ToLegacyDataTypeSql()) + }), ", ") + assert.Equal(t, fmt.Sprintf("TABLE(%s)", legacyColumns), parsed.ToLegacyDataTypeSql()) + + canonicalColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { + parsedType, err := ParseDataType(col.Type) + require.NoError(t, err) + return fmt.Sprintf("%s %s", col.Name, parsedType.Canonical()) + }), ", ") + assert.Equal(t, fmt.Sprintf("TABLE(%s)", canonicalColumns), parsed.Canonical()) + + columns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { + parsedType, err := ParseDataType(col.Type) + require.NoError(t, err) + return fmt.Sprintf("%s %s", col.Name, parsedType.ToSql()) + }), ", ") + assert.Equal(t, fmt.Sprintf("TABLE(%s)", columns), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run("negative: "+tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + func Test_AreTheSame(t *testing.T) { // empty d1/d2 means nil DataType input type test struct { @@ -1145,6 +1239,13 @@ func Test_AreTheSame(t *testing.T) { {d1: "TIME", d2: "TIME", expectedOutcome: true}, {d1: "TIME", d2: "TIME(5)", expectedOutcome: false}, {d1: "TIME", d2: fmt.Sprintf("TIME(%d)", DefaultTimePrecision), expectedOutcome: true}, + {d1: "TABLE()", d2: "TABLE()", expectedOutcome: true}, + {d1: "TABLE(A NUMBER)", d2: "TABLE(B NUMBER)", expectedOutcome: false}, + {d1: "TABLE(A NUMBER)", d2: "TABLE(a NUMBER)", expectedOutcome: false}, + {d1: "TABLE(A NUMBER)", d2: "TABLE(A VARCHAR)", expectedOutcome: false}, + {d1: "TABLE(A NUMBER, B VARCHAR)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: true}, + {d1: "TABLE(A NUMBER, B NUMBER)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: false}, + {d1: "TABLE()", d2: "TABLE(A NUMBER)", expectedOutcome: false}, } for _, tc := range testCases { diff --git a/pkg/sdk/datatypes/legacy.go b/pkg/sdk/datatypes/legacy.go index c77f286f9c..63f523779e 100644 --- a/pkg/sdk/datatypes/legacy.go +++ b/pkg/sdk/datatypes/legacy.go @@ -16,7 +16,6 @@ const ( TimestampNtzLegacyDataType = "TIMESTAMP_NTZ" TimestampTzLegacyDataType = "TIMESTAMP_TZ" VariantLegacyDataType = "VARIANT" - // TableLegacyDataType was not a value of legacy data type in the old implementation. Left for now for an easier implementation. TableLegacyDataType = "TABLE" ) diff --git a/pkg/sdk/datatypes/table.go b/pkg/sdk/datatypes/table.go index e7c398ec6d..a05298c992 100644 --- a/pkg/sdk/datatypes/table.go +++ b/pkg/sdk/datatypes/table.go @@ -1,9 +1,16 @@ package datatypes -// TableDataType is based on TODO [SNOW-1348103] +import ( + "fmt" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +// TableDataType is based on https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#returning-tabular-data. // It does not have synonyms. // It consists of a list of column name + column type; may be empty. -// TODO [SNOW-1348103]: test and improve type TableDataType struct { columns []TableDataTypeColumn underlyingType string @@ -14,6 +21,8 @@ type TableDataTypeColumn struct { dataType DataType } +var TableDataTypeSynonyms = []string{"TABLE"} + func (c *TableDataTypeColumn) ColumnName() string { return c.name } @@ -23,17 +32,79 @@ func (c *TableDataTypeColumn) ColumnType() DataType { } func (t *TableDataType) ToSql() string { - return t.underlyingType + columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { + return fmt.Sprintf("%s %s", col.name, col.dataType.ToSql()) + }), ", ") + return fmt.Sprintf("%s(%s)", t.underlyingType, columns) } func (t *TableDataType) ToLegacyDataTypeSql() string { - return TableLegacyDataType + columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { + return fmt.Sprintf("%s %s", col.name, col.dataType.ToLegacyDataTypeSql()) + }), ", ") + return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns) } func (t *TableDataType) Canonical() string { - return TableLegacyDataType + columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { + return fmt.Sprintf("%s %s", col.name, col.dataType.Canonical()) + }), ", ") + return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns) } func (t *TableDataType) Columns() []TableDataTypeColumn { return t.columns } + +func parseTableDataTypeRaw(raw sanitizedDataTypeRaw) (*TableDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) { + logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) + } + onlyArgs := strings.TrimSpace(r[1 : len(r)-1]) + if onlyArgs == "" { + return &TableDataType{ + columns: make([]TableDataTypeColumn, 0), + underlyingType: raw.matchedByType, + }, nil + } + columns, err := collections.MapErr(strings.Split(onlyArgs, ","), func(arg string) (TableDataTypeColumn, error) { + argParts := strings.SplitN(strings.TrimSpace(arg), " ", 2) + if len(argParts) != 2 { + return TableDataTypeColumn{}, fmt.Errorf("could not parse table column: %s, it should contain the following format ` `; parser failure may be connected to the complex argument names", arg) + } + argDataType, err := ParseDataType(argParts[1]) + if err != nil { + return TableDataTypeColumn{}, err + } + return TableDataTypeColumn{ + name: argParts[0], + dataType: argDataType, + }, nil + }) + if err != nil { + return nil, err + } + return &TableDataType{ + columns: columns, + underlyingType: raw.matchedByType, + }, nil +} + +func areTableDataTypesTheSame(a, b *TableDataType) bool { + if len(a.columns) != len(b.columns) { + return false + } + + for i := range a.columns { + aColumn := a.columns[i] + bColumn := b.columns[i] + + if aColumn.name != bColumn.name || !AreTheSame(aColumn.dataType, bColumn.dataType) { + return false + } + } + + return true +} diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index 8aa1b217fb..b5e9a35fb6 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -2049,4 +2049,38 @@ func TestInt_Functions(t *testing.T) { assert.Equal(t, dataType.Canonical(), pairs["returns"]) }) } + + t.Run("create function for SQL - return table data type", func(t *testing.T) { + argName := "x" + + returnDataType, err := datatypes.ParseDataType(fmt.Sprintf("TABLE(ID %s, PRICE %s, THIRD %s)", datatypes.NumberLegacyDataType, datatypes.FloatLegacyDataType, datatypes.VarcharLegacyDataType)) + require.NoError(t, err) + + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(datatypes.VarcharLegacyDataType) + + definition := ` SELECT 1, 2.2::float, 'abc';` + dt := sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) + argument := sdk.NewFunctionArgumentRequest(argName, nil).WithArgDataTypeOld(datatypes.VarcharLegacyDataType) + request := sdk.NewCreateForSQLFunctionRequestDefinitionWrapped(id.SchemaObjectId(), *returns, definition). + WithArguments([]sdk.FunctionArgumentRequest{*argument}) + + err = client.Functions.CreateForSQL(ctx, request) + require.NoError(t, err) + t.Cleanup(testClientHelper().Function.DropFunctionFunc(t, id)) + + function, err := client.Functions.ShowByID(ctx, id) + require.NoError(t, err) + + assertions.AssertThatObject(t, objectassert.FunctionFromObject(t, function). + HasCreatedOnNotEmpty(). + HasName(id.Name()). + HasSchemaName(id.SchemaName()). + HasArgumentsRawContains(returnDataType.ToLegacyDataTypeSql()), + ) + + assertions.AssertThatObject(t, objectassert.FunctionDetails(t, id). + HasReturnDataType(returnDataType), + ) + }) } diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index e8b54a9a4d..56f1a20248 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -1781,22 +1781,21 @@ def filter_by_role(session, table_name, role): require.GreaterOrEqual(t, len(procedures), 1) }) - // TODO [SNOW-1348103]: adjust or remove t.Run("create procedure for SQL: returns table", func(t *testing.T) { - t.Skipf("Skipped for now; left as inspiration for resource rework as part of SNOW-1348103") - - name := "find_invoice_by_id" - id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER") + column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("double") + column3 := sdk.NewProcedureColumnRequest("third", nil).WithColumnDataTypeOld("Geometry") + returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) + expectedReturnDataType, err := datatypes.ParseDataType(fmt.Sprintf("TABLE(id %s, price %s, third %s)", datatypes.NumberLegacyDataType, datatypes.FloatLegacyDataType, datatypes.GeometryLegacyDataType)) + require.NoError(t, err) definition := ` DECLARE res RESULTSET DEFAULT (SELECT * FROM invoices WHERE id = :id); BEGIN RETURN TABLE(res); END;` - column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER") - column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("NUMBER(12,2)") - returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2}) returns := sdk.NewProcedureSQLReturnsRequest().WithTable(*returnsTable) argument := sdk.NewProcedureArgumentRequest("id", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequestDefinitionWrapped(id.SchemaObjectId(), *returns, definition). @@ -1804,13 +1803,20 @@ def filter_by_role(session, table_name, role): // SNOW-1051627 todo: uncomment once null input behavior working again // WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorReturnsNullInput)). WithArguments([]sdk.ProcedureArgumentRequest{*argument}) - err := client.Procedures.CreateForSQL(ctx, request) + err = client.Procedures.CreateForSQL(ctx, request) require.NoError(t, err) t.Cleanup(testClientHelper().Procedure.DropProcedureFunc(t, id)) - procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) - require.NoError(t, err) - require.GreaterOrEqual(t, len(procedures), 1) + assertions.AssertThatObject(t, objectassert.Procedure(t, id). + HasCreatedOnNotEmpty(). + HasName(id.Name()). + HasSchemaName(id.SchemaName()). + HasArgumentsRawContains(expectedReturnDataType.ToLegacyDataTypeSql()), + ) + + assertions.AssertThatObject(t, objectassert.ProcedureDetails(t, id). + HasReturnDataType(expectedReturnDataType), + ) }) t.Run("show parameters", func(t *testing.T) {