diff --git a/go.mod b/go.mod index a74fad2aac..769fbf3c0e 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/buger/jsonparser v1.1.1 github.com/google/uuid v1.4.0 github.com/gookit/color v1.5.4 + github.com/hashicorp/go-cty v1.4.1-0.20200414143053-d3edf31b6320 github.com/hashicorp/go-uuid v1.0.3 github.com/hashicorp/terraform-plugin-framework v1.4.2 github.com/hashicorp/terraform-plugin-framework-validators v0.12.0 @@ -78,7 +79,6 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-checkpoint v0.5.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect - github.com/hashicorp/go-cty v1.4.1-0.20200414143053-d3edf31b6320 // indirect github.com/hashicorp/go-hclog v1.5.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-plugin v1.5.2 // indirect diff --git a/pkg/resources/grant_privileges_to_role.go b/pkg/resources/grant_privileges_to_role.go index 44b533c2e7..8de01aa7ae 100644 --- a/pkg/resources/grant_privileges_to_role.go +++ b/pkg/resources/grant_privileges_to_role.go @@ -69,9 +69,10 @@ var grantPrivilegesToRoleSchema = map[string]*schema.Schema{ }, true), }, "object_name": { - Type: schema.TypeString, - Required: true, - Description: "The fully qualified name of the object on which privileges will be granted.", + Type: schema.TypeString, + Required: true, + Description: "The fully qualified name of the object on which privileges will be granted.", + ValidateDiagFunc: IsValidIdentifier[sdk.AccountObjectIdentifier](), }, }, }, @@ -86,11 +87,12 @@ var grantPrivilegesToRoleSchema = map[string]*schema.Schema{ Elem: &schema.Resource{ Schema: map[string]*schema.Schema{ "schema_name": { - Type: schema.TypeString, - Optional: true, - Description: "The fully qualified name of the schema.", - ConflictsWith: []string{"on_schema.0.all_schemas_in_database", "on_schema.0.future_schemas_in_database"}, - ForceNew: true, + Type: schema.TypeString, + Optional: true, + Description: "The fully qualified name of the schema.", + ConflictsWith: []string{"on_schema.0.all_schemas_in_database", "on_schema.0.future_schemas_in_database"}, + ValidateDiagFunc: IsValidIdentifier[sdk.DatabaseObjectIdentifier](), + ForceNew: true, }, "all_schemas_in_database": { Type: schema.TypeString, @@ -151,12 +153,13 @@ var grantPrivilegesToRoleSchema = map[string]*schema.Schema{ }, true), }, "object_name": { - Type: schema.TypeString, - Optional: true, - Description: "The fully qualified name of the object on which privileges will be granted.", - RequiredWith: []string{"on_schema_object.0.object_type"}, - ConflictsWith: []string{"on_schema_object.0.all", "on_schema_object.0.future"}, - ForceNew: true, + Type: schema.TypeString, + Optional: true, + Description: "The fully qualified name of the object on which privileges will be granted.", + RequiredWith: []string{"on_schema_object.0.object_type"}, + ConflictsWith: []string{"on_schema_object.0.all", "on_schema_object.0.future"}, + ValidateDiagFunc: IsValidIdentifier[sdk.SchemaObjectIdentifier](), + ForceNew: true, }, "all": { Type: schema.TypeList, @@ -197,18 +200,20 @@ var grantPrivilegesToRoleSchema = map[string]*schema.Schema{ }, true), }, "in_database": { - Type: schema.TypeString, - Optional: true, - Description: "The fully qualified name of the database.", - ConflictsWith: []string{"on_schema_object.0.all.in_schema"}, - ForceNew: true, + Type: schema.TypeString, + Optional: true, + Description: "The fully qualified name of the database.", + ConflictsWith: []string{"on_schema_object.0.all.in_schema"}, + ValidateDiagFunc: IsValidIdentifier[sdk.AccountObjectIdentifier](), + ForceNew: true, }, "in_schema": { - Type: schema.TypeString, - Optional: true, - Description: "The fully qualified name of the schema.", - ConflictsWith: []string{"on_schema_object.0.all.in_database"}, - ForceNew: true, + Type: schema.TypeString, + Optional: true, + Description: "The fully qualified name of the schema.", + ConflictsWith: []string{"on_schema_object.0.all.in_database"}, + ValidateDiagFunc: IsValidIdentifier[sdk.DatabaseObjectIdentifier](), + ForceNew: true, }, }, }, @@ -252,18 +257,20 @@ var grantPrivilegesToRoleSchema = map[string]*schema.Schema{ }, true), }, "in_database": { - Type: schema.TypeString, - Optional: true, - Description: "The fully qualified name of the database.", - ConflictsWith: []string{"on_schema_object.0.all.in_schema"}, - ForceNew: true, + Type: schema.TypeString, + Optional: true, + Description: "The fully qualified name of the database.", + ConflictsWith: []string{"on_schema_object.0.all.in_schema"}, + ValidateDiagFunc: IsValidIdentifier[sdk.AccountObjectIdentifier](), + ForceNew: true, }, "in_schema": { - Type: schema.TypeString, - Optional: true, - Description: "The fully qualified name of the schema.", - ConflictsWith: []string{"on_schema_object.0.all.in_database"}, - ForceNew: true, + Type: schema.TypeString, + Optional: true, + Description: "The fully qualified name of the schema.", + ConflictsWith: []string{"on_schema_object.0.all.in_database"}, + ValidateDiagFunc: IsValidIdentifier[sdk.DatabaseObjectIdentifier](), + ForceNew: true, }, }, }, diff --git a/pkg/resources/grant_privileges_to_role_acceptance_test.go b/pkg/resources/grant_privileges_to_role_acceptance_test.go index 2d6b20be8a..9cd8227a02 100644 --- a/pkg/resources/grant_privileges_to_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_role_acceptance_test.go @@ -2,6 +2,7 @@ package resources_test import ( "fmt" + "regexp" "strings" "testing" @@ -899,3 +900,34 @@ resource "snowflake_grant_privileges_to_role" "grant" { }, }) } + +func TestAcc_GrantPrivilegesToRole_ValidatedIdentifiers(t *testing.T) { + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: nil, + Steps: []resource.TestStep{ + { + Config: fmt.Sprintf(` +resource "snowflake_role" "role" { + name = "TEST_ROLE_123" +} + +resource "snowflake_grant_privileges_to_role" "test_invalidation" { + role_name = snowflake_role.role.name + privileges = ["SELECT"] + on_schema_object { + future { + object_type_plural = "ICEBERG TABLES" + in_schema = "%s" + } + } +}`, acc.TestSchemaName), + ExpectError: regexp.MustCompile(".*Expected DatabaseObjectIdentifier identifier type.*"), + }, + }, + }) +} diff --git a/pkg/resources/helpers.go b/pkg/resources/helpers.go index 0cb19d4e55..08890b354f 100644 --- a/pkg/resources/helpers.go +++ b/pkg/resources/helpers.go @@ -125,20 +125,3 @@ func GetPropertyAsPointer[T any](d *schema.ResourceData, property string) *T { } return &typedValue } - -func IsDataType() schema.SchemaValidateFunc { //nolint:staticcheck - return func(value any, key string) (warnings []string, errors []error) { - stringValue, ok := value.(string) - if !ok { - errors = append(errors, fmt.Errorf("expected type of %s to be string, got %T", key, value)) - return warnings, errors - } - - _, err := sdk.ToDataType(stringValue) - if err != nil { - errors = append(errors, fmt.Errorf("expected %s to be one of %T values, got %s", key, sdk.DataTypeString, stringValue)) - } - - return warnings, errors - } -} diff --git a/pkg/resources/helpers_test.go b/pkg/resources/helpers_test.go index 87a8812274..78c07827b6 100644 --- a/pkg/resources/helpers_test.go +++ b/pkg/resources/helpers_test.go @@ -470,45 +470,3 @@ func tagGrant(t *testing.T, id string, params map[string]interface{}) *schema.Re d.SetId(id) return d } - -func TestIsDataType(t *testing.T) { - isDataType := resources.IsDataType() - key := "tag" - - testCases := []struct { - Name string - Value any - Error string - }{ - { - Name: "validation: correct DataType value", - Value: "NUMBER", - }, - { - Name: "validation: correct DataType value in lowercase", - Value: "number", - }, - { - Name: "validation: incorrect DataType value", - Value: "invalid data type", - Error: "expected tag to be one of", - }, - { - Name: "validation: incorrect value type", - Value: 123, - Error: "expected type of tag to be string", - }, - } - - for _, tt := range testCases { - t.Run(tt.Name, func(t *testing.T) { - _, errors := isDataType(tt.Value, key) - if tt.Error != "" { - assert.Len(t, errors, 1) - assert.ErrorContains(t, errors[0], tt.Error) - } else { - assert.Len(t, errors, 0) - } - }) - } -} diff --git a/pkg/resources/validators.go b/pkg/resources/validators.go new file mode 100644 index 0000000000..a517a7e0ab --- /dev/null +++ b/pkg/resources/validators.go @@ -0,0 +1,110 @@ +package resources + +import ( + "fmt" + "reflect" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/go-cty/cty" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" +) + +func IsDataType() schema.SchemaValidateFunc { //nolint:staticcheck + return func(value any, key string) (warnings []string, errors []error) { + stringValue, ok := value.(string) + if !ok { + errors = append(errors, fmt.Errorf("expected type of %s to be string, got %T", key, value)) + return warnings, errors + } + + _, err := sdk.ToDataType(stringValue) + if err != nil { + errors = append(errors, fmt.Errorf("expected %s to be one of %T values, got %s", key, sdk.DataTypeString, stringValue)) + } + + return warnings, errors + } +} + +// IsValidIdentifier is a validator that can be used for validating identifiers passed in resources and data sources. +// +// Typically, we expect passed identifiers to be a variation of sdk.ObjectIdentifier. +// For now, we're expecting implementations of sdk.ObjectIdentifier, because we won't support sdk.ExternalObjectIdentifiers. +// The reason behind it is that the functions that parse identifiers are not able to differentiate between +// sdk.ExternalObjectIdentifiers and sdk.DatabaseObjectIdentifier or sdk.SchemaObjectIdentifier. +// That's because sdk.ExternalObjectIdentifiers has varying parts count (2 or 3). +// +// To use this function, pass it as a validation function on identifier field with generic parameter set to the desired sdk.ObjectIdentifier implementation. +func IsValidIdentifier[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | sdk.SchemaObjectIdentifier | sdk.TableColumnIdentifier]() schema.SchemaValidateDiagFunc { + return func(value any, path cty.Path) diag.Diagnostics { + var diags diag.Diagnostics + + if _, ok := value.(string); !ok { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Error, + Summary: "Invalid schema identifier type", + Detail: fmt.Sprintf("Expected schema string type, but got: %T. This is a provider error please file a report: https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/new/choose", value), + AttributePath: path, + }) + return diags + } + + stringValue := value.(string) + id, err := helpers.DecodeSnowflakeParameterID(stringValue) + if err != nil { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Error, + Summary: "Unable to parse the identifier", + Detail: fmt.Sprintf( + "Unable to parse the identifier: %s. Make sure you are using the correct form of the fully qualified name for this field: %s.\nOriginal Error: %s", + stringValue, + getExpectedIdentifierRepresentationFromGeneric[T](), + err.Error(), + ), + AttributePath: path, + }) + return diags + } + + if _, ok := id.(T); !ok { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Error, + Summary: "Invalid identifier type", + Detail: fmt.Sprintf( + "Expected %s identifier type, but got: %T. The correct form of the fully qualified name for this field is: %s, but was %s", + reflect.TypeOf(new(T)).Elem().Name(), + id, + getExpectedIdentifierRepresentationFromGeneric[T](), + getExpectedIdentifierRepresentationFromParam(id), + ), + AttributePath: path, + }) + } + + return diags + } +} + +func getExpectedIdentifierRepresentationFromGeneric[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | sdk.SchemaObjectIdentifier | sdk.TableColumnIdentifier]() string { + return getExpectedIdentifierForm(new(T)) +} + +func getExpectedIdentifierRepresentationFromParam(id sdk.ObjectIdentifier) string { + return getExpectedIdentifierForm(id) +} + +func getExpectedIdentifierForm(id any) string { + switch id.(type) { + case sdk.AccountObjectIdentifier, *sdk.AccountObjectIdentifier: + return "" + case sdk.DatabaseObjectIdentifier, *sdk.DatabaseObjectIdentifier: + return "." + case sdk.SchemaObjectIdentifier, *sdk.SchemaObjectIdentifier: + return ".." + case sdk.TableColumnIdentifier, *sdk.TableColumnIdentifier: + return "..." + } + return "" +} diff --git a/pkg/resources/validators_test.go b/pkg/resources/validators_test.go new file mode 100644 index 0000000000..3aa248d1bd --- /dev/null +++ b/pkg/resources/validators_test.go @@ -0,0 +1,214 @@ +package resources + +import ( + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/go-cty/cty" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/stretchr/testify/assert" +) + +func TestIsDataType(t *testing.T) { + isDataType := IsDataType() + key := "tag" + + testCases := []struct { + Name string + Value any + Error string + }{ + { + Name: "validation: correct DataType value", + Value: "NUMBER", + }, + { + Name: "validation: correct DataType value in lowercase", + Value: "number", + }, + { + Name: "validation: incorrect DataType value", + Value: "invalid data type", + Error: "expected tag to be one of", + }, + { + Name: "validation: incorrect value type", + Value: 123, + Error: "expected type of tag to be string", + }, + } + + for _, tt := range testCases { + t.Run(tt.Name, func(t *testing.T) { + _, errors := isDataType(tt.Value, key) + if tt.Error != "" { + assert.Len(t, errors, 1) + assert.ErrorContains(t, errors[0], tt.Error) + } else { + assert.Len(t, errors, 0) + } + }) + } +} + +func TestIsValidIdentifier(t *testing.T) { + accountObjectIdentifierCheck := IsValidIdentifier[sdk.AccountObjectIdentifier]() + databaseObjectIdentifierCheck := IsValidIdentifier[sdk.DatabaseObjectIdentifier]() + schemaObjectIdentifierCheck := IsValidIdentifier[sdk.SchemaObjectIdentifier]() + tableColumnIdentifierCheck := IsValidIdentifier[sdk.TableColumnIdentifier]() + + testCases := []struct { + Name string + Value any + Error string + CheckingFn schema.SchemaValidateDiagFunc + }{ + { + Name: "validation: invalid value type", + Value: 123, + Error: "Expected schema string type, but got: int", + CheckingFn: accountObjectIdentifierCheck, + }, + { + Name: "validation: invalid identifier representation", + Value: "", + Error: "Unable to parse the identifier: ", + CheckingFn: accountObjectIdentifierCheck, + }, + { + Name: "validation: incorrect form for account object identifier", + Value: "a.b", + Error: ", but was .", + CheckingFn: accountObjectIdentifierCheck, + }, + { + Name: "validation: incorrect form for database object identifier", + Value: "a.b.c", + Error: "., but was ..", + CheckingFn: databaseObjectIdentifierCheck, + }, + { + Name: "validation: incorrect form for schema object identifier", + Value: "a.b.c.d", + Error: ".., but was ...", + CheckingFn: schemaObjectIdentifierCheck, + }, + { + Name: "validation: incorrect form for table column identifier", + Value: "a", + Error: "..., but was ", + CheckingFn: tableColumnIdentifierCheck, + }, + { + Name: "correct form for account object identifier", + Value: "a", + CheckingFn: accountObjectIdentifierCheck, + }, + { + Name: "correct form for database object identifier", + Value: "a.b", + CheckingFn: databaseObjectIdentifierCheck, + }, + { + Name: "correct form for schema object identifier", + Value: "a.b.c", + CheckingFn: schemaObjectIdentifierCheck, + }, + { + Name: "correct form for table column identifier", + Value: "a.b.c.d", + CheckingFn: tableColumnIdentifierCheck, + }, + } + + for _, tt := range testCases { + t.Run(tt.Name, func(t *testing.T) { + diag := tt.CheckingFn(tt.Value, cty.IndexStringPath("path")) + if tt.Error != "" { + assert.Len(t, diag, 1) + assert.Contains(t, diag[0].Detail, tt.Error) + } else { + assert.Len(t, diag, 0) + } + }) + } +} + +func TestGetExpectedIdentifierFormGeneric(t *testing.T) { + testCases := []struct { + Name string + Expected string + Actual string + }{ + { + Name: "correct account object identifier from generic parameter", + Expected: "", + Actual: getExpectedIdentifierRepresentationFromGeneric[sdk.AccountObjectIdentifier](), + }, + { + Name: "correct database object identifier from generic parameter", + Expected: ".", + Actual: getExpectedIdentifierRepresentationFromGeneric[sdk.DatabaseObjectIdentifier](), + }, + { + Name: "correct schema object identifier from generic parameter", + Expected: "..", + Actual: getExpectedIdentifierRepresentationFromGeneric[sdk.SchemaObjectIdentifier](), + }, + { + Name: "correct table column identifier from generic parameter", + Expected: "...", + Actual: getExpectedIdentifierRepresentationFromGeneric[sdk.TableColumnIdentifier](), + }, + } + + for _, tt := range testCases { + t.Run(tt.Name, func(t *testing.T) { + assert.Equal(t, tt.Expected, tt.Actual) + }) + } +} + +func TestGetExpectedIdentifierFormParam(t *testing.T) { + testCases := []struct { + Name string + Expected string + Identifier sdk.ObjectIdentifier + IdentifierPointer sdk.ObjectIdentifier + }{ + { + Name: "correct account object identifier from function argument", + Expected: "", + Identifier: sdk.AccountObjectIdentifier{}, + IdentifierPointer: &sdk.AccountObjectIdentifier{}, + }, + { + Name: "correct database object identifier from function argument", + Expected: ".", + Identifier: sdk.DatabaseObjectIdentifier{}, + IdentifierPointer: &sdk.DatabaseObjectIdentifier{}, + }, + { + Name: "correct schema object identifier from function argument", + Expected: "..", + Identifier: sdk.SchemaObjectIdentifier{}, + IdentifierPointer: &sdk.SchemaObjectIdentifier{}, + }, + { + Name: "correct table column identifier from function argument", + Expected: "...", + Identifier: sdk.TableColumnIdentifier{}, + IdentifierPointer: &sdk.TableColumnIdentifier{}, + }, + } + + for _, tt := range testCases { + t.Run(tt.Name+" - non-pointer", func(t *testing.T) { + assert.Equal(t, tt.Expected, getExpectedIdentifierRepresentationFromParam(tt.Identifier)) + }) + + t.Run(tt.Name+" - pointer", func(t *testing.T) { + assert.Equal(t, tt.Expected, getExpectedIdentifierRepresentationFromParam(tt.IdentifierPointer)) + }) + } +}