Commit
changes after review
sfc-gh-jcieslak committed Dec 12, 2024
1 parent cf63249 commit 9526e41
Showing 6 changed files with 51 additions and 28 deletions.
35 changes: 23 additions & 12 deletions pkg/sdk/datatypes/data_types_test.go
@@ -1099,8 +1099,8 @@ func Test_ParseDataType_Vector(t *testing.T) {

func Test_ParseDataType_Table(t *testing.T) {
type column struct {
- name string
- legacyType string
+ Name string
+ Type string
}
type test struct {
input string
@@ -1110,14 +1110,20 @@ func Test_ParseDataType_Table(t *testing.T) {
positiveTestCases := []test{
{input: "TABLE()", expectedColumns: []column{}},
{input: "TABLE ()", expectedColumns: []column{}},
{input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", NumberLegacyDataType}}},
{input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", NumberLegacyDataType}, {"second", FloatLegacyDataType}, {"third", GeographyLegacyDataType}}},
{input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", VarcharLegacyDataType}, {"second", DateLegacyDataType}, {"third", TimeLegacyDataType}}},
{input: "TABLE ( )", expectedColumns: []column{}},
{input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", "NUMBER"}}},
{input: "TABLE(arg_name double precision, arg_name_2 NUMBER)", expectedColumns: []column{{"arg_name", "double precision"}, {"arg_name_2", "NUMBER"}}},
{input: "TABLE(arg_name NUMBER(38))", expectedColumns: []column{{"arg_name", "NUMBER(38)"}}},
{input: "TABLE(arg_name NUMBER(38), arg_name_2 VARCHAR)", expectedColumns: []column{{"arg_name", "NUMBER(38)"}, {"arg_name_2", "VARCHAR"}}},
{input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", "number"}, {"second", "float"}, {"third", "GEOGRAPHY"}}},
{input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", "varchar"}, {"second", "date"}, {"third", "time"}}},
// TODO: Support types with parameters (for now, only legacy types are supported because Snowflake returns only with this output), e.g. TABLE(ARG NUMBER(38, 0))
// TODO: Support nested tables, e.g. TABLE(ARG NUMBER, NESTED TABLE(A VARCHAR, B GEOMETRY))
// TODO: Support complex argument names (with quotes / spaces / special characters / etc)
}

negativeTestCases := []test{
{input: "TABLE())"},
{input: "TABLE(1, 2)"},
{input: "TABLE(INT, INT)"},
{input: "TABLE(a b)"},
@@ -1143,26 +1149,30 @@ func Test_ParseDataType_Table(t *testing.T) {
assert.Equal(t, "TABLE", parsed.(*TableDataType).underlyingType)
assert.Equal(t, len(tc.expectedColumns), len(parsed.(*TableDataType).columns))
for i, column := range tc.expectedColumns {
- assert.Equal(t, column.name, parsed.(*TableDataType).columns[i].name)
- assert.Equal(t, column.legacyType, parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql())
+ assert.Equal(t, column.Name, parsed.(*TableDataType).columns[i].name)
+ parsedType, err := ParseDataType(column.Type)
+ require.NoError(t, err)
+ assert.Equal(t, parsedType.ToLegacyDataTypeSql(), parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql())
}

legacyColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
return fmt.Sprintf("%s %s", col.name, col.legacyType)
parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
return fmt.Sprintf("%s %s", col.Name, parsedType.ToLegacyDataTypeSql())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", legacyColumns), parsed.ToLegacyDataTypeSql())

canonicalColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
- parsedType, err := ParseDataType(col.legacyType)
+ parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
- return fmt.Sprintf("%s %s", col.name, parsedType.Canonical())
+ return fmt.Sprintf("%s %s", col.Name, parsedType.Canonical())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", canonicalColumns), parsed.Canonical())

columns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
- parsedType, err := ParseDataType(col.legacyType)
+ parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
- return fmt.Sprintf("%s %s", col.name, parsedType.ToSql())
+ return fmt.Sprintf("%s %s", col.Name, parsedType.ToSql())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", columns), parsed.ToSql())
})
@@ -1235,6 +1245,7 @@ func Test_AreTheSame(t *testing.T) {
{d1: "TABLE(A NUMBER)", d2: "TABLE(A VARCHAR)", expectedOutcome: false},
{d1: "TABLE(A NUMBER, B VARCHAR)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: true},
{d1: "TABLE(A NUMBER, B NUMBER)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: false},
{d1: "TABLE()", d2: "TABLE(A NUMBER)", expectedOutcome: false},
}

for _, tc := range testCases {
3 changes: 2 additions & 1 deletion pkg/sdk/datatypes/legacy.go
@@ -16,5 +16,6 @@ const (
TimestampNtzLegacyDataType = "TIMESTAMP_NTZ"
TimestampTzLegacyDataType = "TIMESTAMP_TZ"
VariantLegacyDataType = "VARIANT"
- TableLegacyDataType = "TABLE"
+ // TableLegacyDataType was not a value of legacy data type in the old implementation. Left for now for an easier implementation.
+ TableLegacyDataType = "TABLE"
)
25 changes: 14 additions & 11 deletions pkg/sdk/datatypes/table.go
@@ -8,7 +8,8 @@ import (
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging"
)

- // TableDataType does not have synonyms.
+ // TableDataType is based on https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#returning-tabular-data.
+ // It does not have synonyms.
// It consists of a list of column name + column type; may be empty.
type TableDataType struct {
columns []TableDataTypeColumn
@@ -20,6 +21,8 @@ type TableDataTypeColumn struct {
dataType DataType
}

+ var TableDataTypeSynonyms = []string{"TABLE"}

func (c *TableDataTypeColumn) ColumnName() string {
return c.name
}
@@ -32,21 +35,21 @@ func (t *TableDataType) ToSql() string {
columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
return fmt.Sprintf("%s %s", col.name, col.dataType.ToSql())
}), ", ")
return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns)
return fmt.Sprintf("%s(%s)", t.underlyingType, columns)
}

func (t *TableDataType) ToLegacyDataTypeSql() string {
columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
return fmt.Sprintf("%s %s", col.name, col.dataType.ToLegacyDataTypeSql())
}), ", ")
return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns)
return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns)
}

func (t *TableDataType) Canonical() string {
columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
return fmt.Sprintf("%s %s", col.name, col.dataType.Canonical())
}), ", ")
return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns)
return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns)
}

func (t *TableDataType) Columns() []TableDataTypeColumn {
@@ -55,19 +58,19 @@ func (t *TableDataType) Columns() []TableDataTypeColumn {

func parseTableDataTypeRaw(raw sanitizedDataTypeRaw) (*TableDataType, error) {
r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
if r == "()" {
if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) {
logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType)
return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType)
}
onlyArgs := strings.TrimSpace(r[1 : len(r)-1])
if onlyArgs == "" {
return &TableDataType{
columns: make([]TableDataTypeColumn, 0),
underlyingType: raw.matchedByType,
}, nil
}
if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) {
logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType)
return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType)
}
onlyArgs := r[1 : len(r)-1]
columns, err := collections.MapErr(strings.Split(onlyArgs, ","), func(arg string) (TableDataTypeColumn, error) {
- argParts := strings.Split(strings.TrimSpace(arg), " ")
+ argParts := strings.SplitN(strings.TrimSpace(arg), " ", 2)
if len(argParts) != 2 {
return TableDataTypeColumn{}, fmt.Errorf("could not parse table column: %s, it should contain the following format `<arg_name> <arg_type>`; parser failure may be connected to the complex argument names", arg)
}
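Note (not part of the commit): a minimal usage sketch of the parser changed above, assuming the exported ParseDataType, TableDataType, Columns and ColumnName shown in this diff; the sample column types and the standalone main are illustrative only.

package main

import (
	"fmt"
	"log"

	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes"
)

func main() {
	// The input must follow the "TABLE(argName argType, ...)" format enforced by parseTableDataTypeRaw.
	parsed, err := datatypes.ParseDataType("TABLE(id NUMBER(38), name VARCHAR)")
	if err != nil {
		log.Fatal(err)
	}

	// ParseDataType returns the DataType interface; for TABLE inputs the concrete type is *TableDataType.
	table := parsed.(*datatypes.TableDataType)
	for _, col := range table.Columns() {
		fmt.Println(col.ColumnName()) // "id", then "name"
	}

	// The three renderings compared in data_types_test.go above; the exact strings
	// depend on each column type's ToSql/Canonical/ToLegacyDataTypeSql implementation.
	fmt.Println(table.ToSql())
	fmt.Println(table.Canonical())
	fmt.Println(table.ToLegacyDataTypeSql())
}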
1 change: 0 additions & 1 deletion pkg/sdk/datatypes/vector.go
@@ -32,7 +32,6 @@ func (t *VectorDataType) Canonical() string {

var (
VectorDataTypeSynonyms = []string{"VECTOR"}
- TableDataTypeSynonyms = []string{"TABLE"}
VectorAllowedInnerTypes = []string{"INT", "FLOAT"}
)

8 changes: 6 additions & 2 deletions pkg/sdk/testint/functions_integration_test.go
@@ -2016,7 +2016,7 @@ func TestInt_Functions(t *testing.T) {

id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(datatypes.VarcharLegacyDataType)

- definition := ` SELECT 1, 2.2::float, 'abc');`
+ definition := ` SELECT 1, 2.2::float, 'abc';`
dt := sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType)
returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt)
argument := sdk.NewFunctionArgumentRequest(argName, nil).WithArgDataTypeOld(datatypes.VarcharLegacyDataType)
@@ -2034,7 +2034,11 @@ func TestInt_Functions(t *testing.T) {
HasCreatedOnNotEmpty().
HasName(id.Name()).
HasSchemaName(id.SchemaName()).
- HasArgumentsRawContains(fmt.Sprintf(`RETURN %s`, returnDataType.ToLegacyDataTypeSql())),
+ HasArgumentsRawContains(returnDataType.ToLegacyDataTypeSql()),
+ )
+
+ assertions.AssertThatObject(t, objectassert.FunctionDetails(t, id).
+ HasReturnDataType(returnDataType),
)
})
}
7 changes: 6 additions & 1 deletion pkg/sdk/testint/procedures_integration_test.go
@@ -1764,7 +1764,12 @@ def filter_by_role(session, table_name, role):
HasCreatedOnNotEmpty().
HasName(id.Name()).
HasSchemaName(id.SchemaName()).
- HasArgumentsRawContains(fmt.Sprintf(`RETURN %s`, expectedReturnDataType.ToLegacyDataTypeSql())))
+ HasArgumentsRawContains(expectedReturnDataType.ToLegacyDataTypeSql()),
+ )
+
+ assertions.AssertThatObject(t, objectassert.ProcedureDetails(t, id).
+ HasReturnDataType(expectedReturnDataType),
+ )
})

t.Run("show parameters", func(t *testing.T) {
