diff --git a/pkg/resources/helpers.go b/pkg/resources/helpers.go index f539f8499c6..e25c90b9b4b 100644 --- a/pkg/resources/helpers.go +++ b/pkg/resources/helpers.go @@ -2,6 +2,10 @@ package resources import ( "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/hashicorp/go-cty/cty" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "reflect" "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -142,3 +146,77 @@ func IsDataType() schema.SchemaValidateFunc { return warnings, errors } } + +// IsValidIdentifier is a validator that can be used for validating identifiers passed in resources and data sources. +// Typically, we expect passed identifiers to be a variation of sdk.ObjectIdentifier. To use this function, pass it as +// a validation function on identifier field with generic parameter set to the desired sdk.ObjectIdentifier implementation. +func IsValidIdentifier[T sdk.ObjectIdentifier]() schema.SchemaValidateDiagFunc { + return func(value any, path cty.Path) diag.Diagnostics { + var diags diag.Diagnostics + + // For now, we won't support sdk.ExternalObjectIdentifiers. The reason behind it is that the functions that parse identifiers are not + // able to differentiate between sdk.ExternalObjectIdentifiers and sdk.DatabaseObjectIdentifier or sdk.SchemaObjectIdentifier, + // because sdk.ExternalObjectIdentifiers has varying parts count (2 or 3). + if _, ok := any(sdk.ExternalObjectIdentifier{}).(T); ok { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Error, + Summary: "Invalid schema identifier type", + Detail: "Identifier validation is not available for sdk.ExternalObjectIdentifier type. This is a provider error please file a report: https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/new/choose", + AttributePath: path, + }) + return diags + } + + if stringValue, ok := value.(string); ok { + id, err := helpers.DecodeSnowflakeParameterID(stringValue) + if err != nil { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Error, + Summary: "Unable to parse the identifier", + Detail: fmt.Sprintf( + "Unable to parse the identifier: %s. Make sure you are using the correct form of the fully qualified name for this field: %s", + stringValue, + getExpectedIdentifierForm[T](nil), + ), + AttributePath: path, + }) + return diags + } + + if _, ok := id.(T); !ok { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Error, + Summary: "Invalid identifier type", + Detail: fmt.Sprintf( + "Expected %s identifier type, but got: %T. The correct form of the fully qualified name for this field is: %s, but was %s", + reflect.TypeOf(new(T)).Elem().Name(), + id, + getExpectedIdentifierForm[T](nil), + getExpectedIdentifierForm(&id), + ), + AttributePath: path, + }) + } + } else { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Error, + Summary: "Invalid schema identifier type", + Detail: fmt.Sprintf("Expected schema string type, but got: %T. This is a provider error please file a report: https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/new/choose", value), + AttributePath: path, + }) + } + + return diags + } +} + +// getExpectedIdentifierForm will choose the type either from the objectIdentifier parameter if it's present. If it's not, +// then it will create a new identifier based on the generic type parameter T, then it will return the proper structure +// we are expecting for the given sdk.ObjectIdentifier type. +func getExpectedIdentifierForm[T sdk.ObjectIdentifier](objectIdentifier *T) string { + if objectIdentifier != nil { + return (*objectIdentifier).Representation() + } + id := new(T) + return sdk.GetIdentifierRepresentation(*id) +} diff --git a/pkg/resources/helpers_test.go b/pkg/resources/helpers_test.go index 87a8812274f..15e5c0a5363 100644 --- a/pkg/resources/helpers_test.go +++ b/pkg/resources/helpers_test.go @@ -1,6 +1,8 @@ package resources_test import ( + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/go-cty/cty" "testing" "github.com/stretchr/testify/assert" @@ -512,3 +514,95 @@ func TestIsDataType(t *testing.T) { }) } } + +func TestIsValidIdentifier(t *testing.T) { + accountObjectIdentifierCheck := resources.IsValidIdentifier[sdk.AccountObjectIdentifier]() + databaseObjectIdentifierCheck := resources.IsValidIdentifier[sdk.DatabaseObjectIdentifier]() + schemaObjectIdentifierCheck := resources.IsValidIdentifier[sdk.SchemaObjectIdentifier]() + externalObjectIdentifierCheck := resources.IsValidIdentifier[sdk.ExternalObjectIdentifier]() + tableColumnIdentifierCheck := resources.IsValidIdentifier[sdk.TableColumnIdentifier]() + + testCases := []struct { + Name string + Value any + Error string + CheckingFn schema.SchemaValidateDiagFunc + }{ + { + Name: "validation: invalid value type", + Value: 123, + Error: "Expected schema string type, but got: int", + CheckingFn: accountObjectIdentifierCheck, + }, + { + Name: "validation: invalid identifier representation", + Value: "", + Error: "Unable to parse the identifier: ", + CheckingFn: accountObjectIdentifierCheck, + }, + // Invalid form for different checkers (tests getExpectedIdentifierForm function) + { + Name: "validation: incorrect form for account object identifier", + Value: "a.b", + Error: ", but was .", + CheckingFn: accountObjectIdentifierCheck, + }, + { + Name: "validation: incorrect form for database object identifier", + Value: "a.b.c", + Error: "., but was ..", + CheckingFn: databaseObjectIdentifierCheck, + }, + { + Name: "validation: incorrect form for schema object identifier", + Value: "a.b.c.d", + Error: ".., but was ...", + CheckingFn: schemaObjectIdentifierCheck, + }, + { + Name: "validation: incorrect form for table column identifier", + Value: "a", + Error: "..., but was ", + CheckingFn: tableColumnIdentifierCheck, + }, + { + Name: "validation: external object identifier is validated", + Value: "a", + Error: "Identifier validation is not available for sdk.ExternalObjectIdentifier type.", + CheckingFn: externalObjectIdentifierCheck, + }, + // Valid form for different checkers (tests getExpectedIdentifierForm function) + { + Name: "correct form for account object identifier", + Value: "a", + CheckingFn: accountObjectIdentifierCheck, + }, + { + Name: "correct form for database object identifier", + Value: "a.b", + CheckingFn: databaseObjectIdentifierCheck, + }, + { + Name: "correct form for schema object identifier", + Value: "a.b.c", + CheckingFn: schemaObjectIdentifierCheck, + }, + { + Name: "correct form for table column identifier", + Value: "a.b.c.d", + CheckingFn: tableColumnIdentifierCheck, + }, + } + + for _, tt := range testCases { + t.Run(tt.Name, func(t *testing.T) { + diag := tt.CheckingFn(tt.Value, cty.IndexStringPath("path")) + if tt.Error != "" { + assert.Len(t, diag, 1) + assert.Contains(t, diag[0].Detail, tt.Error) + } else { + assert.Len(t, diag, 0) + } + }) + } +} diff --git a/pkg/sdk/identifier_helpers.go b/pkg/sdk/identifier_helpers.go index 71aa0b9d5b5..01645259b3c 100644 --- a/pkg/sdk/identifier_helpers.go +++ b/pkg/sdk/identifier_helpers.go @@ -12,6 +12,7 @@ type Identifier interface { type ObjectIdentifier interface { Identifier FullyQualifiedName() string + Representation() string } func NewObjectIdentifierFromFullyQualifiedName(fullyQualifiedName string) ObjectIdentifier { @@ -80,6 +81,10 @@ func (i ExternalObjectIdentifier) FullyQualifiedName() string { return fmt.Sprintf(`%v.%v`, i.accountIdentifier.Name(), i.objectIdentifier.FullyQualifiedName()) } +func (i ExternalObjectIdentifier) Representation() string { + return ". or .." +} + type AccountIdentifier struct { organizationName string accountName string @@ -142,6 +147,10 @@ func (i AccountObjectIdentifier) FullyQualifiedName() string { return fmt.Sprintf(`"%v"`, i.name) } +func (i AccountObjectIdentifier) Representation() string { + return "" +} + type DatabaseObjectIdentifier struct { databaseName string name string @@ -177,6 +186,10 @@ func (i DatabaseObjectIdentifier) FullyQualifiedName() string { return fmt.Sprintf(`"%v"."%v"`, i.databaseName, i.name) } +func (i DatabaseObjectIdentifier) Representation() string { + return "." +} + type SchemaObjectIdentifier struct { databaseName string schemaName string @@ -262,6 +275,10 @@ func (i SchemaObjectIdentifier) FullyQualifiedName() string { return fmt.Sprintf(`"%v"."%v"."%v"(%v)`, i.databaseName, i.schemaName, i.name, strings.Join(args, ", ")) } +func (i SchemaObjectIdentifier) Representation() string { + return ".." +} + type TableColumnIdentifier struct { databaseName string schemaName string @@ -310,3 +327,11 @@ func (i TableColumnIdentifier) FullyQualifiedName() string { } return fmt.Sprintf(`"%v"."%v"."%v"."%v"`, i.databaseName, i.schemaName, i.tableName, i.columnName) } + +func (i TableColumnIdentifier) Representation() string { + return "..." +} + +func GetIdentifierRepresentation(identifier ObjectIdentifier) string { + return identifier.Representation() +}