diff --git a/pkg/internal/tracking/query.go b/pkg/internal/tracking/query.go index e49421b1a9..6a829bf9b3 100644 --- a/pkg/internal/tracking/query.go +++ b/pkg/internal/tracking/query.go @@ -6,6 +6,11 @@ import ( "strings" ) +func TrimMetadata(sql string) string { + queryParts := strings.Split(sql, fmt.Sprintf(" --%s", MetadataPrefix)) + return queryParts[0] +} + func AppendMetadata(sql string, metadata Metadata) (string, error) { bytes, err := json.Marshal(metadata) if err != nil { diff --git a/pkg/internal/tracking/query_test.go b/pkg/internal/tracking/query_test.go index 6d46162186..0261d77684 100644 --- a/pkg/internal/tracking/query_test.go +++ b/pkg/internal/tracking/query_test.go @@ -5,10 +5,47 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/stretchr/testify/require" ) +func TestTrimMetadata(t *testing.T) { + testCases := []struct { + Input string + Expected string + }{ + { + Input: "select 1", + Expected: "select 1", + }, + { + Input: "select 1; --some comment", + Expected: "select 1; --some comment", + }, + { + Input: fmt.Sprintf("select 1; --%s", MetadataPrefix), + Expected: "select 1;", + }, + { + Input: fmt.Sprintf("select 1; --%s ", MetadataPrefix), + Expected: "select 1;", + }, + { + Input: fmt.Sprintf("select 1; --%s some text after", MetadataPrefix), + Expected: "select 1;", + }, + } + + for _, tc := range testCases { + t.Run("TrimMetadata: "+tc.Input, func(t *testing.T) { + trimmedInput := TrimMetadata(tc.Input) + assert.Equal(t, tc.Expected, trimmedInput) + }) + } +} + func TestAppendMetadata(t *testing.T) { metadata := NewMetadata("123", resources.Account, CreateOperation) sql := "SELECT 1" diff --git a/pkg/provider/resources/resources.go b/pkg/provider/resources/resources.go index dc4de69296..6991cbabe2 100644 --- a/pkg/provider/resources/resources.go +++ b/pkg/provider/resources/resources.go @@ -4,6 +4,9 @@ type resource string const ( Account resource = "snowflake_account" + AccountAuthenticationPolicyAttachment resource = "snowflake_account_authentication_policy_attachment" + AccountParameter resource = "snowflake_account_parameter" + AccountPasswordPolicyAttachment resource = "snowflake_account_password_policy_attachment" AccountRole resource = "snowflake_account_role" Alert resource = "snowflake_alert" ApiAuthenticationIntegrationWithAuthorizationCodeGrant resource = "snowflake_api_authentication_integration_with_authorization_code_grant" @@ -36,10 +39,13 @@ const ( MaskingPolicy resource = "snowflake_masking_policy" MaterializedView resource = "snowflake_materialized_view" NetworkPolicy resource = "snowflake_network_policy" + NetworkPolicyAttachment resource = "snowflake_network_policy_attachment" NetworkRule resource = "snowflake_network_rule" NotificationIntegration resource = "snowflake_notification_integration" + OauthIntegration resource = "snowflake_oauth_integration" OauthIntegrationForCustomClients resource = "snowflake_oauth_integration_for_custom_clients" OauthIntegrationForPartnerApplications resource = "snowflake_oauth_integration_for_partner_applications" + ObjectParameter resource = "snowflake_object_parameter" PasswordPolicy resource = "snowflake_password_policy" Pipe resource = "snowflake_pipe" PrimaryConnection resource = "snowflake_primary_connection" @@ -47,6 +53,7 @@ const ( ResourceMonitor resource = "snowflake_resource_monitor" Role resource = "snowflake_role" RowAccessPolicy resource = "snowflake_row_access_policy" + SamlSecurityIntegration resource = "snowflake_saml_integration" Saml2SecurityIntegration resource = "snowflake_saml2_integration" Schema resource = "snowflake_schema" ScimSecurityIntegration resource = "snowflake_scim_integration" @@ -56,6 +63,7 @@ const ( SecretWithBasicAuthentication resource = "snowflake_secret_with_basic_authentication" SecretWithClientCredentials resource = "snowflake_secret_with_client_credentials" SecretWithGenericString resource = "snowflake_secret_with_generic_string" + SessionParameter resource = "snowflake_session_parameter" Sequence resource = "snowflake_sequence" ServiceUser resource = "snowflake_service_user" Share resource = "snowflake_share" @@ -69,12 +77,17 @@ const ( StreamOnView resource = "snowflake_stream_on_view" Streamlit resource = "snowflake_streamlit" Table resource = "snowflake_table" + TableColumnMaskingPolicyApplication resource = "snowflake_table_column_masking_policy_application" + TableConstraint resource = "snowflake_table_constraint" Tag resource = "snowflake_tag" TagAssociation resource = "snowflake_tag_association" TagMaskingPolicyAssociation resource = "snowflake_tag_masking_policy_association" Task resource = "snowflake_task" UnsafeExecute resource = "snowflake_unsafe_execute" User resource = "snowflake_user" + UserAuthenticationPolicyAttachment resource = "snowflake_user_authentication_policy_attachment" + UserPasswordPolicyAttachment resource = "snowflake_user_password_policy_attachment" + UserPublicKeys resource = "snowflake_user_public_keys" View resource = "snowflake_view" Warehouse resource = "snowflake_warehouse" ) diff --git a/pkg/resources/account.go b/pkg/resources/account.go index dd58b80c5c..8b687c7cd8 100644 --- a/pkg/resources/account.go +++ b/pkg/resources/account.go @@ -7,6 +7,8 @@ import ( "strings" "time" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" @@ -212,11 +214,11 @@ var accountSchema = map[string]*schema.Schema{ func Account() *schema.Resource { return &schema.Resource{ - Description: "The account resource allows you to create and manage Snowflake accounts.", - Create: CreateAccount, - Read: ReadAccount, - Update: UpdateAccount, - Delete: DeleteAccount, + Description: "The account resource allows you to create and manage Snowflake accounts.", + CreateContext: TrackingCreateWrapper(resources.Account, CreateAccount), + ReadContext: TrackingReadWrapper(resources.Account, ReadAccount), + UpdateContext: TrackingUpdateWrapper(resources.Account, UpdateAccount), + DeleteContext: TrackingDeleteWrapper(resources.Account, DeleteAccount), CustomizeDiff: TrackingCustomDiffWrapper(resources.Account, customdiff.All( ComputedIfAnyAttributeChanged(accountSchema, FullyQualifiedNameAttributeName, "name"), @@ -230,9 +232,8 @@ func Account() *schema.Resource { } // CreateAccount implements schema.CreateFunc. -func CreateAccount(d *schema.ResourceData, meta interface{}) error { +func CreateAccount(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() name := d.Get("name").(string) objectIdentifier := sdk.NewAccountObjectIdentifier(name) @@ -267,7 +268,7 @@ func CreateAccount(d *schema.ResourceData, meta interface{}) error { // For organizations that have accounts in multiple region groups, returns . so we need to split on "." currentRegion, err := client.ContextFunctions.CurrentRegion(ctx) if err != nil { - return err + return diag.FromErr(err) } regionParts := strings.Split(currentRegion, ".") if len(regionParts) == 2 { @@ -280,7 +281,7 @@ func CreateAccount(d *schema.ResourceData, meta interface{}) error { // For organizations that have accounts in multiple region groups, returns . so we need to split on "." currentRegion, err := client.ContextFunctions.CurrentRegion(ctx) if err != nil { - return err + return diag.FromErr(err) } regionParts := strings.Split(currentRegion, ".") if len(regionParts) == 2 { @@ -295,7 +296,7 @@ func CreateAccount(d *schema.ResourceData, meta interface{}) error { err := client.Accounts.Create(ctx, objectIdentifier, createOptions) if err != nil { - return err + return diag.FromErr(err) } var account *sdk.Account @@ -308,17 +309,16 @@ func CreateAccount(d *schema.ResourceData, meta interface{}) error { return nil, true }) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(account.AccountLocator)) - return ReadAccount(d, meta) + return ReadAccount(ctx, d, meta) } // ReadAccount implements schema.ReadFunc. -func ReadAccount(d *schema.ResourceData, meta interface{}) error { +func ReadAccount(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) @@ -333,42 +333,42 @@ func ReadAccount(d *schema.ResourceData, meta interface{}) error { return nil, true }) if err != nil { - return err + return diag.FromErr(err) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err = d.Set("name", acc.AccountName); err != nil { - return fmt.Errorf("error setting name: %w", err) + return diag.FromErr(fmt.Errorf("error setting name: %w", err)) } if err = d.Set("edition", acc.Edition); err != nil { - return fmt.Errorf("error setting edition: %w", err) + return diag.FromErr(fmt.Errorf("error setting edition: %w", err)) } if err = d.Set("region_group", acc.RegionGroup); err != nil { - return fmt.Errorf("error setting region_group: %w", err) + return diag.FromErr(fmt.Errorf("error setting region_group: %w", err)) } if err = d.Set("region", acc.SnowflakeRegion); err != nil { - return fmt.Errorf("error setting region: %w", err) + return diag.FromErr(fmt.Errorf("error setting region: %w", err)) } if err = d.Set("comment", acc.Comment); err != nil { - return fmt.Errorf("error setting comment: %w", err) + return diag.FromErr(fmt.Errorf("error setting comment: %w", err)) } if err = d.Set("is_org_admin", acc.IsOrgAdmin); err != nil { - return fmt.Errorf("error setting is_org_admin: %w", err) + return diag.FromErr(fmt.Errorf("error setting is_org_admin: %w", err)) } return nil } // UpdateAccount implements schema.UpdateFunc. -func UpdateAccount(d *schema.ResourceData, meta interface{}) error { +func UpdateAccount(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { /* todo: comments may eventually work again for accounts, so this can be uncommented when that happens client := meta.(*provider.Context).Client @@ -394,12 +394,11 @@ func UpdateAccount(d *schema.ResourceData, meta interface{}) error { } // DeleteAccount implements schema.DeleteFunc. -func DeleteAccount(d *schema.ResourceData, meta interface{}) error { +func DeleteAccount(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() gracePeriodInDays := d.Get("grace_period_in_days").(int) err := client.Accounts.Drop(ctx, helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier), gracePeriodInDays, &sdk.DropAccountOptions{ IfExists: sdk.Bool(true), }) - return err + return diag.FromErr(err) } diff --git a/pkg/resources/account_authentication_policy_attachment.go b/pkg/resources/account_authentication_policy_attachment.go index 628cee492c..bb61f3d215 100644 --- a/pkg/resources/account_authentication_policy_attachment.go +++ b/pkg/resources/account_authentication_policy_attachment.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" @@ -26,9 +29,9 @@ func AccountAuthenticationPolicyAttachment() *schema.Resource { return &schema.Resource{ Description: "Specifies the authentication policy to use for the current account. To set the authentication policy of a different account, use a provider alias.", - Create: CreateAccountAuthenticationPolicyAttachment, - Read: ReadAccountAuthenticationPolicyAttachment, - Delete: DeleteAccountAuthenticationPolicyAttachment, + CreateContext: TrackingCreateWrapper(resources.AccountAuthenticationPolicyAttachment, CreateAccountAuthenticationPolicyAttachment), + ReadContext: TrackingReadWrapper(resources.AccountAuthenticationPolicyAttachment, ReadAccountAuthenticationPolicyAttachment), + DeleteContext: TrackingDeleteWrapper(resources.AccountAuthenticationPolicyAttachment, DeleteAccountAuthenticationPolicyAttachment), Schema: accountAuthenticationPolicyAttachmentSchema, Importer: &schema.ResourceImporter{ @@ -38,13 +41,12 @@ func AccountAuthenticationPolicyAttachment() *schema.Resource { } // CreateAccountAuthenticationPolicyAttachment implements schema.CreateFunc. -func CreateAccountAuthenticationPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func CreateAccountAuthenticationPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() authenticationPolicy, ok := sdk.NewObjectIdentifierFromFullyQualifiedName(d.Get("authentication_policy").(string)).(sdk.SchemaObjectIdentifier) if !ok { - return fmt.Errorf("authentication_policy %s is not a valid authentication policy qualified name, expected format: `\"db\".\"schema\".\"policy\"`", d.Get("authentication_policy")) + return diag.FromErr(fmt.Errorf("authentication_policy %s is not a valid authentication policy qualified name, expected format: `\"db\".\"schema\".\"policy\"`", d.Get("authentication_policy"))) } err := client.Accounts.Alter(ctx, &sdk.AlterAccountOptions{ @@ -53,27 +55,26 @@ func CreateAccountAuthenticationPolicyAttachment(d *schema.ResourceData, meta in }, }) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(authenticationPolicy)) - return ReadAccountAuthenticationPolicyAttachment(d, meta) + return ReadAccountAuthenticationPolicyAttachment(ctx, d, meta) } -func ReadAccountAuthenticationPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func ReadAccountAuthenticationPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { authenticationPolicy := helpers.DecodeSnowflakeID(d.Id()) if err := d.Set("authentication_policy", authenticationPolicy.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } return nil } // DeleteAccountAuthenticationPolicyAttachment implements schema.DeleteFunc. -func DeleteAccountAuthenticationPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func DeleteAccountAuthenticationPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() err := client.Accounts.Alter(ctx, &sdk.AlterAccountOptions{ Unset: &sdk.AccountUnset{ @@ -81,7 +82,7 @@ func DeleteAccountAuthenticationPolicyAttachment(d *schema.ResourceData, meta in }, }) if err != nil { - return err + return diag.FromErr(err) } return nil diff --git a/pkg/resources/account_parameter.go b/pkg/resources/account_parameter.go index ccff34a089..95cea28579 100644 --- a/pkg/resources/account_parameter.go +++ b/pkg/resources/account_parameter.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -26,10 +29,10 @@ var accountParameterSchema = map[string]*schema.Schema{ func AccountParameter() *schema.Resource { return &schema.Resource{ - Create: CreateAccountParameter, - Read: ReadAccountParameter, - Update: UpdateAccountParameter, - Delete: DeleteAccountParameter, + CreateContext: TrackingCreateWrapper(resources.AccountParameter, CreateAccountParameter), + ReadContext: TrackingReadWrapper(resources.AccountParameter, ReadAccountParameter), + UpdateContext: TrackingUpdateWrapper(resources.AccountParameter, UpdateAccountParameter), + DeleteContext: TrackingDeleteWrapper(resources.AccountParameter, DeleteAccountParameter), Schema: accountParameterSchema, Importer: &schema.ResourceImporter{ @@ -39,59 +42,56 @@ func AccountParameter() *schema.Resource { } // CreateAccountParameter implements schema.CreateFunc. -func CreateAccountParameter(d *schema.ResourceData, meta interface{}) error { +func CreateAccountParameter(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client key := d.Get("key").(string) value := d.Get("value").(string) - ctx := context.Background() parameter := sdk.AccountParameter(key) err := client.Parameters.SetAccountParameter(ctx, parameter, value) if err != nil { - return err + return diag.FromErr(err) } d.SetId(key) - return ReadAccountParameter(d, meta) + return ReadAccountParameter(ctx, d, meta) } // ReadAccountParameter implements schema.ReadFunc. -func ReadAccountParameter(d *schema.ResourceData, meta interface{}) error { +func ReadAccountParameter(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() parameterName := d.Id() parameter, err := client.Parameters.ShowAccountParameter(ctx, sdk.AccountParameter(parameterName)) if err != nil { - return fmt.Errorf("error reading account parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error reading account parameter err = %w", err)) } err = d.Set("value", parameter.Value) if err != nil { - return fmt.Errorf("error setting account parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error setting account parameter err = %w", err)) } err = d.Set("key", parameter.Key) if err != nil { - return fmt.Errorf("error setting account parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error setting account parameter err = %w", err)) } return nil } // UpdateAccountParameter implements schema.UpdateFunc. -func UpdateAccountParameter(d *schema.ResourceData, meta interface{}) error { - return CreateAccountParameter(d, meta) +func UpdateAccountParameter(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + return CreateAccountParameter(ctx, d, meta) } // DeleteAccountParameter implements schema.DeleteFunc. -func DeleteAccountParameter(d *schema.ResourceData, meta interface{}) error { +func DeleteAccountParameter(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client key := d.Get("key").(string) - ctx := context.Background() parameter := sdk.AccountParameter(key) defaultParameter, err := client.Parameters.ShowAccountParameter(ctx, sdk.AccountParameter(key)) if err != nil { - return err + return diag.FromErr(err) } defaultValue := defaultParameter.Default err = client.Parameters.SetAccountParameter(ctx, parameter, defaultValue) if err != nil { - return fmt.Errorf("error resetting account parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error resetting account parameter err = %w", err)) } d.SetId("") diff --git a/pkg/resources/account_password_policy_attachment.go b/pkg/resources/account_password_policy_attachment.go index 245b2d33c2..03375b0c75 100644 --- a/pkg/resources/account_password_policy_attachment.go +++ b/pkg/resources/account_password_policy_attachment.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" @@ -26,9 +29,9 @@ func AccountPasswordPolicyAttachment() *schema.Resource { return &schema.Resource{ Description: "Specifies the password policy to use for the current account. To set the password policy of a different account, use a provider alias.", - Create: CreateAccountPasswordPolicyAttachment, - Read: ReadAccountPasswordPolicyAttachment, - Delete: DeleteAccountPasswordPolicyAttachment, + CreateContext: TrackingCreateWrapper(resources.AccountPasswordPolicyAttachment, CreateAccountPasswordPolicyAttachment), + ReadContext: TrackingReadWrapper(resources.AccountPasswordPolicyAttachment, ReadAccountPasswordPolicyAttachment), + DeleteContext: TrackingDeleteWrapper(resources.AccountPasswordPolicyAttachment, DeleteAccountPasswordPolicyAttachment), Schema: accountPasswordPolicyAttachmentSchema, Importer: &schema.ResourceImporter{ @@ -38,13 +41,12 @@ func AccountPasswordPolicyAttachment() *schema.Resource { } // CreateAccountPasswordPolicyAttachment implements schema.CreateFunc. -func CreateAccountPasswordPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func CreateAccountPasswordPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() passwordPolicy, ok := sdk.NewObjectIdentifierFromFullyQualifiedName(d.Get("password_policy").(string)).(sdk.SchemaObjectIdentifier) if !ok { - return fmt.Errorf("password_policy %s is not a valid password policy qualified name, expected format: `\"db\".\"schema\".\"policy\"`", d.Get("password_policy")) + return diag.FromErr(fmt.Errorf("password_policy %s is not a valid password policy qualified name, expected format: `\"db\".\"schema\".\"policy\"`", d.Get("password_policy"))) } err := client.Accounts.Alter(ctx, &sdk.AlterAccountOptions{ @@ -53,27 +55,26 @@ func CreateAccountPasswordPolicyAttachment(d *schema.ResourceData, meta interfac }, }) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(passwordPolicy)) - return ReadAccountPasswordPolicyAttachment(d, meta) + return ReadAccountPasswordPolicyAttachment(ctx, d, meta) } -func ReadAccountPasswordPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func ReadAccountPasswordPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { passwordPolicy := helpers.DecodeSnowflakeID(d.Id()) if err := d.Set("password_policy", passwordPolicy.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } return nil } // DeleteAccountPasswordPolicyAttachment implements schema.DeleteFunc. -func DeleteAccountPasswordPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func DeleteAccountPasswordPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() err := client.Accounts.Alter(ctx, &sdk.AlterAccountOptions{ Unset: &sdk.AccountUnset{ @@ -81,7 +82,7 @@ func DeleteAccountPasswordPolicyAttachment(d *schema.ResourceData, meta interfac }, }) if err != nil { - return err + return diag.FromErr(err) } return nil diff --git a/pkg/resources/alert.go b/pkg/resources/alert.go index 2c16b7c8c7..6ca90697db 100644 --- a/pkg/resources/alert.go +++ b/pkg/resources/alert.go @@ -119,7 +119,7 @@ func Alert() *schema.Resource { Schema: alertSchema, Importer: &schema.ResourceImporter{ - StateContext: TrackingImportWrapper(resources.Alert, schema.ImportStatePassthroughContext), + StateContext: schema.ImportStatePassthroughContext, }, } } diff --git a/pkg/resources/api_integration.go b/pkg/resources/api_integration.go index c6d8ffc8d7..fee0a6312c 100644 --- a/pkg/resources/api_integration.go +++ b/pkg/resources/api_integration.go @@ -6,6 +6,9 @@ import ( "log" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -121,10 +124,10 @@ var apiIntegrationSchema = map[string]*schema.Schema{ // APIIntegration returns a pointer to the resource representing an api integration. func APIIntegration() *schema.Resource { return &schema.Resource{ - Create: CreateAPIIntegration, - Read: ReadAPIIntegration, - Update: UpdateAPIIntegration, - Delete: DeleteAPIIntegration, + CreateContext: TrackingCreateWrapper(resources.ApiIntegration, CreateAPIIntegration), + ReadContext: TrackingReadWrapper(resources.ApiIntegration, ReadAPIIntegration), + UpdateContext: TrackingUpdateWrapper(resources.ApiIntegration, UpdateAPIIntegration), + DeleteContext: TrackingDeleteWrapper(resources.ApiIntegration, DeleteAPIIntegration), Schema: apiIntegrationSchema, Importer: &schema.ResourceImporter{ @@ -142,9 +145,8 @@ func toApiIntegrationEndpointPrefix(paths []string) []sdk.ApiIntegrationEndpoint } // CreateAPIIntegration implements schema.CreateFunc. -func CreateAPIIntegration(d *schema.ResourceData, meta interface{}) error { +func CreateAPIIntegration(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() name := d.Get("name").(string) id := sdk.NewAccountObjectIdentifier(name) @@ -169,7 +171,7 @@ func CreateAPIIntegration(d *schema.ResourceData, meta interface{}) error { case "aws_api_gateway", "aws_private_api_gateway", "aws_gov_api_gateway", "aws_gov_private_api_gateway": roleArn, ok := d.GetOk("api_aws_role_arn") if !ok { - return fmt.Errorf("if you use AWS api provider you must specify an api_aws_role_arn") + return diag.FromErr(fmt.Errorf("if you use AWS api provider you must specify an api_aws_role_arn")) } awsParams := sdk.NewAwsApiParamsRequest(sdk.ApiIntegrationAwsApiProviderType(apiProvider), roleArn.(string)) if v, ok := d.GetOk("api_key"); ok { @@ -179,11 +181,11 @@ func CreateAPIIntegration(d *schema.ResourceData, meta interface{}) error { case "azure_api_management": tenantId, ok := d.GetOk("azure_tenant_id") if !ok { - return fmt.Errorf("if you use the Azure api provider you must specify an azure_tenant_id") + return diag.FromErr(fmt.Errorf("if you use the Azure api provider you must specify an azure_tenant_id")) } applicationId, ok := d.GetOk("azure_ad_application_id") if !ok { - return fmt.Errorf("if you use the Azure api provider you must specify an azure_ad_application_id") + return diag.FromErr(fmt.Errorf("if you use the Azure api provider you must specify an azure_ad_application_id")) } azureParams := sdk.NewAzureApiParamsRequest(tenantId.(string), applicationId.(string)) if v, ok := d.GetOk("api_key"); ok { @@ -193,66 +195,65 @@ func CreateAPIIntegration(d *schema.ResourceData, meta interface{}) error { case "google_api_gateway": audience, ok := d.GetOk("google_audience") if !ok { - return fmt.Errorf("if you use GCP api provider you must specify a google_audience") + return diag.FromErr(fmt.Errorf("if you use GCP api provider you must specify a google_audience")) } googleParams := sdk.NewGoogleApiParamsRequest(audience.(string)) createRequest.WithGoogleApiProviderParams(googleParams) default: - return fmt.Errorf("unexpected provider %v", apiProvider) + return diag.FromErr(fmt.Errorf("unexpected provider %v", apiProvider)) } err := client.ApiIntegrations.Create(ctx, createRequest) if err != nil { - return fmt.Errorf("error creating api integration: %w", err) + return diag.FromErr(fmt.Errorf("error creating api integration: %w", err)) } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadAPIIntegration(d, meta) + return ReadAPIIntegration(ctx, d, meta) } // ReadAPIIntegration implements schema.ReadFunc. -func ReadAPIIntegration(d *schema.ResourceData, meta interface{}) error { +func ReadAPIIntegration(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) integration, err := client.ApiIntegrations.ShowByID(ctx, id) if err != nil { log.Printf("[DEBUG] api integration (%s) not found", d.Id()) d.SetId("") - return err + return diag.FromErr(err) } // Note: category must be API or something is broken if c := integration.Category; c != "API" { - return fmt.Errorf("expected %v to be an api integration, got %v", id, c) + return diag.FromErr(fmt.Errorf("expected %v to be an api integration, got %v", id, c)) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", integration.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", integration.Comment); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("created_on", integration.CreatedOn.String()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("enabled", integration.Enabled); err != nil { - return err + return diag.FromErr(err) } // Some properties come from the DESCRIBE INTEGRATION call integrationProperties, err := client.ApiIntegrations.Describe(ctx, id) if err != nil { - return fmt.Errorf("could not describe api integration: %w", err) + return diag.FromErr(fmt.Errorf("could not describe api integration: %w", err)) } for _, property := range integrationProperties { @@ -263,66 +264,65 @@ func ReadAPIIntegration(d *schema.ResourceData, meta interface{}) error { // We set this using the SHOW INTEGRATION call so let's ignore it here case "API_ALLOWED_PREFIXES": if err := d.Set("api_allowed_prefixes", strings.Split(value, ",")); err != nil { - return err + return diag.FromErr(err) } case "API_BLOCKED_PREFIXES": if val := value; val != "" { if err := d.Set("api_blocked_prefixes", strings.Split(val, ",")); err != nil { - return err + return diag.FromErr(err) } } case "API_AWS_IAM_USER_ARN": if err := d.Set("api_aws_iam_user_arn", value); err != nil { - return err + return diag.FromErr(err) } case "API_AWS_ROLE_ARN": if err := d.Set("api_aws_role_arn", value); err != nil { - return err + return diag.FromErr(err) } case "API_AWS_EXTERNAL_ID": if err := d.Set("api_aws_external_id", value); err != nil { - return err + return diag.FromErr(err) } case "AZURE_CONSENT_URL": if err := d.Set("azure_consent_url", value); err != nil { - return err + return diag.FromErr(err) } case "AZURE_MULTI_TENANT_APP_NAME": if err := d.Set("azure_multi_tenant_app_name", value); err != nil { - return err + return diag.FromErr(err) } case "AZURE_TENANT_ID": if err := d.Set("azure_tenant_id", value); err != nil { - return err + return diag.FromErr(err) } case "AZURE_AD_APPLICATION_ID": if err := d.Set("azure_ad_application_id", value); err != nil { - return err + return diag.FromErr(err) } case "GOOGLE_AUDIENCE": if err := d.Set("google_audience", value); err != nil { - return err + return diag.FromErr(err) } case "API_GCP_SERVICE_ACCOUNT": if err := d.Set("api_gcp_service_account", value); err != nil { - return err + return diag.FromErr(err) } case "API_PROVIDER": if err := d.Set("api_provider", strings.ToLower(value)); err != nil { - return err + return diag.FromErr(err) } default: log.Printf("[WARN] unexpected api integration property %v returned from Snowflake", name) } } - return err + return diag.FromErr(err) } // UpdateAPIIntegration implements schema.UpdateFunc. -func UpdateAPIIntegration(d *schema.ResourceData, meta interface{}) error { +func UpdateAPIIntegration(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) var runSetStatement bool @@ -348,7 +348,7 @@ func UpdateAPIIntegration(d *schema.ResourceData, meta interface{}) error { if len(v) == 0 { err := client.ApiIntegrations.Alter(ctx, sdk.NewAlterApiIntegrationRequest(id).WithUnset(sdk.NewApiIntegrationUnsetRequest().WithApiBlockedPrefixes(sdk.Bool(true)))) if err != nil { - return fmt.Errorf("error unsetting api_blocked_prefixes: %w", err) + return diag.FromErr(fmt.Errorf("error unsetting api_blocked_prefixes: %w", err)) } } else { runSetStatement = true @@ -392,28 +392,27 @@ func UpdateAPIIntegration(d *schema.ResourceData, meta interface{}) error { setRequest.WithGoogleParams(googleParams) } default: - return fmt.Errorf("unexpected provider %v", apiProvider) + return diag.FromErr(fmt.Errorf("unexpected provider %v", apiProvider)) } if runSetStatement { err := client.ApiIntegrations.Alter(ctx, sdk.NewAlterApiIntegrationRequest(id).WithSet(setRequest)) if err != nil { - return fmt.Errorf("error updating api integration: %w", err) + return diag.FromErr(fmt.Errorf("error updating api integration: %w", err)) } } - return ReadAPIIntegration(d, meta) + return ReadAPIIntegration(ctx, d, meta) } // DeleteAPIIntegration implements schema.DeleteFunc. -func DeleteAPIIntegration(d *schema.ResourceData, meta interface{}) error { +func DeleteAPIIntegration(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) err := client.ApiIntegrations.Drop(ctx, sdk.NewDropApiIntegrationRequest(id)) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/database_old.go b/pkg/resources/database_old.go index 15d4fca440..6c6b6c3641 100644 --- a/pkg/resources/database_old.go +++ b/pkg/resources/database_old.go @@ -7,6 +7,8 @@ import ( "slices" "strconv" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" @@ -90,10 +92,10 @@ var databaseOldSchema = map[string]*schema.Schema{ // Database returns a pointer to the resource representing a database. func DatabaseOld() *schema.Resource { return &schema.Resource{ - Create: CreateDatabaseOld, - Read: ReadDatabaseOld, - Delete: DeleteDatabaseOld, - Update: UpdateDatabaseOld, + CreateContext: TrackingCreateWrapper(resources.DatabaseOld, CreateDatabaseOld), + ReadContext: TrackingReadWrapper(resources.DatabaseOld, ReadDatabaseOld), + DeleteContext: TrackingDeleteWrapper(resources.DatabaseOld, DeleteDatabaseOld), + UpdateContext: TrackingUpdateWrapper(resources.DatabaseOld, UpdateDatabaseOld), DeprecationMessage: "This resource is deprecated and will be removed in a future major version release. Please use snowflake_database or snowflake_shared_database or snowflake_secondary_database instead.", Schema: databaseOldSchema, @@ -104,9 +106,8 @@ func DatabaseOld() *schema.Resource { } // CreateDatabase implements schema.CreateFunc. -func CreateDatabaseOld(d *schema.ResourceData, meta interface{}) error { +func CreateDatabaseOld(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() name := d.Get("name").(string) id := sdk.NewAccountObjectIdentifier(name) @@ -121,10 +122,10 @@ func CreateDatabaseOld(d *schema.ResourceData, meta interface{}) error { } err := client.Databases.CreateShared(ctx, id, shareID, opts) if err != nil { - return fmt.Errorf("error creating database %v: %w", name, err) + return diag.FromErr(fmt.Errorf("error creating database %v: %w", name, err)) } d.SetId(name) - return ReadDatabaseOld(d, meta) + return ReadDatabaseOld(ctx, d, meta) } // Is it a Secondary Database? if primaryName, ok := d.GetOk("from_replica"); ok { @@ -135,11 +136,11 @@ func CreateDatabaseOld(d *schema.ResourceData, meta interface{}) error { } err := client.Databases.CreateSecondary(ctx, id, primaryID, opts) if err != nil { - return fmt.Errorf("error creating database %v: %w", name, err) + return diag.FromErr(fmt.Errorf("error creating database %v: %w", name, err)) } d.SetId(name) // todo: add failover_configuration block - return ReadDatabaseOld(d, meta) + return ReadDatabaseOld(ctx, d, meta) } // Otherwise it is a Standard Database @@ -164,7 +165,7 @@ func CreateDatabaseOld(d *schema.ResourceData, meta interface{}) error { err := client.Databases.Create(ctx, id, &opts) if err != nil { - return fmt.Errorf("error creating database %v: %w", name, err) + return diag.FromErr(fmt.Errorf("error creating database %v: %w", name, err)) } d.SetId(name) @@ -185,16 +186,15 @@ func CreateDatabaseOld(d *schema.ResourceData, meta interface{}) error { } err := client.Databases.AlterReplication(ctx, id, opts) if err != nil { - return fmt.Errorf("error enabling replication for database %v: %w", name, err) + return diag.FromErr(fmt.Errorf("error enabling replication for database %v: %w", name, err)) } } - return ReadDatabaseOld(d, meta) + return ReadDatabaseOld(ctx, d, meta) } -func ReadDatabaseOld(d *schema.ResourceData, meta interface{}) error { +func ReadDatabaseOld(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) database, err := client.Databases.ShowByID(ctx, id) @@ -205,35 +205,34 @@ func ReadDatabaseOld(d *schema.ResourceData, meta interface{}) error { } if err := d.Set("comment", database.Comment); err != nil { - return err + return diag.FromErr(err) } dataRetention, err := client.Parameters.ShowAccountParameter(ctx, sdk.AccountParameterDataRetentionTimeInDays) if err != nil { - return err + return diag.FromErr(err) } paramDataRetention, err := strconv.Atoi(dataRetention.Value) if err != nil { - return err + return diag.FromErr(err) } if dataRetentionDays := d.Get("data_retention_time_in_days"); dataRetentionDays.(int) != IntDefault || database.RetentionTime != paramDataRetention { if err := d.Set("data_retention_time_in_days", database.RetentionTime); err != nil { - return err + return diag.FromErr(err) } } if err := d.Set("is_transient", database.Transient); err != nil { - return err + return diag.FromErr(err) } return nil } -func UpdateDatabaseOld(d *schema.ResourceData, meta interface{}) error { +func UpdateDatabaseOld(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) client := meta.(*provider.Context).Client - ctx := context.Background() if d.HasChange("name") { newName := d.Get("name").(string) @@ -243,7 +242,7 @@ func UpdateDatabaseOld(d *schema.ResourceData, meta interface{}) error { } err := client.Databases.Alter(ctx, id, opts) if err != nil { - return fmt.Errorf("error updating database name on %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error updating database name on %v err = %w", d.Id(), err)) } d.SetId(helpers.EncodeSnowflakeID(newId)) id = newId @@ -261,7 +260,7 @@ func UpdateDatabaseOld(d *schema.ResourceData, meta interface{}) error { } err := client.Databases.Alter(ctx, id, opts) if err != nil { - return fmt.Errorf("error updating database comment on %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error updating database comment on %v err = %w", d.Id(), err)) } } @@ -273,7 +272,7 @@ func UpdateDatabaseOld(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return fmt.Errorf("error when setting database data retention time on %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error when setting database data retention time on %v err = %w", d.Id(), err)) } } else { err := client.Databases.Alter(ctx, id, &sdk.AlterDatabaseOptions{ @@ -282,7 +281,7 @@ func UpdateDatabaseOld(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return fmt.Errorf("error when usetting database data retention time on %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error when usetting database data retention time on %v err = %w", d.Id(), err)) } } } @@ -335,7 +334,7 @@ func UpdateDatabaseOld(d *schema.ResourceData, meta interface{}) error { } err := client.Databases.AlterReplication(ctx, id, opts) if err != nil { - return fmt.Errorf("error enabling replication configuration on %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error enabling replication configuration on %v err = %w", d.Id(), err)) } } @@ -347,23 +346,22 @@ func UpdateDatabaseOld(d *schema.ResourceData, meta interface{}) error { } err := client.Databases.AlterReplication(ctx, id, opts) if err != nil { - return fmt.Errorf("error disabling replication configuration on %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error disabling replication configuration on %v err = %w", d.Id(), err)) } } } - return ReadDatabaseOld(d, meta) + return ReadDatabaseOld(ctx, d, meta) } -func DeleteDatabaseOld(d *schema.ResourceData, meta interface{}) error { +func DeleteDatabaseOld(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) err := client.Databases.Drop(ctx, id, &sdk.DropDatabaseOptions{ IfExists: sdk.Bool(true), }) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") return nil diff --git a/pkg/resources/dynamic_table.go b/pkg/resources/dynamic_table.go index e8b64b64e4..72446d5b91 100644 --- a/pkg/resources/dynamic_table.go +++ b/pkg/resources/dynamic_table.go @@ -7,6 +7,9 @@ import ( "strings" "time" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -163,10 +166,10 @@ var dynamicTableSchema = map[string]*schema.Schema{ // DynamicTable returns a pointer to the resource representing a dynamic table. func DynamicTable() *schema.Resource { return &schema.Resource{ - Create: CreateDynamicTable, - Read: ReadDynamicTable, - Update: UpdateDynamicTable, - Delete: DeleteDynamicTable, + CreateContext: TrackingCreateWrapper(resources.DynamicTable, CreateDynamicTable), + ReadContext: TrackingReadWrapper(resources.DynamicTable, ReadDynamicTable), + UpdateContext: TrackingUpdateWrapper(resources.DynamicTable, UpdateDynamicTable), + DeleteContext: TrackingDeleteWrapper(resources.DynamicTable, DeleteDynamicTable), Schema: dynamicTableSchema, Importer: &schema.ResourceImporter{ @@ -176,93 +179,93 @@ func DynamicTable() *schema.Resource { } // ReadDynamicTable implements schema.ReadFunc. -func ReadDynamicTable(d *schema.ResourceData, meta interface{}) error { +func ReadDynamicTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - dynamicTable, err := client.DynamicTables.ShowByID(context.Background(), id) + dynamicTable, err := client.DynamicTables.ShowByID(ctx, id) if err != nil { log.Printf("[DEBUG] dynamic table (%s) not found", d.Id()) d.SetId("") return nil } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", dynamicTable.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("database", dynamicTable.DatabaseName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("schema", dynamicTable.SchemaName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("warehouse", dynamicTable.Warehouse); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", dynamicTable.Comment); err != nil { - return err + return diag.FromErr(err) } tl := map[string]interface{}{} if dynamicTable.TargetLag == "DOWNSTREAM" { tl["downstream"] = true if err := d.Set("target_lag", []interface{}{tl}); err != nil { - return err + return diag.FromErr(err) } } else { tl["maximum_duration"] = dynamicTable.TargetLag if err := d.Set("target_lag", []interface{}{tl}); err != nil { - return err + return diag.FromErr(err) } } if strings.Contains(dynamicTable.Text, "OR REPLACE") { if err := d.Set("or_replace", true); err != nil { - return err + return diag.FromErr(err) } } else { if err := d.Set("or_replace", false); err != nil { - return err + return diag.FromErr(err) } } if strings.Contains(dynamicTable.Text, "initialize = 'ON_CREATE'") { if err := d.Set("initialize", "ON_CREATE"); err != nil { - return err + return diag.FromErr(err) } } else if strings.Contains(dynamicTable.Text, "initialize = 'ON_SCHEDULE'") { if err := d.Set("initialize", "ON_SCHEDULE"); err != nil { - return err + return diag.FromErr(err) } } m := refreshModePattern.FindStringSubmatch(dynamicTable.Text) if len(m) > 1 { if err := d.Set("refresh_mode", m[1]); err != nil { - return err + return diag.FromErr(err) } } if err := d.Set("created_on", dynamicTable.CreatedOn.Format(time.RFC3339)); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("cluster_by", dynamicTable.ClusterBy); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("rows", dynamicTable.Rows); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("bytes", dynamicTable.Bytes); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("owner", dynamicTable.Owner); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("refresh_mode_reason", dynamicTable.RefreshModeReason); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("automatic_clustering", dynamicTable.AutomaticClustering); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("scheduling_state", string(dynamicTable.SchedulingState)); err != nil { - return err + return diag.FromErr(err) } /* guides on time formatting @@ -271,25 +274,25 @@ func ReadDynamicTable(d *schema.ResourceData, meta interface{}) error { note: format may depend on what the account parameter for TIMESTAMP_OUTPUT_FORMAT is set to. Perhaps we should return this as a string rather than a time.Time? */ if err := d.Set("last_suspended_on", dynamicTable.LastSuspendedOn.Format("2006-01-02T16:04:05.000 -0700")); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("is_clone", dynamicTable.IsClone); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("is_replica", dynamicTable.IsReplica); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("data_timestamp", dynamicTable.DataTimestamp.Format("2006-01-02T16:04:05.000 -0700")); err != nil { - return err + return diag.FromErr(err) } extractor := snowflake.NewViewSelectStatementExtractor(dynamicTable.Text) query, err := extractor.ExtractDynamicTable() if err != nil { - return err + return diag.FromErr(err) } if err := d.Set("query", query); err != nil { - return err + return diag.FromErr(err) } return nil @@ -309,7 +312,7 @@ func parseTargetLag(v interface{}) sdk.TargetLag { } // CreateDynamicTable implements schema.CreateFunc. -func CreateDynamicTable(d *schema.ResourceData, meta interface{}) error { +func CreateDynamicTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client databaseName := d.Get("database").(string) @@ -334,18 +337,17 @@ func CreateDynamicTable(d *schema.ResourceData, meta interface{}) error { if v, ok := d.GetOk("initialize"); ok { request.WithInitialize(sdk.DynamicTableInitialize(v.(string))) } - if err := client.DynamicTables.Create(context.Background(), request); err != nil { - return err + if err := client.DynamicTables.Create(ctx, request); err != nil { + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadDynamicTable(d, meta) + return ReadDynamicTable(ctx, d, meta) } // UpdateDynamicTable implements schema.UpdateFunc. -func UpdateDynamicTable(d *schema.ResourceData, meta interface{}) error { +func UpdateDynamicTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) request := sdk.NewAlterDynamicTableRequest(id) @@ -366,7 +368,7 @@ func UpdateDynamicTable(d *schema.ResourceData, meta interface{}) error { if runSet { request.WithSet(set) if err := client.DynamicTables.Alter(ctx, request); err != nil { - return err + return diag.FromErr(err) } } @@ -377,19 +379,19 @@ func UpdateDynamicTable(d *schema.ResourceData, meta interface{}) error { Value: sdk.String(d.Get("comment").(string)), }) if err != nil { - return err + return diag.FromErr(err) } } - return ReadDynamicTable(d, meta) + return ReadDynamicTable(ctx, d, meta) } // DeleteDynamicTable implements schema.DeleteFunc. -func DeleteDynamicTable(d *schema.ResourceData, meta interface{}) error { +func DeleteDynamicTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - if err := client.DynamicTables.Drop(context.Background(), sdk.NewDropDynamicTableRequest(id)); err != nil { - return err + if err := client.DynamicTables.Drop(ctx, sdk.NewDropDynamicTableRequest(id)); err != nil { + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/email_notification_integration.go b/pkg/resources/email_notification_integration.go index 586adfc72d..bd0a38c64e 100644 --- a/pkg/resources/email_notification_integration.go +++ b/pkg/resources/email_notification_integration.go @@ -6,6 +6,9 @@ import ( "log" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -42,10 +45,10 @@ var emailNotificationIntegrationSchema = map[string]*schema.Schema{ // EmailNotificationIntegration returns a pointer to the resource representing a notification integration. func EmailNotificationIntegration() *schema.Resource { return &schema.Resource{ - Create: CreateEmailNotificationIntegration, - Read: ReadEmailNotificationIntegration, - Update: UpdateEmailNotificationIntegration, - Delete: DeleteEmailNotificationIntegration, + CreateContext: TrackingCreateWrapper(resources.EmailNotificationIntegration, CreateEmailNotificationIntegration), + ReadContext: TrackingReadWrapper(resources.EmailNotificationIntegration, ReadEmailNotificationIntegration), + UpdateContext: TrackingUpdateWrapper(resources.EmailNotificationIntegration, UpdateEmailNotificationIntegration), + DeleteContext: TrackingDeleteWrapper(resources.EmailNotificationIntegration, DeleteEmailNotificationIntegration), Schema: emailNotificationIntegrationSchema, Importer: &schema.ResourceImporter{ @@ -63,9 +66,8 @@ func toAllowedRecipients(emails []string) []sdk.NotificationIntegrationAllowedRe } // CreateEmailNotificationIntegration implements schema.CreateFunc. -func CreateEmailNotificationIntegration(d *schema.ResourceData, meta interface{}) error { +func CreateEmailNotificationIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() name := d.Get("name").(string) id := sdk.NewAccountObjectIdentifier(name) @@ -85,47 +87,46 @@ func CreateEmailNotificationIntegration(d *schema.ResourceData, meta interface{} err := client.NotificationIntegrations.Create(ctx, createRequest) if err != nil { - return fmt.Errorf("error creating notification integration: %w", err) + return diag.FromErr(fmt.Errorf("error creating notification integration: %w", err)) } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadEmailNotificationIntegration(d, meta) + return ReadEmailNotificationIntegration(ctx, d, meta) } // ReadEmailNotificationIntegration implements schema.ReadFunc. -func ReadEmailNotificationIntegration(d *schema.ResourceData, meta interface{}) error { +func ReadEmailNotificationIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) integration, err := client.NotificationIntegrations.ShowByID(ctx, id) if err != nil { log.Printf("[DEBUG] notification integration (%s) not found", d.Id()) d.SetId("") - return err + return diag.FromErr(err) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", integration.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("enabled", integration.Enabled); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", integration.Comment); err != nil { - return err + return diag.FromErr(err) } // Some properties come from the DESCRIBE INTEGRATION call integrationProperties, err := client.NotificationIntegrations.Describe(ctx, id) if err != nil { - return fmt.Errorf("could not describe notification integration: %w", err) + return diag.FromErr(fmt.Errorf("could not describe notification integration: %w", err)) } for _, property := range integrationProperties { name := property.Name @@ -135,11 +136,11 @@ func ReadEmailNotificationIntegration(d *schema.ResourceData, meta interface{}) case "ALLOWED_RECIPIENTS": if value == "" { if err := d.Set("allowed_recipients", make([]string, 0)); err != nil { - return err + return diag.FromErr(err) } } else { if err := d.Set("allowed_recipients", strings.Split(value, ",")); err != nil { - return err + return diag.FromErr(err) } } default: @@ -147,13 +148,12 @@ func ReadEmailNotificationIntegration(d *schema.ResourceData, meta interface{}) } } - return err + return diag.FromErr(err) } // UpdateEmailNotificationIntegration implements schema.UpdateFunc. -func UpdateEmailNotificationIntegration(d *schema.ResourceData, meta interface{}) error { +func UpdateEmailNotificationIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) var runSetStatement bool @@ -190,29 +190,28 @@ func UpdateEmailNotificationIntegration(d *schema.ResourceData, meta interface{} if runSetStatement { err := client.NotificationIntegrations.Alter(ctx, sdk.NewAlterNotificationIntegrationRequest(id).WithSet(setRequest)) if err != nil { - return fmt.Errorf("error updating notification integration: %w", err) + return diag.FromErr(fmt.Errorf("error updating notification integration: %w", err)) } } if runUnsetStatement { err := client.NotificationIntegrations.Alter(ctx, sdk.NewAlterNotificationIntegrationRequest(id).WithUnsetEmailParams(unsetRequest)) if err != nil { - return fmt.Errorf("error updating notification integration: %w", err) + return diag.FromErr(fmt.Errorf("error updating notification integration: %w", err)) } } - return ReadEmailNotificationIntegration(d, meta) + return ReadEmailNotificationIntegration(ctx, d, meta) } // DeleteEmailNotificationIntegration implements schema.DeleteFunc. -func DeleteEmailNotificationIntegration(d *schema.ResourceData, meta interface{}) error { +func DeleteEmailNotificationIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) err := client.NotificationIntegrations.Drop(ctx, sdk.NewDropNotificationIntegrationRequest(id)) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/external_table.go b/pkg/resources/external_table.go index 94cd921cf1..833b8c4dd8 100644 --- a/pkg/resources/external_table.go +++ b/pkg/resources/external_table.go @@ -5,6 +5,9 @@ import ( "fmt" "log" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -140,10 +143,10 @@ var externalTableSchema = map[string]*schema.Schema{ func ExternalTable() *schema.Resource { return &schema.Resource{ - Create: CreateExternalTable, - Read: ReadExternalTable, - Update: UpdateExternalTable, - Delete: DeleteExternalTable, + CreateContext: TrackingCreateWrapper(resources.ExternalTable, CreateExternalTable), + ReadContext: TrackingReadWrapper(resources.ExternalTable, ReadExternalTable), + UpdateContext: TrackingUpdateWrapper(resources.ExternalTable, UpdateExternalTable), + DeleteContext: TrackingDeleteWrapper(resources.ExternalTable, DeleteExternalTable), Schema: externalTableSchema, Importer: &schema.ResourceImporter{ @@ -153,9 +156,8 @@ func ExternalTable() *schema.Resource { } // CreateExternalTable implements schema.CreateFunc. -func CreateExternalTable(d *schema.ResourceData, meta any) error { +func CreateExternalTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() database := d.Get("database").(string) schema := d.Get("schema").(string) @@ -214,7 +216,7 @@ func CreateExternalTable(d *schema.ResourceData, meta any) error { } err := client.ExternalTables.CreateDeltaLake(ctx, req) if err != nil { - return err + return diag.FromErr(err) } default: req := sdk.NewCreateExternalTableRequest(id, location). @@ -236,47 +238,45 @@ func CreateExternalTable(d *schema.ResourceData, meta any) error { } err := client.ExternalTables.Create(ctx, req) if err != nil { - return err + return diag.FromErr(err) } } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadExternalTable(d, meta) + return ReadExternalTable(ctx, d, meta) } // ReadExternalTable implements schema.ReadFunc. -func ReadExternalTable(d *schema.ResourceData, meta any) error { +func ReadExternalTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) externalTable, err := client.ExternalTables.ShowByID(ctx, id) if err != nil { log.Printf("[DEBUG] external table (%s) not found", d.Id()) d.SetId("") - return err + return diag.FromErr(err) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", externalTable.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("owner", externalTable.Owner); err != nil { - return err + return diag.FromErr(err) } return nil } // UpdateExternalTable implements schema.UpdateFunc. -func UpdateExternalTable(d *schema.ResourceData, meta any) error { +func UpdateExternalTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) if d.HasChange("tag") { @@ -285,7 +285,7 @@ func UpdateExternalTable(d *schema.ResourceData, meta any) error { if len(unsetTags) > 0 { err := client.ExternalTables.Alter(ctx, sdk.NewAlterExternalTableRequest(id).WithUnsetTag(unsetTags)) if err != nil { - return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err)) } } @@ -296,23 +296,22 @@ func UpdateExternalTable(d *schema.ResourceData, meta any) error { } err := client.ExternalTables.Alter(ctx, sdk.NewAlterExternalTableRequest(id).WithSetTag(tagAssociationRequests)) if err != nil { - return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err)) } } } - return ReadExternalTable(d, meta) + return ReadExternalTable(ctx, d, meta) } // DeleteExternalTable implements schema.DeleteFunc. -func DeleteExternalTable(d *schema.ResourceData, meta any) error { +func DeleteExternalTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) err := client.ExternalTables.Drop(ctx, sdk.NewDropExternalTableRequest(id)) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/failover_group.go b/pkg/resources/failover_group.go index e873436b58..4a8ab0e3fe 100644 --- a/pkg/resources/failover_group.go +++ b/pkg/resources/failover_group.go @@ -8,6 +8,9 @@ import ( "strconv" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -138,10 +141,10 @@ var failoverGroupSchema = map[string]*schema.Schema{ // FailoverGroup returns a pointer to the resource representing a failover group. func FailoverGroup() *schema.Resource { return &schema.Resource{ - Create: CreateFailoverGroup, - Read: ReadFailoverGroup, - Update: UpdateFailoverGroup, - Delete: DeleteFailoverGroup, + CreateContext: TrackingCreateWrapper(resources.FailoverGroup, CreateFailoverGroup), + ReadContext: TrackingReadWrapper(resources.FailoverGroup, ReadFailoverGroup), + UpdateContext: TrackingUpdateWrapper(resources.FailoverGroup, UpdateFailoverGroup), + DeleteContext: TrackingDeleteWrapper(resources.FailoverGroup, DeleteFailoverGroup), Schema: failoverGroupSchema, Importer: &schema.ResourceImporter{ @@ -151,9 +154,8 @@ func FailoverGroup() *schema.Resource { } // CreateFailoverGroup implements schema.CreateFunc. -func CreateFailoverGroup(d *schema.ResourceData, meta interface{}) error { +func CreateFailoverGroup(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() // getting required attributes name := d.Get("name").(string) id := sdk.NewAccountObjectIdentifier(name) @@ -168,15 +170,15 @@ func CreateFailoverGroup(d *schema.ResourceData, meta interface{}) error { primaryFailoverGroupID := sdk.NewExternalObjectIdentifier(sdk.NewAccountIdentifier(organizationName, sourceAccountName), sdk.NewAccountObjectIdentifier(sourceFailoverGroupName)) err := client.FailoverGroups.CreateSecondaryReplicationGroup(ctx, id, primaryFailoverGroupID, nil) if err != nil { - return err + return diag.FromErr(err) } d.SetId(name) - return ReadFailoverGroup(d, meta) + return ReadFailoverGroup(ctx, d, meta) } // these two are required attributes if from_replica is not set if _, ok := d.GetOk("object_types"); !ok { - return errors.New("object_types is required when not creating from a replica") + return diag.FromErr(errors.New("object_types is required when not creating from a replica")) } objectTypesList := expandStringList(d.Get("object_types").(*schema.Set).List()) objectTypes := make([]sdk.PluralObjectType, len(objectTypesList)) @@ -185,7 +187,7 @@ func CreateFailoverGroup(d *schema.ResourceData, meta interface{}) error { } if _, ok := d.GetOk("allowed_accounts"); !ok { - return errors.New("allowed_accounts is required when not creating from a replica") + return diag.FromErr(errors.New("allowed_accounts is required when not creating from a replica")) } aaList := expandStringList(d.Get("allowed_accounts").(*schema.Set).List()) allowedAccounts := make([]sdk.AccountIdentifier, len(aaList)) @@ -193,7 +195,7 @@ func CreateFailoverGroup(d *schema.ResourceData, meta interface{}) error { // validation since we cannot do that in the ValidateFunc parts := strings.Split(v, ".") if len(parts) != 2 { - return fmt.Errorf("allowed_account %s cannot be an account locator and must be of the format .", allowedAccounts[i]) + return diag.FromErr(fmt.Errorf("allowed_account %s cannot be an account locator and must be of the format .", allowedAccounts[i])) } organizationName := parts[0] accountName := parts[1] @@ -257,29 +259,28 @@ func CreateFailoverGroup(d *schema.ResourceData, meta interface{}) error { err := client.FailoverGroups.Create(ctx, id, objectTypes, allowedAccounts, &opts) if err != nil { - return err + return diag.FromErr(err) } d.SetId(name) - return ReadFailoverGroup(d, meta) + return ReadFailoverGroup(ctx, d, meta) } // ReadFailoverGroup implements schema.ReadFunc. -func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { +func ReadFailoverGroup(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) failoverGroup, err := client.FailoverGroups.ShowByID(ctx, id) if err != nil { - return err + return diag.FromErr(err) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", failoverGroup.Name); err != nil { - return err + return diag.FromErr(err) } // if the failover group is created from a replica, then we do not want to get the other values if _, ok := d.GetOk("from_replica"); ok { @@ -291,7 +292,7 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { if strings.Contains(replicationSchedule, "MINUTE") { interval, err := strconv.Atoi(strings.TrimSuffix(replicationSchedule, " MINUTE")) if err != nil { - return err + return diag.FromErr(err) } err = d.Set("replication_schedule", []interface{}{ map[string]interface{}{ @@ -299,7 +300,7 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return err + return diag.FromErr(err) } } else { repScheduleParts := strings.Split(replicationSchedule, " ") @@ -316,7 +317,7 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return err + return diag.FromErr(err) } } } @@ -327,7 +328,7 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { } objectTypesSet := schema.NewSet(schema.HashString, objectTypes) if err := d.Set("object_types", objectTypesSet); err != nil { - return err + return diag.FromErr(err) } // integration types @@ -338,7 +339,7 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { allowedIntegrationsTypesSet := schema.NewSet(schema.HashString, allowedIntegrationTypes) if err := d.Set("allowed_integration_types", allowedIntegrationsTypesSet); err != nil { - return err + return diag.FromErr(err) } // allowed accounts @@ -348,13 +349,13 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { } allowedAccountsSet := schema.NewSet(schema.HashString, allowedAccounts) if err := d.Set("allowed_accounts", allowedAccountsSet); err != nil { - return err + return diag.FromErr(err) } // allowed databases databases, err := client.FailoverGroups.ShowDatabases(ctx, id) if err != nil { - return err + return diag.FromErr(err) } allowedDatabases := make([]interface{}, len(databases)) for i, database := range databases { @@ -363,18 +364,18 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { allowedDatabasesSet := schema.NewSet(schema.HashString, allowedDatabases) if len(allowedDatabases) > 0 { if err := d.Set("allowed_databases", allowedDatabasesSet); err != nil { - return err + return diag.FromErr(err) } } else { if err := d.Set("allowed_databases", nil); err != nil { - return err + return diag.FromErr(err) } } // allowed shares shares, err := client.FailoverGroups.ShowShares(ctx, id) if err != nil { - return err + return diag.FromErr(err) } allowedShares := make([]interface{}, len(shares)) for i, share := range shares { @@ -383,11 +384,11 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { allowedSharesSet := schema.NewSet(schema.HashString, allowedShares) if len(allowedShares) > 0 { if err := d.Set("allowed_shares", allowedSharesSet); err != nil { - return err + return diag.FromErr(err) } } else { if err := d.Set("allowed_shares", nil); err != nil { - return err + return diag.FromErr(err) } } @@ -395,9 +396,8 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { } // UpdateFailoverGroup implements schema.UpdateFunc. -func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { +func UpdateFailoverGroup(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) // alter failover group set ... @@ -437,7 +437,7 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { } if runSet { if err := client.FailoverGroups.AlterSource(ctx, id, opts); err != nil { - return err + return diag.FromErr(err) } } @@ -467,7 +467,7 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return err + return diag.FromErr(err) } } else { err := client.FailoverGroups.AlterSource(ctx, id, &sdk.AlterSourceFailoverGroupOptions{ @@ -476,7 +476,7 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return err + return diag.FromErr(err) } } } @@ -507,7 +507,7 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, } if err := client.FailoverGroups.AlterSource(ctx, id, opts); err != nil { - return fmt.Errorf("error removing allowed databases for failover group %v err = %w", id.Name(), err) + return diag.FromErr(fmt.Errorf("error removing allowed databases for failover group %v err = %w", id.Name(), err)) } } @@ -525,7 +525,7 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, } if err := client.FailoverGroups.AlterSource(ctx, id, opts); err != nil { - return fmt.Errorf("error removing allowed databases for failover group %v err = %w", id.Name(), err) + return diag.FromErr(fmt.Errorf("error removing allowed databases for failover group %v err = %w", id.Name(), err)) } } } @@ -556,7 +556,7 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, } if err := client.FailoverGroups.AlterSource(ctx, id, opts); err != nil { - return fmt.Errorf("error removing allowed shares for failover group %v err = %w", id.Name(), err) + return diag.FromErr(fmt.Errorf("error removing allowed shares for failover group %v err = %w", id.Name(), err)) } } @@ -574,7 +574,7 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, } if err := client.FailoverGroups.AlterSource(ctx, id, opts); err != nil { - return fmt.Errorf("error removing allowed shares for failover group %v err = %w", id.Name(), err) + return diag.FromErr(fmt.Errorf("error removing allowed shares for failover group %v err = %w", id.Name(), err)) } } } @@ -613,7 +613,7 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, } if err := client.FailoverGroups.AlterSource(ctx, id, opts); err != nil { - return fmt.Errorf("error removing allowed accounts for failover group %v err = %w", id.Name(), err) + return diag.FromErr(fmt.Errorf("error removing allowed accounts for failover group %v err = %w", id.Name(), err)) } } @@ -631,22 +631,21 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { }, } if err := client.FailoverGroups.AlterSource(ctx, id, opts); err != nil { - return fmt.Errorf("error removing allowed accounts for failover group %v err = %w", id.Name(), err) + return diag.FromErr(fmt.Errorf("error removing allowed accounts for failover group %v err = %w", id.Name(), err)) } } } - return ReadFailoverGroup(d, meta) + return ReadFailoverGroup(ctx, d, meta) } // DeleteFailoverGroup implements schema.DeleteFunc. -func DeleteFailoverGroup(d *schema.ResourceData, meta interface{}) error { +func DeleteFailoverGroup(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) - ctx := context.Background() err := client.FailoverGroups.Drop(ctx, id, &sdk.DropFailoverGroupOptions{IfExists: sdk.Bool(true)}) if err != nil { - return fmt.Errorf("error deleting failover group %v err = %w", id.Name(), err) + return diag.FromErr(fmt.Errorf("error deleting failover group %v err = %w", id.Name(), err)) } d.SetId("") diff --git a/pkg/resources/file_format.go b/pkg/resources/file_format.go index 561212e487..8ff7cec53f 100644 --- a/pkg/resources/file_format.go +++ b/pkg/resources/file_format.go @@ -7,6 +7,8 @@ import ( "fmt" "strings" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" @@ -315,10 +317,10 @@ func (ffi *fileFormatID) String() (string, error) { // FileFormat returns a pointer to the resource representing a file format. func FileFormat() *schema.Resource { return &schema.Resource{ - Create: CreateFileFormat, - Read: ReadFileFormat, - Update: UpdateFileFormat, - Delete: DeleteFileFormat, + CreateContext: TrackingCreateWrapper(resources.FileFormat, CreateFileFormat), + ReadContext: TrackingReadWrapper(resources.FileFormat, ReadFileFormat), + UpdateContext: TrackingUpdateWrapper(resources.FileFormat, UpdateFileFormat), + DeleteContext: TrackingDeleteWrapper(resources.FileFormat, DeleteFileFormat), CustomizeDiff: TrackingCustomDiffWrapper(resources.FileFormat, customdiff.All( ComputedIfAnyAttributeChanged(fileFormatSchema, FullyQualifiedNameAttributeName, "name"), @@ -332,9 +334,8 @@ func FileFormat() *schema.Resource { } // CreateFileFormat implements schema.CreateFunc. -func CreateFileFormat(d *schema.ResourceData, meta interface{}) error { +func CreateFileFormat(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() dbName := d.Get("database").(string) schemaName := d.Get("schema").(string) @@ -520,7 +521,7 @@ func CreateFileFormat(d *schema.ResourceData, meta interface{}) error { err := client.FileFormats.Create(ctx, id, &opts) if err != nil { - return err + return diag.FromErr(err) } fileFormatID := &fileFormatID{ @@ -530,252 +531,250 @@ func CreateFileFormat(d *schema.ResourceData, meta interface{}) error { } dataIDInput, err := fileFormatID.String() if err != nil { - return err + return diag.FromErr(err) } d.SetId(dataIDInput) - return ReadFileFormat(d, meta) + return ReadFileFormat(ctx, d, meta) } // ReadFileFormat implements schema.ReadFunc. -func ReadFileFormat(d *schema.ResourceData, meta interface{}) error { +func ReadFileFormat(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() fileFormatID, err := fileFormatIDFromString(d.Id()) if err != nil { - return err + return diag.FromErr(err) } id := sdk.NewSchemaObjectIdentifier(fileFormatID.DatabaseName, fileFormatID.SchemaName, fileFormatID.FileFormatName) fileFormat, err := client.FileFormats.ShowByID(ctx, id) if err != nil { - return fmt.Errorf("cannot read file format: %w", err) + return diag.FromErr(fmt.Errorf("cannot read file format: %w", err)) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", fileFormat.Name.Name()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("database", fileFormat.Name.DatabaseName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("schema", fileFormat.Name.SchemaName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("format_type", fileFormat.Type); err != nil { - return err + return diag.FromErr(err) } switch fileFormat.Type { case sdk.FileFormatTypeCSV: if err := d.Set("compression", fileFormat.Options.CSVCompression); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("record_delimiter", fileFormat.Options.CSVRecordDelimiter); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("field_delimiter", fileFormat.Options.CSVFieldDelimiter); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("file_extension", fileFormat.Options.CSVFileExtension); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("parse_header", fileFormat.Options.CSVParseHeader); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("skip_header", fileFormat.Options.CSVSkipHeader); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("skip_blank_lines", fileFormat.Options.CSVSkipBlankLines); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("date_format", fileFormat.Options.CSVDateFormat); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("time_format", fileFormat.Options.CSVTimeFormat); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("timestamp_format", fileFormat.Options.CSVTimestampFormat); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("binary_format", fileFormat.Options.CSVBinaryFormat); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("escape", fileFormat.Options.CSVEscape); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("escape_unenclosed_field", fileFormat.Options.CSVEscapeUnenclosedField); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("trim_space", fileFormat.Options.CSVTrimSpace); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("field_optionally_enclosed_by", fileFormat.Options.CSVFieldOptionallyEnclosedBy); err != nil { - return err + return diag.FromErr(err) } nullIf := []string{} for _, s := range *fileFormat.Options.CSVNullIf { nullIf = append(nullIf, s.S) } if err := d.Set("null_if", nullIf); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("error_on_column_count_mismatch", fileFormat.Options.CSVErrorOnColumnCountMismatch); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("replace_invalid_characters", fileFormat.Options.CSVReplaceInvalidCharacters); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("empty_field_as_null", fileFormat.Options.CSVEmptyFieldAsNull); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("skip_byte_order_mark", fileFormat.Options.CSVSkipByteOrderMark); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("encoding", fileFormat.Options.CSVEncoding); err != nil { - return err + return diag.FromErr(err) } case sdk.FileFormatTypeJSON: if err := d.Set("compression", fileFormat.Options.JSONCompression); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("date_format", fileFormat.Options.JSONDateFormat); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("time_format", fileFormat.Options.JSONTimeFormat); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("timestamp_format", fileFormat.Options.JSONTimestampFormat); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("binary_format", fileFormat.Options.JSONBinaryFormat); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("trim_space", fileFormat.Options.JSONTrimSpace); err != nil { - return err + return diag.FromErr(err) } nullIf := []string{} for _, s := range fileFormat.Options.JSONNullIf { nullIf = append(nullIf, s.S) } if err := d.Set("null_if", nullIf); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("file_extension", fileFormat.Options.JSONFileExtension); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("enable_octal", fileFormat.Options.JSONEnableOctal); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("allow_duplicate", fileFormat.Options.JSONAllowDuplicate); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("strip_outer_array", fileFormat.Options.JSONStripOuterArray); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("strip_null_values", fileFormat.Options.JSONStripNullValues); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("replace_invalid_characters", fileFormat.Options.JSONReplaceInvalidCharacters); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("ignore_utf8_errors", fileFormat.Options.JSONIgnoreUTF8Errors); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("skip_byte_order_mark", fileFormat.Options.JSONSkipByteOrderMark); err != nil { - return err + return diag.FromErr(err) } case sdk.FileFormatTypeAvro: if err := d.Set("compression", fileFormat.Options.AvroCompression); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("trim_space", fileFormat.Options.AvroTrimSpace); err != nil { - return err + return diag.FromErr(err) } nullIf := []string{} for _, s := range *fileFormat.Options.AvroNullIf { nullIf = append(nullIf, s.S) } if err := d.Set("null_if", nullIf); err != nil { - return err + return diag.FromErr(err) } case sdk.FileFormatTypeORC: if err := d.Set("trim_space", fileFormat.Options.ORCTrimSpace); err != nil { - return err + return diag.FromErr(err) } nullIf := []string{} for _, s := range *fileFormat.Options.ORCNullIf { nullIf = append(nullIf, s.S) } if err := d.Set("null_if", nullIf); err != nil { - return err + return diag.FromErr(err) } case sdk.FileFormatTypeParquet: if err := d.Set("compression", fileFormat.Options.ParquetCompression); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("binary_as_text", fileFormat.Options.ParquetBinaryAsText); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("trim_space", fileFormat.Options.ParquetTrimSpace); err != nil { - return err + return diag.FromErr(err) } nullIf := []string{} for _, s := range *fileFormat.Options.ParquetNullIf { nullIf = append(nullIf, s.S) } if err := d.Set("null_if", nullIf); err != nil { - return err + return diag.FromErr(err) } case sdk.FileFormatTypeXML: if err := d.Set("compression", fileFormat.Options.XMLCompression); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("ignore_utf8_errors", fileFormat.Options.XMLIgnoreUTF8Errors); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("preserve_space", fileFormat.Options.XMLPreserveSpace); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("strip_outer_element", fileFormat.Options.XMLStripOuterElement); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("disable_snowflake_data", fileFormat.Options.XMLDisableSnowflakeData); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("disable_auto_convert", fileFormat.Options.XMLDisableAutoConvert); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("skip_byte_order_mark", fileFormat.Options.XMLSkipByteOrderMark); err != nil { - return err + return diag.FromErr(err) } // Terraform doesn't like it when computed fields aren't set. if err := d.Set("null_if", []string{}); err != nil { - return err + return diag.FromErr(err) } } if err := d.Set("comment", fileFormat.Comment); err != nil { - return err + return diag.FromErr(err) } return nil } // UpdateFileFormat implements schema.UpdateFunc. -func UpdateFileFormat(d *schema.ResourceData, meta interface{}) error { +func UpdateFileFormat(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() fileFormatID, err := fileFormatIDFromString(d.Id()) if err != nil { - return err + return diag.FromErr(err) } id := sdk.NewSchemaObjectIdentifier(fileFormatID.DatabaseName, fileFormatID.SchemaName, fileFormatID.FileFormatName) @@ -788,7 +787,7 @@ func UpdateFileFormat(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return fmt.Errorf("error renaming file format: %w", err) + return diag.FromErr(fmt.Errorf("error renaming file format: %w", err)) } d.SetId(helpers.EncodeSnowflakeID(newId)) @@ -1116,27 +1115,26 @@ func UpdateFileFormat(d *schema.ResourceData, meta interface{}) error { if runSet { err = client.FileFormats.Alter(ctx, id, &opts) if err != nil { - return err + return diag.FromErr(err) } } - return ReadFileFormat(d, meta) + return ReadFileFormat(ctx, d, meta) } // DeleteFileFormat implements schema.DeleteFunc. -func DeleteFileFormat(d *schema.ResourceData, meta interface{}) error { +func DeleteFileFormat(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() fileFormatID, err := fileFormatIDFromString(d.Id()) if err != nil { - return err + return diag.FromErr(err) } id := sdk.NewSchemaObjectIdentifier(fileFormatID.DatabaseName, fileFormatID.SchemaName, fileFormatID.FileFormatName) err = client.FileFormats.Drop(ctx, id, nil) if err != nil { - return fmt.Errorf("error while deleting file format: %w", err) + return diag.FromErr(fmt.Errorf("error while deleting file format: %w", err)) } d.SetId("") diff --git a/pkg/resources/grant_account_role.go b/pkg/resources/grant_account_role.go index 78e6b35f2c..ce02dc6f9c 100644 --- a/pkg/resources/grant_account_role.go +++ b/pkg/resources/grant_account_role.go @@ -6,6 +6,8 @@ import ( "log" "strings" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" @@ -49,10 +51,10 @@ var grantAccountRoleSchema = map[string]*schema.Schema{ func GrantAccountRole() *schema.Resource { return &schema.Resource{ - Create: CreateGrantAccountRole, - Read: ReadGrantAccountRole, - Delete: DeleteGrantAccountRole, - Schema: grantAccountRoleSchema, + CreateContext: TrackingCreateWrapper(resources.GrantAccountRole, CreateGrantAccountRole), + ReadContext: TrackingReadWrapper(resources.GrantAccountRole, ReadGrantAccountRole), + DeleteContext: TrackingDeleteWrapper(resources.GrantAccountRole, DeleteGrantAccountRole), + Schema: grantAccountRoleSchema, Importer: &schema.ResourceImporter{ StateContext: TrackingImportWrapper(resources.GrantAccountRole, func(ctx context.Context, d *schema.ResourceData, m interface{}) ([]*schema.ResourceData, error) { parts := strings.Split(d.Id(), helpers.IDDelimiter) @@ -82,9 +84,8 @@ func GrantAccountRole() *schema.Resource { } // CreateGrantAccountRole implements schema.CreateFunc. -func CreateGrantAccountRole(d *schema.ResourceData, meta interface{}) error { +func CreateGrantAccountRole(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() roleName := d.Get("role_name").(string) roleIdentifier := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(roleName) // format of snowflakeResourceID is || @@ -96,7 +97,7 @@ func CreateGrantAccountRole(d *schema.ResourceData, meta interface{}) error { Role: &parentRoleIdentifier, }) if err := client.Roles.Grant(ctx, req); err != nil { - return err + return diag.FromErr(err) } } else if userName, ok := d.GetOk("user_name"); ok && userName.(string) != "" { userIdentifier := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(userName.(string)) @@ -105,26 +106,25 @@ func CreateGrantAccountRole(d *schema.ResourceData, meta interface{}) error { User: &userIdentifier, }) if err := client.Roles.Grant(ctx, req); err != nil { - return err + return diag.FromErr(err) } } else { - return fmt.Errorf("invalid role grant specified: %v", d) + return diag.FromErr(fmt.Errorf("invalid role grant specified: %v", d)) } d.SetId(snowflakeResourceID) - return ReadGrantAccountRole(d, meta) + return ReadGrantAccountRole(ctx, d, meta) } -func ReadGrantAccountRole(d *schema.ResourceData, meta interface{}) error { +func ReadGrantAccountRole(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client parts := strings.Split(d.Id(), helpers.IDDelimiter) if len(parts) != 3 { - return fmt.Errorf("invalid ID specified: %v, expected ||", d.Id()) + return diag.FromErr(fmt.Errorf("invalid ID specified: %v, expected ||", d.Id())) } roleName := parts[0] roleIdentifier := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(roleName) objectType := parts[1] targetIdentifier := parts[2] - ctx := context.Background() grants, err := client.Grants.Show(ctx, &sdk.ShowGrantOptions{ Of: &sdk.ShowGrantsOf{ Role: roleIdentifier, @@ -153,28 +153,27 @@ func ReadGrantAccountRole(d *schema.ResourceData, meta interface{}) error { return nil } -func DeleteGrantAccountRole(d *schema.ResourceData, meta interface{}) error { +func DeleteGrantAccountRole(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client parts := strings.Split(d.Id(), helpers.IDDelimiter) if len(parts) != 3 { - return fmt.Errorf("invalid ID specified: %v, expected ||", d.Id()) + return diag.FromErr(fmt.Errorf("invalid ID specified: %v, expected ||", d.Id())) } id := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(parts[0]) objectType := parts[1] granteeName := parts[2] - ctx := context.Background() granteeIdentifier := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(granteeName) switch objectType { case "ROLE": if err := client.Roles.Revoke(ctx, sdk.NewRevokeRoleRequest(id, sdk.RevokeRole{Role: &granteeIdentifier})); err != nil { - return err + return diag.FromErr(err) } case "USER": if err := client.Roles.Revoke(ctx, sdk.NewRevokeRoleRequest(id, sdk.RevokeRole{User: &granteeIdentifier})); err != nil { - return err + return diag.FromErr(err) } default: - return fmt.Errorf("invalid object type specified: %v, expected ROLE or USER", objectType) + return diag.FromErr(fmt.Errorf("invalid object type specified: %v, expected ROLE or USER", objectType)) } d.SetId("") return nil diff --git a/pkg/resources/grant_database_role.go b/pkg/resources/grant_database_role.go index 1c845200af..e67946fc5d 100644 --- a/pkg/resources/grant_database_role.go +++ b/pkg/resources/grant_database_role.go @@ -5,6 +5,8 @@ import ( "fmt" "log" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" @@ -66,10 +68,10 @@ var grantDatabaseRoleSchema = map[string]*schema.Schema{ func GrantDatabaseRole() *schema.Resource { return &schema.Resource{ - Create: CreateGrantDatabaseRole, - Read: ReadGrantDatabaseRole, - Delete: DeleteGrantDatabaseRole, - Schema: grantDatabaseRoleSchema, + CreateContext: TrackingCreateWrapper(resources.GrantDatabaseRole, CreateGrantDatabaseRole), + ReadContext: TrackingReadWrapper(resources.GrantDatabaseRole, ReadGrantDatabaseRole), + DeleteContext: TrackingDeleteWrapper(resources.GrantDatabaseRole, DeleteGrantDatabaseRole), + Schema: grantDatabaseRoleSchema, Importer: &schema.ResourceImporter{ StateContext: TrackingImportWrapper(resources.GrantDatabaseRole, func(ctx context.Context, d *schema.ResourceData, m interface{}) ([]*schema.ResourceData, error) { parts := helpers.ParseResourceIdentifier(d.Id()) @@ -121,63 +123,61 @@ func GrantDatabaseRole() *schema.Resource { } // CreateGrantDatabaseRole implements schema.CreateFunc. -func CreateGrantDatabaseRole(d *schema.ResourceData, meta interface{}) error { +func CreateGrantDatabaseRole(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() databaseRoleName := d.Get("database_role_name").(string) databaseRoleIdentifier, err := sdk.ParseDatabaseObjectIdentifier(databaseRoleName) if err != nil { - return err + return diag.FromErr(err) } // format of snowflakeResourceID is || var snowflakeResourceID string if parentRoleName, ok := d.GetOk("parent_role_name"); ok && parentRoleName.(string) != "" { parentRoleIdentifier, err := sdk.ParseAccountObjectIdentifier(parentRoleName.(string)) if err != nil { - return err + return diag.FromErr(err) } snowflakeResourceID = helpers.EncodeResourceIdentifier(databaseRoleIdentifier.FullyQualifiedName(), sdk.ObjectTypeRole.String(), parentRoleIdentifier.FullyQualifiedName()) req := sdk.NewGrantDatabaseRoleRequest(databaseRoleIdentifier).WithAccountRole(parentRoleIdentifier) if err := client.DatabaseRoles.Grant(ctx, req); err != nil { - return err + return diag.FromErr(err) } } else if parentDatabaseRoleName, ok := d.GetOk("parent_database_role_name"); ok && parentDatabaseRoleName.(string) != "" { parentRoleIdentifier, err := sdk.ParseDatabaseObjectIdentifier(parentDatabaseRoleName.(string)) if err != nil { - return err + return diag.FromErr(err) } snowflakeResourceID = helpers.EncodeResourceIdentifier(databaseRoleIdentifier.FullyQualifiedName(), sdk.ObjectTypeDatabaseRole.String(), parentRoleIdentifier.FullyQualifiedName()) req := sdk.NewGrantDatabaseRoleRequest(databaseRoleIdentifier).WithDatabaseRole(parentRoleIdentifier) if err := client.DatabaseRoles.Grant(ctx, req); err != nil { - return err + return diag.FromErr(err) } } else if shareName, ok := d.GetOk("share_name"); ok && shareName.(string) != "" { shareIdentifier, err := sdk.ParseAccountObjectIdentifier(shareName.(string)) if err != nil { - return err + return diag.FromErr(err) } snowflakeResourceID = helpers.EncodeResourceIdentifier(databaseRoleIdentifier.FullyQualifiedName(), sdk.ObjectTypeShare.String(), shareIdentifier.FullyQualifiedName()) req := sdk.NewGrantDatabaseRoleToShareRequest(databaseRoleIdentifier, shareIdentifier) if err := client.DatabaseRoles.GrantToShare(ctx, req); err != nil { - return err + return diag.FromErr(err) } } d.SetId(snowflakeResourceID) - return ReadGrantDatabaseRole(d, meta) + return ReadGrantDatabaseRole(ctx, d, meta) } // ReadGrantDatabaseRole implements schema.ReadFunc. -func ReadGrantDatabaseRole(d *schema.ResourceData, meta interface{}) error { +func ReadGrantDatabaseRole(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client parts := helpers.ParseResourceIdentifier(d.Id()) databaseRoleName := parts[0] databaseRoleIdentifier, err := sdk.ParseDatabaseObjectIdentifier(databaseRoleName) if err != nil { - return err + return diag.FromErr(err) } objectType := parts[1] targetIdentifier := parts[2] - ctx := context.Background() grants, err := client.Grants.Show(ctx, &sdk.ShowGrantOptions{ Of: &sdk.ShowGrantsOf{ DatabaseRole: databaseRoleIdentifier, @@ -206,41 +206,40 @@ func ReadGrantDatabaseRole(d *schema.ResourceData, meta interface{}) error { } // DeleteGrantDatabaseRole implements schema.DeleteFunc. -func DeleteGrantDatabaseRole(d *schema.ResourceData, meta interface{}) error { +func DeleteGrantDatabaseRole(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client parts := helpers.ParseResourceIdentifier(d.Id()) id, err := sdk.ParseDatabaseObjectIdentifier(parts[0]) if err != nil { - return err + return diag.FromErr(err) } objectType := parts[1] granteeName := parts[2] - ctx := context.Background() switch objectType { case "ROLE": accountRoleId, err := sdk.ParseAccountObjectIdentifier(granteeName) if err != nil { - return err + return diag.FromErr(err) } if err := client.DatabaseRoles.Revoke(ctx, sdk.NewRevokeDatabaseRoleRequest(id).WithAccountRole(accountRoleId)); err != nil { - return err + return diag.FromErr(err) } case "DATABASE ROLE": databaseRoleId, err := sdk.ParseDatabaseObjectIdentifier(granteeName) if err != nil { - return err + return diag.FromErr(err) } if err := client.DatabaseRoles.Revoke(ctx, sdk.NewRevokeDatabaseRoleRequest(id).WithDatabaseRole(databaseRoleId)); err != nil { - return err + return diag.FromErr(err) } case "SHARE": sharedId, err := sdk.ParseAccountObjectIdentifier(granteeName) if err != nil { - return err + return diag.FromErr(err) } if err := client.DatabaseRoles.RevokeFromShare(ctx, sdk.NewRevokeDatabaseRoleFromShareRequest(id, sharedId)); err != nil { - return err + return diag.FromErr(err) } } d.SetId("") diff --git a/pkg/resources/managed_account.go b/pkg/resources/managed_account.go index 9a01a6cf67..510d7af633 100644 --- a/pkg/resources/managed_account.go +++ b/pkg/resources/managed_account.go @@ -6,6 +6,9 @@ import ( "log" "time" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/util" @@ -84,9 +87,9 @@ var managedAccountSchema = map[string]*schema.Schema{ // ManagedAccount returns a pointer to the resource representing a managed account. func ManagedAccount() *schema.Resource { return &schema.Resource{ - Create: CreateManagedAccount, - Read: ReadManagedAccount, - Delete: DeleteManagedAccount, + CreateContext: TrackingCreateWrapper(resources.ManagedAccount, CreateManagedAccount), + ReadContext: TrackingReadWrapper(resources.ManagedAccount, ReadManagedAccount), + DeleteContext: TrackingDeleteWrapper(resources.ManagedAccount, DeleteManagedAccount), Schema: managedAccountSchema, Importer: &schema.ResourceImporter{ @@ -96,9 +99,8 @@ func ManagedAccount() *schema.Resource { } // CreateManagedAccount implements schema.CreateFunc. -func CreateManagedAccount(d *schema.ResourceData, meta interface{}) error { +func CreateManagedAccount(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() name := d.Get("name").(string) id := sdk.NewAccountObjectIdentifier(name) @@ -115,19 +117,18 @@ func CreateManagedAccount(d *schema.ResourceData, meta interface{}) error { err := client.ManagedAccounts.Create(ctx, createRequest) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadManagedAccount(d, meta) + return ReadManagedAccount(ctx, d, meta) } // ReadManagedAccount implements schema.ReadFunc. -func ReadManagedAccount(d *schema.ResourceData, meta interface{}) error { +func ReadManagedAccount(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) // We have to wait during the first read, since the locator takes some time to appear. @@ -144,60 +145,59 @@ func ReadManagedAccount(d *schema.ResourceData, meta interface{}) error { return nil, true }) if err != nil { - return err + return diag.FromErr(err) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", managedAccount.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("cloud", managedAccount.Cloud); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("region", managedAccount.Region); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("locator", managedAccount.Locator); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("created_on", managedAccount.CreatedOn); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("url", managedAccount.URL); err != nil { - return err + return diag.FromErr(err) } if managedAccount.IsReader { if err := d.Set("type", "READER"); err != nil { - return err + return diag.FromErr(err) } } else { - return fmt.Errorf("unable to determine the account type") + return diag.FromErr(fmt.Errorf("unable to determine the account type")) } if err := d.Set("comment", managedAccount.Comment); err != nil { - return err + return diag.FromErr(err) } return nil } // DeleteManagedAccount implements schema.DeleteFunc. -func DeleteManagedAccount(d *schema.ResourceData, meta interface{}) error { +func DeleteManagedAccount(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) err := client.ManagedAccounts.Drop(ctx, sdk.NewDropManagedAccountRequest(objectIdentifier)) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/materialized_view.go b/pkg/resources/materialized_view.go index 5151daf88a..2dacd668e6 100644 --- a/pkg/resources/materialized_view.go +++ b/pkg/resources/materialized_view.go @@ -6,6 +6,8 @@ import ( "log" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -13,7 +15,6 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -72,10 +73,10 @@ var materializedViewSchema = map[string]*schema.Schema{ // MaterializedView returns a pointer to the resource representing a view. func MaterializedView() *schema.Resource { return &schema.Resource{ - Create: CreateMaterializedView, - Read: ReadMaterializedView, - Update: UpdateMaterializedView, - Delete: DeleteMaterializedView, + CreateContext: TrackingCreateWrapper(resources.MaterializedView, CreateMaterializedView), + ReadContext: TrackingReadWrapper(resources.MaterializedView, ReadMaterializedView), + UpdateContext: TrackingUpdateWrapper(resources.MaterializedView, UpdateMaterializedView), + DeleteContext: TrackingDeleteWrapper(resources.MaterializedView, DeleteMaterializedView), CustomizeDiff: TrackingCustomDiffWrapper(resources.MaterializedView, customdiff.All( ComputedIfAnyAttributeChanged(materializedViewSchema, FullyQualifiedNameAttributeName, "name"), @@ -89,9 +90,8 @@ func MaterializedView() *schema.Resource { } // CreateMaterializedView implements schema.CreateFunc. -func CreateMaterializedView(d *schema.ResourceData, meta interface{}) error { +func CreateMaterializedView(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) @@ -117,12 +117,12 @@ func CreateMaterializedView(d *schema.ResourceData, meta interface{}) error { // TODO [SNOW-1348355]: this was the old implementation, it's left for now, we will address this with resources rework discussions err := client.Sessions.UseWarehouse(ctx, sdk.NewAccountObjectIdentifier(warehouseName)) if err != nil { - return fmt.Errorf("error setting warehouse %s while creating materialized view %v err = %w", warehouseName, name, err) + return diag.FromErr(fmt.Errorf("error setting warehouse %s while creating materialized view %v err = %w", warehouseName, name, err)) } err = client.MaterializedViews.Create(ctx, createRequest) if err != nil { - return fmt.Errorf("error creating materialized view %v err = %w", name, err) + return diag.FromErr(fmt.Errorf("error creating materialized view %v err = %w", name, err)) } // TODO [SNOW-1348355]: we have to set tags after creation because existing materialized view extractor is not aware of TAG during CREATE @@ -130,19 +130,19 @@ func CreateMaterializedView(d *schema.ResourceData, meta interface{}) error { if _, ok := d.GetOk("tag"); ok { err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetTags(getPropertyTags(d, "tag"))) if err != nil { - return fmt.Errorf("error setting tags on materialized view %v, err = %w", id, err) + return diag.FromErr(fmt.Errorf("error setting tags on materialized view %v, err = %w", id, err)) } } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadMaterializedView(d, meta) + return ReadMaterializedView(ctx, d, meta) } // ReadMaterializedView implements schema.ReadFunc. -func ReadMaterializedView(d *schema.ResourceData, meta interface{}) error { +func ReadMaterializedView(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) materializedView, err := client.MaterializedViews.ShowByID(ctx, id) @@ -152,46 +152,46 @@ func ReadMaterializedView(d *schema.ResourceData, meta interface{}) error { return nil } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", materializedView.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("is_secure", materializedView.IsSecure); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", materializedView.Comment); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("schema", materializedView.SchemaName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("database", materializedView.DatabaseName); err != nil { - return err + return diag.FromErr(err) } // Want to only capture the SELECT part of the query because before that is the CREATE part of the view. extractor := snowflake.NewViewSelectStatementExtractor(materializedView.Text) substringOfQuery, err := extractor.ExtractMaterializedView() if err != nil { - return err + return diag.FromErr(err) } if err := d.Set("statement", substringOfQuery); err != nil { - return err + return diag.FromErr(err) } return nil } // UpdateMaterializedView implements schema.UpdateFunc. -func UpdateMaterializedView(d *schema.ResourceData, meta interface{}) error { +func UpdateMaterializedView(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) if d.HasChange("name") { @@ -199,7 +199,7 @@ func UpdateMaterializedView(d *schema.ResourceData, meta interface{}) error { err := client.MaterializedViews.Alter(ctx, sdk.NewAlterMaterializedViewRequest(id).WithRenameTo(&newId)) if err != nil { - return fmt.Errorf("error renaming materialized view %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error renaming materialized view %v err = %w", d.Id(), err)) } d.SetId(helpers.EncodeSnowflakeID(newId)) @@ -234,14 +234,14 @@ func UpdateMaterializedView(d *schema.ResourceData, meta interface{}) error { if runSetStatement { err := client.MaterializedViews.Alter(ctx, sdk.NewAlterMaterializedViewRequest(id).WithSet(setRequest)) if err != nil { - return fmt.Errorf("error updating materialized view: %w", err) + return diag.FromErr(fmt.Errorf("error updating materialized view: %w", err)) } } if runUnsetStatement { err := client.MaterializedViews.Alter(ctx, sdk.NewAlterMaterializedViewRequest(id).WithUnset(unsetRequest)) if err != nil { - return fmt.Errorf("error updating materialized view: %w", err) + return diag.FromErr(fmt.Errorf("error updating materialized view: %w", err)) } } @@ -252,7 +252,7 @@ func UpdateMaterializedView(d *schema.ResourceData, meta interface{}) error { // TODO [SNOW-1022645]: view is used on purpose here; change after we have an agreement on situations like this in the SDK err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetTags(unsetTags)) if err != nil { - return fmt.Errorf("error unsetting tags on %v, err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error unsetting tags on %v, err = %w", d.Id(), err)) } } @@ -260,23 +260,23 @@ func UpdateMaterializedView(d *schema.ResourceData, meta interface{}) error { // TODO [SNOW-1022645]: view is used on purpose here; change after we have an agreement on situations like this in the SDK err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetTags(setTags)) if err != nil { - return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err)) } } } - return ReadMaterializedView(d, meta) + return ReadMaterializedView(ctx, d, meta) } // DeleteMaterializedView implements schema.DeleteFunc. -func DeleteMaterializedView(d *schema.ResourceData, meta interface{}) error { +func DeleteMaterializedView(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) err := client.MaterializedViews.Drop(ctx, sdk.NewDropMaterializedViewRequest(id)) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/network_policy_attachment.go b/pkg/resources/network_policy_attachment.go index e1d64590b1..c68ab02f85 100644 --- a/pkg/resources/network_policy_attachment.go +++ b/pkg/resources/network_policy_attachment.go @@ -6,6 +6,9 @@ import ( "log" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -36,10 +39,10 @@ var networkPolicyAttachmentSchema = map[string]*schema.Schema{ // NetworkPolicyAttachment returns a pointer to the resource representing a network policy attachment. func NetworkPolicyAttachment() *schema.Resource { return &schema.Resource{ - Create: CreateNetworkPolicyAttachment, - Read: ReadNetworkPolicyAttachment, - Update: UpdateNetworkPolicyAttachment, - Delete: DeleteNetworkPolicyAttachment, + CreateContext: TrackingCreateWrapper(resources.NetworkPolicyAttachment, CreateNetworkPolicyAttachment), + ReadContext: TrackingReadWrapper(resources.NetworkPolicyAttachment, ReadNetworkPolicyAttachment), + UpdateContext: TrackingUpdateWrapper(resources.NetworkPolicyAttachment, UpdateNetworkPolicyAttachment), + DeleteContext: TrackingDeleteWrapper(resources.NetworkPolicyAttachment, DeleteNetworkPolicyAttachment), Schema: networkPolicyAttachmentSchema, Importer: &schema.ResourceImporter{ @@ -49,40 +52,40 @@ func NetworkPolicyAttachment() *schema.Resource { } // CreateNetworkPolicyAttachment implements schema.CreateFunc. -func CreateNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func CreateNetworkPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { policyName := d.Get("network_policy_name").(string) d.SetId(policyName + "_attachment") if d.Get("set_for_account").(bool) { - if err := setOnAccount(d, meta); err != nil { - return fmt.Errorf("error creating attachment for network policy %v err = %w", policyName, err) + if err := setOnAccount(ctx, d, meta); err != nil { + return diag.FromErr(fmt.Errorf("error creating attachment for network policy %v err = %w", policyName, err)) } } if u, ok := d.GetOk("users"); ok { users := expandStringList(u.(*schema.Set).List()) - if err := ensureUserAlterPrivileges(users, meta); err != nil { - return err + if err := ensureUserAlterPrivileges(ctx, users, meta); err != nil { + return diag.FromErr(err) } - if err := setOnUsers(users, d, meta); err != nil { - return fmt.Errorf("error creating attachment for network policy %v err = %w", policyName, err) + if err := setOnUsers(ctx, users, d, meta); err != nil { + return diag.FromErr(fmt.Errorf("error creating attachment for network policy %v err = %w", policyName, err)) } } - return ReadNetworkPolicyAttachment(d, meta) + return ReadNetworkPolicyAttachment(ctx, d, meta) } // ReadNetworkPolicyAttachment implements schema.ReadFunc. -func ReadNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func ReadNetworkPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + policyName := strings.Replace(d.Id(), "_attachment", "", 1) var currentUsers []string if err := d.Set("network_policy_name", policyName); err != nil { - return err + return diag.FromErr(err) } if u, ok := d.GetOk("users"); ok { @@ -100,7 +103,7 @@ func ReadNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) error } if err := d.Set("users", currentUsers); err != nil { - return err + return diag.FromErr(err) } } @@ -117,22 +120,22 @@ func ReadNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) error } if err := d.Set("set_for_account", isSetOnAccount); err != nil { - return err + return diag.FromErr(err) } return nil } // UpdateNetworkPolicyAttachment implements schema.UpdateFunc. -func UpdateNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func UpdateNetworkPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { if d.HasChange("set_for_account") { oldAcctFlag, newAcctFlag := d.GetChange("set_for_account") if newAcctFlag.(bool) { - if err := setOnAccount(d, meta); err != nil { - return err + if err := setOnAccount(ctx, d, meta); err != nil { + return diag.FromErr(err) } } else if !newAcctFlag.(bool) && oldAcctFlag == true { - if err := unsetOnAccount(d, meta); err != nil { - return err + if err := unsetOnAccount(ctx, d, meta); err != nil { + return diag.FromErr(err) } } } @@ -145,50 +148,50 @@ func UpdateNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) err removedUsers := expandStringList(oldUsersSet.Difference(newUsersSet).List()) addedUsers := expandStringList(newUsersSet.Difference(oldUsersSet).List()) - if err := ensureUserAlterPrivileges(removedUsers, meta); err != nil { - return err + if err := ensureUserAlterPrivileges(ctx, removedUsers, meta); err != nil { + return diag.FromErr(err) } - if err := ensureUserAlterPrivileges(addedUsers, meta); err != nil { - return err + if err := ensureUserAlterPrivileges(ctx, addedUsers, meta); err != nil { + return diag.FromErr(err) } for _, user := range removedUsers { - if err := unsetOnUser(user, d, meta); err != nil { - return err + if err := unsetOnUser(ctx, user, d, meta); err != nil { + return diag.FromErr(err) } } for _, user := range addedUsers { - if err := setOnUser(user, d, meta); err != nil { - return err + if err := setOnUser(ctx, user, d, meta); err != nil { + return diag.FromErr(err) } } } - return ReadNetworkPolicyAttachment(d, meta) + return ReadNetworkPolicyAttachment(ctx, d, meta) } // DeleteNetworkPolicyAttachment implements schema.DeleteFunc. -func DeleteNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) error { +func DeleteNetworkPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { policyName := d.Get("network_policy_name").(string) d.SetId(policyName + "_attachment") if d.Get("set_for_account").(bool) { - if err := unsetOnAccount(d, meta); err != nil { - return fmt.Errorf("error deleting attachment for network policy %v err = %w", policyName, err) + if err := unsetOnAccount(ctx, d, meta); err != nil { + return diag.FromErr(fmt.Errorf("error deleting attachment for network policy %v err = %w", policyName, err)) } } if u, ok := d.GetOk("users"); ok { users := expandStringList(u.(*schema.Set).List()) - if err := ensureUserAlterPrivileges(users, meta); err != nil { - return err + if err := ensureUserAlterPrivileges(ctx, users, meta); err != nil { + return diag.FromErr(err) } - if err := unsetOnUsers(users, d, meta); err != nil { - return fmt.Errorf("error deleting attachment for network policy %v err = %w", policyName, err) + if err := unsetOnUsers(ctx, users, d, meta); err != nil { + return diag.FromErr(fmt.Errorf("error deleting attachment for network policy %v err = %w", policyName, err)) } } @@ -197,9 +200,9 @@ func DeleteNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) err // setOnAccount sets the network policy globally for the Snowflake account // Note: the ip address of the session executing this SQL must be allowed by the network policy being set. -func setOnAccount(d *schema.ResourceData, meta interface{}) error { +func setOnAccount(ctx context.Context, d *schema.ResourceData, meta any) error { client := meta.(*provider.Context).Client - ctx := context.Background() + policyName := d.Get("network_policy_name").(string) err := client.Accounts.Alter(ctx, &sdk.AlterAccountOptions{Set: &sdk.AccountSet{Parameters: &sdk.AccountLevelParameters{ObjectParameters: &sdk.ObjectParameters{NetworkPolicy: sdk.String(policyName)}}}}) @@ -211,9 +214,9 @@ func setOnAccount(d *schema.ResourceData, meta interface{}) error { } // setOnAccount unsets the network policy globally for the Snowflake account. -func unsetOnAccount(d *schema.ResourceData, meta interface{}) error { +func unsetOnAccount(ctx context.Context, d *schema.ResourceData, meta any) error { client := meta.(*provider.Context).Client - ctx := context.Background() + policyName := d.Get("network_policy_name").(string) err := client.Accounts.Alter(ctx, &sdk.AlterAccountOptions{Unset: &sdk.AccountUnset{Parameters: &sdk.AccountLevelParametersUnset{ObjectParameters: &sdk.ObjectParametersUnset{NetworkPolicy: sdk.Bool(true)}}}}) @@ -225,10 +228,10 @@ func unsetOnAccount(d *schema.ResourceData, meta interface{}) error { } // setOnUsers sets the network policy for list of users. -func setOnUsers(users []string, data *schema.ResourceData, meta interface{}) error { +func setOnUsers(ctx context.Context, users []string, data *schema.ResourceData, meta interface{}) error { policyName := data.Get("network_policy_name").(string) for _, user := range users { - if err := setOnUser(user, data, meta); err != nil { + if err := setOnUser(ctx, user, data, meta); err != nil { return fmt.Errorf("error setting network policy %v on user %v err = %w", policyName, user, err) } } @@ -237,9 +240,9 @@ func setOnUsers(users []string, data *schema.ResourceData, meta interface{}) err } // setOnUser sets the network policy for a given user. -func setOnUser(user string, data *schema.ResourceData, meta interface{}) error { +func setOnUser(ctx context.Context, user string, data *schema.ResourceData, meta interface{}) error { client := meta.(*provider.Context).Client - ctx := context.Background() + policyName := data.Get("network_policy_name").(string) err := client.Users.Alter(ctx, sdk.NewAccountObjectIdentifier(user), &sdk.AlterUserOptions{Set: &sdk.UserSet{ObjectParameters: &sdk.UserObjectParameters{NetworkPolicy: sdk.Pointer(sdk.NewAccountObjectIdentifier(policyName))}}}) @@ -251,10 +254,10 @@ func setOnUser(user string, data *schema.ResourceData, meta interface{}) error { } // unsetOnUsers unsets the network policy for list of users. -func unsetOnUsers(users []string, data *schema.ResourceData, meta interface{}) error { +func unsetOnUsers(ctx context.Context, users []string, data *schema.ResourceData, meta interface{}) error { policyName := data.Get("network_policy_name").(string) for _, user := range users { - if err := unsetOnUser(user, data, meta); err != nil { + if err := unsetOnUser(ctx, user, data, meta); err != nil { return fmt.Errorf("error unsetting network policy %v on user %v err = %w", policyName, user, err) } } @@ -263,9 +266,9 @@ func unsetOnUsers(users []string, data *schema.ResourceData, meta interface{}) e } // unsetOnUser sets the network policy for a given user. -func unsetOnUser(user string, data *schema.ResourceData, meta interface{}) error { +func unsetOnUser(ctx context.Context, user string, data *schema.ResourceData, meta interface{}) error { client := meta.(*provider.Context).Client - ctx := context.Background() + policyName := data.Get("network_policy_name").(string) err := client.Users.Alter(ctx, sdk.NewAccountObjectIdentifier(user), &sdk.AlterUserOptions{Unset: &sdk.UserUnset{ObjectParameters: &sdk.UserObjectParametersUnset{NetworkPolicy: sdk.Bool(true)}}}) @@ -277,9 +280,8 @@ func unsetOnUser(user string, data *schema.ResourceData, meta interface{}) error } // ensureUserAlterPrivileges ensures the executing Snowflake user can alter each user in the set of users. -func ensureUserAlterPrivileges(users []string, meta interface{}) error { +func ensureUserAlterPrivileges(ctx context.Context, users []string, meta interface{}) error { client := meta.(*provider.Context).Client - ctx := context.Background() for _, user := range users { _, err := client.Users.Describe(ctx, sdk.NewAccountObjectIdentifier(user)) diff --git a/pkg/resources/notification_integration.go b/pkg/resources/notification_integration.go index a69e474f33..109a46e0ba 100644 --- a/pkg/resources/notification_integration.go +++ b/pkg/resources/notification_integration.go @@ -6,6 +6,9 @@ import ( "log" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -151,10 +154,10 @@ var notificationIntegrationSchema = map[string]*schema.Schema{ // NotificationIntegration returns a pointer to the resource representing a notification integration. func NotificationIntegration() *schema.Resource { return &schema.Resource{ - Create: CreateNotificationIntegration, - Read: ReadNotificationIntegration, - Update: UpdateNotificationIntegration, - Delete: DeleteNotificationIntegration, + CreateContext: TrackingCreateWrapper(resources.NotificationIntegration, CreateNotificationIntegration), + ReadContext: TrackingReadWrapper(resources.NotificationIntegration, ReadNotificationIntegration), + UpdateContext: TrackingUpdateWrapper(resources.NotificationIntegration, UpdateNotificationIntegration), + DeleteContext: TrackingDeleteWrapper(resources.NotificationIntegration, DeleteNotificationIntegration), Schema: notificationIntegrationSchema, Importer: &schema.ResourceImporter{ @@ -164,9 +167,8 @@ func NotificationIntegration() *schema.Resource { } // CreateNotificationIntegration implements schema.CreateFunc. -func CreateNotificationIntegration(d *schema.ResourceData, meta interface{}) error { +func CreateNotificationIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() name := d.Get("name").(string) id := sdk.NewAccountObjectIdentifier(name) @@ -183,11 +185,11 @@ func CreateNotificationIntegration(d *schema.ResourceData, meta interface{}) err case "AWS_SNS": topic, ok := d.GetOk("aws_sns_topic_arn") if !ok { - return fmt.Errorf("if you use AWS_SNS provider you must specify an aws_sns_topic_arn") + return diag.FromErr(fmt.Errorf("if you use AWS_SNS provider you must specify an aws_sns_topic_arn")) } role, ok := d.GetOk("aws_sns_role_arn") if !ok { - return fmt.Errorf("if you use AWS_SNS provider you must specify an aws_sns_role_arn") + return diag.FromErr(fmt.Errorf("if you use AWS_SNS provider you must specify an aws_sns_role_arn")) } createRequest.WithPushNotificationParams( sdk.NewPushNotificationParamsRequest().WithAmazonPushParams(sdk.NewAmazonPushParamsRequest(topic.(string), role.(string))), @@ -206,63 +208,63 @@ func CreateNotificationIntegration(d *schema.ResourceData, meta interface{}) err case "AZURE_STORAGE_QUEUE": uri, ok := d.GetOk("azure_storage_queue_primary_uri") if !ok { - return fmt.Errorf("if you use AZURE_STORAGE_QUEUE provider you must specify an azure_storage_queue_primary_uri") + return diag.FromErr(fmt.Errorf("if you use AZURE_STORAGE_QUEUE provider you must specify an azure_storage_queue_primary_uri")) } tenantId, ok := d.GetOk("azure_tenant_id") if !ok { - return fmt.Errorf("if you use AZURE_STORAGE_QUEUE provider you must specify an azure_tenant_id") + return diag.FromErr(fmt.Errorf("if you use AZURE_STORAGE_QUEUE provider you must specify an azure_tenant_id")) } createRequest.WithAutomatedDataLoadsParams( sdk.NewAutomatedDataLoadsParamsRequest().WithAzureAutoParams(sdk.NewAzureAutoParamsRequest(uri.(string), tenantId.(string))), ) default: - return fmt.Errorf("unexpected provider %v", notificationProvider) + return diag.FromErr(fmt.Errorf("unexpected provider %v", notificationProvider)) } err := client.NotificationIntegrations.Create(ctx, createRequest) if err != nil { - return fmt.Errorf("error creating notification integration: %w", err) + return diag.FromErr(fmt.Errorf("error creating notification integration: %w", err)) } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadNotificationIntegration(d, meta) + return ReadNotificationIntegration(ctx, d, meta) } // ReadNotificationIntegration implements schema.ReadFunc. -func ReadNotificationIntegration(d *schema.ResourceData, meta interface{}) error { +func ReadNotificationIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) integration, err := client.NotificationIntegrations.ShowByID(ctx, id) if err != nil { log.Printf("[DEBUG] notification integration (%s) not found", d.Id()) d.SetId("") - return err + return diag.FromErr(err) } // Note: category must be NOTIFICATION or something is broken if c := integration.Category; c != "NOTIFICATION" { - return fmt.Errorf("expected %v to be a NOTIFICATION integration, got %v", id, c) + return diag.FromErr(fmt.Errorf("expected %v to be a NOTIFICATION integration, got %v", id, c)) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", integration.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", integration.Comment); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("created_on", integration.CreatedOn.String()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("enabled", integration.Enabled); err != nil { - return err + return diag.FromErr(err) } // Snowflake returns "QUEUE - AZURE_STORAGE_QUEUE" instead of simple "QUEUE" as a type @@ -270,13 +272,13 @@ func ReadNotificationIntegration(d *schema.ResourceData, meta interface{}) error typeParts := strings.Split(integration.NotificationType, "-") parsedType := strings.TrimSpace(typeParts[0]) if err := d.Set("type", parsedType); err != nil { - return err + return diag.FromErr(err) } // Some properties come from the DESCRIBE INTEGRATION call integrationProperties, err := client.NotificationIntegrations.Describe(ctx, id) if err != nil { - return fmt.Errorf("could not describe notification integration: %w", err) + return diag.FromErr(fmt.Errorf("could not describe notification integration: %w", err)) } for _, property := range integrationProperties { name := property.Name @@ -286,72 +288,72 @@ func ReadNotificationIntegration(d *schema.ResourceData, meta interface{}) error // We set this using the SHOW INTEGRATION call so let's ignore it here case "DIRECTION": if err := d.Set("direction", value); err != nil { - return err + return diag.FromErr(err) } case "NOTIFICATION_PROVIDER": if err := d.Set("notification_provider", value); err != nil { - return err + return diag.FromErr(err) } case "AZURE_STORAGE_QUEUE_PRIMARY_URI": if err := d.Set("azure_storage_queue_primary_uri", value); err != nil { - return err + return diag.FromErr(err) } // NOTIFICATION_PROVIDER is not returned for azure automated data load, so we set it manually in such a case if err := d.Set("notification_provider", "AZURE_STORAGE_QUEUE"); err != nil { - return err + return diag.FromErr(err) } case "AZURE_TENANT_ID": if err := d.Set("azure_tenant_id", value); err != nil { - return err + return diag.FromErr(err) } case "AWS_SNS_TOPIC_ARN": if err := d.Set("aws_sns_topic_arn", value); err != nil { - return err + return diag.FromErr(err) } case "AWS_SNS_ROLE_ARN": if err := d.Set("aws_sns_role_arn", value); err != nil { - return err + return diag.FromErr(err) } case "SF_AWS_EXTERNAL_ID": if err := d.Set("aws_sns_external_id", value); err != nil { - return err + return diag.FromErr(err) } case "SF_AWS_IAM_USER_ARN": if err := d.Set("aws_sns_iam_user_arn", value); err != nil { - return err + return diag.FromErr(err) } case "GCP_PUBSUB_SUBSCRIPTION_NAME": if err := d.Set("gcp_pubsub_subscription_name", value); err != nil { - return err + return diag.FromErr(err) } // NOTIFICATION_PROVIDER is not returned for gcp, so we set it manually in such a case if err := d.Set("notification_provider", "GCP_PUBSUB"); err != nil { - return err + return diag.FromErr(err) } case "GCP_PUBSUB_TOPIC_NAME": if err := d.Set("gcp_pubsub_topic_name", value); err != nil { - return err + return diag.FromErr(err) } // NOTIFICATION_PROVIDER is not returned for gcp, so we set it manually in such a case if err := d.Set("notification_provider", "GCP_PUBSUB"); err != nil { - return err + return diag.FromErr(err) } case "GCP_PUBSUB_SERVICE_ACCOUNT": if err := d.Set("gcp_pubsub_service_account", value); err != nil { - return err + return diag.FromErr(err) } default: log.Printf("[WARN] unexpected property %v returned from Snowflake", name) } } - return err + return diag.FromErr(err) } // UpdateNotificationIntegration implements schema.UpdateFunc. -func UpdateNotificationIntegration(d *schema.ResourceData, meta interface{}) error { +func UpdateNotificationIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) var runSetStatement bool @@ -379,28 +381,28 @@ func UpdateNotificationIntegration(d *schema.ResourceData, meta interface{}) err case "AZURE_STORAGE_QUEUE": log.Printf("[WARN] all AZURE_STORAGE_QUEUE properties should recreate the resource") default: - return fmt.Errorf("unexpected provider %v", notificationProvider) + return diag.FromErr(fmt.Errorf("unexpected provider %v", notificationProvider)) } if runSetStatement { err := client.NotificationIntegrations.Alter(ctx, sdk.NewAlterNotificationIntegrationRequest(id).WithSet(setRequest)) if err != nil { - return fmt.Errorf("error updating notification integration: %w", err) + return diag.FromErr(fmt.Errorf("error updating notification integration: %w", err)) } } - return ReadNotificationIntegration(d, meta) + return ReadNotificationIntegration(ctx, d, meta) } // DeleteNotificationIntegration implements schema.DeleteFunc. -func DeleteNotificationIntegration(d *schema.ResourceData, meta interface{}) error { +func DeleteNotificationIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) err := client.NotificationIntegrations.Drop(ctx, sdk.NewDropNotificationIntegrationRequest(id)) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/oauth_integration.go b/pkg/resources/oauth_integration.go index 2d7e4152ae..22b812ed91 100644 --- a/pkg/resources/oauth_integration.go +++ b/pkg/resources/oauth_integration.go @@ -1,11 +1,15 @@ package resources import ( + "context" "fmt" "log" "strconv" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" @@ -86,10 +90,10 @@ var oauthIntegrationSchema = map[string]*schema.Schema{ // OAuthIntegration returns a pointer to the resource representing an OAuth integration. func OAuthIntegration() *schema.Resource { return &schema.Resource{ - Create: CreateOAuthIntegration, - Read: ReadOAuthIntegration, - Update: UpdateOAuthIntegration, - Delete: DeleteOAuthIntegration, + CreateContext: TrackingCreateWrapper(resources.OauthIntegration, CreateOAuthIntegration), + ReadContext: TrackingReadWrapper(resources.OauthIntegration, ReadOAuthIntegration), + UpdateContext: TrackingUpdateWrapper(resources.OauthIntegration, UpdateOAuthIntegration), + DeleteContext: TrackingDeleteWrapper(resources.OauthIntegration, DeleteOAuthIntegration), DeprecationMessage: "This resource is deprecated and will be removed in a future major version release. Please use snowflake_oauth_integration_for_custom_clients or snowflake_oauth_integration_for_partner_applications instead.", Schema: oauthIntegrationSchema, @@ -100,7 +104,7 @@ func OAuthIntegration() *schema.Resource { } // CreateOAuthIntegration implements schema.CreateFunc. -func CreateOAuthIntegration(d *schema.ResourceData, meta interface{}) error { +func CreateOAuthIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB name := d.Get("name").(string) @@ -137,16 +141,16 @@ func CreateOAuthIntegration(d *schema.ResourceData, meta interface{}) error { } if err := snowflake.Exec(db, stmt.Statement()); err != nil { - return fmt.Errorf("error creating security integration err = %w", err) + return diag.FromErr(fmt.Errorf("error creating security integration err = %w", err)) } d.SetId(name) - return ReadOAuthIntegration(d, meta) + return ReadOAuthIntegration(ctx, d, meta) } // ReadOAuthIntegration implements schema.ReadFunc. -func ReadOAuthIntegration(d *schema.ResourceData, meta interface{}) error { +func ReadOAuthIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB id := d.Id() @@ -158,32 +162,32 @@ func ReadOAuthIntegration(d *schema.ResourceData, meta interface{}) error { s, err := snowflake.ScanOAuthIntegration(row) if err != nil { - return fmt.Errorf("could not show security integration err = %w", err) + return diag.FromErr(fmt.Errorf("could not show security integration err = %w", err)) } // Note: category must be Security or something is broken if c := s.Category.String; c != "SECURITY" { - return fmt.Errorf("expected %v to be an Security integration, got %v err = %w", id, c, err) + return diag.FromErr(fmt.Errorf("expected %v to be an Security integration, got %v err = %w", id, c, err)) } if err := d.Set("oauth_client", strings.TrimPrefix(s.IntegrationType.String, "OAUTH - ")); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", s.Name.String); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("enabled", s.Enabled.Bool); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", s.Comment.String); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("created_on", s.CreatedOn.String); err != nil { - return err + return diag.FromErr(err) } // Some properties come from the DESCRIBE INTEGRATION call @@ -193,12 +197,12 @@ func ReadOAuthIntegration(d *schema.ResourceData, meta interface{}) error { stmt = snowflake.NewOAuthIntegrationBuilder(id).Describe() rows, err := db.Query(stmt) if err != nil { - return fmt.Errorf("could not describe security integration err = %w", err) + return diag.FromErr(fmt.Errorf("could not describe security integration err = %w", err)) } defer rows.Close() for rows.Next() { if err := rows.Scan(&k, &pType, &v, &unused); err != nil { - return fmt.Errorf("unable to parse security integration rows err = %w", err) + return diag.FromErr(fmt.Errorf("unable to parse security integration rows err = %w", err)) } switch k { case "ENABLED": @@ -208,22 +212,22 @@ func ReadOAuthIntegration(d *schema.ResourceData, meta interface{}) error { case "OAUTH_ISSUE_REFRESH_TOKENS": b, err := strconv.ParseBool(v.(string)) if err != nil { - return fmt.Errorf("returned OAuth issue refresh tokens that is not boolean err = %w", err) + return diag.FromErr(fmt.Errorf("returned OAuth issue refresh tokens that is not boolean err = %w", err)) } if err := d.Set("oauth_issue_refresh_tokens", b); err != nil { - return fmt.Errorf("unable to set OAuth issue refresh tokens for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set OAuth issue refresh tokens for security integration err = %w", err)) } case "OAUTH_REFRESH_TOKEN_VALIDITY": i, err := strconv.Atoi(v.(string)) if err != nil { - return fmt.Errorf("returned OAuth refresh token validity that is not integer err = %w", err) + return diag.FromErr(fmt.Errorf("returned OAuth refresh token validity that is not integer err = %w", err)) } if err := d.Set("oauth_refresh_token_validity", i); err != nil { - return fmt.Errorf("unable to set OAuth refresh token validity for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set OAuth refresh token validity for security integration err = %w", err)) } case "OAUTH_USE_SECONDARY_ROLES": if err := d.Set("oauth_use_secondary_roles", v.(string)); err != nil { - return fmt.Errorf("unable to set OAuth use secondary roles for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set OAuth use secondary roles for security integration err = %w", err)) } case "BLOCKED_ROLES_LIST": blockedRolesAll := strings.Split(v.(string), ",") @@ -238,18 +242,18 @@ func ReadOAuthIntegration(d *schema.ResourceData, meta interface{}) error { } if err := d.Set("blocked_roles_list", blockedRolesCustom); err != nil { - return fmt.Errorf("unable to set blocked roles list for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set blocked roles list for security integration err = %w", err)) } case "OAUTH_REDIRECT_URI": if err := d.Set("oauth_redirect_uri", v.(string)); err != nil { - return fmt.Errorf("unable to set OAuth redirect URI for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set OAuth redirect URI for security integration err = %w", err)) } case "OAUTH_CLIENT_TYPE": isTableau := strings.HasSuffix(s.IntegrationType.String, "TABLEAU_DESKTOP") || strings.HasSuffix(s.IntegrationType.String, "TABLEAU_SERVER") if !isTableau { if err = d.Set("oauth_client_type", v.(string)); err != nil { - return fmt.Errorf("unable to set OAuth client type for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set OAuth client type for security integration err = %w", err)) } } case "OAUTH_ENFORCE_PKCE": @@ -270,11 +274,11 @@ func ReadOAuthIntegration(d *schema.ResourceData, meta interface{}) error { } } - return err + return diag.FromErr(err) } // UpdateOAuthIntegration implements schema.UpdateFunc. -func UpdateOAuthIntegration(d *schema.ResourceData, meta interface{}) error { +func UpdateOAuthIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB id := d.Id() @@ -330,14 +334,14 @@ func UpdateOAuthIntegration(d *schema.ResourceData, meta interface{}) error { if runSetStatement { if err := snowflake.Exec(db, stmt.Statement()); err != nil { - return fmt.Errorf("error updating security integration err = %w", err) + return diag.FromErr(fmt.Errorf("error updating security integration err = %w", err)) } } - return ReadOAuthIntegration(d, meta) + return ReadOAuthIntegration(ctx, d, meta) } // DeleteOAuthIntegration implements schema.DeleteFunc. -func DeleteOAuthIntegration(d *schema.ResourceData, meta interface{}) error { - return DeleteResource("", snowflake.NewOAuthIntegrationBuilder)(d, meta) +func DeleteOAuthIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { + return diag.FromErr(DeleteResource("", snowflake.NewOAuthIntegrationBuilder)(d, meta)) } diff --git a/pkg/resources/oauth_integration_test.go b/pkg/resources/oauth_integration_test.go index c8e2d9a126..7a353dd8f2 100644 --- a/pkg/resources/oauth_integration_test.go +++ b/pkg/resources/oauth_integration_test.go @@ -1,6 +1,7 @@ package resources_test import ( + "context" "database/sql" "testing" @@ -37,10 +38,10 @@ func TestOAuthIntegrationCreate(t *testing.T) { ).WillReturnResult(sqlmock.NewResult(1, 1)) expectReadOAuthIntegration(mock) - err := resources.CreateOAuthIntegration(d, &internalprovider.Context{ + diags := resources.CreateOAuthIntegration(context.Background(), d, &internalprovider.Context{ Client: sdk.NewClientFromDB(db), }) - r.NoError(err) + r.Empty(diags) }) } @@ -52,10 +53,10 @@ func TestOAuthIntegrationRead(t *testing.T) { WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expectReadOAuthIntegration(mock) - err := resources.ReadOAuthIntegration(d, &internalprovider.Context{ + diags := resources.ReadOAuthIntegration(context.Background(), d, &internalprovider.Context{ Client: sdk.NewClientFromDB(db), }) - r.NoError(err) + r.Empty(diags) }) } @@ -66,10 +67,10 @@ func TestOAuthIntegrationDelete(t *testing.T) { WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { mock.ExpectExec(`DROP SECURITY INTEGRATION "drop_it"`).WillReturnResult(sqlmock.NewResult(1, 1)) - err := resources.DeleteOAuthIntegration(d, &internalprovider.Context{ + diags := resources.DeleteOAuthIntegration(context.Background(), d, &internalprovider.Context{ Client: sdk.NewClientFromDB(db), }) - r.NoError(err) + r.Empty(diags) }) } diff --git a/pkg/resources/object_parameter.go b/pkg/resources/object_parameter.go index ab5dbc0ecc..7d5b01ae08 100644 --- a/pkg/resources/object_parameter.go +++ b/pkg/resources/object_parameter.go @@ -5,6 +5,9 @@ import ( "fmt" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -69,10 +72,10 @@ var objectParameterSchema = map[string]*schema.Schema{ func ObjectParameter() *schema.Resource { return &schema.Resource{ - Create: CreateObjectParameter, - Read: ReadObjectParameter, - Update: UpdateObjectParameter, - Delete: DeleteObjectParameter, + CreateContext: TrackingCreateWrapper(resources.ObjectParameter, CreateObjectParameter), + ReadContext: TrackingReadWrapper(resources.ObjectParameter, ReadObjectParameter), + UpdateContext: TrackingUpdateWrapper(resources.ObjectParameter, UpdateObjectParameter), + DeleteContext: TrackingDeleteWrapper(resources.ObjectParameter, DeleteObjectParameter), Schema: objectParameterSchema, Importer: &schema.ResourceImporter{ @@ -82,11 +85,11 @@ func ObjectParameter() *schema.Resource { } // CreateObjectParameter implements schema.CreateFunc. -func CreateObjectParameter(d *schema.ResourceData, meta interface{}) error { +func CreateObjectParameter(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client key := d.Get("key").(string) value := d.Get("value").(string) - ctx := context.Background() + parameter := sdk.ObjectParameter(key) o := sdk.Object{} @@ -102,12 +105,12 @@ func CreateObjectParameter(d *schema.ResourceData, meta interface{}) error { if onAccount { err := client.Parameters.SetObjectParameterOnAccount(ctx, parameter, value) if err != nil { - return fmt.Errorf("error creating object parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error creating object parameter err = %w", err)) } } else { err := client.Parameters.SetObjectParameterOnObject(ctx, o, parameter, value) if err != nil { - return fmt.Errorf("error setting object parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error setting object parameter err = %w", err)) } } @@ -127,26 +130,26 @@ func CreateObjectParameter(d *schema.ResourceData, meta interface{}) error { p, err = client.Parameters.ShowObjectParameter(ctx, sdk.ObjectParameter(key), o) } if err != nil { - return fmt.Errorf("error reading object parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error reading object parameter err = %w", err)) } err = d.Set("value", p.Value) if err != nil { - return err + return diag.FromErr(err) } - return ReadObjectParameter(d, meta) + return ReadObjectParameter(ctx, d, meta) } // ReadObjectParameter implements schema.ReadFunc. -func ReadObjectParameter(d *schema.ResourceData, meta interface{}) error { +func ReadObjectParameter(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := d.Id() parts := strings.Split(id, "|") if len(parts) != 3 { parts = strings.Split(id, "❄️") // for backwards compatibility } if len(parts) != 3 { - return fmt.Errorf("unexpected format of ID (%v), expected key|object_type|object_identifier", id) + return diag.FromErr(fmt.Errorf("unexpected format of ID (%v), expected key|object_type|object_identifier", id)) } key := parts[0] var p *sdk.Parameter @@ -163,35 +166,35 @@ func ReadObjectParameter(d *schema.ResourceData, meta interface{}) error { }) } if err != nil { - return fmt.Errorf("error reading object parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error reading object parameter err = %w", err)) } if err := d.Set("value", p.Value); err != nil { - return err + return diag.FromErr(err) } return nil } // UpdateObjectParameter implements schema.UpdateFunc. -func UpdateObjectParameter(d *schema.ResourceData, meta interface{}) error { - return CreateObjectParameter(d, meta) +func UpdateObjectParameter(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { + return CreateObjectParameter(ctx, d, meta) } // DeleteObjectParameter implements schema.DeleteFunc. -func DeleteObjectParameter(d *schema.ResourceData, meta interface{}) error { +func DeleteObjectParameter(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + key := d.Get("key").(string) onAccount := d.Get("on_account").(bool) if onAccount { defaultParameter, err := client.Parameters.ShowAccountParameter(ctx, sdk.AccountParameter(key)) if err != nil { - return err + return diag.FromErr(err) } defaultValue := defaultParameter.Default err = client.Parameters.SetAccountParameter(ctx, sdk.AccountParameter(key), defaultValue) if err != nil { - return fmt.Errorf("error resetting account parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error resetting account parameter err = %w", err)) } } else { v := d.Get("object_identifier") @@ -205,12 +208,12 @@ func DeleteObjectParameter(d *schema.ResourceData, meta interface{}) error { objectParameter := sdk.ObjectParameter(key) defaultParameter, err := client.Parameters.ShowObjectParameter(ctx, objectParameter, o) if err != nil { - return err + return diag.FromErr(err) } defaultValue := defaultParameter.Default err = client.Parameters.SetObjectParameterOnObject(ctx, o, objectParameter, defaultValue) if err != nil { - return fmt.Errorf("error resetting object parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error resetting object parameter err = %w", err)) } } d.SetId("") diff --git a/pkg/resources/password_policy.go b/pkg/resources/password_policy.go index 0fba86bf06..da3cf85038 100644 --- a/pkg/resources/password_policy.go +++ b/pkg/resources/password_policy.go @@ -4,13 +4,14 @@ import ( "context" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) @@ -140,11 +141,11 @@ var passwordPolicySchema = map[string]*schema.Schema{ func PasswordPolicy() *schema.Resource { return &schema.Resource{ - Description: "A password policy specifies the requirements that must be met to create and reset a password to authenticate to Snowflake.", - Create: CreatePasswordPolicy, - Read: ReadPasswordPolicy, - Update: UpdatePasswordPolicy, - Delete: DeletePasswordPolicy, + Description: "A password policy specifies the requirements that must be met to create and reset a password to authenticate to Snowflake.", + CreateContext: TrackingCreateWrapper(resources.PasswordPolicy, CreatePasswordPolicy), + ReadContext: TrackingReadWrapper(resources.PasswordPolicy, ReadPasswordPolicy), + UpdateContext: TrackingUpdateWrapper(resources.PasswordPolicy, UpdatePasswordPolicy), + DeleteContext: TrackingDeleteWrapper(resources.PasswordPolicy, DeletePasswordPolicy), CustomizeDiff: TrackingCustomDiffWrapper(resources.PasswordPolicy, customdiff.All( ComputedIfAnyAttributeChanged(passwordPolicySchema, FullyQualifiedNameAttributeName, "name"), @@ -158,9 +159,9 @@ func PasswordPolicy() *schema.Resource { } // CreatePasswordPolicy implements schema.CreateFunc. -func CreatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { +func CreatePasswordPolicy(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + name := d.Get("name").(string) database := d.Get("database").(string) schema := d.Get("schema").(string) @@ -188,84 +189,83 @@ func CreatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { err := client.PasswordPolicies.Create(ctx, objectIdentifier, createOptions) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(objectIdentifier)) - return ReadPasswordPolicy(d, meta) + return ReadPasswordPolicy(ctx, d, meta) } // ReadPasswordPolicy implements schema.ReadFunc. -func ReadPasswordPolicy(d *schema.ResourceData, meta interface{}) error { +func ReadPasswordPolicy(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) passwordPolicy, err := client.PasswordPolicies.ShowByID(ctx, id) if err != nil { - return err + return diag.FromErr(err) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("database", passwordPolicy.DatabaseName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("schema", passwordPolicy.SchemaName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", passwordPolicy.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", passwordPolicy.Comment); err != nil { - return err + return diag.FromErr(err) } passwordPolicyDetails, err := client.PasswordPolicies.Describe(ctx, id) if err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "min_length", passwordPolicyDetails.PasswordMinLength); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "max_length", passwordPolicyDetails.PasswordMaxLength); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "min_upper_case_chars", passwordPolicyDetails.PasswordMinUpperCaseChars); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "min_lower_case_chars", passwordPolicyDetails.PasswordMinLowerCaseChars); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "min_numeric_chars", passwordPolicyDetails.PasswordMinNumericChars); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "min_special_chars", passwordPolicyDetails.PasswordMinSpecialChars); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "min_age_days", passwordPolicyDetails.PasswordMinAgeDays); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "max_age_days", passwordPolicyDetails.PasswordMaxAgeDays); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "max_retries", passwordPolicyDetails.PasswordMaxRetries); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "lockout_time_mins", passwordPolicyDetails.PasswordLockoutTimeMins); err != nil { - return err + return diag.FromErr(err) } if err := setFromIntProperty(d, "history", passwordPolicyDetails.PasswordHistory); err != nil { - return err + return diag.FromErr(err) } return nil } // UpdatePasswordPolicy implements schema.UpdateFunc. -func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { +func UpdatePasswordPolicy(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) @@ -276,7 +276,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { NewName: &newId, }) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(newId)) @@ -291,7 +291,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } if d.HasChange("max_length") { @@ -302,7 +302,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } if d.HasChange("min_upper_case_chars") { @@ -313,7 +313,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } if d.HasChange("min_lower_case_chars") { @@ -324,7 +324,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } @@ -336,7 +336,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } @@ -348,7 +348,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } @@ -360,7 +360,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } @@ -372,7 +372,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } @@ -384,7 +384,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } @@ -396,7 +396,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } @@ -408,7 +408,7 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } @@ -425,21 +425,21 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) if err != nil { - return err + return diag.FromErr(err) } } - return ReadPasswordPolicy(d, meta) + return ReadPasswordPolicy(ctx, d, meta) } // DeletePasswordPolicy implements schema.DeleteFunc. -func DeletePasswordPolicy(d *schema.ResourceData, meta interface{}) error { +func DeletePasswordPolicy(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) err := client.PasswordPolicies.Drop(ctx, objectIdentifier, nil) if err != nil { - return err + return diag.FromErr(err) } return nil diff --git a/pkg/resources/pipe.go b/pkg/resources/pipe.go index 208d08e13a..c15140621f 100644 --- a/pkg/resources/pipe.go +++ b/pkg/resources/pipe.go @@ -6,6 +6,9 @@ import ( "log" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -83,10 +86,10 @@ var pipeSchema = map[string]*schema.Schema{ func Pipe() *schema.Resource { return &schema.Resource{ - Create: CreatePipe, - Read: ReadPipe, - Update: UpdatePipe, - Delete: DeletePipe, + CreateContext: TrackingCreateWrapper(resources.Pipe, CreatePipe), + ReadContext: TrackingReadWrapper(resources.Pipe, ReadPipe), + UpdateContext: TrackingUpdateWrapper(resources.Pipe, UpdatePipe), + DeleteContext: TrackingDeleteWrapper(resources.Pipe, DeletePipe), Schema: pipeSchema, Importer: &schema.ResourceImporter{ @@ -105,14 +108,13 @@ func pipeCopyStatementDiffSuppress(_, o, n string, _ *schema.ResourceData) bool } // CreatePipe implements schema.CreateFunc. -func CreatePipe(d *schema.ResourceData, meta interface{}) error { +func CreatePipe(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) name := d.Get("name").(string) - ctx := context.Background() objectIdentifier := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name) opts := &sdk.CreatePipeOptions{} @@ -142,20 +144,19 @@ func CreatePipe(d *schema.ResourceData, meta interface{}) error { err := client.Pipes.Create(ctx, objectIdentifier, copyStatement, opts) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(objectIdentifier)) - return ReadPipe(d, meta) + return ReadPipe(ctx, d, meta) } // ReadPipe implements schema.ReadFunc. -func ReadPipe(d *schema.ResourceData, meta interface{}) error { +func ReadPipe(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - ctx := context.Background() pipe, err := client.Pipes.ShowByID(ctx, id) if err != nil { // If not found, mark resource to be removed from state file during apply or refresh @@ -165,59 +166,58 @@ func ReadPipe(d *schema.ResourceData, meta interface{}) error { } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", pipe.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("database", pipe.DatabaseName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("schema", pipe.SchemaName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("copy_statement", pipe.Definition); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("owner", pipe.Owner); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", pipe.Comment); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("notification_channel", pipe.NotificationChannel); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("auto_ingest", pipe.NotificationChannel != ""); err != nil { - return err + return diag.FromErr(err) } if strings.Contains(pipe.NotificationChannel, "arn:aws:sns:") { if err := d.Set("aws_sns_topic_arn", pipe.NotificationChannel); err != nil { - return err + return diag.FromErr(err) } } if err := d.Set("error_integration", pipe.ErrorIntegration); err != nil { - return err + return diag.FromErr(err) } return nil } // UpdatePipe implements schema.UpdateFunc. -func UpdatePipe(d *schema.ResourceData, meta interface{}) error { +func UpdatePipe(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - ctx := context.Background() pipeSet := &sdk.PipeSet{} pipeUnset := &sdk.PipeUnset{} @@ -248,7 +248,7 @@ func UpdatePipe(d *schema.ResourceData, meta interface{}) error { options := &sdk.AlterPipeOptions{Set: pipeSet} err := client.Pipes.Alter(ctx, objectIdentifier, options) if err != nil { - return fmt.Errorf("error updating pipe %v: %w", objectIdentifier.Name(), err) + return diag.FromErr(fmt.Errorf("error updating pipe %v: %w", objectIdentifier.Name(), err)) } } @@ -256,22 +256,22 @@ func UpdatePipe(d *schema.ResourceData, meta interface{}) error { options := &sdk.AlterPipeOptions{Unset: pipeUnset} err := client.Pipes.Alter(ctx, objectIdentifier, options) if err != nil { - return fmt.Errorf("error updating pipe %v: %w", objectIdentifier.Name(), err) + return diag.FromErr(fmt.Errorf("error updating pipe %v: %w", objectIdentifier.Name(), err)) } } - return ReadPipe(d, meta) + return ReadPipe(ctx, d, meta) } // DeletePipe implements schema.DeleteFunc. -func DeletePipe(d *schema.ResourceData, meta interface{}) error { +func DeletePipe(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) err := client.Pipes.Drop(ctx, objectIdentifier, &sdk.DropPipeOptions{IfExists: sdk.Bool(true)}) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/saml_integration.go b/pkg/resources/saml_integration.go index 8062a9e40d..3cd6a6a18d 100644 --- a/pkg/resources/saml_integration.go +++ b/pkg/resources/saml_integration.go @@ -1,11 +1,15 @@ package resources import ( + "context" "fmt" "log" "strconv" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" @@ -141,10 +145,10 @@ var samlIntegrationSchema = map[string]*schema.Schema{ // SAMLIntegration returns a pointer to the resource representing a SAML2 security integration. func SAMLIntegration() *schema.Resource { return &schema.Resource{ - Create: CreateSAMLIntegration, - Read: ReadSAMLIntegration, - Update: UpdateSAMLIntegration, - Delete: DeleteSAMLIntegration, + CreateContext: TrackingCreateWrapper(resources.SamlSecurityIntegration, CreateSAMLIntegration), + ReadContext: TrackingReadWrapper(resources.SamlSecurityIntegration, ReadSAMLIntegration), + UpdateContext: TrackingUpdateWrapper(resources.SamlSecurityIntegration, UpdateSAMLIntegration), + DeleteContext: TrackingDeleteWrapper(resources.SamlSecurityIntegration, DeleteSAMLIntegration), DeprecationMessage: "This resource is deprecated and will be removed in a future major version release. Please use snowflake_saml2_integration instead.", Schema: samlIntegrationSchema, @@ -155,7 +159,7 @@ func SAMLIntegration() *schema.Resource { } // CreateSAMLIntegration implements schema.CreateFunc. -func CreateSAMLIntegration(d *schema.ResourceData, meta interface{}) error { +func CreateSAMLIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB name := d.Get("name").(string) @@ -212,16 +216,16 @@ func CreateSAMLIntegration(d *schema.ResourceData, meta interface{}) error { err := snowflake.Exec(db, stmt.Statement()) if err != nil { - return fmt.Errorf("error creating security integration err = %w", err) + return diag.FromErr(fmt.Errorf("error creating security integration err = %w", err)) } d.SetId(name) - return ReadSAMLIntegration(d, meta) + return ReadSAMLIntegration(ctx, d, meta) } // ReadSAMLIntegration implements schema.ReadFunc. -func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { +func ReadSAMLIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB id := d.Id() @@ -233,29 +237,29 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { s, err := snowflake.ScanSamlIntegration(row) if err != nil { - return fmt.Errorf("could not show security integration err = %w", err) + return diag.FromErr(fmt.Errorf("could not show security integration err = %w", err)) } // Note: category must be Security or something is broken if c := s.Category.String; c != "SECURITY" { - return fmt.Errorf("expected %v to be an Security integration, got %v", id, c) + return diag.FromErr(fmt.Errorf("expected %v to be an Security integration, got %v", id, c)) } // Note: type must be SAML2 or something is broken if c := s.IntegrationType.String; c != "SAML2" { - return fmt.Errorf("expected %v to be a SAML2 integration type, got %v", id, c) + return diag.FromErr(fmt.Errorf("expected %v to be a SAML2 integration type, got %v", id, c)) } if err := d.Set("name", s.Name.String); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("created_on", s.CreatedOn.String); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("enabled", s.Enabled.Bool); err != nil { - return err + return diag.FromErr(err) } // Some properties come from the DESCRIBE INTEGRATION call @@ -265,35 +269,35 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { stmt = snowflake.NewSamlIntegrationBuilder(id).Describe() rows, err := db.Query(stmt) if err != nil { - return fmt.Errorf("could not describe security integration err = %w", err) + return diag.FromErr(fmt.Errorf("could not describe security integration err = %w", err)) } defer rows.Close() for rows.Next() { if err := rows.Scan(&k, &pType, &v, &unused); err != nil { - return fmt.Errorf("unable to parse security integration rows err = %w", err) + return diag.FromErr(fmt.Errorf("unable to parse security integration rows err = %w", err)) } switch k { case "ENABLED": // set using the SHOW INTEGRATION, ignoring here case "SAML2_ISSUER": if err := d.Set("saml2_issuer", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_issuer for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_issuer for security integration err = %w", err)) } case "SAML2_SSO_URL": if err := d.Set("saml2_sso_url", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_sso_url for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_sso_url for security integration err = %w", err)) } case "SAML2_PROVIDER": if err := d.Set("saml2_provider", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_provider for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_provider for security integration err = %w", err)) } case "SAML2_X509_CERT": if err := d.Set("saml2_x509_cert", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_x509_cert for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_x509_cert for security integration err = %w", err)) } case "SAML2_SP_INITIATED_LOGIN_PAGE_LABEL": if err := d.Set("saml2_sp_initiated_login_page_label", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_sp_initiated_login_page_label for security integration") + return diag.FromErr(fmt.Errorf("unable to set saml2_sp_initiated_login_page_label for security integration")) } case "SAML2_ENABLE_SP_INITIATED": var b bool @@ -303,17 +307,17 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { case string: b, err = strconv.ParseBool(v.(string)) if err != nil { - return fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err) + return diag.FromErr(fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err)) } default: - return fmt.Errorf("returned saml2_force_authn that is not boolean") + return diag.FromErr(fmt.Errorf("returned saml2_force_authn that is not boolean")) } if err := d.Set("saml2_enable_sp_initiated", b); err != nil { - return fmt.Errorf("unable to set saml2_enable_sp_initiated for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_enable_sp_initiated for security integration err = %w", err)) } case "SAML2_SNOWFLAKE_X509_CERT": if err := d.Set("saml2_snowflake_x509_cert", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_snowflake_x509_cert for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_snowflake_x509_cert for security integration err = %w", err)) } case "SAML2_SIGN_REQUEST": var b bool @@ -323,21 +327,21 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { case string: b, err = strconv.ParseBool(v.(string)) if err != nil { - return fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err) + return diag.FromErr(fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err)) } default: - return fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err) + return diag.FromErr(fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err)) } if err := d.Set("saml2_sign_request", b); err != nil { - return fmt.Errorf("unable to set saml2_sign_request for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_sign_request for security integration err = %w", err)) } case "SAML2_REQUESTED_NAMEID_FORMAT": if err := d.Set("saml2_requested_nameid_format", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_requested_nameid_format for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_requested_nameid_format for security integration err = %w", err)) } case "SAML2_POST_LOGOUT_REDIRECT_URL": if err := d.Set("saml2_post_logout_redirect_url", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_post_logout_redirect_url for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_post_logout_redirect_url for security integration err = %w", err)) } case "SAML2_FORCE_AUTHN": var b bool @@ -347,33 +351,33 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { case string: b, err = strconv.ParseBool(v.(string)) if err != nil { - return fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err) + return diag.FromErr(fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err)) } default: - return fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err) + return diag.FromErr(fmt.Errorf("returned saml2_force_authn that is not boolean err = %w", err)) } if err := d.Set("saml2_force_authn", b); err != nil { - return fmt.Errorf("unable to set saml2_force_authn for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_force_authn for security integration err = %w", err)) } case "SAML2_SNOWFLAKE_ISSUER_URL": if err := d.Set("saml2_snowflake_issuer_url", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_snowflake_issuer_url for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_snowflake_issuer_url for security integration err = %w", err)) } case "SAML2_SNOWFLAKE_ACS_URL": if err := d.Set("saml2_snowflake_acs_url", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_snowflake_acs_url for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_snowflake_acs_url for security integration err = %w", err)) } case "SAML2_SNOWFLAKE_METADATA": if err := d.Set("saml2_snowflake_metadata", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_snowflake_metadata for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_snowflake_metadata for security integration err = %w", err)) } case "SAML2_DIGEST_METHODS_USED": if err := d.Set("saml2_digest_methods_used", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_digest_methods_used for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_digest_methods_used for security integration err = %w", err)) } case "SAML2_SIGNATURE_METHODS_USED": if err := d.Set("saml2_signature_methods_used", v.(string)); err != nil { - return fmt.Errorf("unable to set saml2_signature_methods_used for security integration err = %w", err) + return diag.FromErr(fmt.Errorf("unable to set saml2_signature_methods_used for security integration err = %w", err)) } case "COMMENT": // COMMENT cannot be set according to snowflake docs, so ignoring @@ -382,11 +386,11 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { } } - return err + return diag.FromErr(err) } // UpdateSAMLIntegration implements schema.UpdateFunc. -func UpdateSAMLIntegration(d *schema.ResourceData, meta interface{}) error { +func UpdateSAMLIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB id := d.Id() @@ -467,14 +471,14 @@ func UpdateSAMLIntegration(d *schema.ResourceData, meta interface{}) error { if runSetStatement { if err := snowflake.Exec(db, stmt.Statement()); err != nil { - return fmt.Errorf("error updating security integration err = %w", err) + return diag.FromErr(fmt.Errorf("error updating security integration err = %w", err)) } } - return ReadSAMLIntegration(d, meta) + return ReadSAMLIntegration(ctx, d, meta) } // DeleteSAMLIntegration implements schema.DeleteFunc. -func DeleteSAMLIntegration(d *schema.ResourceData, meta interface{}) error { - return DeleteResource("", snowflake.NewSamlIntegrationBuilder)(d, meta) +func DeleteSAMLIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { + return diag.FromErr(DeleteResource("", snowflake.NewSamlIntegrationBuilder)(d, meta)) } diff --git a/pkg/resources/saml_integration_test.go b/pkg/resources/saml_integration_test.go index d326c12212..36327ec029 100644 --- a/pkg/resources/saml_integration_test.go +++ b/pkg/resources/saml_integration_test.go @@ -1,6 +1,7 @@ package resources_test import ( + "context" "database/sql" "testing" @@ -41,10 +42,10 @@ func TestSAMLIntegrationCreate(t *testing.T) { ).WillReturnResult(sqlmock.NewResult(1, 1)) expectReadSAMLIntegration(mock) - err := resources.CreateSAMLIntegration(d, &internalprovider.Context{ + diags := resources.CreateSAMLIntegration(context.Background(), d, &internalprovider.Context{ Client: sdk.NewClientFromDB(db), }) - r.NoError(err) + r.Empty(diags) }) } @@ -56,10 +57,10 @@ func TestSAMLIntegrationRead(t *testing.T) { WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expectReadSAMLIntegration(mock) - err := resources.ReadSAMLIntegration(d, &internalprovider.Context{ + diags := resources.ReadSAMLIntegration(context.Background(), d, &internalprovider.Context{ Client: sdk.NewClientFromDB(db), }) - r.NoError(err) + r.Empty(diags) }) } @@ -70,10 +71,10 @@ func TestSAMLIntegrationDelete(t *testing.T) { WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { mock.ExpectExec(`DROP SECURITY INTEGRATION "drop_it"`).WillReturnResult(sqlmock.NewResult(1, 1)) - err := resources.DeleteSAMLIntegration(d, &internalprovider.Context{ + diags := resources.DeleteSAMLIntegration(context.Background(), d, &internalprovider.Context{ Client: sdk.NewClientFromDB(db), }) - r.NoError(err) + r.Empty(diags) }) } diff --git a/pkg/resources/sequence.go b/pkg/resources/sequence.go index 207b78d909..449dc56694 100644 --- a/pkg/resources/sequence.go +++ b/pkg/resources/sequence.go @@ -3,6 +3,9 @@ package resources import ( "context" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -64,10 +67,10 @@ var sequenceSchema = map[string]*schema.Schema{ func Sequence() *schema.Resource { return &schema.Resource{ - Create: CreateSequence, - Read: ReadSequence, - Delete: DeleteSequence, - Update: UpdateSequence, + CreateContext: TrackingCreateWrapper(resources.Sequence, CreateSequence), + ReadContext: TrackingReadWrapper(resources.Sequence, ReadSequence), + DeleteContext: TrackingDeleteWrapper(resources.Sequence, DeleteSequence), + UpdateContext: TrackingUpdateWrapper(resources.Sequence, UpdateSequence), Schema: sequenceSchema, Importer: &schema.ResourceImporter{ @@ -76,9 +79,9 @@ func Sequence() *schema.Resource { } } -func CreateSequence(d *schema.ResourceData, meta interface{}) error { +func CreateSequence(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + database := d.Get("database").(string) schema := d.Get("schema").(string) name := d.Get("name").(string) @@ -97,72 +100,72 @@ func CreateSequence(d *schema.ResourceData, meta interface{}) error { } err := client.Sequences.Create(ctx, req) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeSnowflakeID(database, schema, name)) - return ReadSequence(d, meta) + return ReadSequence(ctx, d, meta) } -func ReadSequence(d *schema.ResourceData, meta interface{}) error { +func ReadSequence(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) seq, err := client.Sequences.ShowByID(ctx, id) if err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", seq.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("schema", seq.SchemaName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("database", seq.DatabaseName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", seq.Comment); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("increment", seq.Interval); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("next_value", seq.NextValue); err != nil { - return err + return diag.FromErr(err) } if seq.Ordered { if err := d.Set("ordering", "ORDER"); err != nil { - return err + return diag.FromErr(err) } } else { if err := d.Set("ordering", "NOORDER"); err != nil { - return err + return diag.FromErr(err) } } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } return nil } -func UpdateSequence(d *schema.ResourceData, meta interface{}) error { +func UpdateSequence(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) if d.HasChange("comment") { req := sdk.NewAlterSequenceRequest(id) req.WithSet(sdk.NewSequenceSetRequest().WithComment(sdk.String(d.Get("comment").(string)))) if err := client.Sequences.Alter(ctx, req); err != nil { - return err + return diag.FromErr(err) } } @@ -170,7 +173,7 @@ func UpdateSequence(d *schema.ResourceData, meta interface{}) error { req := sdk.NewAlterSequenceRequest(id) req.WithSetIncrement(sdk.Int(d.Get("increment").(int))) if err := client.Sequences.Alter(ctx, req); err != nil { - return err + return diag.FromErr(err) } } @@ -178,20 +181,20 @@ func UpdateSequence(d *schema.ResourceData, meta interface{}) error { req := sdk.NewAlterSequenceRequest(id) req.WithSet(sdk.NewSequenceSetRequest().WithValuesBehavior(sdk.ValuesBehaviorPointer(sdk.ValuesBehavior(d.Get("ordering").(string))))) if err := client.Sequences.Alter(ctx, req); err != nil { - return err + return diag.FromErr(err) } } - return ReadSequence(d, meta) + return ReadSequence(ctx, d, meta) } -func DeleteSequence(d *schema.ResourceData, meta interface{}) error { +func DeleteSequence(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) err := client.Sequences.Drop(ctx, sdk.NewDropSequenceRequest(id).WithIfExists(sdk.Bool(true))) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") return nil diff --git a/pkg/resources/session_parameter.go b/pkg/resources/session_parameter.go index 9c4d13e09d..1efbc5a40e 100644 --- a/pkg/resources/session_parameter.go +++ b/pkg/resources/session_parameter.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -37,10 +40,10 @@ var sessionParameterSchema = map[string]*schema.Schema{ func SessionParameter() *schema.Resource { return &schema.Resource{ - Create: CreateSessionParameter, - Read: ReadSessionParameter, - Update: UpdateSessionParameter, - Delete: DeleteSessionParameter, + CreateContext: TrackingCreateWrapper(resources.SessionParameter, CreateSessionParameter), + ReadContext: TrackingReadWrapper(resources.SessionParameter, ReadSessionParameter), + UpdateContext: TrackingUpdateWrapper(resources.SessionParameter, UpdateSessionParameter), + DeleteContext: TrackingDeleteWrapper(resources.SessionParameter, DeleteSessionParameter), Schema: sessionParameterSchema, Importer: &schema.ResourceImporter{ @@ -50,11 +53,11 @@ func SessionParameter() *schema.Resource { } // CreateSessionParameter implements schema.CreateFunc. -func CreateSessionParameter(d *schema.ResourceData, meta interface{}) error { +func CreateSessionParameter(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client key := d.Get("key").(string) value := d.Get("value").(string) - ctx := context.Background() + onAccount := d.Get("on_account").(bool) user := d.Get("user").(string) parameter := sdk.SessionParameter(key) @@ -63,28 +66,28 @@ func CreateSessionParameter(d *schema.ResourceData, meta interface{}) error { if onAccount { err := client.Parameters.SetSessionParameterOnAccount(ctx, parameter, value) if err != nil { - return err + return diag.FromErr(err) } } else { if user == "" { - return fmt.Errorf("user is required if on_account is false") + return diag.FromErr(fmt.Errorf("user is required if on_account is false")) } userId := sdk.NewAccountObjectIdentifier(user) err = client.Parameters.SetSessionParameterOnUser(ctx, userId, parameter, value) if err != nil { - return fmt.Errorf("error creating session parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error creating session parameter err = %w", err)) } } d.SetId(key) - return ReadSessionParameter(d, meta) + return ReadSessionParameter(ctx, d, meta) } // ReadSessionParameter implements schema.ReadFunc. -func ReadSessionParameter(d *schema.ResourceData, meta interface{}) error { +func ReadSessionParameter(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + parameter := d.Id() onAccount := d.Get("on_account").(bool) @@ -98,25 +101,24 @@ func ReadSessionParameter(d *schema.ResourceData, meta interface{}) error { p, err = client.Parameters.ShowUserParameter(ctx, sdk.UserParameter(parameter), userId) } if err != nil { - return fmt.Errorf("error reading session parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error reading session parameter err = %w", err)) } err = d.Set("value", p.Value) if err != nil { - return fmt.Errorf("error setting session parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error setting session parameter err = %w", err)) } return nil } // UpdateSessionParameter implements schema.UpdateFunc. -func UpdateSessionParameter(d *schema.ResourceData, meta interface{}) error { - return CreateSessionParameter(d, meta) +func UpdateSessionParameter(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { + return CreateSessionParameter(ctx, d, meta) } // DeleteSessionParameter implements schema.DeleteFunc. -func DeleteSessionParameter(d *schema.ResourceData, meta interface{}) error { +func DeleteSessionParameter(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client key := d.Get("key").(string) - ctx := context.Background() onAccount := d.Get("on_account").(bool) parameter := sdk.SessionParameter(key) @@ -124,27 +126,27 @@ func DeleteSessionParameter(d *schema.ResourceData, meta interface{}) error { if onAccount { defaultParameter, err := client.Parameters.ShowAccountParameter(ctx, sdk.AccountParameter(key)) if err != nil { - return err + return diag.FromErr(err) } defaultValue := defaultParameter.Default err = client.Parameters.SetSessionParameterOnAccount(ctx, parameter, defaultValue) if err != nil { - return fmt.Errorf("error creating session parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error creating session parameter err = %w", err)) } } else { user := d.Get("user").(string) if user == "" { - return fmt.Errorf("user is required if on_account is false") + return diag.FromErr(fmt.Errorf("user is required if on_account is false")) } userId := sdk.NewAccountObjectIdentifier(user) defaultParameter, err := client.Parameters.ShowSessionParameter(ctx, sdk.SessionParameter(key)) if err != nil { - return err + return diag.FromErr(err) } defaultValue := defaultParameter.Default err = client.Parameters.SetSessionParameterOnUser(ctx, userId, parameter, defaultValue) if err != nil { - return fmt.Errorf("error deleting session parameter err = %w", err) + return diag.FromErr(fmt.Errorf("error deleting session parameter err = %w", err)) } } diff --git a/pkg/resources/share.go b/pkg/resources/share.go index 8c4d5ebe33..334a931f30 100644 --- a/pkg/resources/share.go +++ b/pkg/resources/share.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" @@ -47,10 +48,10 @@ var shareSchema = map[string]*schema.Schema{ // Share returns a pointer to the resource representing a share. func Share() *schema.Resource { return &schema.Resource{ - Create: CreateShare, - Read: ReadShare, - Update: UpdateShare, - Delete: DeleteShare, + CreateContext: TrackingCreateWrapper(resources.Share, CreateShare), + ReadContext: TrackingReadWrapper(resources.Share, ReadShare), + UpdateContext: TrackingUpdateWrapper(resources.Share, UpdateShare), + DeleteContext: TrackingDeleteWrapper(resources.Share, DeleteShare), Schema: shareSchema, Importer: &schema.ResourceImporter{ @@ -60,10 +61,10 @@ func Share() *schema.Resource { } // CreateShare implements schema.CreateFunc. -func CreateShare(d *schema.ResourceData, meta interface{}) error { +func CreateShare(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - ctx := context.Background() + comment := d.Get("comment").(string) id := sdk.NewAccountObjectIdentifier(name) var opts sdk.CreateShareOptions @@ -73,7 +74,7 @@ func CreateShare(d *schema.ResourceData, meta interface{}) error { } } if err := client.Shares.Create(ctx, id, &opts); err != nil { - return fmt.Errorf("error creating share (%v) err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error creating share (%v) err = %w", d.Id(), err)) } d.SetId(name) @@ -89,10 +90,10 @@ func CreateShare(d *schema.ResourceData, meta interface{}) error { } err := setShareAccounts(ctx, client, shareID, accountIdentifiers) if err != nil { - return err + return diag.FromErr(err) } } - return ReadShare(d, meta) + return ReadShare(ctx, d, meta) } func setShareAccounts(ctx context.Context, client *sdk.Client, shareID sdk.AccountObjectIdentifier, accounts []sdk.AccountIdentifier) error { @@ -158,20 +159,19 @@ func setShareAccounts(ctx context.Context, client *sdk.Client, shareID sdk.Accou } // ReadShare implements schema.ReadFunc. -func ReadShare(d *schema.ResourceData, meta interface{}) error { +func ReadShare(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) - ctx := context.Background() share, err := client.Shares.ShowByID(ctx, id) if err != nil { - return fmt.Errorf("error reading share (%v) err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error reading share (%v) err = %w", d.Id(), err)) } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", share.Comment); err != nil { - return err + return diag.FromErr(err) } accounts := make([]string, len(share.To)) for i, accountIdentifier := range share.To { @@ -186,10 +186,10 @@ func ReadShare(d *schema.ResourceData, meta interface{}) error { accounts = reorderStringList(currentAccounts, accounts) } if err := d.Set("accounts", accounts); err != nil { - return err + return diag.FromErr(err) } - return err + return diag.FromErr(err) } func accountIdentifiersFromSlice(accounts []string) []sdk.AccountIdentifier { @@ -204,10 +204,10 @@ func accountIdentifiersFromSlice(accounts []string) []sdk.AccountIdentifier { } // UpdateShare implements schema.UpdateFunc. -func UpdateShare(d *schema.ResourceData, meta interface{}) error { +func UpdateShare(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) client := meta.(*provider.Context).Client - ctx := context.Background() + if d.HasChange("accounts") { o, n := d.GetChange("accounts") oldAccounts := expandStringList(o.([]interface{})) @@ -220,13 +220,13 @@ func UpdateShare(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return fmt.Errorf("error removing accounts from share (%v) err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error removing accounts from share (%v) err = %w", d.Id(), err)) } } else { accountIdentifiers := accountIdentifiersFromSlice(newAccounts) err := setShareAccounts(ctx, client, id, accountIdentifiers) if err != nil { - return err + return diag.FromErr(err) } } } @@ -238,21 +238,21 @@ func UpdateShare(d *schema.ResourceData, meta interface{}) error { }, }) if err != nil { - return fmt.Errorf("error updating share (%v) comment err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error updating share (%v) comment err = %w", d.Id(), err)) } } - return ReadShare(d, meta) + return ReadShare(ctx, d, meta) } // DeleteShare implements schema.DeleteFunc. -func DeleteShare(d *schema.ResourceData, meta interface{}) error { +func DeleteShare(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) client := meta.(*provider.Context).Client - ctx := context.Background() + err := client.Shares.Drop(ctx, id, &sdk.DropShareOptions{IfExists: sdk.Bool(true)}) if err != nil { - return fmt.Errorf("error deleting share (%v) err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error deleting share (%v) err = %w", d.Id(), err)) } return nil } diff --git a/pkg/resources/stream.go b/pkg/resources/stream.go index 1957dfedc4..219161a107 100644 --- a/pkg/resources/stream.go +++ b/pkg/resources/stream.go @@ -7,6 +7,7 @@ import ( "strings" providerresources "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" @@ -100,10 +101,10 @@ var streamSchema = map[string]*schema.Schema{ func Stream() *schema.Resource { return &schema.Resource{ - Create: CreateStream, - Read: ReadStream, - Update: UpdateStream, - Delete: DeleteStream, + CreateContext: TrackingCreateWrapper(providerresources.Stream, CreateStream), + ReadContext: TrackingReadWrapper(providerresources.Stream, ReadStream), + UpdateContext: TrackingUpdateWrapper(providerresources.Stream, UpdateStream), + DeleteContext: TrackingDeleteWrapper(providerresources.Stream, DeleteStream), DeprecationMessage: deprecatedResourceDescription( string(providerresources.StreamOnDirectoryTable), string(providerresources.StreamOnExternalTable), @@ -119,7 +120,7 @@ func Stream() *schema.Resource { } // CreateStream implements schema.CreateFunc. -func CreateStream(d *schema.ResourceData, meta interface{}) error { +func CreateStream(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) @@ -129,8 +130,6 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { showInitialRows := d.Get("show_initial_rows").(bool) id := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name) - ctx := context.Background() - onTable, onTableSet := d.GetOk("on_table") onView, onViewSet := d.GetOk("on_view") onStage, onStageSet := d.GetOk("on_stage") @@ -139,13 +138,13 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { case onTableSet: tableObjectIdentifier, err := helpers.DecodeSnowflakeParameterID(onTable.(string)) if err != nil { - return err + return diag.FromErr(err) } tableId := tableObjectIdentifier.(sdk.SchemaObjectIdentifier) table, err := client.Tables.ShowByID(ctx, tableId) if err != nil { - return err + return diag.FromErr(err) } if table.IsExternal { @@ -158,7 +157,7 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { } err := client.Streams.CreateOnExternalTable(ctx, req) if err != nil { - return fmt.Errorf("error creating stream %v err = %w", name, err) + return diag.FromErr(fmt.Errorf("error creating stream %v err = %w", name, err)) } } else { req := sdk.NewCreateOnTableStreamRequest(id, tableId) @@ -173,19 +172,19 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { } err := client.Streams.CreateOnTable(ctx, req) if err != nil { - return fmt.Errorf("error creating stream %v err = %w", name, err) + return diag.FromErr(fmt.Errorf("error creating stream %v err = %w", name, err)) } } case onViewSet: viewObjectIdentifier, err := helpers.DecodeSnowflakeParameterID(onView.(string)) viewId := viewObjectIdentifier.(sdk.SchemaObjectIdentifier) if err != nil { - return err + return diag.FromErr(err) } _, err = client.Views.ShowByID(ctx, viewId) if err != nil { - return err + return diag.FromErr(err) } req := sdk.NewCreateOnViewStreamRequest(id, viewId) @@ -200,20 +199,20 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { } err = client.Streams.CreateOnView(ctx, req) if err != nil { - return fmt.Errorf("error creating stream %v err = %w", name, err) + return diag.FromErr(fmt.Errorf("error creating stream %v err = %w", name, err)) } case onStageSet: stageObjectIdentifier, err := helpers.DecodeSnowflakeParameterID(onStage.(string)) stageId := stageObjectIdentifier.(sdk.SchemaObjectIdentifier) if err != nil { - return err + return diag.FromErr(err) } stageProperties, err := client.Stages.Describe(ctx, stageId) if err != nil { - return err + return diag.FromErr(err) } if findStagePropertyValueByName(stageProperties, "ENABLE") != "true" { - return fmt.Errorf("directory must be enabled on stage") + return diag.FromErr(fmt.Errorf("directory must be enabled on stage")) } req := sdk.NewCreateOnDirectoryTableStreamRequest(id, stageId) if v, ok := d.GetOk("comment"); ok { @@ -221,19 +220,19 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { } err = client.Streams.CreateOnDirectoryTable(ctx, req) if err != nil { - return fmt.Errorf("error creating stream %v err = %w", name, err) + return diag.FromErr(fmt.Errorf("error creating stream %v err = %w", name, err)) } } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadStream(d, meta) + return ReadStream(ctx, d, meta) } // ReadStream implements schema.ReadFunc. -func ReadStream(d *schema.ResourceData, meta interface{}) error { +func ReadStream(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) stream, err := client.Streams.ShowByID(ctx, id) if err != nil { @@ -242,56 +241,56 @@ func ReadStream(d *schema.ResourceData, meta interface{}) error { return nil } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("name", stream.Name); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("database", stream.DatabaseName); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("schema", stream.SchemaName); err != nil { - return err + return diag.FromErr(err) } switch *stream.SourceType { case sdk.StreamSourceTypeStage: if err := d.Set("on_stage", *stream.TableName); err != nil { - return err + return diag.FromErr(err) } case sdk.StreamSourceTypeView: if err := d.Set("on_view", *stream.TableName); err != nil { - return err + return diag.FromErr(err) } default: if err := d.Set("on_table", *stream.TableName); err != nil { - return err + return diag.FromErr(err) } } if err := d.Set("append_only", *stream.Mode == "APPEND_ONLY"); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("insert_only", *stream.Mode == "INSERT_ONLY"); err != nil { - return err + return diag.FromErr(err) } // TODO: SHOW STREAMS doesn't return that value right now (I'm not sure if it ever did), but probably we can assume // the customers got 'false' every time and hardcode it (it's only on create thing, so it's not necessary // to track its value after creation). if err := d.Set("show_initial_rows", false); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("comment", *stream.Comment); err != nil { - return err + return diag.FromErr(err) } if err := d.Set("owner", *stream.Owner); err != nil { - return err + return diag.FromErr(err) } return nil } // UpdateStream implements schema.UpdateFunc. -func UpdateStream(d *schema.ResourceData, meta interface{}) error { +func UpdateStream(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) if d.HasChange("comment") { @@ -299,28 +298,28 @@ func UpdateStream(d *schema.ResourceData, meta interface{}) error { if comment == "" { err := client.Streams.Alter(ctx, sdk.NewAlterStreamRequest(id).WithUnsetComment(true)) if err != nil { - return fmt.Errorf("error unsetting stream comment on %v", d.Id()) + return diag.FromErr(fmt.Errorf("error unsetting stream comment on %v", d.Id())) } } else { err := client.Streams.Alter(ctx, sdk.NewAlterStreamRequest(id).WithSetComment(comment)) if err != nil { - return fmt.Errorf("error setting stream comment on %v", d.Id()) + return diag.FromErr(fmt.Errorf("error setting stream comment on %v", d.Id())) } } } - return ReadStream(d, meta) + return ReadStream(ctx, d, meta) } // DeleteStream implements schema.DeleteFunc. -func DeleteStream(d *schema.ResourceData, meta interface{}) error { +func DeleteStream(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + streamId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) err := client.Streams.Drop(ctx, sdk.NewDropStreamRequest(streamId)) if err != nil { - return fmt.Errorf("error deleting stream %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error deleting stream %v err = %w", d.Id(), err)) } d.SetId("") diff --git a/pkg/resources/table.go b/pkg/resources/table.go index ce39f90765..017ad799d8 100644 --- a/pkg/resources/table.go +++ b/pkg/resources/table.go @@ -8,6 +8,8 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -15,7 +17,6 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) @@ -208,10 +209,10 @@ var tableSchema = map[string]*schema.Schema{ func Table() *schema.Resource { return &schema.Resource{ - Create: CreateTable, - Read: ReadTable, - Update: UpdateTable, - Delete: DeleteTable, + CreateContext: TrackingCreateWrapper(resources.Table, CreateTable), + ReadContext: TrackingReadWrapper(resources.Table, ReadTable), + UpdateContext: TrackingUpdateWrapper(resources.Table, UpdateTable), + DeleteContext: TrackingDeleteWrapper(resources.Table, DeleteTable), CustomizeDiff: TrackingCustomDiffWrapper(resources.Table, customdiff.All( ComputedIfAnyAttributeChanged(tableSchema, FullyQualifiedNameAttributeName, "name"), @@ -568,9 +569,8 @@ func toColumnIdentityConfig(td sdk.TableColumnDetails) map[string]any { } // CreateTable implements schema.CreateFunc. -func CreateTable(d *schema.ResourceData, meta interface{}) error { +func CreateTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) @@ -622,18 +622,18 @@ func CreateTable(d *schema.ResourceData, meta interface{}) error { err := client.Tables.Create(ctx, createRequest) if err != nil { - return fmt.Errorf("error creating table %v err = %w", name, err) + return diag.FromErr(fmt.Errorf("error creating table %v err = %w", name, err)) } d.SetId(helpers.EncodeSnowflakeID(id)) - return ReadTable(d, meta) + return ReadTable(ctx, d, meta) } // ReadTable implements schema.ReadFunc. -func ReadTable(d *schema.ResourceData, meta interface{}) error { +func ReadTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) table, err := client.Tables.ShowByID(ctx, id) @@ -650,7 +650,7 @@ func ReadTable(d *schema.ResourceData, meta interface{}) error { return nil } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } var schemaRetentionTime int64 // "retention_time" may sometimes be empty string instead of an integer @@ -662,13 +662,13 @@ func ReadTable(d *schema.ResourceData, meta interface{}) error { schemaRetentionTime, err = strconv.ParseInt(rt, 10, 64) if err != nil { - return err + return diag.FromErr(err) } } tableDescription, err := client.Tables.DescribeColumns(ctx, sdk.NewDescribeTableColumnsRequest(id)) if err != nil { - return err + return diag.FromErr(err) } // Set the relevant data in the state @@ -688,16 +688,16 @@ func ReadTable(d *schema.ResourceData, meta interface{}) error { for key, val := range toSet { if err := d.Set(key, val); err != nil { // lintignore:R001 - return err + return diag.FromErr(err) } } return nil } // UpdateTable implements schema.UpdateFunc. -func UpdateTable(d *schema.ResourceData, meta interface{}) error { +func UpdateTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) if d.HasChange("name") { @@ -705,7 +705,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithNewName(&newId)) if err != nil { - return fmt.Errorf("error renaming table %v err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error renaming table %v err = %w", d.Id(), err)) } d.SetId(helpers.EncodeSnowflakeID(newId)) @@ -747,14 +747,14 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { if runSetStatement { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithSet(setRequest)) if err != nil { - return fmt.Errorf("error updating table: %w", err) + return diag.FromErr(fmt.Errorf("error updating table: %w", err)) } } if runUnsetStatement { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithUnset(unsetRequest)) if err != nil { - return fmt.Errorf("error updating table: %w", err) + return diag.FromErr(fmt.Errorf("error updating table: %w", err)) } } @@ -764,12 +764,12 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { if len(cb) != 0 { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithClusteringAction(sdk.NewTableClusteringActionRequest().WithClusterBy(cb))) if err != nil { - return fmt.Errorf("error updating table: %w", err) + return diag.FromErr(fmt.Errorf("error updating table: %w", err)) } } else { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithClusteringAction(sdk.NewTableClusteringActionRequest().WithDropClusteringKey(sdk.Bool(true)))) if err != nil { - return fmt.Errorf("error updating table: %w", err) + return diag.FromErr(fmt.Errorf("error updating table: %w", err)) } } } @@ -785,7 +785,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { } err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithDropColumns(snowflake.QuoteStringList(removedColumnNames)))) if err != nil { - return fmt.Errorf("error updating table: %w", err) + return diag.FromErr(fmt.Errorf("error updating table: %w", err)) } } @@ -795,7 +795,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { if cA._default != nil { if cA._default._type() != "constant" { - return fmt.Errorf("failed to add column %v => Only adding a column as a constant is supported by Snowflake", cA.name) + return diag.FromErr(fmt.Errorf("failed to add column %v => Only adding a column as a constant is supported by Snowflake", cA.name)) } var expression string if sdk.IsStringType(cA.dataType) { @@ -824,7 +824,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAdd(addRequest))) if err != nil { - return fmt.Errorf("error adding column: %w", err) + return diag.FromErr(fmt.Errorf("error adding column: %w", err)) } } for _, cA := range changed { @@ -835,7 +835,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { } err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithType(sdk.Pointer(sdk.DataType(cA.newColumn.dataType))).WithCollate(newCollation)}))) if err != nil { - return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error changing property on %v: err %w", d.Id(), err)) } } if cA.changedNullConstraint { @@ -847,13 +847,13 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { } err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithNotNullConstraint(nullabilityRequest)}))) if err != nil { - return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error changing property on %v: err %w", d.Id(), err)) } } if cA.dropedDefault { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithDropDefault(sdk.Bool(true))}))) if err != nil { - return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error changing property on %v: err %w", d.Id(), err)) } } if cA.changedComment { @@ -866,7 +866,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*columnAlterActionRequest}))) if err != nil { - return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error changing property on %v: err %w", d.Id(), err)) } } if cA.changedMaskingPolicy { @@ -878,7 +878,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { } err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(columnAction)) if err != nil { - return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error changing property on %v: err %w", d.Id(), err)) } } } @@ -897,7 +897,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { WithDrop(sdk.NewTableConstraintDropActionRequest().WithPrimaryKey(sdk.Bool(true))), )) if err != nil { - return fmt.Errorf("error updating table: %w", err) + return diag.FromErr(fmt.Errorf("error updating table: %w", err)) } } @@ -910,7 +910,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { sdk.NewTableConstraintActionRequest().WithAdd(constraint), )) if err != nil { - return fmt.Errorf("error updating table: %w", err) + return diag.FromErr(fmt.Errorf("error updating table: %w", err)) } } } @@ -921,7 +921,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { if len(unsetTags) > 0 { err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithUnsetTags(unsetTags)) if err != nil { - return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err)) } } @@ -932,23 +932,23 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { } err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithSetTags(tagAssociationRequests)) if err != nil { - return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) + return diag.FromErr(fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err)) } } } - return ReadTable(d, meta) + return ReadTable(ctx, d, meta) } // DeleteTable implements schema.DeleteFunc. -func DeleteTable(d *schema.ResourceData, meta interface{}) error { +func DeleteTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) err := client.Tables.Drop(ctx, sdk.NewDropTableRequest(id)) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/table_column_masking_policy_application.go b/pkg/resources/table_column_masking_policy_application.go index 71b722f2ad..e48d4447ad 100644 --- a/pkg/resources/table_column_masking_policy_application.go +++ b/pkg/resources/table_column_masking_policy_application.go @@ -1,8 +1,12 @@ package resources import ( + "context" "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" @@ -32,10 +36,10 @@ var tableColumnMaskingPolicyApplicationSchema = map[string]*schema.Schema{ func TableColumnMaskingPolicyApplication() *schema.Resource { return &schema.Resource{ - Description: "Applies a masking policy to a table column.", - Create: CreateTableColumnMaskingPolicyApplication, - Read: ReadTableColumnMaskingPolicyApplication, - Delete: DeleteTableColumnMaskingPolicyApplication, + Description: "Applies a masking policy to a table column.", + CreateContext: TrackingCreateWrapper(resources.TableColumnMaskingPolicyApplication, CreateTableColumnMaskingPolicyApplication), + ReadContext: TrackingReadWrapper(resources.TableColumnMaskingPolicyApplication, ReadTableColumnMaskingPolicyApplication), + DeleteContext: TrackingDeleteWrapper(resources.TableColumnMaskingPolicyApplication, DeleteTableColumnMaskingPolicyApplication), Schema: tableColumnMaskingPolicyApplicationSchema, Importer: &schema.ResourceImporter{ @@ -45,7 +49,7 @@ func TableColumnMaskingPolicyApplication() *schema.Resource { } // CreateTableColumnMaskingPolicyApplication implements schema.CreateFunc. -func CreateTableColumnMaskingPolicyApplication(d *schema.ResourceData, meta interface{}) error { +func CreateTableColumnMaskingPolicyApplication(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { manager := snowflake.NewTableColumnMaskingPolicyApplicationManager() input := &snowflake.TableColumnMaskingPolicyApplicationCreateInput{ @@ -62,25 +66,25 @@ func CreateTableColumnMaskingPolicyApplication(d *schema.ResourceData, meta inte db := client.GetConn().DB _, err := db.Exec(stmt) if err != nil { - return fmt.Errorf("error applying masking policy: %w", err) + return diag.FromErr(fmt.Errorf("error applying masking policy: %w", err)) } d.SetId(TableColumnMaskingPolicyApplicationID(&input.TableColumnMaskingPolicyApplication)) - return ReadTableColumnMaskingPolicyApplication(d, meta) + return ReadTableColumnMaskingPolicyApplication(ctx, d, meta) } // ReadTableColumnMaskingPolicyApplication implements schema.ReadFunc. -func ReadTableColumnMaskingPolicyApplication(d *schema.ResourceData, meta interface{}) error { +func ReadTableColumnMaskingPolicyApplication(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { manager := snowflake.NewTableColumnMaskingPolicyApplicationManager() table, column := TableColumnMaskingPolicyApplicationIdentifier(d.Id()) if err := d.Set("table", table.QualifiedName()); err != nil { - return fmt.Errorf("error setting table: %w", err) + return diag.FromErr(fmt.Errorf("error setting table: %w", err)) } if err := d.Set("column", column); err != nil { - return fmt.Errorf("error setting column: %w", err) + return diag.FromErr(fmt.Errorf("error setting column: %w", err)) } input := &snowflake.TableColumnMaskingPolicyApplicationReadInput{ @@ -94,24 +98,24 @@ func ReadTableColumnMaskingPolicyApplication(d *schema.ResourceData, meta interf db := client.GetConn().DB rows, err := db.Query(stmt) if err != nil { - return fmt.Errorf("error querying password policy: %w", err) + return diag.FromErr(fmt.Errorf("error querying password policy: %w", err)) } defer rows.Close() maskingPolicy, err := manager.Parse(rows, column) if err != nil { - return fmt.Errorf("failed to parse result of describe: %w", err) + return diag.FromErr(fmt.Errorf("failed to parse result of describe: %w", err)) } if err = d.Set("masking_policy", maskingPolicy); err != nil { - return fmt.Errorf("error setting masking_policy: %w", err) + return diag.FromErr(fmt.Errorf("error setting masking_policy: %w", err)) } return nil } // DeleteTableColumnMaskingPolicyApplication implements schema.DeleteFunc. -func DeleteTableColumnMaskingPolicyApplication(d *schema.ResourceData, meta interface{}) error { +func DeleteTableColumnMaskingPolicyApplication(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { manager := snowflake.NewTableColumnMaskingPolicyApplicationManager() input := &snowflake.TableColumnMaskingPolicyApplicationDeleteInput{ @@ -127,7 +131,7 @@ func DeleteTableColumnMaskingPolicyApplication(d *schema.ResourceData, meta inte db := client.GetConn().DB _, err := db.Exec(stmt) if err != nil { - return fmt.Errorf("error executing drop statement: %w", err) + return diag.FromErr(fmt.Errorf("error executing drop statement: %w", err)) } return nil diff --git a/pkg/resources/table_constraint.go b/pkg/resources/table_constraint.go index 415322c9fd..92e2f6a3f7 100644 --- a/pkg/resources/table_constraint.go +++ b/pkg/resources/table_constraint.go @@ -5,6 +5,9 @@ import ( "fmt" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -179,10 +182,10 @@ var tableConstraintSchema = map[string]*schema.Schema{ func TableConstraint() *schema.Resource { return &schema.Resource{ - Create: CreateTableConstraint, - Read: ReadTableConstraint, - Update: UpdateTableConstraint, - Delete: DeleteTableConstraint, + CreateContext: TrackingCreateWrapper(resources.TableConstraint, CreateTableConstraint), + ReadContext: TrackingReadWrapper(resources.TableConstraint, ReadTableConstraint), + UpdateContext: TrackingUpdateWrapper(resources.TableConstraint, UpdateTableConstraint), + DeleteContext: TrackingDeleteWrapper(resources.TableConstraint, DeleteTableConstraint), Schema: tableConstraintSchema, Importer: &schema.ResourceImporter{ @@ -229,9 +232,8 @@ func getTableIdentifier(s string) (*sdk.SchemaObjectIdentifier, error) { } // CreateTableConstraint implements schema.CreateFunc. -func CreateTableConstraint(d *schema.ResourceData, meta interface{}) error { +func CreateTableConstraint(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() name := d.Get("name").(string) cType := d.Get("type").(string) @@ -239,12 +241,12 @@ func CreateTableConstraint(d *schema.ResourceData, meta interface{}) error { tableIdentifier, err := getTableIdentifier(tableID) if err != nil { - return err + return diag.FromErr(err) } constraintType, err := sdk.ToColumnConstraintType(cType) if err != nil { - return err + return diag.FromErr(err) } constraintRequest := sdk.NewOutOfLineConstraintRequest(constraintType).WithName(&name) @@ -290,11 +292,11 @@ func CreateTableConstraint(d *schema.ResourceData, meta interface{}) error { fkTableID := references["table_id"].(string) fkId, err := helpers.DecodeSnowflakeParameterID(fkTableID) if err != nil { - return fmt.Errorf("table id is incorrect: %s, err: %w", fkTableID, err) + return diag.FromErr(fmt.Errorf("table id is incorrect: %s, err: %w", fkTableID, err)) } referencedTableIdentifier, ok := fkId.(sdk.SchemaObjectIdentifier) if !ok { - return fmt.Errorf("table id is incorrect: %s", fkId) + return diag.FromErr(fmt.Errorf("table id is incorrect: %s", fkId)) } cols := references["columns"].([]interface{}) @@ -306,17 +308,17 @@ func CreateTableConstraint(d *schema.ResourceData, meta interface{}) error { matchType, err := sdk.ToMatchType(foreignKeyProperties["match"].(string)) if err != nil { - return err + return diag.FromErr(err) } foreignKeyRequest.WithMatch(&matchType) onUpdate, err := sdk.ToForeignKeyAction(foreignKeyProperties["on_update"].(string)) if err != nil { - return err + return diag.FromErr(err) } onDelete, err := sdk.ToForeignKeyAction(foreignKeyProperties["on_delete"].(string)) if err != nil { - return err + return diag.FromErr(err) } foreignKeyRequest.WithOn(sdk.NewForeignKeyOnAction(). WithOnDelete(&onDelete). @@ -328,7 +330,7 @@ func CreateTableConstraint(d *schema.ResourceData, meta interface{}) error { alterStatement := sdk.NewAlterTableRequest(*tableIdentifier).WithConstraintAction(sdk.NewTableConstraintActionRequest().WithAdd(constraintRequest)) err = client.Tables.Alter(ctx, alterStatement) if err != nil { - return fmt.Errorf("error creating table constraint %v err = %w", name, err) + return diag.FromErr(fmt.Errorf("error creating table constraint %v err = %w", name, err)) } tc := tableConstraintID{ @@ -338,11 +340,11 @@ func CreateTableConstraint(d *schema.ResourceData, meta interface{}) error { } d.SetId(tc.String()) - return ReadTableConstraint(d, meta) + return ReadTableConstraint(ctx, d, meta) } // ReadTableConstraint implements schema.ReadFunc. -func ReadTableConstraint(_ *schema.ResourceData, _ interface{}) error { +func ReadTableConstraint(ctx context.Context, _ *schema.ResourceData, _ interface{}) diag.Diagnostics { // TODO(issue-2683): Implement read operation // commenting this out since it requires an active warehouse to be set which may not be intuitive. // also it takes a while for the database to reflect changes. Would likely need to add a validation @@ -356,24 +358,24 @@ func ReadTableConstraint(_ *schema.ResourceData, _ interface{}) error { // just need to check to make sure it exists _, err := snowflake.ShowTableConstraint(tc.name, databaseName, schemaName, tableName, db) if err != nil { - return fmt.Errorf(fmt.Sprintf("error reading table constraint %v", tc.String())) + return diag.FromErr(fmt.Errorf(fmt.Sprintf("error reading table constraint %v", tc.String())) }*/ return nil } // UpdateTableConstraint implements schema.UpdateFunc. -func UpdateTableConstraint(d *schema.ResourceData, meta interface{}) error { +func UpdateTableConstraint(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { /* TODO(issue-2683): Update isn't be possible with non-existing Read operation. The Update logic is ready to be uncommented once the Read operation is ready. client := meta.(*provider.Context).Client - ctx := context.Background() + tc := tableConstraintID{} tc.parse(d.Id()) tableIdentifier, err := getTableIdentifier(tc.tableID) if err != nil { - return err + return diag.FromErr(err) } if d.HasChange("name") { @@ -383,7 +385,7 @@ func UpdateTableConstraint(d *schema.ResourceData, meta interface{}) error { err = client.Tables.Alter(ctx, alterStatement) if err != nil { - return fmt.Errorf("error renaming table constraint %s err = %w", tc.name, err) + return diag.FromErr(fmt.Errorf("error renaming table constraint %s err = %w", tc.name, err) } tc.name = newName @@ -391,20 +393,19 @@ func UpdateTableConstraint(d *schema.ResourceData, meta interface{}) error { } */ - return ReadTableConstraint(d, meta) + return ReadTableConstraint(ctx, d, meta) } // DeleteTableConstraint implements schema.DeleteFunc. -func DeleteTableConstraint(d *schema.ResourceData, meta interface{}) error { +func DeleteTableConstraint(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() tc := tableConstraintID{} tc.parse(d.Id()) tableIdentifier, err := getTableIdentifier(tc.tableID) if err != nil { - return err + return diag.FromErr(err) } dropRequest := sdk.NewTableConstraintDropActionRequest().WithConstraintName(&tc.name) @@ -416,7 +417,7 @@ func DeleteTableConstraint(d *schema.ResourceData, meta interface{}) error { d.SetId("") return nil } - return fmt.Errorf("error dropping table constraint %v err = %w", tc.name, err) + return diag.FromErr(fmt.Errorf("error dropping table constraint %v err = %w", tc.name, err)) } d.SetId("") diff --git a/pkg/resources/user_authentication_policy_attachment.go b/pkg/resources/user_authentication_policy_attachment.go index c29b54880f..7e03ef118b 100644 --- a/pkg/resources/user_authentication_policy_attachment.go +++ b/pkg/resources/user_authentication_policy_attachment.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" @@ -31,20 +34,19 @@ var userAuthenticationPolicyAttachmentSchema = map[string]*schema.Schema{ // UserAuthenticationPolicyAttachment returns a pointer to the resource representing a user authentication policy attachment. func UserAuthenticationPolicyAttachment() *schema.Resource { return &schema.Resource{ - Description: "Specifies the authentication policy to use for a certain user.", - Create: CreateUserAuthenticationPolicyAttachment, - Read: ReadUserAuthenticationPolicyAttachment, - Delete: DeleteUserAuthenticationPolicyAttachment, - Schema: userAuthenticationPolicyAttachmentSchema, + Description: "Specifies the authentication policy to use for a certain user.", + CreateContext: TrackingCreateWrapper(resources.UserAuthenticationPolicyAttachment, CreateUserAuthenticationPolicyAttachment), + ReadContext: TrackingReadWrapper(resources.UserAuthenticationPolicyAttachment, ReadUserAuthenticationPolicyAttachment), + DeleteContext: TrackingDeleteWrapper(resources.UserAuthenticationPolicyAttachment, DeleteUserAuthenticationPolicyAttachment), + Schema: userAuthenticationPolicyAttachmentSchema, Importer: &schema.ResourceImporter{ StateContext: schema.ImportStatePassthroughContext, }, } } -func CreateUserAuthenticationPolicyAttachment(d *schema.ResourceData, meta any) error { +func CreateUserAuthenticationPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() userName := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(d.Get("user_name").(string)) authenticationPolicy := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Get("authentication_policy_name").(string)) @@ -55,28 +57,27 @@ func CreateUserAuthenticationPolicyAttachment(d *schema.ResourceData, meta any) }, }) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeResourceIdentifier(userName.FullyQualifiedName(), authenticationPolicy.FullyQualifiedName())) - return ReadUserAuthenticationPolicyAttachment(d, meta) + return ReadUserAuthenticationPolicyAttachment(ctx, d, meta) } -func ReadUserAuthenticationPolicyAttachment(d *schema.ResourceData, meta any) error { +func ReadUserAuthenticationPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() parts := helpers.ParseResourceIdentifier(d.Id()) if len(parts) != 2 { - return fmt.Errorf("required id format 'user_name|authentication_policy_name', but got: '%s'", d.Id()) + return diag.FromErr(fmt.Errorf("required id format 'user_name|authentication_policy_name', but got: '%s'", d.Id())) } // Note: there is no alphanumeric id for an attachment, so we retrieve the authentication policies attached to a certain user. userName := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(parts[0]) policyReferences, err := client.PolicyReferences.GetForEntity(ctx, sdk.NewGetForEntityPolicyReferenceRequest(userName, sdk.PolicyEntityDomainUser)) if err != nil { - return err + return diag.FromErr(err) } authenticationPolicyReferences := make([]sdk.PolicyReference, 0) @@ -88,7 +89,7 @@ func ReadUserAuthenticationPolicyAttachment(d *schema.ResourceData, meta any) er // Note: this should never happen, but just in case: so far, Snowflake only allows one Authentication Policy per user. if len(authenticationPolicyReferences) > 1 { - return fmt.Errorf("internal error: multiple policy references attached to a user. This should never happen") + return diag.FromErr(fmt.Errorf("internal error: multiple policy references attached to a user. This should never happen")) } // Note: this means the resource has been deleted outside of Terraform. @@ -98,7 +99,7 @@ func ReadUserAuthenticationPolicyAttachment(d *schema.ResourceData, meta any) er } if err := d.Set("user_name", userName.Name()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set( "authentication_policy_name", @@ -107,15 +108,14 @@ func ReadUserAuthenticationPolicyAttachment(d *schema.ResourceData, meta any) er *authenticationPolicyReferences[0].PolicySchema, authenticationPolicyReferences[0].PolicyName, ).FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } - return err + return diag.FromErr(err) } -func DeleteUserAuthenticationPolicyAttachment(d *schema.ResourceData, meta any) error { +func DeleteUserAuthenticationPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() userName := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(d.Get("user_name").(string)) @@ -125,7 +125,7 @@ func DeleteUserAuthenticationPolicyAttachment(d *schema.ResourceData, meta any) }, }) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/user_password_policy_attachment.go b/pkg/resources/user_password_policy_attachment.go index 96bac9523a..84cdd42a1f 100644 --- a/pkg/resources/user_password_policy_attachment.go +++ b/pkg/resources/user_password_policy_attachment.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" @@ -31,20 +34,19 @@ var userPasswordPolicyAttachmentSchema = map[string]*schema.Schema{ // UserPasswordPolicyAttachment returns a pointer to the resource representing a user password policy attachment. func UserPasswordPolicyAttachment() *schema.Resource { return &schema.Resource{ - Description: "Specifies the password policy to use for a certain user.", - Create: CreateUserPasswordPolicyAttachment, - Read: ReadUserPasswordPolicyAttachment, - Delete: DeleteUserPasswordPolicyAttachment, - Schema: userPasswordPolicyAttachmentSchema, + Description: "Specifies the password policy to use for a certain user.", + CreateContext: TrackingCreateWrapper(resources.UserPasswordPolicyAttachment, CreateUserPasswordPolicyAttachment), + ReadContext: TrackingReadWrapper(resources.UserPasswordPolicyAttachment, ReadUserPasswordPolicyAttachment), + DeleteContext: TrackingDeleteWrapper(resources.UserPasswordPolicyAttachment, DeleteUserPasswordPolicyAttachment), + Schema: userPasswordPolicyAttachmentSchema, Importer: &schema.ResourceImporter{ StateContext: schema.ImportStatePassthroughContext, }, } } -func CreateUserPasswordPolicyAttachment(d *schema.ResourceData, meta any) error { +func CreateUserPasswordPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() userName := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(d.Get("user_name").(string)) passwordPolicy := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Get("password_policy_name").(string)) @@ -55,28 +57,27 @@ func CreateUserPasswordPolicyAttachment(d *schema.ResourceData, meta any) error }, }) if err != nil { - return err + return diag.FromErr(err) } d.SetId(helpers.EncodeResourceIdentifier(userName.FullyQualifiedName(), passwordPolicy.FullyQualifiedName())) - return ReadUserPasswordPolicyAttachment(d, meta) + return ReadUserPasswordPolicyAttachment(ctx, d, meta) } -func ReadUserPasswordPolicyAttachment(d *schema.ResourceData, meta any) error { +func ReadUserPasswordPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() parts := helpers.ParseResourceIdentifier(d.Id()) if len(parts) != 2 { - return fmt.Errorf("required id format 'user_name|password_policy_name', but got: '%s'", d.Id()) + return diag.FromErr(fmt.Errorf("required id format 'user_name|password_policy_name', but got: '%s'", d.Id())) } // Note: there is no alphanumeric id for an attachment, so we retrieve the password policies attached to a certain user. userName := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(parts[0]) policyReferences, err := client.PolicyReferences.GetForEntity(ctx, sdk.NewGetForEntityPolicyReferenceRequest(userName, sdk.PolicyEntityDomainUser)) if err != nil { - return err + return diag.FromErr(err) } passwordPolicyReferences := make([]sdk.PolicyReference, 0) @@ -88,7 +89,7 @@ func ReadUserPasswordPolicyAttachment(d *schema.ResourceData, meta any) error { // Note: this should never happen, but just in case: so far, Snowflake only allows one Password Policy per user. if len(passwordPolicyReferences) > 1 { - return fmt.Errorf("internal error: multiple policy references attached to a user. This should never happen") + return diag.FromErr(fmt.Errorf("internal error: multiple policy references attached to a user. This should never happen")) } // Note: this means the resource has been deleted outside of Terraform. @@ -98,7 +99,7 @@ func ReadUserPasswordPolicyAttachment(d *schema.ResourceData, meta any) error { } if err := d.Set("user_name", userName.Name()); err != nil { - return err + return diag.FromErr(err) } if err := d.Set( "password_policy_name", @@ -107,15 +108,14 @@ func ReadUserPasswordPolicyAttachment(d *schema.ResourceData, meta any) error { *passwordPolicyReferences[0].PolicySchema, passwordPolicyReferences[0].PolicyName, ).FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } - return err + return diag.FromErr(err) } -func DeleteUserPasswordPolicyAttachment(d *schema.ResourceData, meta any) error { +func DeleteUserPasswordPolicyAttachment(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() userName := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(d.Get("user_name").(string)) @@ -125,7 +125,7 @@ func DeleteUserPasswordPolicyAttachment(d *schema.ResourceData, meta any) error }, }) if err != nil { - return err + return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/user_public_keys.go b/pkg/resources/user_public_keys.go index 6f97a1e9a6..c46b0990c7 100644 --- a/pkg/resources/user_public_keys.go +++ b/pkg/resources/user_public_keys.go @@ -8,6 +8,9 @@ import ( "log" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -51,10 +54,10 @@ var userPublicKeysSchema = map[string]*schema.Schema{ func UserPublicKeys() *schema.Resource { return &schema.Resource{ - Create: CreateUserPublicKeys, - Read: ReadUserPublicKeys, - Update: UpdateUserPublicKeys, - Delete: DeleteUserPublicKeys, + CreateContext: TrackingCreateWrapper(resources.UserPublicKeys, CreateUserPublicKeys), + ReadContext: TrackingReadWrapper(resources.UserPublicKeys, ReadUserPublicKeys), + UpdateContext: TrackingUpdateWrapper(resources.UserPublicKeys, UpdateUserPublicKeys), + DeleteContext: TrackingDeleteWrapper(resources.UserPublicKeys, DeleteUserPublicKeys), Schema: userPublicKeysSchema, Importer: &schema.ResourceImporter{ @@ -63,9 +66,7 @@ func UserPublicKeys() *schema.Resource { } } -func checkUserExists(client *sdk.Client, userId sdk.AccountObjectIdentifier) (bool, error) { - ctx := context.Background() - +func checkUserExists(ctx context.Context, client *sdk.Client, userId sdk.AccountObjectIdentifier) (bool, error) { // First check if user exists _, err := client.Users.Describe(ctx, userId) if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { @@ -79,13 +80,13 @@ func checkUserExists(client *sdk.Client, userId sdk.AccountObjectIdentifier) (bo return true, nil } -func ReadUserPublicKeys(d *schema.ResourceData, meta interface{}) error { +func ReadUserPublicKeys(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client id := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) - exists, err := checkUserExists(client, id) + exists, err := checkUserExists(ctx, client, id) if err != nil { - return err + return diag.FromErr(err) } // If not found, mark resource to be removed from state file during apply or refresh if !exists { @@ -96,7 +97,7 @@ func ReadUserPublicKeys(d *schema.ResourceData, meta interface{}) error { return nil } -func CreateUserPublicKeys(d *schema.ResourceData, meta interface{}) error { +func CreateUserPublicKeys(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB name := d.Get("name").(string) @@ -108,15 +109,15 @@ func CreateUserPublicKeys(d *schema.ResourceData, meta interface{}) error { } err := updateUserPublicKeys(db, name, prop, publicKey.(string)) if err != nil { - return err + return diag.FromErr(err) } } d.SetId(name) - return ReadUserPublicKeys(d, meta) + return ReadUserPublicKeys(ctx, d, meta) } -func UpdateUserPublicKeys(d *schema.ResourceData, meta interface{}) error { +func UpdateUserPublicKeys(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB name := d.Id() @@ -142,7 +143,7 @@ func UpdateUserPublicKeys(d *schema.ResourceData, meta interface{}) error { for prop, value := range propsToSet { err := updateUserPublicKeys(db, name, prop, value) if err != nil { - return err + return diag.FromErr(err) } } @@ -150,14 +151,14 @@ func UpdateUserPublicKeys(d *schema.ResourceData, meta interface{}) error { for k := range propsToUnset { err := unsetUserPublicKeys(db, name, k) if err != nil { - return err + return diag.FromErr(err) } } // re-sync - return ReadUserPublicKeys(d, meta) + return ReadUserPublicKeys(ctx, d, meta) } -func DeleteUserPublicKeys(d *schema.ResourceData, meta interface{}) error { +func DeleteUserPublicKeys(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client db := client.GetConn().DB name := d.Id() @@ -165,7 +166,7 @@ func DeleteUserPublicKeys(d *schema.ResourceData, meta interface{}) error { for _, prop := range userPublicKeyProperties { err := unsetUserPublicKeys(db, name, prop) if err != nil { - return err + return diag.FromErr(err) } } diff --git a/pkg/sdk/dynamic_table.go b/pkg/sdk/dynamic_table.go index 7d697e9709..dac13dc576 100644 --- a/pkg/sdk/dynamic_table.go +++ b/pkg/sdk/dynamic_table.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "time" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" ) type DynamicTables interface { @@ -172,7 +174,7 @@ func (dtr dynamicTableRow) convert() *DynamicTable { RefreshMode: DynamicTableRefreshMode(dtr.RefreshMode), Warehouse: dtr.Warehouse, Comment: dtr.Comment, - Text: dtr.Text, + Text: tracking.TrimMetadata(dtr.Text), AutomaticClustering: dtr.AutomaticClustering == "ON", // "ON" or "OFF SchedulingState: DynamicTableSchedulingState(dtr.SchedulingState), IsClone: dtr.IsClone, diff --git a/pkg/sdk/materialized_views_impl_gen.go b/pkg/sdk/materialized_views_impl_gen.go index d422d59b04..9f1393d30c 100644 --- a/pkg/sdk/materialized_views_impl_gen.go +++ b/pkg/sdk/materialized_views_impl_gen.go @@ -4,6 +4,7 @@ import ( "context" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" ) var _ MaterializedViews = (*materializedViews)(nil) @@ -168,7 +169,7 @@ func (r materializedViewDBRow) convert() *MaterializedView { Owner: r.Owner, Invalid: r.Invalid, BehindBy: r.BehindBy, - Text: r.Text, + Text: tracking.TrimMetadata(r.Text), IsSecure: r.IsSecure, } if r.Reserved.Valid { diff --git a/pkg/sdk/testint/dynamic_table_integration_test.go b/pkg/sdk/testint/dynamic_table_integration_test.go index 778334f561..b7025d4e04 100644 --- a/pkg/sdk/testint/dynamic_table_integration_test.go +++ b/pkg/sdk/testint/dynamic_table_integration_test.go @@ -3,8 +3,12 @@ package testint import ( "context" "errors" + "fmt" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/stretchr/testify/assert" @@ -49,6 +53,23 @@ func TestInt_DynamicTableCreateAndDrop(t *testing.T) { assert.Equal(t, "ROLE", dynamicTableById.OwnerRoleType) }) + t.Run("create with usage tracking comment", func(t *testing.T) { + id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + plainQuery := fmt.Sprintf("SELECT id FROM %s", tableTest.ID().FullyQualifiedName()) + query, err := tracking.AppendMetadata(plainQuery, tracking.NewVersionedMetadata(resources.DynamicTable, tracking.CreateOperation)) + require.NoError(t, err) + + err = client.DynamicTables.Create(ctx, sdk.NewCreateDynamicTableRequest(id, testClientHelper().Ids.WarehouseId(), sdk.TargetLag{ + MaximumDuration: sdk.String("2 minutes"), + }, query)) + require.NoError(t, err) + + dynamicTable, err := client.DynamicTables.ShowByID(ctx, id) + require.NoError(t, err) + + assert.Equal(t, fmt.Sprintf("CREATE DYNAMIC TABLE %s lag = '2 minutes' refresh_mode = 'AUTO' initialize = 'ON_CREATE' warehouse = %s AS %s", id.FullyQualifiedName(), testClientHelper().Ids.WarehouseId().FullyQualifiedName(), plainQuery), dynamicTable.Text) + }) + t.Run("test complete with target lag", func(t *testing.T) { id := testClientHelper().Ids.RandomSchemaObjectIdentifier() targetLag := sdk.TargetLag{ diff --git a/pkg/sdk/testint/materialized_views_gen_integration_test.go b/pkg/sdk/testint/materialized_views_gen_integration_test.go index 98ba774f94..6ef58925e9 100644 --- a/pkg/sdk/testint/materialized_views_gen_integration_test.go +++ b/pkg/sdk/testint/materialized_views_gen_integration_test.go @@ -5,6 +5,9 @@ import ( "fmt" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/stretchr/testify/assert" @@ -109,6 +112,18 @@ func TestInt_MaterializedViews(t *testing.T) { assertMaterializedView(t, view, request.GetName()) }) + t.Run("create materialized view: with usage tracking comment", func(t *testing.T) { + id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + plainQuery := fmt.Sprintf("SELECT id FROM %s", table.ID().FullyQualifiedName()) + query, err := tracking.AppendMetadata(plainQuery, tracking.NewVersionedMetadata(resources.MaterializedView, tracking.CreateOperation)) + require.NoError(t, err) + + view := createMaterializedViewWithRequest(t, sdk.NewCreateMaterializedViewRequest(id, query)) + + assertMaterializedView(t, view, sdk.NewCreateMaterializedViewRequest(id, query).GetName()) + assert.Equal(t, fmt.Sprintf("CREATE MATERIALIZED VIEW %s AS %s", id.FullyQualifiedName(), plainQuery), view.Text) + }) + t.Run("create materialized view: almost complete case", func(t *testing.T) { rowAccessPolicy, rowAccessPolicyCleanup := testClientHelper().RowAccessPolicy.CreateRowAccessPolicy(t) t.Cleanup(rowAccessPolicyCleanup) diff --git a/pkg/sdk/testint/views_gen_integration_test.go b/pkg/sdk/testint/views_gen_integration_test.go index 7fda5f15ae..3682a3ee09 100644 --- a/pkg/sdk/testint/views_gen_integration_test.go +++ b/pkg/sdk/testint/views_gen_integration_test.go @@ -7,6 +7,9 @@ import ( "slices" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + assertions "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/assert" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/assert/objectassert" @@ -166,6 +169,19 @@ func TestInt_Views(t *testing.T) { assertView(t, view, request.GetName()) }) + t.Run("create view: with usage tracking comment", func(t *testing.T) { + id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + plainQuery := "SELECT NULL AS TYPE" + query, err := tracking.AppendMetadata(plainQuery, tracking.NewVersionedMetadata(resources.View, tracking.CreateOperation)) + require.NoError(t, err) + request := sdk.NewCreateViewRequest(id, query) + + view := createViewWithRequest(t, request) + + assertView(t, view, request.GetName()) + assert.Equal(t, fmt.Sprintf("CREATE VIEW %s AS %s", id.FullyQualifiedName(), plainQuery), view.Text) + }) + t.Run("create view: almost complete case - without masking and projection policies", func(t *testing.T) { rowAccessPolicy, rowAccessPolicyCleanup := testClientHelper().RowAccessPolicy.CreateRowAccessPolicy(t) t.Cleanup(rowAccessPolicyCleanup) diff --git a/pkg/sdk/views_impl_gen.go b/pkg/sdk/views_impl_gen.go index d9a939a268..a9a2783fb4 100644 --- a/pkg/sdk/views_impl_gen.go +++ b/pkg/sdk/views_impl_gen.go @@ -3,6 +3,8 @@ package sdk import ( "context" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" ) @@ -277,7 +279,7 @@ func (r viewDBRow) convert() *View { view.Comment = r.Comment.String } if r.Text.Valid { - view.Text = r.Text.String + view.Text = tracking.TrimMetadata(r.Text.String) } if r.Kind.Valid { view.Kind = r.Kind.String