From 1c98a80bb1486a04ead57dd0d8abcf65e00ea86c Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Tue, 12 Dec 2023 16:49:57 +0100 Subject: [PATCH] fix: Fix encode Snowflake ID for object identifiers (#2256) * Add test that is failing for dots * Fix encoding snowflake ID for object identifiers * Fix linter complaints * Add acceptance test * Add issue description * Add tests for nil and pointer * Fix type --- pkg/helpers/helpers.go | 27 ++++- pkg/helpers/helpers_test.go | 114 ++++++++++++++++++ .../TestAcc_ExternalTable_basic/test.tf | 20 +-- pkg/resources/user_acceptance_test.go | 26 ++++ pkg/sdk/identifier_helpers.go | 11 +- pkg/sdk/privileges.go | 12 +- 6 files changed, 188 insertions(+), 22 deletions(-) diff --git a/pkg/helpers/helpers.go b/pkg/helpers/helpers.go index 3ad4342fd1..5893b97f2a 100644 --- a/pkg/helpers/helpers.go +++ b/pkg/helpers/helpers.go @@ -65,10 +65,29 @@ func EncodeSnowflakeID(attributes ...interface{}) string { // is attribute already an object identifier? if len(attributes) == 1 { if id, ok := attributes[0].(sdk.ObjectIdentifier); ok { - // remove quotes and replace dots with pipes - parts := strings.Split(id.FullyQualifiedName(), ".") - for i, part := range parts { - parts[i] = strings.Trim(part, `"`) + if val := reflect.ValueOf(id); val.Kind() == reflect.Ptr && val.IsNil() { + log.Panicf("Nil object identifier received") + } + parts := make([]string, 0) + switch v := id.(type) { + case sdk.AccountObjectIdentifier: + parts = append(parts, v.Name()) + case *sdk.AccountObjectIdentifier: + parts = append(parts, v.Name()) + case sdk.DatabaseObjectIdentifier: + parts = append(parts, v.DatabaseName(), v.Name()) + case *sdk.DatabaseObjectIdentifier: + parts = append(parts, v.DatabaseName(), v.Name()) + case sdk.SchemaObjectIdentifier: + parts = append(parts, v.DatabaseName(), v.SchemaName(), v.Name()) + case *sdk.SchemaObjectIdentifier: + parts = append(parts, v.DatabaseName(), v.SchemaName(), v.Name()) + case sdk.TableColumnIdentifier: + parts = append(parts, v.DatabaseName(), v.SchemaName(), v.TableName(), v.Name()) + case *sdk.TableColumnIdentifier: + parts = append(parts, v.DatabaseName(), v.SchemaName(), v.TableName(), v.Name()) + default: + log.Panicf("Unsupported object identifier: %v", id) } return strings.Join(parts, IDDelimiter) } diff --git a/pkg/helpers/helpers_test.go b/pkg/helpers/helpers_test.go index 1eb536525b..7122c9962e 100644 --- a/pkg/helpers/helpers_test.go +++ b/pkg/helpers/helpers_test.go @@ -1,8 +1,10 @@ package helpers import ( + "fmt" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/stretchr/testify/require" ) @@ -71,3 +73,115 @@ func TestDecodeSnowflakeParameterID(t *testing.T) { require.Errorf(t, err, "incompatible identifier: %s", id) }) } + +// TODO: add tests for non object identifiers +func TestEncodeSnowflakeID(t *testing.T) { + testCases := map[string]struct { + identifier sdk.ObjectIdentifier + expectedEncodedID string + }{ + "encodes account object identifier": { + identifier: sdk.NewAccountObjectIdentifier("database"), + expectedEncodedID: `database`, + }, + "encodes quoted account object identifier": { + identifier: sdk.NewAccountObjectIdentifier("\"database\""), + expectedEncodedID: `database`, + }, + "encodes account object identifier with a dot": { + identifier: sdk.NewAccountObjectIdentifier("data.base"), + expectedEncodedID: `data.base`, + }, + "encodes pointer to account object identifier": { + identifier: sdk.Pointer(sdk.NewAccountObjectIdentifier("database")), + expectedEncodedID: `database`, + }, + "encodes database object identifier": { + identifier: sdk.NewDatabaseObjectIdentifier("database", "schema"), + expectedEncodedID: `database|schema`, + }, + "encodes quoted database object identifier": { + identifier: sdk.NewDatabaseObjectIdentifier("\"database\"", "\"schema\""), + expectedEncodedID: `database|schema`, + }, + "encodes database object identifier with dots": { + identifier: sdk.NewDatabaseObjectIdentifier("data.base", "sche.ma"), + expectedEncodedID: `data.base|sche.ma`, + }, + "encodes pointer to database object identifier": { + identifier: sdk.Pointer(sdk.NewDatabaseObjectIdentifier("database", "schema")), + expectedEncodedID: `database|schema`, + }, + "encodes schema object identifier": { + identifier: sdk.NewSchemaObjectIdentifier("database", "schema", "table"), + expectedEncodedID: `database|schema|table`, + }, + "encodes quoted schema object identifier": { + identifier: sdk.NewSchemaObjectIdentifier("\"database\"", "\"schema\"", "\"table\""), + expectedEncodedID: `database|schema|table`, + }, + "encodes schema object identifier with dots": { + identifier: sdk.NewSchemaObjectIdentifier("data.base", "sche.ma", "tab.le"), + expectedEncodedID: `data.base|sche.ma|tab.le`, + }, + "encodes pointer to schema object identifier": { + identifier: sdk.Pointer(sdk.NewSchemaObjectIdentifier("database", "schema", "table")), + expectedEncodedID: `database|schema|table`, + }, + "encodes table column identifier": { + identifier: sdk.NewTableColumnIdentifier("database", "schema", "table", "column"), + expectedEncodedID: `database|schema|table|column`, + }, + "encodes quoted table column identifier": { + identifier: sdk.NewTableColumnIdentifier("\"database\"", "\"schema\"", "\"table\"", "\"column\""), + expectedEncodedID: `database|schema|table|column`, + }, + "encodes table column identifier with dots": { + identifier: sdk.NewTableColumnIdentifier("data.base", "sche.ma", "tab.le", "col.umn"), + expectedEncodedID: `data.base|sche.ma|tab.le|col.umn`, + }, + "encodes pointer to table column identifier": { + identifier: sdk.Pointer(sdk.NewTableColumnIdentifier("database", "schema", "table", "column")), + expectedEncodedID: `database|schema|table|column`, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + encodedID := EncodeSnowflakeID(tc.identifier) + require.Equal(t, tc.expectedEncodedID, encodedID) + }) + } + + t.Run("panics for unsupported object identifier", func(t *testing.T) { + id := unsupportedObjectIdentifier{} + require.PanicsWithValue(t, fmt.Sprintf("Unsupported object identifier: %v", id), func() { + EncodeSnowflakeID(id) + }) + }) + + nilTestCases := []any{ + (*sdk.AccountObjectIdentifier)(nil), + (*sdk.DatabaseObjectIdentifier)(nil), + (*sdk.SchemaObjectIdentifier)(nil), + (*sdk.TableColumnIdentifier)(nil), + } + + for i, tt := range nilTestCases { + t.Run(fmt.Sprintf("handle nil pointer to object identifier %d", i), func(t *testing.T) { + require.PanicsWithValue(t, "Nil object identifier received", func() { + EncodeSnowflakeID(tt) + }) + }) + } +} + +type unsupportedObjectIdentifier struct{} + +func (i unsupportedObjectIdentifier) Name() string { + return "name" +} + +func (i unsupportedObjectIdentifier) FullyQualifiedName() string { + return "fully qualified name" +} diff --git a/pkg/resources/testdata/TestAcc_ExternalTable_basic/test.tf b/pkg/resources/testdata/TestAcc_ExternalTable_basic/test.tf index aab230558f..c8efacb28f 100644 --- a/pkg/resources/testdata/TestAcc_ExternalTable_basic/test.tf +++ b/pkg/resources/testdata/TestAcc_ExternalTable_basic/test.tf @@ -1,22 +1,22 @@ resource "snowflake_storage_integration" "i" { - name = var.name + name = var.name storage_allowed_locations = [var.location] - storage_provider = "S3" - storage_aws_role_arn = var.aws_arn + storage_provider = "S3" + storage_aws_role_arn = var.aws_arn } resource "snowflake_stage" "test" { - name = var.name - url = var.location - database = var.database - schema = var.schema + name = var.name + url = var.location + database = var.database + schema = var.schema storage_integration = snowflake_storage_integration.i.name } resource "snowflake_external_table" "test_table" { - name = var.name + name = var.name database = var.database - schema = var.schema + schema = var.schema comment = "Terraform acceptance test" column { name = "column1" @@ -29,5 +29,5 @@ resource "snowflake_external_table" "test_table" { as = "($1:\"CreatedDate\"::timestamp)" } file_format = "TYPE = CSV" - location = "@\"${var.database}\".\"${var.schema}\".\"${snowflake_stage.test.name}\"" + location = "@\"${var.database}\".\"${var.schema}\".\"${snowflake_stage.test.name}\"" } diff --git a/pkg/resources/user_acceptance_test.go b/pkg/resources/user_acceptance_test.go index 47c66abfe4..7300bd876d 100644 --- a/pkg/resources/user_acceptance_test.go +++ b/pkg/resources/user_acceptance_test.go @@ -162,3 +162,29 @@ resource "snowflake_user" "w" { log.Printf("[DEBUG] s2 %s", s) return fmt.Sprintf(s, prefix, prefix) } + +// TestAcc_User_issue2058 proves https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2058 issue. +// The problem was with a dot in user identifier. +// Before the fix it results in panic: interface conversion: sdk.ObjectIdentifier is sdk.DatabaseObjectIdentifier, not sdk.AccountObjectIdentifier error. +func TestAcc_User_issue2058(t *testing.T) { + r := require.New(t) + prefix := "tst-terraform" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + "user.123" + sshkey1, err := testhelpers.Fixture("userkey1") + r.NoError(err) + sshkey2, err := testhelpers.Fixture("userkey2") + r.NoError(err) + + resource.Test(t, resource.TestCase{ + Providers: acc.TestAccProviders(), + PreCheck: func() { acc.TestAccPreCheck(t) }, + CheckDestroy: nil, + Steps: []resource.TestStep{ + { + Config: uConfig(prefix, sshkey1, sshkey2), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_user.w", "name", prefix), + ), + }, + }, + }) +} diff --git a/pkg/sdk/identifier_helpers.go b/pkg/sdk/identifier_helpers.go index 99bc086092..71aa0b9d5b 100644 --- a/pkg/sdk/identifier_helpers.go +++ b/pkg/sdk/identifier_helpers.go @@ -121,7 +121,9 @@ type AccountObjectIdentifier struct { } func NewAccountObjectIdentifier(name string) AccountObjectIdentifier { - return AccountObjectIdentifier{name: name} + return AccountObjectIdentifier{ + name: strings.Trim(name, `"`), + } } func NewAccountObjectIdentifierFromFullyQualifiedName(fullyQualifiedName string) AccountObjectIdentifier { @@ -268,7 +270,12 @@ type TableColumnIdentifier struct { } func NewTableColumnIdentifier(databaseName, schemaName, tableName, columnName string) TableColumnIdentifier { - return TableColumnIdentifier{databaseName: databaseName, schemaName: schemaName, tableName: tableName, columnName: columnName} + return TableColumnIdentifier{ + databaseName: strings.Trim(databaseName, `"`), + schemaName: strings.Trim(schemaName, `"`), + tableName: strings.Trim(tableName, `"`), + columnName: strings.Trim(columnName, `"`), + } } func NewTableColumnIdentifierFromFullyQualifiedName(fullyQualifiedName string) TableColumnIdentifier { diff --git a/pkg/sdk/privileges.go b/pkg/sdk/privileges.go index 47bb95a520..51b24cedf5 100644 --- a/pkg/sdk/privileges.go +++ b/pkg/sdk/privileges.go @@ -186,12 +186,12 @@ const ( // -- For ICEBERG TABLE SchemaObjectPrivilegeApplyBudget SchemaObjectPrivilege = "APPLYBUDGET" - //SchemaObjectPrivilegeDelete SchemaObjectPrivilege = "DELETE" (duplicate) - //SchemaObjectPrivilegeInsert SchemaObjectPrivilege = "INSERT" (duplicate) - //SchemaObjectPrivilegeReferences SchemaObjectPrivilege = "REFERENCES" (duplicate) - //SchemaObjectPrivilegeSelect SchemaObjectPrivilege = "SELECT" (duplicate) - //SchemaObjectPrivilegeTruncate SchemaObjectPrivilege = "Truncate" (duplicate) - //SchemaObjectPrivilegeUpdate SchemaObjectPrivilege = "Update" (duplicate) + // SchemaObjectPrivilegeDelete SchemaObjectPrivilege = "DELETE" (duplicate) + // SchemaObjectPrivilegeInsert SchemaObjectPrivilege = "INSERT" (duplicate) + // SchemaObjectPrivilegeReferences SchemaObjectPrivilege = "REFERENCES" (duplicate) + // SchemaObjectPrivilegeSelect SchemaObjectPrivilege = "SELECT" (duplicate) + // SchemaObjectPrivilegeTruncate SchemaObjectPrivilege = "Truncate" (duplicate) + // SchemaObjectPrivilegeUpdate SchemaObjectPrivilege = "Update" (duplicate) // -- For PIPE // { MONITOR | OPERATE } [ , ... ]