Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support table data type #3274

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,14 @@ func (f *FunctionAssert) HasExactlySecrets(expectedSecrets map[string]sdk.Schema
})
return f
}

func (f *FunctionAssert) HasArgumentsRawContains(substring string) *FunctionAssert {
f.AddAssertion(func(t *testing.T, o *sdk.Function) error {
t.Helper()
if !strings.Contains(o.ArgumentsRaw, substring) {
return fmt.Errorf("expected arguments raw contain: %v, to contain: %v", o.ArgumentsRaw, substring)
}
return nil
})
return f
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,14 @@ func (f *ProcedureAssert) HasExactlyExternalAccessIntegrations(integrations ...s
})
return f
}

func (p *ProcedureAssert) HasArgumentsRawContains(substring string) *ProcedureAssert {
p.AddAssertion(func(t *testing.T, o *sdk.Procedure) error {
t.Helper()
if !strings.Contains(o.ArgumentsRaw, substring) {
return fmt.Errorf("expected arguments raw contain: %v, to contain: %v", o.ArgumentsRaw, substring)
}
return nil
})
return p
}
5 changes: 5 additions & 0 deletions pkg/sdk/datatypes/data_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ func ParseDataType(raw string) (DataType, error) {
if idx := slices.IndexFunc(VectorDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 {
return parseVectorDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, VectorDataTypeSynonyms[idx]})
}
if idx := slices.IndexFunc(TableDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 {
return parseTableDataTypeRaw(sanitizedDataTypeRaw{strings.TrimSpace(raw), TableDataTypeSynonyms[idx]})
}

return nil, fmt.Errorf("invalid data type: %s", raw)
}
Expand Down Expand Up @@ -118,6 +121,8 @@ func AreTheSame(a DataType, b DataType) bool {
return castSuccessfully(v, b, areNumberDataTypesTheSame)
case *ObjectDataType:
return castSuccessfully(v, b, noArgsDataTypesAreTheSame)
case *TableDataType:
return castSuccessfully(v, b, areTableDataTypesTheSame)
case *TextDataType:
return castSuccessfully(v, b, areTextDataTypesTheSame)
case *TimeDataType:
Expand Down
101 changes: 101 additions & 0 deletions pkg/sdk/datatypes/data_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strings"
"testing"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -1095,6 +1097,98 @@ func Test_ParseDataType_Vector(t *testing.T) {
}
}

func Test_ParseDataType_Table(t *testing.T) {
type column struct {
Name string
Type string
}
type test struct {
input string
expectedColumns []column
}

positiveTestCases := []test{
{input: "TABLE()", expectedColumns: []column{}},
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
{input: "TABLE ()", expectedColumns: []column{}},
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
{input: "TABLE ( )", expectedColumns: []column{}},
{input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", "NUMBER"}}},
{input: "TABLE(arg_name double precision, arg_name_2 NUMBER)", expectedColumns: []column{{"arg_name", "double precision"}, {"arg_name_2", "NUMBER"}}},
{input: "TABLE(arg_name NUMBER(38))", expectedColumns: []column{{"arg_name", "NUMBER(38)"}}},
{input: "TABLE(arg_name NUMBER(38), arg_name_2 VARCHAR)", expectedColumns: []column{{"arg_name", "NUMBER(38)"}, {"arg_name_2", "VARCHAR"}}},
{input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", "number"}, {"second", "float"}, {"third", "GEOGRAPHY"}}},
{input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", "varchar"}, {"second", "date"}, {"third", "time"}}},
// TODO: Support types with parameters (for now, only legacy types are supported because Snowflake returns only with this output), e.g. TABLE(ARG NUMBER(38, 0))
// TODO: Support nested tables, e.g. TABLE(ARG NUMBER, NESTED TABLE(A VARCHAR, B GEOMETRY))
sfc-gh-jmichalak marked this conversation as resolved.
Show resolved Hide resolved
// TODO: Support complex argument names (with quotes / spaces / special characters / etc)
}

negativeTestCases := []test{
{input: "TABLE())"},
{input: "TABLE(1, 2)"},
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
{input: "TABLE(INT, INT)"},
{input: "TABLE(a b)"},
{input: "TABLE(1)"},
{input: "TABLE(2, INT)"},
{input: "TABLE"},
{input: "TABLE(INT, 2, 3)"},
{input: "TABLE(INT)"},
{input: "TABLE(x, 2)"},
{input: "TABLE("},
{input: "TABLE)"},
{input: "TA BLE"},
}

for _, tc := range positiveTestCases {
tc := tc
t.Run(tc.input, func(t *testing.T) {
parsed, err := ParseDataType(tc.input)

require.NoError(t, err)
require.IsType(t, &TableDataType{}, parsed)

assert.Equal(t, "TABLE", parsed.(*TableDataType).underlyingType)
assert.Equal(t, len(tc.expectedColumns), len(parsed.(*TableDataType).columns))
sfc-gh-jmichalak marked this conversation as resolved.
Show resolved Hide resolved
for i, column := range tc.expectedColumns {
assert.Equal(t, column.Name, parsed.(*TableDataType).columns[i].name)
parsedType, err := ParseDataType(column.Type)
require.NoError(t, err)
assert.Equal(t, parsedType.ToLegacyDataTypeSql(), parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql())
}

legacyColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
return fmt.Sprintf("%s %s", col.Name, parsedType.ToLegacyDataTypeSql())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", legacyColumns), parsed.ToLegacyDataTypeSql())

canonicalColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
return fmt.Sprintf("%s %s", col.Name, parsedType.Canonical())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", canonicalColumns), parsed.Canonical())

columns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string {
parsedType, err := ParseDataType(col.Type)
require.NoError(t, err)
return fmt.Sprintf("%s %s", col.Name, parsedType.ToSql())
}), ", ")
assert.Equal(t, fmt.Sprintf("TABLE(%s)", columns), parsed.ToSql())
})
}

for _, tc := range negativeTestCases {
tc := tc
t.Run("negative: "+tc.input, func(t *testing.T) {
parsed, err := ParseDataType(tc.input)

require.Error(t, err)
require.Nil(t, parsed)
})
}
}

func Test_AreTheSame(t *testing.T) {
// empty d1/d2 means nil DataType input
type test struct {
Expand Down Expand Up @@ -1145,6 +1239,13 @@ func Test_AreTheSame(t *testing.T) {
{d1: "TIME", d2: "TIME", expectedOutcome: true},
{d1: "TIME", d2: "TIME(5)", expectedOutcome: false},
{d1: "TIME", d2: fmt.Sprintf("TIME(%d)", DefaultTimePrecision), expectedOutcome: true},
{d1: "TABLE()", d2: "TABLE()", expectedOutcome: true},
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
{d1: "TABLE(A NUMBER)", d2: "TABLE(B NUMBER)", expectedOutcome: false},
{d1: "TABLE(A NUMBER)", d2: "TABLE(a NUMBER)", expectedOutcome: false},
{d1: "TABLE(A NUMBER)", d2: "TABLE(A VARCHAR)", expectedOutcome: false},
{d1: "TABLE(A NUMBER, B VARCHAR)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: true},
{d1: "TABLE(A NUMBER, B NUMBER)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: false},
{d1: "TABLE()", d2: "TABLE(A NUMBER)", expectedOutcome: false},
}

for _, tc := range testCases {
Expand Down
1 change: 0 additions & 1 deletion pkg/sdk/datatypes/legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ const (
TimestampNtzLegacyDataType = "TIMESTAMP_NTZ"
TimestampTzLegacyDataType = "TIMESTAMP_TZ"
VariantLegacyDataType = "VARIANT"

// TableLegacyDataType was not a value of legacy data type in the old implementation. Left for now for an easier implementation.
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
TableLegacyDataType = "TABLE"
)
81 changes: 76 additions & 5 deletions pkg/sdk/datatypes/table.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package datatypes

// TableDataType is based on TODO [SNOW-1348103]
import (
"fmt"
"strings"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging"
)

// TableDataType is based on https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#returning-tabular-data.
// It does not have synonyms.
// It consists of a list of column name + column type; may be empty.
// TODO [SNOW-1348103]: test and improve
type TableDataType struct {
columns []TableDataTypeColumn
underlyingType string
Expand All @@ -14,6 +21,8 @@ type TableDataTypeColumn struct {
dataType DataType
}

var TableDataTypeSynonyms = []string{"TABLE"}

func (c *TableDataTypeColumn) ColumnName() string {
return c.name
}
Expand All @@ -23,17 +32,79 @@ func (c *TableDataTypeColumn) ColumnType() DataType {
}

func (t *TableDataType) ToSql() string {
return t.underlyingType
columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
return fmt.Sprintf("%s %s", col.name, col.dataType.ToSql())
}), ", ")
return fmt.Sprintf("%s(%s)", t.underlyingType, columns)
}

func (t *TableDataType) ToLegacyDataTypeSql() string {
return TableLegacyDataType
columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
return fmt.Sprintf("%s %s", col.name, col.dataType.ToLegacyDataTypeSql())
}), ", ")
return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns)
}

func (t *TableDataType) Canonical() string {
return TableLegacyDataType
columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string {
return fmt.Sprintf("%s %s", col.name, col.dataType.Canonical())
}), ", ")
return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns)
}

func (t *TableDataType) Columns() []TableDataTypeColumn {
return t.columns
}

func parseTableDataTypeRaw(raw sanitizedDataTypeRaw) (*TableDataType, error) {
r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) {
logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType)
return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType)
}
onlyArgs := strings.TrimSpace(r[1 : len(r)-1])
if onlyArgs == "" {
return &TableDataType{
columns: make([]TableDataTypeColumn, 0),
underlyingType: raw.matchedByType,
}, nil
}
columns, err := collections.MapErr(strings.Split(onlyArgs, ","), func(arg string) (TableDataTypeColumn, error) {
argParts := strings.SplitN(strings.TrimSpace(arg), " ", 2)
if len(argParts) != 2 {
return TableDataTypeColumn{}, fmt.Errorf("could not parse table column: %s, it should contain the following format `<arg_name> <arg_type>`; parser failure may be connected to the complex argument names", arg)
}
argDataType, err := ParseDataType(argParts[1])
if err != nil {
return TableDataTypeColumn{}, err
}
return TableDataTypeColumn{
name: argParts[0],
dataType: argDataType,
}, nil
})
if err != nil {
return nil, err
}
return &TableDataType{
columns: columns,
underlyingType: raw.matchedByType,
}, nil
}

func areTableDataTypesTheSame(a, b *TableDataType) bool {
if len(a.columns) != len(b.columns) {
return false
}

for i := range a.columns {
aColumn := a.columns[i]
bColumn := b.columns[i]

if aColumn.name != bColumn.name || !AreTheSame(aColumn.dataType, bColumn.dataType) {
return false
}
}

return true
}
34 changes: 34 additions & 0 deletions pkg/sdk/testint/functions_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2007,4 +2007,38 @@ func TestInt_Functions(t *testing.T) {
assert.Equal(t, dataType.Canonical(), pairs["returns"])
})
}

t.Run("create function for SQL - return table data type", func(t *testing.T) {
argName := "x"

returnDataType, err := datatypes.ParseDataType(fmt.Sprintf("TABLE(ID %s, PRICE %s, THIRD %s)", datatypes.NumberLegacyDataType, datatypes.FloatLegacyDataType, datatypes.VarcharLegacyDataType))
require.NoError(t, err)

id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(datatypes.VarcharLegacyDataType)

definition := ` SELECT 1, 2.2::float, 'abc';`
dt := sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType)
returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt)
argument := sdk.NewFunctionArgumentRequest(argName, nil).WithArgDataTypeOld(datatypes.VarcharLegacyDataType)
request := sdk.NewCreateForSQLFunctionRequestDefinitionWrapped(id.SchemaObjectId(), *returns, definition).
WithArguments([]sdk.FunctionArgumentRequest{*argument})

err = client.Functions.CreateForSQL(ctx, request)
require.NoError(t, err)
t.Cleanup(testClientHelper().Function.DropFunctionFunc(t, id))

function, err := client.Functions.ShowByID(ctx, id)
require.NoError(t, err)

assertions.AssertThatObject(t, objectassert.FunctionFromObject(t, function).
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
HasCreatedOnNotEmpty().
HasName(id.Name()).
HasSchemaName(id.SchemaName()).
HasArgumentsRawContains(returnDataType.ToLegacyDataTypeSql()),
)

assertions.AssertThatObject(t, objectassert.FunctionDetails(t, id).
HasReturnDataType(returnDataType),
)
})
}
30 changes: 18 additions & 12 deletions pkg/sdk/testint/procedures_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1734,36 +1734,42 @@ def filter_by_role(session, table_name, role):
require.GreaterOrEqual(t, len(procedures), 1)
})

// TODO [SNOW-1348103]: adjust or remove
t.Run("create procedure for SQL: returns table", func(t *testing.T) {
t.Skipf("Skipped for now; left as inspiration for resource rework as part of SNOW-1348103")

name := "find_invoice_by_id"
id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR)
id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR)

column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER")
column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("double")
column3 := sdk.NewProcedureColumnRequest("third", nil).WithColumnDataTypeOld("Geometry")
returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3})
expectedReturnDataType, err := datatypes.ParseDataType(fmt.Sprintf("TABLE(id %s, price %s, third %s)", datatypes.NumberLegacyDataType, datatypes.FloatLegacyDataType, datatypes.GeometryLegacyDataType))
require.NoError(t, err)
definition := `
DECLARE
res RESULTSET DEFAULT (SELECT * FROM invoices WHERE id = :id);
BEGIN
RETURN TABLE(res);
END;`
column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER")
column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("NUMBER(12,2)")
returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2})
returns := sdk.NewProcedureSQLReturnsRequest().WithTable(*returnsTable)
argument := sdk.NewProcedureArgumentRequest("id", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR)
request := sdk.NewCreateForSQLProcedureRequestDefinitionWrapped(id.SchemaObjectId(), *returns, definition).
WithOrReplace(true).
// SNOW-1051627 todo: uncomment once null input behavior working again
// WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorReturnsNullInput)).
WithArguments([]sdk.ProcedureArgumentRequest{*argument})
err := client.Procedures.CreateForSQL(ctx, request)
err = client.Procedures.CreateForSQL(ctx, request)
require.NoError(t, err)
t.Cleanup(testClientHelper().Procedure.DropProcedureFunc(t, id))

procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest())
require.NoError(t, err)
require.GreaterOrEqual(t, len(procedures), 1)
assertions.AssertThatObject(t, objectassert.Procedure(t, id).
HasCreatedOnNotEmpty().
HasName(id.Name()).
HasSchemaName(id.SchemaName()).
HasArgumentsRawContains(expectedReturnDataType.ToLegacyDataTypeSql()),
)

assertions.AssertThatObject(t, objectassert.ProcedureDetails(t, id).
HasReturnDataType(expectedReturnDataType),
)
})

t.Run("show parameters", func(t *testing.T) {
Expand Down
Loading