Skip to content

Commit

Permalink
feat: support table data type (#3274)
Browse files Browse the repository at this point in the history
### Changes
- support table data types
- add unit tests on parsing and comparing
- add integration tests in function and procedure tests
  • Loading branch information
sfc-gh-jcieslak authored Dec 12, 2024
1 parent 15aa9c2 commit 13401d5
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,14 @@ func (f *FunctionAssert) HasExactlySecrets(expectedSecrets map[string]sdk.Schema
})
return f
}

// HasArgumentsRawContains asserts that the function's raw arguments string
// (ArgumentsRaw as returned by SHOW) contains the given substring.
// Useful when only a fragment of the signature (e.g. a TABLE(...) return type)
// can be predicted.
func (f *FunctionAssert) HasArgumentsRawContains(substring string) *FunctionAssert {
	f.AddAssertion(func(t *testing.T, o *sdk.Function) error {
		t.Helper()
		if !strings.Contains(o.ArgumentsRaw, substring) {
			return fmt.Errorf("expected arguments raw: %v; to contain: %v", o.ArgumentsRaw, substring)
		}
		return nil
	})
	return f
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,14 @@ func (f *ProcedureAssert) HasExactlyExternalAccessIntegrations(integrations ...s
})
return f
}

// HasArgumentsRawContains asserts that the procedure's raw arguments string
// (ArgumentsRaw as returned by SHOW) contains the given substring.
// Useful when only a fragment of the signature (e.g. a TABLE(...) return type)
// can be predicted.
func (p *ProcedureAssert) HasArgumentsRawContains(substring string) *ProcedureAssert {
	p.AddAssertion(func(t *testing.T, o *sdk.Procedure) error {
		t.Helper()
		if !strings.Contains(o.ArgumentsRaw, substring) {
			return fmt.Errorf("expected arguments raw: %v; to contain: %v", o.ArgumentsRaw, substring)
		}
		return nil
	})
	return p
}
5 changes: 5 additions & 0 deletions pkg/sdk/datatypes/data_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ func ParseDataType(raw string) (DataType, error) {
if idx := slices.IndexFunc(VectorDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 {
return parseVectorDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, VectorDataTypeSynonyms[idx]})
}
if idx := slices.IndexFunc(TableDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 {
return parseTableDataTypeRaw(sanitizedDataTypeRaw{strings.TrimSpace(raw), TableDataTypeSynonyms[idx]})
}

return nil, fmt.Errorf("invalid data type: %s", raw)
}
Expand Down Expand Up @@ -118,6 +121,8 @@ func AreTheSame(a DataType, b DataType) bool {
return castSuccessfully(v, b, areNumberDataTypesTheSame)
case *ObjectDataType:
return castSuccessfully(v, b, noArgsDataTypesAreTheSame)
case *TableDataType:
return castSuccessfully(v, b, areTableDataTypesTheSame)
case *TextDataType:
return castSuccessfully(v, b, areTextDataTypesTheSame)
case *TimeDataType:
Expand Down
101 changes: 101 additions & 0 deletions pkg/sdk/datatypes/data_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strings"
"testing"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -1095,6 +1097,98 @@ func Test_ParseDataType_Vector(t *testing.T) {
}
}

func Test_ParseDataType_Table(t *testing.T) {
type column struct {
Name string
Type string
}
type test struct {
input string
expectedColumns []column
}

positiveTestCases := []test{
{input: "TABLE()", expectedColumns: []column{}},
{input: "TABLE ()", expectedColumns: []column{}},
{input: "TABLE ( )", expectedColumns: []column{}},
{input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", "NUMBER"}}},
{input: "TABLE(arg_name double precision, arg_name_2 NUMBER)", expectedColumns: []column{{"arg_name", "double precision"}, {"arg_name_2", "NUMBER"}}},
{input: "TABLE(arg_name NUMBER(38))", expectedColumns: []column{{"arg_name", "NUMBER(38)"}}},
{input: "TABLE(arg_name NUMBER(38), arg_name_2 VARCHAR)", expectedColumns: []column{{"arg_name", "NUMBER(38)"}, {"arg_name_2", "VARCHAR"}}},
{input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", "number"}, {"second", "float"}, {"third", "GEOGRAPHY"}}},
{input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", "varchar"}, {"second", "date"}, {"third", "time"}}},
// TODO: Support types with parameters (for now, only legacy types are supported because Snowflake returns only with this output), e.g. TABLE(ARG NUMBER(38, 0))
// TODO: Support nested tables, e.g. TABLE(ARG NUMBER, NESTED TABLE(A VARCHAR, B GEOMETRY))
// TODO: Support complex argument names (with quotes / spaces / special characters / etc)
}

negativeTestCases := []test{
{input: "TABLE())"},
{input: "TABLE(1, 2)"},
{input: "TABLE(INT, INT)"},
{input: "TABLE(a b)"},
{input: "TABLE(1)"},
{input: "TABLE(2, INT)"},
{input: "TABLE"},
{input: "TABLE(INT, 2, 3)"},
{input: "TABLE(INT)"},
{input: "TABLE(x, 2)"},
{input: "TABLE("},
{input: "TABLE)"},
{input: "TA BLE"},
}

for _, tc := range positiveTestCases {
tc := tc
t.Run(tc.input, func(t *testing.T) {
parsed, err := ParseDataType(tc.input)

require.NoError(t, err)
require.IsType(t, &TableDataType{}, parsed)

assert.Equal(t, "TABLE", parsed.(*TableDataType).underlyingType)
assert.Equal(t, len(tc.expectedColumns), len(parsed.(*TableDataType).columns))
for i, column := range tc.expectedColumns {
assert.Equal(t, column.Name, parsed.(*TableDataType).columns[i].name)
parsedType, err := ParseDataType(column.Type)
require.NoError(t, err)
assert.Equal(t, parsedType.ToLegacyDataTypeSql(), parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql())
}

legacyColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
return fmt.Sprintf("%s %s", col.Name, parsedType.ToLegacyDataTypeSql())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", legacyColumns), parsed.ToLegacyDataTypeSql())

canonicalColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
return fmt.Sprintf("%s %s", col.Name, parsedType.Canonical())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", canonicalColumns), parsed.Canonical())

columns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
return fmt.Sprintf("%s %s", col.Name, parsedType.ToSql())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", columns), parsed.ToSql())
})
}

for _, tc := range negativeTestCases {
tc := tc
t.Run("negative: "+tc.input, func(t *testing.T) {
parsed, err := ParseDataType(tc.input)

require.Error(t, err)
require.Nil(t, parsed)
})
}
}

func Test_AreTheSame(t *testing.T) {
// empty d1/d2 means nil DataType input
type test struct {
Expand Down Expand Up @@ -1145,6 +1239,13 @@ func Test_AreTheSame(t *testing.T) {
{d1: "TIME", d2: "TIME", expectedOutcome: true},
{d1: "TIME", d2: "TIME(5)", expectedOutcome: false},
{d1: "TIME", d2: fmt.Sprintf("TIME(%d)", DefaultTimePrecision), expectedOutcome: true},
{d1: "TABLE()", d2: "TABLE()", expectedOutcome: true},
{d1: "TABLE(A NUMBER)", d2: "TABLE(B NUMBER)", expectedOutcome: false},
{d1: "TABLE(A NUMBER)", d2: "TABLE(a NUMBER)", expectedOutcome: false},
{d1: "TABLE(A NUMBER)", d2: "TABLE(A VARCHAR)", expectedOutcome: false},
{d1: "TABLE(A NUMBER, B VARCHAR)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: true},
{d1: "TABLE(A NUMBER, B NUMBER)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: false},
{d1: "TABLE()", d2: "TABLE(A NUMBER)", expectedOutcome: false},
}

for _, tc := range testCases {
Expand Down
1 change: 0 additions & 1 deletion pkg/sdk/datatypes/legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ const (
TimestampNtzLegacyDataType = "TIMESTAMP_NTZ"
TimestampTzLegacyDataType = "TIMESTAMP_TZ"
VariantLegacyDataType = "VARIANT"

// TableLegacyDataType was not a value of legacy data type in the old implementation. Left for now for an easier implementation.
TableLegacyDataType = "TABLE"
)
81 changes: 76 additions & 5 deletions pkg/sdk/datatypes/table.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package datatypes

// TableDataType is based on TODO [SNOW-1348103]
import (
"fmt"
"strings"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging"
)

// TableDataType is based on https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#returning-tabular-data.
// It does not have synonyms.
// It consists of a list of column name + column type; may be empty.
// TODO [SNOW-1348103]: test and improve
type TableDataType struct {
columns []TableDataTypeColumn
underlyingType string
Expand All @@ -14,6 +21,8 @@ type TableDataTypeColumn struct {
dataType DataType
}

// TableDataTypeSynonyms contains the type keywords that resolve to TableDataType; TABLE has no synonyms.
var TableDataTypeSynonyms = []string{"TABLE"}

// ColumnName returns the name of the table column.
func (c *TableDataTypeColumn) ColumnName() string {
return c.name
}
Expand All @@ -23,17 +32,79 @@ func (c *TableDataTypeColumn) ColumnType() DataType {
}

// ToSql renders the table type as "<underlyingType>(<name> <type>, ...)",
// rendering each column's data type with its ToSql representation.
func (t *TableDataType) ToSql() string {
	columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
		return fmt.Sprintf("%s %s", col.name, col.dataType.ToSql())
	}), ", ")
	return fmt.Sprintf("%s(%s)", t.underlyingType, columns)
}

// ToLegacyDataTypeSql renders the table type as "TABLE(<name> <type>, ...)",
// rendering each column's data type with its legacy representation.
func (t *TableDataType) ToLegacyDataTypeSql() string {
	columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
		return fmt.Sprintf("%s %s", col.name, col.dataType.ToLegacyDataTypeSql())
	}), ", ")
	return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns)
}

// Canonical renders the table type as "TABLE(<name> <type>, ...)",
// rendering each column's data type with its canonical representation.
func (t *TableDataType) Canonical() string {
	columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
		return fmt.Sprintf("%s %s", col.name, col.dataType.Canonical())
	}), ", ")
	return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns)
}

// Columns returns the table's columns (name + data type); may be empty for TABLE().
func (t *TableDataType) Columns() []TableDataTypeColumn {
return t.columns
}

// parseTableDataTypeRaw parses a raw TABLE data type definition, e.g. TABLE(a NUMBER, b VARCHAR).
// raw.raw holds the full, trimmed definition and raw.matchedByType the matched "TABLE" synonym.
// An empty column list (TABLE()) is valid. Each column must follow the `<arg_name> <arg_type>`
// format. Columns are split on "," only, so column types that themselves contain commas
// (e.g. NUMBER(38, 0)) and nested tables are not supported yet.
func parseTableDataTypeRaw(raw sanitizedDataTypeRaw) (*TableDataType, error) {
	r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
	// An empty remainder also fails the prefix check, so it needs no separate condition.
	if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") {
		err := fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType)
		logging.DebugLogger.Printf("%v", err)
		return nil, err
	}
	onlyArgs := strings.TrimSpace(r[1 : len(r)-1])
	if onlyArgs == "" {
		// TABLE() - a table type with no columns.
		return &TableDataType{
			columns:        make([]TableDataTypeColumn, 0),
			underlyingType: raw.matchedByType,
		}, nil
	}
	columns, err := collections.MapErr(strings.Split(onlyArgs, ","), func(arg string) (TableDataTypeColumn, error) {
		// Split on the first space only, because the type itself may contain
		// spaces (e.g. "double precision").
		argParts := strings.SplitN(strings.TrimSpace(arg), " ", 2)
		if len(argParts) != 2 {
			return TableDataTypeColumn{}, fmt.Errorf("could not parse table column: %s, it should contain the following format `<arg_name> <arg_type>`; parser failure may be connected to the complex argument names", arg)
		}
		argDataType, err := ParseDataType(argParts[1])
		if err != nil {
			return TableDataTypeColumn{}, err
		}
		return TableDataTypeColumn{
			name:     argParts[0],
			dataType: argDataType,
		}, nil
	})
	if err != nil {
		return nil, err
	}
	return &TableDataType{
		columns:        columns,
		underlyingType: raw.matchedByType,
	}, nil
}

// areTableDataTypesTheSame reports whether two table data types have the same
// columns: equal length, matching names (case-sensitive) position by position,
// and column data types that compare equal via AreTheSame.
func areTableDataTypesTheSame(a, b *TableDataType) bool {
	if len(a.columns) != len(b.columns) {
		return false
	}
	for i, aColumn := range a.columns {
		bColumn := b.columns[i]
		if aColumn.name != bColumn.name {
			return false
		}
		if !AreTheSame(aColumn.dataType, bColumn.dataType) {
			return false
		}
	}
	return true
}
34 changes: 34 additions & 0 deletions pkg/sdk/testint/functions_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2049,4 +2049,38 @@ func TestInt_Functions(t *testing.T) {
assert.Equal(t, dataType.Canonical(), pairs["returns"])
})
}

// Integration test: create a SQL function whose return type is a TABLE data type
// and verify the created function through the object assertions.
t.Run("create function for SQL - return table data type", func(t *testing.T) {
argName := "x"

// Expected return type built from the legacy data type names (ID/PRICE/THIRD columns).
returnDataType, err := datatypes.ParseDataType(fmt.Sprintf("TABLE(ID %s, PRICE %s, THIRD %s)", datatypes.NumberLegacyDataType, datatypes.FloatLegacyDataType, datatypes.VarcharLegacyDataType))
require.NoError(t, err)

id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(datatypes.VarcharLegacyDataType)

// Function body returns one row; presumably its three values line up with the
// three declared table columns - confirm against Snowflake's validation rules.
definition := ` SELECT 1, 2.2::float, 'abc';`
dt := sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType)
returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt)
argument := sdk.NewFunctionArgumentRequest(argName, nil).WithArgDataTypeOld(datatypes.VarcharLegacyDataType)
request := sdk.NewCreateForSQLFunctionRequestDefinitionWrapped(id.SchemaObjectId(), *returns, definition).
WithArguments([]sdk.FunctionArgumentRequest{*argument})

err = client.Functions.CreateForSQL(ctx, request)
require.NoError(t, err)
t.Cleanup(testClientHelper().Function.DropFunctionFunc(t, id))

function, err := client.Functions.ShowByID(ctx, id)
require.NoError(t, err)

// The raw arguments string is only checked for the legacy TABLE(...) substring,
// not compared in full.
assertions.AssertThatObject(t, objectassert.FunctionFromObject(t, function).
HasCreatedOnNotEmpty().
HasName(id.Name()).
HasSchemaName(id.SchemaName()).
HasArgumentsRawContains(returnDataType.ToLegacyDataTypeSql()),
)

// The function details are expected to report the full table return data type.
assertions.AssertThatObject(t, objectassert.FunctionDetails(t, id).
HasReturnDataType(returnDataType),
)
})
}
30 changes: 18 additions & 12 deletions pkg/sdk/testint/procedures_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1781,36 +1781,42 @@ def filter_by_role(session, table_name, role):
require.GreaterOrEqual(t, len(procedures), 1)
})

// TODO [SNOW-1348103]: adjust or remove
// Integration test: create a SQL procedure that returns a table and verify the
// created procedure through the object assertions.
// NOTE(review): this block was reconstructed from an interleaved diff (duplicate
// declarations of column1/column2/returnsTable and both `err :=` and `err =` were
// present); the post-change side was kept - confirm against the original commit.
t.Run("create procedure for SQL: returns table", func(t *testing.T) {
	id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR)

	// Declared return table columns; the expected data type below mirrors them
	// using the legacy data type names.
	column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER")
	column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("double")
	column3 := sdk.NewProcedureColumnRequest("third", nil).WithColumnDataTypeOld("Geometry")
	returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3})
	expectedReturnDataType, err := datatypes.ParseDataType(fmt.Sprintf("TABLE(id %s, price %s, third %s)", datatypes.NumberLegacyDataType, datatypes.FloatLegacyDataType, datatypes.GeometryLegacyDataType))
	require.NoError(t, err)
	definition := `
	DECLARE
		res RESULTSET DEFAULT (SELECT * FROM invoices WHERE id = :id);
	BEGIN
		RETURN TABLE(res);
	END;`
	returns := sdk.NewProcedureSQLReturnsRequest().WithTable(*returnsTable)
	argument := sdk.NewProcedureArgumentRequest("id", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR)
	request := sdk.NewCreateForSQLProcedureRequestDefinitionWrapped(id.SchemaObjectId(), *returns, definition).
		WithOrReplace(true).
		// SNOW-1051627 todo: uncomment once null input behavior working again
		// WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorReturnsNullInput)).
		WithArguments([]sdk.ProcedureArgumentRequest{*argument})
	err = client.Procedures.CreateForSQL(ctx, request)
	require.NoError(t, err)
	t.Cleanup(testClientHelper().Procedure.DropProcedureFunc(t, id))

	// The raw arguments string is only checked for the legacy TABLE(...) substring.
	assertions.AssertThatObject(t, objectassert.Procedure(t, id).
		HasCreatedOnNotEmpty().
		HasName(id.Name()).
		HasSchemaName(id.SchemaName()).
		HasArgumentsRawContains(expectedReturnDataType.ToLegacyDataTypeSql()),
	)

	// The procedure details are expected to report the full table return data type.
	assertions.AssertThatObject(t, objectassert.ProcedureDetails(t, id).
		HasReturnDataType(expectedReturnDataType),
	)
})

t.Run("show parameters", func(t *testing.T) {
Expand Down

0 comments on commit 13401d5

Please sign in to comment.