From 9526e410ea2341b992b949469c9258caacc4c947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 12 Dec 2024 15:32:52 +0100 Subject: [PATCH] changes after review --- pkg/sdk/datatypes/data_types_test.go | 35 ++++++++++++------- pkg/sdk/datatypes/legacy.go | 3 +- pkg/sdk/datatypes/table.go | 25 +++++++------ pkg/sdk/datatypes/vector.go | 1 - pkg/sdk/testint/functions_integration_test.go | 8 +++-- .../testint/procedures_integration_test.go | 7 +++- 6 files changed, 51 insertions(+), 28 deletions(-) diff --git a/pkg/sdk/datatypes/data_types_test.go b/pkg/sdk/datatypes/data_types_test.go index 33004c6997..7e6382e63f 100644 --- a/pkg/sdk/datatypes/data_types_test.go +++ b/pkg/sdk/datatypes/data_types_test.go @@ -1099,8 +1099,8 @@ func Test_ParseDataType_Vector(t *testing.T) { func Test_ParseDataType_Table(t *testing.T) { type column struct { - name string - legacyType string + Name string + Type string } type test struct { input string @@ -1110,14 +1110,20 @@ func Test_ParseDataType_Table(t *testing.T) { positiveTestCases := []test{ {input: "TABLE()", expectedColumns: []column{}}, {input: "TABLE ()", expectedColumns: []column{}}, - {input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", NumberLegacyDataType}}}, - {input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", NumberLegacyDataType}, {"second", FloatLegacyDataType}, {"third", GeographyLegacyDataType}}}, - {input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", VarcharLegacyDataType}, {"second", DateLegacyDataType}, {"third", TimeLegacyDataType}}}, + {input: "TABLE ( )", expectedColumns: []column{}}, + {input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", "NUMBER"}}}, + {input: "TABLE(arg_name double precision, arg_name_2 NUMBER)", expectedColumns: []column{{"arg_name", "double precision"}, {"arg_name_2", "NUMBER"}}}, + {input: "TABLE(arg_name NUMBER(38))", expectedColumns: []column{{"arg_name", "NUMBER(38)"}}}, + {input: "TABLE(arg_name NUMBER(38), arg_name_2 VARCHAR)", expectedColumns: []column{{"arg_name", "NUMBER(38)"}, {"arg_name_2", "VARCHAR"}}}, + {input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", "number"}, {"second", "float"}, {"third", "GEOGRAPHY"}}}, + {input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", "varchar"}, {"second", "date"}, {"third", "time"}}}, // TODO: Support types with parameters (for now, only legacy types are supported because Snowflake returns only with this output), e.g. TABLE(ARG NUMBER(38, 0)) // TODO: Support nested tables, e.g. TABLE(ARG NUMBER, NESTED TABLE(A VARCHAR, B GEOMETRY)) + // TODO: Support complex argument names (with quotes / spaces / special characters / etc) } negativeTestCases := []test{ + {input: "TABLE())"}, {input: "TABLE(1, 2)"}, {input: "TABLE(INT, INT)"}, {input: "TABLE(a b)"}, @@ -1143,26 +1149,30 @@ func Test_ParseDataType_Table(t *testing.T) { assert.Equal(t, "TABLE", parsed.(*TableDataType).underlyingType) assert.Equal(t, len(tc.expectedColumns), len(parsed.(*TableDataType).columns)) for i, column := range tc.expectedColumns { - assert.Equal(t, column.name, parsed.(*TableDataType).columns[i].name) - assert.Equal(t, column.legacyType, parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql()) + assert.Equal(t, column.Name, parsed.(*TableDataType).columns[i].name) + parsedType, err := ParseDataType(column.Type) + require.NoError(t, err) + assert.Equal(t, parsedType.ToLegacyDataTypeSql(), parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql()) } legacyColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { - return fmt.Sprintf("%s %s", col.name, col.legacyType) + parsedType, err := ParseDataType(col.Type) + require.NoError(t, err) + return fmt.Sprintf("%s %s", col.Name, parsedType.ToLegacyDataTypeSql()) }), ", ") assert.Equal(t, fmt.Sprintf("TABLE(%s)", legacyColumns), parsed.ToLegacyDataTypeSql()) canonicalColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { - parsedType, err := ParseDataType(col.legacyType) + parsedType, err := ParseDataType(col.Type) require.NoError(t, err) - return fmt.Sprintf("%s %s", col.name, parsedType.Canonical()) + return fmt.Sprintf("%s %s", col.Name, parsedType.Canonical()) }), ", ") assert.Equal(t, fmt.Sprintf("TABLE(%s)", canonicalColumns), parsed.Canonical()) columns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { - parsedType, err := ParseDataType(col.legacyType) + parsedType, err := ParseDataType(col.Type) require.NoError(t, err) - return fmt.Sprintf("%s %s", col.name, parsedType.ToSql()) + return fmt.Sprintf("%s %s", col.Name, parsedType.ToSql()) }), ", ") assert.Equal(t, fmt.Sprintf("TABLE(%s)", columns), parsed.ToSql()) }) @@ -1235,6 +1245,7 @@ func Test_AreTheSame(t *testing.T) { {d1: "TABLE(A NUMBER)", d2: "TABLE(A VARCHAR)", expectedOutcome: false}, {d1: "TABLE(A NUMBER, B VARCHAR)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: true}, {d1: "TABLE(A NUMBER, B NUMBER)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: false}, + {d1: "TABLE()", d2: "TABLE(A NUMBER)", expectedOutcome: false}, } for _, tc := range testCases { diff --git a/pkg/sdk/datatypes/legacy.go b/pkg/sdk/datatypes/legacy.go index b8bd63040f..63f523779e 100644 --- a/pkg/sdk/datatypes/legacy.go +++ b/pkg/sdk/datatypes/legacy.go @@ -16,5 +16,6 @@ const ( TimestampNtzLegacyDataType = "TIMESTAMP_NTZ" TimestampTzLegacyDataType = "TIMESTAMP_TZ" VariantLegacyDataType = "VARIANT" - TableLegacyDataType = "TABLE" + // TableLegacyDataType was not a value of legacy data type in the old implementation. Left for now for an easier implementation. + TableLegacyDataType = "TABLE" ) diff --git a/pkg/sdk/datatypes/table.go b/pkg/sdk/datatypes/table.go index dbe87307fe..a05298c992 100644 --- a/pkg/sdk/datatypes/table.go +++ b/pkg/sdk/datatypes/table.go @@ -8,7 +8,8 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" ) -// TableDataType does not have synonyms. +// TableDataType is based on https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#returning-tabular-data. +// It does not have synonyms. // It consists of a list of column name + column type; may be empty. type TableDataType struct { columns []TableDataTypeColumn @@ -20,6 +21,8 @@ type TableDataTypeColumn struct { dataType DataType } +var TableDataTypeSynonyms = []string{"TABLE"} + func (c *TableDataTypeColumn) ColumnName() string { return c.name } @@ -32,21 +35,21 @@ func (t *TableDataType) ToSql() string { columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { return fmt.Sprintf("%s %s", col.name, col.dataType.ToSql()) }), ", ") - return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) + return fmt.Sprintf("%s(%s)", t.underlyingType, columns) } func (t *TableDataType) ToLegacyDataTypeSql() string { columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { return fmt.Sprintf("%s %s", col.name, col.dataType.ToLegacyDataTypeSql()) }), ", ") - return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) + return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns) } func (t *TableDataType) Canonical() string { columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { return fmt.Sprintf("%s %s", col.name, col.dataType.Canonical()) }), ", ") - return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) + return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns) } func (t *TableDataType) Columns() []TableDataTypeColumn { @@ -55,19 +58,19 @@ func (t *TableDataType) Columns() []TableDataTypeColumn { func parseTableDataTypeRaw(raw sanitizedDataTypeRaw) (*TableDataType, error) { r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) - if r == "()" { + if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) { + logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) + } + onlyArgs := strings.TrimSpace(r[1 : len(r)-1]) + if onlyArgs == "" { return &TableDataType{ columns: make([]TableDataTypeColumn, 0), underlyingType: raw.matchedByType, }, nil } - if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) { - logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) - return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) - } - onlyArgs := r[1 : len(r)-1] columns, err := collections.MapErr(strings.Split(onlyArgs, ","), func(arg string) (TableDataTypeColumn, error) { - argParts := strings.Split(strings.TrimSpace(arg), " ") + argParts := strings.SplitN(strings.TrimSpace(arg), " ", 2) if len(argParts) != 2 { return TableDataTypeColumn{}, fmt.Errorf("could not parse table column: %s, it should contain the following format ` `; parser failure may be connected to the complex argument names", arg) } diff --git a/pkg/sdk/datatypes/vector.go b/pkg/sdk/datatypes/vector.go index d4fa9e9050..035249af64 100644 --- a/pkg/sdk/datatypes/vector.go +++ b/pkg/sdk/datatypes/vector.go @@ -32,7 +32,6 @@ func (t *VectorDataType) Canonical() string { var ( VectorDataTypeSynonyms = []string{"VECTOR"} - TableDataTypeSynonyms = []string{"TABLE"} VectorAllowedInnerTypes = []string{"INT", "FLOAT"} ) diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index d7a0e386b4..349c17de83 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -2016,7 +2016,7 @@ func TestInt_Functions(t *testing.T) { id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(datatypes.VarcharLegacyDataType) - definition := ` SELECT 1, 2.2::float, 'abc');` + definition := ` SELECT 1, 2.2::float, 'abc';` dt := sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) argument := sdk.NewFunctionArgumentRequest(argName, nil).WithArgDataTypeOld(datatypes.VarcharLegacyDataType) @@ -2034,7 +2034,11 @@ func TestInt_Functions(t *testing.T) { HasCreatedOnNotEmpty(). HasName(id.Name()). HasSchemaName(id.SchemaName()). - HasArgumentsRawContains(fmt.Sprintf(`RETURN %s`, returnDataType.ToLegacyDataTypeSql())), + HasArgumentsRawContains(returnDataType.ToLegacyDataTypeSql()), + ) + + assertions.AssertThatObject(t, objectassert.FunctionDetails(t, id). + HasReturnDataType(returnDataType), ) }) } diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index dd357e11d8..3b4e4f041d 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -1764,7 +1764,12 @@ def filter_by_role(session, table_name, role): HasCreatedOnNotEmpty(). HasName(id.Name()). HasSchemaName(id.SchemaName()). - HasArgumentsRawContains(fmt.Sprintf(`RETURN %s`, expectedReturnDataType.ToLegacyDataTypeSql()))) + HasArgumentsRawContains(expectedReturnDataType.ToLegacyDataTypeSql()), + ) + + assertions.AssertThatObject(t, objectassert.ProcedureDetails(t, id). + HasReturnDataType(expectedReturnDataType), + ) }) t.Run("show parameters", func(t *testing.T) {