diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md
index ad1af3ec80..dbfef7d7f6 100644
--- a/MIGRATION_GUIDE.md
+++ b/MIGRATION_GUIDE.md
@@ -6,7 +6,103 @@ across different versions.
> [!TIP]
> We highly recommend upgrading the versions one by one instead of bulk upgrades.
-
+
+## v0.99.0 ➞ v0.100.0
+
+### snowflake_tag_association resource changes
+#### *(behavior change)* new id format
+In order to provide more functionality for tagging objects, we have changed the resource id from `"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"` to `"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"|TAG_VALUE|OBJECT_TYPE`. This allows grouping tag associations by tag ID, tag value, and object type in one resource.
+```
+resource "snowflake_tag_association" "gold_warehouses" {
+  object_identifiers = [snowflake_warehouse.w1.fully_qualified_name, snowflake_warehouse.w2.fully_qualified_name]
+  object_type = "WAREHOUSE"
+  tag_id = snowflake_tag.tier.fully_qualified_name
+  tag_value = "gold"
+}
+resource "snowflake_tag_association" "silver_warehouses" {
+  object_identifiers = [snowflake_warehouse.w3.fully_qualified_name]
+  object_type = "WAREHOUSE"
+  tag_id = snowflake_tag.tier.fully_qualified_name
+  tag_value = "silver"
+}
+resource "snowflake_tag_association" "silver_databases" {
+  object_identifiers = [snowflake_database.d1.fully_qualified_name]
+  object_type = "DATABASE"
+  tag_id = snowflake_tag.tier.fully_qualified_name
+  tag_value = "silver"
+}
+```
+
+Note that if you want to promote silver instances to gold, you cannot simply change `tag_value` in `silver_warehouses`.
Instead, you should first remove `object_identifiers` from `silver_warehouses`, run `terraform apply`, and then add the relevant `object_identifiers` in `gold_warehouses`, like this (note that the `silver_warehouses` resource was deleted):
+```
+resource "snowflake_tag_association" "gold_warehouses" {
+  object_identifiers = [snowflake_warehouse.w1.fully_qualified_name, snowflake_warehouse.w2.fully_qualified_name, snowflake_warehouse.w3.fully_qualified_name]
+  object_type = "WAREHOUSE"
+  tag_id = snowflake_tag.tier.fully_qualified_name
+  tag_value = "gold"
+}
+```
+and run `terraform apply` again.
+
+Note that the order of operations is not deterministic in this case. If you make both changes in one step, the tag value may be changed first and then unset when the resource with the old value is removed.
+
+The state is migrated automatically. There is no need to adjust configuration files unless you use the resource id (`snowflake_tag_association.example.id`) as a reference in other resources.
+
+#### *(behavior change)* changed fields
+Behavior of some fields was changed:
+- `object_identifier` was renamed to `object_identifiers` and is now a set of fully qualified names. 
Change your configurations from
+```
+resource "snowflake_tag_association" "table_association" {
+  object_identifier {
+    name = snowflake_table.test.name
+    database = snowflake_database.test.name
+    schema = snowflake_schema.test.name
+  }
+  object_type = "TABLE"
+  tag_id = snowflake_tag.test.fully_qualified_name
+  tag_value = "engineering"
+}
+```
+to
+```
+resource "snowflake_tag_association" "table_association" {
+  object_identifiers = [snowflake_table.test.fully_qualified_name]
+  object_type = "TABLE"
+  tag_id = snowflake_tag.test.fully_qualified_name
+  tag_value = "engineering"
+}
+```
+- `tag_id` now has identifier quoting suppressed to prevent issues with Terraform showing permanent differences, like [this one](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2982)
+- `object_type` and `tag_id` are now marked as ForceNew
+
+The state is migrated automatically. Please adjust your configuration files.
+
+### Data type changes
+
+As part of reworking functions, procedures, and any other resource utilizing Snowflake data types, we adjusted the parsing of data types to be more aligned with Snowflake (according to the [docs](https://docs.snowflake.com/en/sql-reference/intro-summary-data-types)).
+
+Affected resources:
+- `snowflake_function`
+- `snowflake_procedure`
+- `snowflake_table`
+- `snowflake_external_function`
+- `snowflake_masking_policy`
+- `snowflake_row_access_policy`
+- `snowflake_dynamic_table`
+
+You may encounter non-empty plans in these resources after bumping.
+
+Changes to the previous implementation/limitations:
+- `BOOL` is no longer supported; use `BOOLEAN` instead. 
+- Following the change described [here](#bugfix-handle-data-type-diff-suppression-better-for-text-and-number), comparing and suppressing changes of data types was extended to all other data types with the following rules:
+  - `CHARACTER`, `CHAR`, and `NCHAR` now have the default size set to 1 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-text#char-character-nchar))
+  - `BINARY` has the default size set to 8388608 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-text#binary))
+  - `TIME` has the default precision set to 9 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#time))
+  - `TIMESTAMP_LTZ` has the default precision set to 9 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp)); supported aliases: `TIMESTAMPLTZ`, `TIMESTAMP WITH LOCAL TIME ZONE`.
+  - `TIMESTAMP_NTZ` has the default precision set to 9 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp)); supported aliases: `TIMESTAMPNTZ`, `TIMESTAMP WITHOUT TIME ZONE`, `DATETIME`.
+  - `TIMESTAMP_TZ` has the default precision set to 9 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp)); supported aliases: `TIMESTAMPTZ`, `TIMESTAMP WITH TIME ZONE`.
+- The session-settable `TIMESTAMP` is NOT supported ([docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp))
+- `VECTOR` type support is still limited and will be addressed soon (likely before the V1 release)
 
 ## v0.98.0 ➞ v0.99.0
 
 ### snowflake_tasks data source changes
@@ -39,7 +135,7 @@ data "snowflake_tasks" "new_tasks" {
   in {
     # for IN SCHEMA specify:
     schema = "<database_name>.<schema_name>"
-
+
    # for IN DATABASE specify:
    database = "<database_name>"
  }
@@ -65,7 +161,7 @@ New fields:
- `config` - enables specifying JSON-formatted metadata that can be retrieved in the `sql_statement` by using [SYSTEM$GET_TASK_GRAPH_CONFIG](https://docs.snowflake.com/en/sql-reference/functions/system_get_task_graph_config).
- `show_output` and `parameters` fields added for holding SHOW and SHOW PARAMETERS output (see [raw Snowflake output](./v1-preparations/CHANGES_BEFORE_V1.md#raw-snowflake-output)).
- Added support for finalizer tasks with the `finalize` field. It conflicts with `after` and `schedule` (see [finalizer tasks](https://docs.snowflake.com/en/user-guide/tasks-graphs#release-and-cleanup-of-task-graphs)).
-
+
Changes:
- `enabled` field changed to `started` and its type changed to string with only boolean values available (see ["empty" values](./v1-preparations/CHANGES_BEFORE_V1.md#empty-values)). It is also now a required field, so make sure it's explicitly set (previously it was optional with the default value set to `false`).
- `allow_overlapping_execution` type was changed to string with only boolean values available (see ["empty" values](./v1-preparations/CHANGES_BEFORE_V1.md#empty-values)). Previously, it had the default set to `false`, which will be migrated. If nothing is set, the provider will plan a change to the `default` value. If you want to make sure it's turned off, set it explicitly to `false`.
@@ -132,7 +228,7 @@ resource "snowflake_task" "example" {
```
- `after` field type was changed from `list` to `set` and the values were changed from names to fully qualified names. 
-
+
Before:
```terraform
resource "snowflake_task" "example" {
diff --git a/docs/resources/tag_association.md b/docs/resources/tag_association.md
index 77acc40091..fef230c0c4 100644
--- a/docs/resources/tag_association.md
+++ b/docs/resources/tag_association.md
@@ -2,12 +2,20 @@
page_title: "snowflake_tag_association Resource - terraform-provider-snowflake"
subcategory: ""
description: |-
-
+  Resource used to manage tag associations. For more information, check object tagging documentation https://docs.snowflake.com/en/user-guide/object-tagging.
---

-# snowflake_tag_association (Resource)
+!> **V1 release candidate** This resource was reworked and is a release candidate for the V1. We do not expect significant changes to it before the V1. We welcome any feedback and will adjust the resource if needed. Any errors reported will be resolved with a higher priority. We encourage checking this resource out before the V1 release. Please follow the [migration guide](https://github.com/Snowflake-Labs/terraform-provider-snowflake/blob/main/MIGRATION_GUIDE.md#v0990--v01000) to use it.
+
+-> **Note** For the `ACCOUNT` object type, only identifiers with an organization name are supported. See [account identifier docs](https://docs.snowflake.com/en/user-guide/admin-account-identifier#format-1-preferred-account-name-in-your-organization) for more details.
+
+-> **Note** The tag association resource ID has the following format: `"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"|TAG_VALUE|OBJECT_TYPE`. This means that the tuple of tag ID, tag value, and object type should be unique across resources. If you want to specify this combination for more than one object, use a single `tag_association` resource and list all of the objects in `object_identifiers`. 
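+
+For example, instead of three single-object resources, one resource can tag several warehouses with the same value (a minimal sketch; the resource, warehouse, and tag names below are illustrative):
+
+```terraform
+resource "snowflake_tag_association" "cost_center_warehouses" {
+  # one resource per (tag, tag value, object type) tuple
+  object_identifiers = [
+    snowflake_warehouse.w1.fully_qualified_name,
+    snowflake_warehouse.w2.fully_qualified_name,
+  ]
+  object_type = "WAREHOUSE"
+  tag_id      = snowflake_tag.cost_center.fully_qualified_name
+  tag_value   = "finance"
+}
+```
+
+Adding or removing tagged objects later means editing `object_identifiers` of this single resource rather than creating another resource with the same tag value.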
+-> **Note** If you want to change a tag value to one that is already present in another `tag_association` resource, first remove the relevant `object_identifiers` from the resource with the old value, run `terraform apply`, then add the relevant `object_identifiers` to the resource with the new value, and run `terraform apply` once again.
+
+# snowflake_tag_association (Resource)
+Resource used to manage tag associations. For more information, check [object tagging documentation](https://docs.snowflake.com/en/user-guide/object-tagging).

## Example Usage

@@ -29,12 +37,10 @@ resource "snowflake_tag" "test" {
}

resource "snowflake_tag_association" "db_association" {
-  object_identifier {
-    name = snowflake_database.test.name
-  }
-  object_type = "DATABASE"
-  tag_id = snowflake_tag.test.id
-  tag_value = "finance"
+  object_identifiers = [snowflake_database.test.fully_qualified_name]
+  object_type = "DATABASE"
+  tag_id = snowflake_tag.test.fully_qualified_name
+  tag_value = "finance"
}

resource "snowflake_table" "test" {
@@ -53,28 +59,26 @@ resource "snowflake_table" "test" {
}

resource "snowflake_tag_association" "table_association" {
-  object_identifier {
-    name = snowflake_table.test.name
-    database = snowflake_database.test.name
-    schema = snowflake_schema.test.name
-  }
-  object_type = "TABLE"
-  tag_id = snowflake_tag.test.id
-  tag_value = "engineering"
+  object_identifiers = [snowflake_table.test.fully_qualified_name]
+  object_type = "TABLE"
+  tag_id = snowflake_tag.test.fully_qualified_name
+  tag_value = "engineering"
}

resource "snowflake_tag_association" "column_association" {
-  object_identifier {
-    name = "${snowflake_table.test.name}.column_name"
-    database = snowflake_database.test.name
-    schema = snowflake_schema.test.name
-  }
-  object_type = "COLUMN"
-  tag_id = snowflake_tag.test.id
-  tag_value = "engineering"
+  object_identifiers = ["${snowflake_table.test.fully_qualified_name}.\"column_name\""]
+  object_type = "COLUMN"
+  tag_id = snowflake_tag.test.fully_qualified_name
+ 
tag_value = "engineering"
}
-```
+resource "snowflake_tag_association" "account_association" {
+  object_identifiers = ["\"ORGANIZATION_NAME\".\"ACCOUNT_NAME\""]
+  object_type = "ACCOUNT"
+  tag_id = snowflake_tag.test.fully_qualified_name
+  tag_value = "engineering"
+}
+```

-> **Note** Instead of using fully_qualified_name, you can reference objects managed outside Terraform by constructing a correct ID, consult [identifiers guide](https://registry.terraform.io/providers/Snowflake-Labs/snowflake/latest/docs/guides/identifiers#new-computed-fully-qualified-name-field-in-resources).

@@ -83,9 +87,9 @@ resource "snowflake_tag_association" "column_association" {

### Required

-- `object_identifier` (Block List, Min: 1) Specifies the object identifier for the tag association. (see [below for nested schema](#nestedblock--object_identifier))
+- `object_identifiers` (Set of String) Specifies the object identifiers for the tag association.
- `object_type` (String) Specifies the type of object to add a tag to. Allowed object types: [ACCOUNT APPLICATION APPLICATION PACKAGE DATABASE FAILOVER GROUP INTEGRATION NETWORK POLICY REPLICATION GROUP ROLE SHARE USER WAREHOUSE DATABASE ROLE SCHEMA ALERT SNOWFLAKE.CORE.BUDGET SNOWFLAKE.ML.CLASSIFICATION EXTERNAL FUNCTION EXTERNAL TABLE FUNCTION GIT REPOSITORY ICEBERG TABLE MATERIALIZED VIEW PIPE MASKING POLICY PASSWORD POLICY ROW ACCESS POLICY SESSION POLICY PRIVACY POLICY PROCEDURE STAGE STREAM TABLE TASK VIEW COLUMN EVENT TABLE].
-- `tag_id` (String) Specifies the identifier for the tag. Note: format must follow: "databaseName"."schemaName"."tagName" or "databaseName.schemaName.tagName" or "databaseName|schemaName.tagName" (snowflake_tag.tag.id)
+- `tag_id` (String) Specifies the identifier for the tag.
- `tag_value` (String) Specifies the value of the tag (e.g. 'finance' or 'engineering')

### Optional

@@ -98,19 +102,6 @@ resource "snowflake_tag_association" "column_association" {

- `id` (String) The ID of this resource. 
-
-### Nested Schema for `object_identifier`
-
-Required:
-
-- `name` (String) Name of the object to associate the tag with.
-
-Optional:
-
-- `database` (String) Name of the database that the object was created in.
-- `schema` (String) Name of the schema that the object was created in.
-
-

### Nested Schema for `timeouts`

@@ -120,9 +111,10 @@ Optional:

## Import

+~> **Note** Due to technical limitations of the Terraform SDK, `object_identifiers` are not set in the imported state. Please run `terraform refresh` after importing to get this field populated.
+
Import is supported using the following syntax:

```shell
-# format is dbName.schemaName.tagName or dbName.schemaName.tagName
-terraform import snowflake_tag_association.example 'dbName.schemaName.tagName'
+terraform import snowflake_tag_association.example '"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"|TAG_VALUE|OBJECT_TYPE'
```
diff --git a/docs/technical-documentation/grants_redesign_design_decisions.md b/docs/technical-documentation/grants_redesign_design_decisions.md
index b70af89be8..ced7616159 100644
--- a/docs/technical-documentation/grants_redesign_design_decisions.md
+++ b/docs/technical-documentation/grants_redesign_design_decisions.md
@@ -13,7 +13,7 @@ Here’s a list of resources and data sources we introduced during the grant red
- [snowflake_grant_privileges_to_account_role](https://registry.terraform.io/providers/Snowflake-Labs/snowflake/latest/docs/resources/grant_privileges_to_account_role)
- [snowflake_grant_account_role](https://registry.terraform.io/providers/Snowflake-Labs/snowflake/latest/docs/resources/grant_account_role)
- [snowflake_grant_database_role](https://registry.terraform.io/providers/Snowflake-Labs/snowflake/latest/docs/resources/grant_database_role)
-- snowflake_grant_application_role (coming soon)
+- [snowflake_grant_application_role](https://registry.terraform.io/providers/Snowflake-Labs/snowflake/latest/docs/resources/grant_application_role)
- 
[snowflake_grant_privileges_to_share](https://registry.terraform.io/providers/Snowflake-Labs/snowflake/latest/docs/resources/grant_privileges_to_share) - [snowflake_grant_ownership](https://registry.terraform.io/providers/Snowflake-Labs/snowflake/latest/docs/resources/grant_ownership) diff --git a/examples/resources/snowflake_tag_association/import.sh b/examples/resources/snowflake_tag_association/import.sh index 8b55fc9a15..a3339e11e9 100644 --- a/examples/resources/snowflake_tag_association/import.sh +++ b/examples/resources/snowflake_tag_association/import.sh @@ -1,2 +1 @@ -# format is dbName.schemaName.tagName or dbName.schemaName.tagName -terraform import snowflake_tag_association.example 'dbName.schemaName.tagName' +terraform import snowflake_tag_association.example '"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"|TAG_VALUE|OBJECT_TYPE' diff --git a/examples/resources/snowflake_tag_association/resource.tf b/examples/resources/snowflake_tag_association/resource.tf index 36d5fbf7de..00a3cc1324 100644 --- a/examples/resources/snowflake_tag_association/resource.tf +++ b/examples/resources/snowflake_tag_association/resource.tf @@ -15,12 +15,10 @@ resource "snowflake_tag" "test" { } resource "snowflake_tag_association" "db_association" { - object_identifier { - name = snowflake_database.test.name - } - object_type = "DATABASE" - tag_id = snowflake_tag.test.id - tag_value = "finance" + object_identifiers = [snowflake_database.test.fully_qualified_name] + object_type = "DATABASE" + tag_id = snowflake_tag.test.fully_qualified_name + tag_value = "finance" } resource "snowflake_table" "test" { @@ -39,23 +37,22 @@ resource "snowflake_table" "test" { } resource "snowflake_tag_association" "table_association" { - object_identifier { - name = snowflake_table.test.name - database = snowflake_database.test.name - schema = snowflake_schema.test.name - } - object_type = "TABLE" - tag_id = snowflake_tag.test.id - tag_value = "engineering" + object_identifiers = 
[snowflake_table.test.fully_qualified_name]
+  object_type = "TABLE"
+  tag_id = snowflake_tag.test.fully_qualified_name
+  tag_value = "engineering"
}

resource "snowflake_tag_association" "column_association" {
-  object_identifier {
-    name = "${snowflake_table.test.name}.column_name"
-    database = snowflake_database.test.name
-    schema = snowflake_schema.test.name
-  }
-  object_type = "COLUMN"
-  tag_id = snowflake_tag.test.id
-  tag_value = "engineering"
+  object_identifiers = ["${snowflake_table.test.fully_qualified_name}.\"column_name\""]
+  object_type = "COLUMN"
+  tag_id = snowflake_tag.test.fully_qualified_name
+  tag_value = "engineering"
+}
+
+resource "snowflake_tag_association" "account_association" {
+  object_identifiers = ["\"ORGANIZATION_NAME\".\"ACCOUNT_NAME\""]
+  object_type = "ACCOUNT"
+  tag_id = snowflake_tag.test.fully_qualified_name
+  tag_value = "engineering"
}
diff --git a/pkg/acceptance/bettertestspoc/README.md b/pkg/acceptance/bettertestspoc/README.md
index 82a7b98f17..a7c4d4eec0 100644
--- a/pkg/acceptance/bettertestspoc/README.md
+++ b/pkg/acceptance/bettertestspoc/README.md
@@ -217,7 +217,7 @@ it will result in:
    object WAREHOUSE["XHZJCKAT_35D0BCC1_7797_974E_ACAF_C622C56FA2D2"] assertion [7/13]: failed with error: expected scaling policy: ECONOMY; got: STANDARD
    object WAREHOUSE["XHZJCKAT_35D0BCC1_7797_974E_ACAF_C622C56FA2D2"] assertion [8/13]: failed with error: expected auto suspend: 123; got: 600
    object WAREHOUSE["XHZJCKAT_35D0BCC1_7797_974E_ACAF_C622C56FA2D2"] assertion [9/13]: failed with error: expected auto resume: false; got: true
-    object WAREHOUSE["XHZJCKAT_35D0BCC1_7797_974E_ACAF_C622C56FA2D2"] assertion [10/13]: failed with error: expected resource monitor: some-id; got: 
+    object WAREHOUSE["XHZJCKAT_35D0BCC1_7797_974E_ACAF_C622C56FA2D2"] assertion [10/13]: failed with error: expected resource monitor: some-id; got:
    object WAREHOUSE["XHZJCKAT_35D0BCC1_7797_974E_ACAF_C622C56FA2D2"] assertion [11/13]: failed with error: expected comment: bad comment; got: Who
does encouraging eagerly annoying dream several their scold straightaway. object WAREHOUSE["XHZJCKAT_35D0BCC1_7797_974E_ACAF_C622C56FA2D2"] assertion [12/13]: failed with error: expected enable query acceleration: true; got: false object WAREHOUSE["XHZJCKAT_35D0BCC1_7797_974E_ACAF_C622C56FA2D2"] assertion [13/13]: failed with error: expected query acceleration max scale factor: 12; got: 8 @@ -281,17 +281,17 @@ it will result in: WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported resource assertion [6/12]: failed with error: expected: ECONOMY, got: STANDARD WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported resource assertion [7/12]: failed with error: expected: 123, got: 600 WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported resource assertion [8/12]: failed with error: expected: false, got: true - WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported resource assertion [9/12]: failed with error: expected: abc, got: + WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported resource assertion [9/12]: failed with error: expected: abc, got: WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported resource assertion [10/12]: failed with error: expected: bad comment, got: Promise my huh off certain you bravery dynasty with Roman. 
WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported resource assertion [11/12]: failed with error: expected: true, got: false WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported resource assertion [12/12]: failed with error: expected: 16, got: 8 check 9/11 error: WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [2/7]: failed with error: expected: 1, got: 8 - WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [3/7]: failed with error: expected: WAREHOUSE, got: + WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [3/7]: failed with error: expected: WAREHOUSE, got: WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [4/7]: failed with error: expected: 23, got: 0 - WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [5/7]: failed with error: expected: WAREHOUSE, got: + WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [5/7]: failed with error: expected: WAREHOUSE, got: WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [6/7]: failed with error: expected: 1232, got: 172800 - WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [7/7]: failed with error: expected: WAREHOUSE, got: + WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 imported parameters assertion [7/7]: failed with error: expected: WAREHOUSE, got: check 10/11 error: object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [1/13]: failed with error: expected name: bad name; got: WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65 object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [2/13]: failed with error: expected state: SUSPENDED; got: STARTED @@ -302,7 +302,7 @@ it will result in: object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [7/13]: failed with error: expected scaling policy: ECONOMY; got: STANDARD object 
WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [8/13]: failed with error: expected auto suspend: 123; got: 600 object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [9/13]: failed with error: expected auto resume: false; got: true - object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [10/13]: failed with error: expected resource monitor: some-id; got: + object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [10/13]: failed with error: expected resource monitor: some-id; got: object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [11/13]: failed with error: expected comment: bad comment; got: Promise my huh off certain you bravery dynasty with Roman. object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [12/13]: failed with error: expected enable query acceleration: true; got: false object WAREHOUSE["WBJKHLAT_2E52D1E6_D23D_33A0_F568_4EBDBE083B65"] assertion [13/13]: failed with error: expected query acceleration max scale factor: 12; got: 8 @@ -331,7 +331,7 @@ it will result in: ``` it will result in: ``` - commons.go:101: + commons.go:101: Error Trace: /Users/asawicki/Projects/terraform-provider-snowflake/pkg/sdk/testint/warehouses_integration_test.go:149 Error: Received unexpected error: object WAREHOUSE["VKSENEIT_535F314F_6549_348F_370E_AB430EE4BC7B"] assertion [1/13]: failed with error: expected name: bad name; got: VKSENEIT_535F314F_6549_348F_370E_AB430EE4BC7B @@ -402,6 +402,7 @@ func (w *WarehouseDatasourceShowOutputAssert) IsEmpty() { - consider duplicating the builders template from resource (currently same template used for datasources and provider which limits the customization possibilities for just one block type) - consider merging ResourceModel with DatasourceModel (currently the implementation is really similar) - remove schema.TypeMap workaround or make it wiser (e.g. 
during generation we could programmatically gather all schema.TypeMap and use this workaround only for them)
+- support asserting resource id in `assert/resourceassert/*_gen.go`

## Known limitations
- generating provider config may misbehave when used only with one object/map parameter (like `params`), e.g.:
diff --git a/pkg/acceptance/bettertestspoc/assert/resourceassert/gen/resource_schema_def.go b/pkg/acceptance/bettertestspoc/assert/resourceassert/gen/resource_schema_def.go
index bcbe79ed5b..0352763bb0 100644
--- a/pkg/acceptance/bettertestspoc/assert/resourceassert/gen/resource_schema_def.go
+++ b/pkg/acceptance/bettertestspoc/assert/resourceassert/gen/resource_schema_def.go
@@ -109,6 +109,10 @@ var allResourceSchemaDefs = []ResourceSchemaDef{
		name:   "Tag",
		schema: resources.Tag().Schema,
	},
+	{
+		name:   "TagAssociation",
+		schema: resources.TagAssociation().Schema,
+	},
	{
		name:   "Task",
		schema: resources.Task().Schema,
diff --git a/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_association_resource_ext.go b/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_association_resource_ext.go
new file mode 100644
index 0000000000..d7ae9ed91d
--- /dev/null
+++ b/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_association_resource_ext.go
@@ -0,0 +1,17 @@
+package resourceassert
+
+import (
+	"fmt"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/assert"
+)
+
+func (t *TagAssociationResourceAssert) HasObjectIdentifiersLength(len int) *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueSet("object_identifiers.#", fmt.Sprintf("%d", len)))
+	return t
+}
+
+func (t *TagAssociationResourceAssert) HasIdString(expected string) *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueSet("id", expected))
+	return t
+}
diff --git a/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_association_resource_gen.go b/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_association_resource_gen.go
new file
mode 100644 index 0000000000..d9c10ba7b2 --- /dev/null +++ b/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_association_resource_gen.go @@ -0,0 +1,97 @@ +// Code generated by assertions generator; DO NOT EDIT. + +package resourceassert + +import ( + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/assert" +) + +type TagAssociationResourceAssert struct { + *assert.ResourceAssert +} + +func TagAssociationResource(t *testing.T, name string) *TagAssociationResourceAssert { + t.Helper() + + return &TagAssociationResourceAssert{ + ResourceAssert: assert.NewResourceAssert(name, "resource"), + } +} + +func ImportedTagAssociationResource(t *testing.T, id string) *TagAssociationResourceAssert { + t.Helper() + + return &TagAssociationResourceAssert{ + ResourceAssert: assert.NewImportedResourceAssert(id, "imported resource"), + } +} + +/////////////////////////////////// +// Attribute value string checks // +/////////////////////////////////// + +func (t *TagAssociationResourceAssert) HasObjectIdentifiersString(expected string) *TagAssociationResourceAssert { + t.AddAssertion(assert.ValueSet("object_identifiers", expected)) + return t +} + +func (t *TagAssociationResourceAssert) HasObjectNameString(expected string) *TagAssociationResourceAssert { + t.AddAssertion(assert.ValueSet("object_name", expected)) + return t +} + +func (t *TagAssociationResourceAssert) HasObjectTypeString(expected string) *TagAssociationResourceAssert { + t.AddAssertion(assert.ValueSet("object_type", expected)) + return t +} + +func (t *TagAssociationResourceAssert) HasSkipValidationString(expected string) *TagAssociationResourceAssert { + t.AddAssertion(assert.ValueSet("skip_validation", expected)) + return t +} + +func (t *TagAssociationResourceAssert) HasTagIdString(expected string) *TagAssociationResourceAssert { + t.AddAssertion(assert.ValueSet("tag_id", expected)) + return t +} + +func (t *TagAssociationResourceAssert) 
HasTagValueString(expected string) *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueSet("tag_value", expected))
+	return t
+}
+
+////////////////////////////
+// Attribute empty checks //
+////////////////////////////
+
+func (t *TagAssociationResourceAssert) HasNoObjectIdentifiers() *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueNotSet("object_identifiers"))
+	return t
+}
+
+func (t *TagAssociationResourceAssert) HasNoObjectName() *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueNotSet("object_name"))
+	return t
+}
+
+func (t *TagAssociationResourceAssert) HasNoObjectType() *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueNotSet("object_type"))
+	return t
+}
+
+func (t *TagAssociationResourceAssert) HasNoSkipValidation() *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueNotSet("skip_validation"))
+	return t
+}
+
+func (t *TagAssociationResourceAssert) HasNoTagId() *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueNotSet("tag_id"))
+	return t
+}
+
+func (t *TagAssociationResourceAssert) HasNoTagValue() *TagAssociationResourceAssert {
+	t.AddAssertion(assert.ValueNotSet("tag_value"))
+	return t
+}
diff --git a/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_resource_gen.go b/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_resource_gen.go
index 27102d9656..919658aff2 100644
--- a/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_resource_gen.go
+++ b/pkg/acceptance/bettertestspoc/assert/resourceassert/tag_resource_gen.go
@@ -52,8 +52,8 @@ func (t *TagResourceAssert) HasFullyQualifiedNameString(expected string) *TagRes
 	return t
 }
 
-func (t *TagResourceAssert) HasMaskingPolicyString(expected string) *TagResourceAssert {
-	t.AddAssertion(assert.ValueSet("masking_policy", expected))
+func (t *TagResourceAssert) HasMaskingPoliciesString(expected string) *TagResourceAssert {
+	t.AddAssertion(assert.ValueSet("masking_policies", expected))
 	return t
 }
 
diff --git a/pkg/acceptance/bettertestspoc/config/model/primary_connection_model_gen.go b/pkg/acceptance/bettertestspoc/config/model/primary_connection_model_gen.go
index f8f29bf1cf..3cbb735d91 100644
--- a/pkg/acceptance/bettertestspoc/config/model/primary_connection_model_gen.go
+++ b/pkg/acceptance/bettertestspoc/config/model/primary_connection_model_gen.go
@@ -13,6 +13,7 @@ type PrimaryConnectionModel struct {
 	Comment                  tfconfig.Variable `json:"comment,omitempty"`
 	EnableFailoverToAccounts tfconfig.Variable `json:"enable_failover_to_accounts,omitempty"`
 	FullyQualifiedName       tfconfig.Variable `json:"fully_qualified_name,omitempty"`
+	IsPrimary                tfconfig.Variable `json:"is_primary,omitempty"`
 	Name                     tfconfig.Variable `json:"name,omitempty"`
 
 	*config.ResourceModelMeta
@@ -55,6 +56,11 @@ func (p *PrimaryConnectionModel) WithFullyQualifiedName(fullyQualifiedName strin
 	return p
 }
 
+func (p *PrimaryConnectionModel) WithIsPrimary(isPrimary bool) *PrimaryConnectionModel {
+	p.IsPrimary = tfconfig.BoolVariable(isPrimary)
+	return p
+}
+
 func (p *PrimaryConnectionModel) WithName(name string) *PrimaryConnectionModel {
 	p.Name = tfconfig.StringVariable(name)
 	return p
@@ -79,6 +85,11 @@ func (p *PrimaryConnectionModel) WithFullyQualifiedNameValue(value tfconfig.Vari
 	return p
 }
 
+func (p *PrimaryConnectionModel) WithIsPrimaryValue(value tfconfig.Variable) *PrimaryConnectionModel {
+	p.IsPrimary = value
+	return p
+}
+
 func (p *PrimaryConnectionModel) WithNameValue(value tfconfig.Variable) *PrimaryConnectionModel {
 	p.Name = value
 	return p
diff --git a/pkg/acceptance/bettertestspoc/config/model/stream_on_directory_table_model_gen.go b/pkg/acceptance/bettertestspoc/config/model/stream_on_directory_table_model_gen.go
index 4956e940c9..cc2270d46c 100644
--- a/pkg/acceptance/bettertestspoc/config/model/stream_on_directory_table_model_gen.go
+++ b/pkg/acceptance/bettertestspoc/config/model/stream_on_directory_table_model_gen.go
@@ -17,6 +17,8 @@ type StreamOnDirectoryTableModel struct {
 	Name       tfconfig.Variable `json:"name,omitempty"`
 	Schema     tfconfig.Variable `json:"schema,omitempty"`
 	Stage      tfconfig.Variable `json:"stage,omitempty"`
+	Stale      tfconfig.Variable `json:"stale,omitempty"`
+	StreamType tfconfig.Variable `json:"stream_type,omitempty"`
 
 	*config.ResourceModelMeta
 }
@@ -93,6 +95,16 @@ func (s *StreamOnDirectoryTableModel) WithStage(stage string) *StreamOnDirectory
 	return s
 }
 
+func (s *StreamOnDirectoryTableModel) WithStale(stale bool) *StreamOnDirectoryTableModel {
+	s.Stale = tfconfig.BoolVariable(stale)
+	return s
+}
+
+func (s *StreamOnDirectoryTableModel) WithStreamType(streamType string) *StreamOnDirectoryTableModel {
+	s.StreamType = tfconfig.StringVariable(streamType)
+	return s
+}
+
 //////////////////////////////////////////
 // below it's possible to set any value //
 //////////////////////////////////////////
@@ -131,3 +143,13 @@ func (s *StreamOnDirectoryTableModel) WithStageValue(value tfconfig.Variable) *S
 	s.Stage = value
 	return s
 }
+
+func (s *StreamOnDirectoryTableModel) WithStaleValue(value tfconfig.Variable) *StreamOnDirectoryTableModel {
+	s.Stale = value
+	return s
+}
+
+func (s *StreamOnDirectoryTableModel) WithStreamTypeValue(value tfconfig.Variable) *StreamOnDirectoryTableModel {
+	s.StreamType = value
+	return s
+}
diff --git a/pkg/acceptance/bettertestspoc/config/model/stream_on_external_table_model_gen.go b/pkg/acceptance/bettertestspoc/config/model/stream_on_external_table_model_gen.go
index 09c87c5e23..de66941879 100644
--- a/pkg/acceptance/bettertestspoc/config/model/stream_on_external_table_model_gen.go
+++ b/pkg/acceptance/bettertestspoc/config/model/stream_on_external_table_model_gen.go
@@ -20,6 +20,8 @@ type StreamOnExternalTableModel struct {
 	InsertOnly tfconfig.Variable `json:"insert_only,omitempty"`
 	Name       tfconfig.Variable `json:"name,omitempty"`
 	Schema     tfconfig.Variable `json:"schema,omitempty"`
+	Stale      tfconfig.Variable `json:"stale,omitempty"`
+	StreamType tfconfig.Variable `json:"stream_type,omitempty"`
 
 	*config.ResourceModelMeta
 }
@@ -105,6 +107,16 @@ func (s *StreamOnExternalTableModel) WithSchema(schema string) *StreamOnExternal
 	return s
 }
 
+func (s *StreamOnExternalTableModel) WithStale(stale bool) *StreamOnExternalTableModel {
+	s.Stale = tfconfig.BoolVariable(stale)
+	return s
+}
+
+func (s *StreamOnExternalTableModel) WithStreamType(streamType string) *StreamOnExternalTableModel {
+	s.StreamType = tfconfig.StringVariable(streamType)
+	return s
+}
+
 //////////////////////////////////////////
 // below it's possible to set any value //
 //////////////////////////////////////////
@@ -158,3 +170,13 @@ func (s *StreamOnExternalTableModel) WithSchemaValue(value tfconfig.Variable) *S
 	s.Schema = value
 	return s
 }
+
+func (s *StreamOnExternalTableModel) WithStaleValue(value tfconfig.Variable) *StreamOnExternalTableModel {
+	s.Stale = value
+	return s
+}
+
+func (s *StreamOnExternalTableModel) WithStreamTypeValue(value tfconfig.Variable) *StreamOnExternalTableModel {
+	s.StreamType = value
+	return s
+}
diff --git a/pkg/acceptance/bettertestspoc/config/model/stream_on_table_model_gen.go b/pkg/acceptance/bettertestspoc/config/model/stream_on_table_model_gen.go
index 05cbdb51e9..3c337e9e9a 100644
--- a/pkg/acceptance/bettertestspoc/config/model/stream_on_table_model_gen.go
+++ b/pkg/acceptance/bettertestspoc/config/model/stream_on_table_model_gen.go
@@ -20,6 +20,8 @@ type StreamOnTableModel struct {
 	Name            tfconfig.Variable `json:"name,omitempty"`
 	Schema          tfconfig.Variable `json:"schema,omitempty"`
 	ShowInitialRows tfconfig.Variable `json:"show_initial_rows,omitempty"`
+	Stale           tfconfig.Variable `json:"stale,omitempty"`
+	StreamType      tfconfig.Variable `json:"stream_type,omitempty"`
 	Table           tfconfig.Variable `json:"table,omitempty"`
 
 	*config.ResourceModelMeta
@@ -106,6 +108,16 @@ func (s *StreamOnTableModel) WithShowInitialRows(showInitialRows string) *Stream
 	return s
 }
 
+func (s *StreamOnTableModel) WithStale(stale bool) *StreamOnTableModel {
+	s.Stale = tfconfig.BoolVariable(stale)
+	return s
+}
+
+func (s *StreamOnTableModel) WithStreamType(streamType string) *StreamOnTableModel {
+	s.StreamType = tfconfig.StringVariable(streamType)
+	return s
+}
+
 func (s *StreamOnTableModel) WithTable(table string) *StreamOnTableModel {
 	s.Table = tfconfig.StringVariable(table)
 	return s
@@ -165,6 +177,16 @@ func (s *StreamOnTableModel) WithShowInitialRowsValue(value tfconfig.Variable) *
 	return s
 }
 
+func (s *StreamOnTableModel) WithStaleValue(value tfconfig.Variable) *StreamOnTableModel {
+	s.Stale = value
+	return s
+}
+
+func (s *StreamOnTableModel) WithStreamTypeValue(value tfconfig.Variable) *StreamOnTableModel {
+	s.StreamType = value
+	return s
+}
+
 func (s *StreamOnTableModel) WithTableValue(value tfconfig.Variable) *StreamOnTableModel {
 	s.Table = value
 	return s
diff --git a/pkg/acceptance/bettertestspoc/config/model/stream_on_view_model_gen.go b/pkg/acceptance/bettertestspoc/config/model/stream_on_view_model_gen.go
index b627e07a5c..d4942dd8a6 100644
--- a/pkg/acceptance/bettertestspoc/config/model/stream_on_view_model_gen.go
+++ b/pkg/acceptance/bettertestspoc/config/model/stream_on_view_model_gen.go
@@ -20,6 +20,8 @@ type StreamOnViewModel struct {
 	Name            tfconfig.Variable `json:"name,omitempty"`
 	Schema          tfconfig.Variable `json:"schema,omitempty"`
 	ShowInitialRows tfconfig.Variable `json:"show_initial_rows,omitempty"`
+	Stale           tfconfig.Variable `json:"stale,omitempty"`
+	StreamType      tfconfig.Variable `json:"stream_type,omitempty"`
 	View            tfconfig.Variable `json:"view,omitempty"`
 
 	*config.ResourceModelMeta
@@ -106,6 +108,16 @@ func (s *StreamOnViewModel) WithShowInitialRows(showInitialRows string) *StreamO
 	return s
 }
 
+func (s *StreamOnViewModel) WithStale(stale bool) *StreamOnViewModel {
+	s.Stale = tfconfig.BoolVariable(stale)
+	return s
+}
+
+func (s *StreamOnViewModel) WithStreamType(streamType string) *StreamOnViewModel {
+	s.StreamType = tfconfig.StringVariable(streamType)
+	return s
+}
+
 func (s *StreamOnViewModel) WithView(view string) *StreamOnViewModel {
 	s.View = tfconfig.StringVariable(view)
 	return s
@@ -165,6 +177,16 @@ func (s *StreamOnViewModel) WithShowInitialRowsValue(value tfconfig.Variable) *S
 	return s
 }
 
+func (s *StreamOnViewModel) WithStaleValue(value tfconfig.Variable) *StreamOnViewModel {
+	s.Stale = value
+	return s
+}
+
+func (s *StreamOnViewModel) WithStreamTypeValue(value tfconfig.Variable) *StreamOnViewModel {
+	s.StreamType = value
+	return s
+}
+
 func (s *StreamOnViewModel) WithViewValue(value tfconfig.Variable) *StreamOnViewModel {
 	s.View = value
 	return s
diff --git a/pkg/acceptance/bettertestspoc/config/model/tag_association_model_ext.go b/pkg/acceptance/bettertestspoc/config/model/tag_association_model_ext.go
new file mode 100644
index 0000000000..7d565264fc
--- /dev/null
+++ b/pkg/acceptance/bettertestspoc/config/model/tag_association_model_ext.go
@@ -0,0 +1,16 @@
+package model
+
+import (
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
+	tfconfig "github.com/hashicorp/terraform-plugin-testing/config"
+)
+
+func (t *TagAssociationModel) WithObjectIdentifiers(objectIdentifiers ...sdk.ObjectIdentifier) *TagAssociationModel {
+	objectIdentifiersStringVariables := make([]tfconfig.Variable, len(objectIdentifiers))
+	for i, v := range objectIdentifiers {
+		objectIdentifiersStringVariables[i] = tfconfig.StringVariable(v.FullyQualifiedName())
+	}
+
+	t.ObjectIdentifiers = tfconfig.SetVariable(objectIdentifiersStringVariables...)
+	return t
+}
diff --git a/pkg/acceptance/bettertestspoc/config/model/tag_association_model_gen.go b/pkg/acceptance/bettertestspoc/config/model/tag_association_model_gen.go
new file mode 100644
index 0000000000..12457d4973
--- /dev/null
+++ b/pkg/acceptance/bettertestspoc/config/model/tag_association_model_gen.go
@@ -0,0 +1,120 @@
+// Code generated by config model builder generator; DO NOT EDIT.
+
+package model
+
+import (
+	tfconfig "github.com/hashicorp/terraform-plugin-testing/config"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/config"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
+)
+
+type TagAssociationModel struct {
+	ObjectIdentifiers tfconfig.Variable `json:"object_identifiers,omitempty"`
+	ObjectName        tfconfig.Variable `json:"object_name,omitempty"`
+	ObjectType        tfconfig.Variable `json:"object_type,omitempty"`
+	SkipValidation    tfconfig.Variable `json:"skip_validation,omitempty"`
+	TagId             tfconfig.Variable `json:"tag_id,omitempty"`
+	TagValue          tfconfig.Variable `json:"tag_value,omitempty"`
+
+	*config.ResourceModelMeta
+}
+
+/////////////////////////////////////////////////
+// Basic builders (resource name and required) //
+/////////////////////////////////////////////////
+
+func TagAssociation(
+	resourceName string,
+	objectIdentifiers []sdk.ObjectIdentifier,
+	objectType string,
+	tagId string,
+	tagValue string,
+) *TagAssociationModel {
+	t := &TagAssociationModel{ResourceModelMeta: config.Meta(resourceName, resources.TagAssociation)}
+	t.WithObjectIdentifiers(objectIdentifiers...)
+	t.WithObjectType(objectType)
+	t.WithTagId(tagId)
+	t.WithTagValue(tagValue)
+	return t
+}
+
+func TagAssociationWithDefaultMeta(
+	objectIdentifiers []sdk.ObjectIdentifier,
+	objectType string,
+	tagId string,
+	tagValue string,
+) *TagAssociationModel {
+	t := &TagAssociationModel{ResourceModelMeta: config.DefaultMeta(resources.TagAssociation)}
+	t.WithObjectIdentifiers(objectIdentifiers...)
+	t.WithObjectType(objectType)
+	t.WithTagId(tagId)
+	t.WithTagValue(tagValue)
+	return t
+}
+
+/////////////////////////////////
+// below all the proper values //
+/////////////////////////////////
+
+// object_identifiers attribute type is not yet supported, so WithObjectIdentifiers can't be generated
+
+func (t *TagAssociationModel) WithObjectName(objectName string) *TagAssociationModel {
+	t.ObjectName = tfconfig.StringVariable(objectName)
+	return t
+}
+
+func (t *TagAssociationModel) WithObjectType(objectType string) *TagAssociationModel {
+	t.ObjectType = tfconfig.StringVariable(objectType)
+	return t
+}
+
+func (t *TagAssociationModel) WithSkipValidation(skipValidation bool) *TagAssociationModel {
+	t.SkipValidation = tfconfig.BoolVariable(skipValidation)
+	return t
+}
+
+func (t *TagAssociationModel) WithTagId(tagId string) *TagAssociationModel {
+	t.TagId = tfconfig.StringVariable(tagId)
+	return t
+}
+
+func (t *TagAssociationModel) WithTagValue(tagValue string) *TagAssociationModel {
+	t.TagValue = tfconfig.StringVariable(tagValue)
+	return t
+}
+
+//////////////////////////////////////////
+// below it's possible to set any value //
+//////////////////////////////////////////
+
+func (t *TagAssociationModel) WithObjectIdentifiersValue(value tfconfig.Variable) *TagAssociationModel {
+	t.ObjectIdentifiers = value
+	return t
+}
+
+func (t *TagAssociationModel) WithObjectNameValue(value tfconfig.Variable) *TagAssociationModel {
+	t.ObjectName = value
+	return t
+}
+
+func (t *TagAssociationModel) WithObjectTypeValue(value tfconfig.Variable) *TagAssociationModel {
+	t.ObjectType = value
+	return t
+}
+
+func (t *TagAssociationModel) WithSkipValidationValue(value tfconfig.Variable) *TagAssociationModel {
+	t.SkipValidation = value
+	return t
+}
+
+func (t *TagAssociationModel) WithTagIdValue(value tfconfig.Variable) *TagAssociationModel {
+	t.TagId = value
+	return t
+}
+
+func (t *TagAssociationModel) WithTagValueValue(value tfconfig.Variable) *TagAssociationModel {
+	t.TagValue = value
+	return t
+}
diff --git a/pkg/acceptance/bettertestspoc/config/model/tag_model_gen.go b/pkg/acceptance/bettertestspoc/config/model/tag_model_gen.go
index 91b5bb9eff..a649154cb2 100644
--- a/pkg/acceptance/bettertestspoc/config/model/tag_model_gen.go
+++ b/pkg/acceptance/bettertestspoc/config/model/tag_model_gen.go
@@ -71,7 +71,7 @@ func (t *TagModel) WithFullyQualifiedName(fullyQualifiedName string) *TagModel {
 	return t
 }
 
-// masking_policy attribute type is not yet supported, so WithMaskingPolicy can't be generated
+// masking_policies attribute type is not yet supported, so WithMaskingPolicies can't be generated
 
 func (t *TagModel) WithName(name string) *TagModel {
 	t.Name = tfconfig.StringVariable(name)
diff --git a/pkg/acceptance/check_destroy.go b/pkg/acceptance/check_destroy.go
index 57145b726f..404ad98917 100644
--- a/pkg/acceptance/check_destroy.go
+++ b/pkg/acceptance/check_destroy.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"reflect"
+	"strconv"
 	"strings"
 	"testing"
 
@@ -490,6 +491,75 @@
 	}
 }
 
+// CheckResourceTagUnset is a custom check that should be later incorporated into generic CheckDestroy
+func CheckResourceTagUnset(t *testing.T) func(*terraform.State) error {
+	t.Helper()
+
+	return func(s *terraform.State) error {
+		for _, rs := range s.RootModule().Resources {
+			if rs.Type != "snowflake_tag_association" {
+				continue
+			}
+			objectType := sdk.ObjectType(rs.Primary.Attributes["object_type"])
+			tagId, err := sdk.ParseSchemaObjectIdentifier(rs.Primary.Attributes["tag_id"])
+			if err != nil {
+				return err
+			}
+			idLen, err := strconv.Atoi(rs.Primary.Attributes["object_identifiers.#"])
+			if err != nil {
+				return err
+			}
+			for i := 0; i < idLen; i++ {
+				idRaw := rs.Primary.Attributes[fmt.Sprintf("object_identifiers.%d", i)]
+				var id sdk.ObjectIdentifier
+				// TODO(SNOW-1229218): Use a common mapper to get object id.
+				if objectType == sdk.ObjectTypeAccount {
+					id, err = sdk.ParseAccountIdentifier(idRaw)
+					if err != nil {
+						return fmt.Errorf("invalid account id: %w", err)
+					}
+				} else {
+					id, err = sdk.ParseObjectIdentifierString(idRaw)
+					if err != nil {
+						return fmt.Errorf("invalid object id: %w", err)
+					}
+				}
+				if err := assertTagUnset(t, tagId, id, objectType); err != nil {
+					return err
+				}
+			}
+		}
+		return nil
+	}
+}
+
+// CheckTagUnset is a custom check that should be later incorporated into generic CheckDestroy
+func CheckTagUnset(t *testing.T, tagId sdk.SchemaObjectIdentifier, id sdk.ObjectIdentifier, objectType sdk.ObjectType) func(*terraform.State) error {
+	t.Helper()
+
+	return func(s *terraform.State) error {
+		return assertTagUnset(t, tagId, id, objectType)
+	}
+}
+
+func assertTagUnset(t *testing.T, tagId sdk.SchemaObjectIdentifier, id sdk.ObjectIdentifier, objectType sdk.ObjectType) error {
+	t.Helper()
+
+	tag, err := TestClient().Tag.GetForObject(t, tagId, id, objectType)
+	if err != nil {
+		if strings.Contains(err.Error(), "does not exist or not authorized") {
+			// Note: this can happen if the referenced object was deleted before; in this case, ignore the error
+			t.Logf("could not get tag for %v : %v, continuing...", id.FullyQualifiedName(), err)
+			return nil
+		}
+		return err
+	}
+	if tag != nil {
+		return fmt.Errorf("tag %s for object %s expected to be empty, got %s", tagId.FullyQualifiedName(), id.FullyQualifiedName(), *tag)
+	}
+	return err
+}
+
 func TestAccCheckGrantApplicationRoleDestroy(s *terraform.State) error {
 	client := TestAccProvider.Meta().(*provider.Context).Client
 	for _, rs := range s.RootModule().Resources {
diff --git a/pkg/acceptance/helpers/tag_client.go b/pkg/acceptance/helpers/tag_client.go
index 90d7cbcb8b..c32598f9b2 100644
--- a/pkg/acceptance/helpers/tag_client.go
+++ b/pkg/acceptance/helpers/tag_client.go
@@ -52,6 +52,22 @@ func (c *TagClient) CreateWithRequest(t *testing.T, req *sdk.CreateTagRequest) (
 	return tag, c.DropTagFunc(t, req.GetName())
 }
 
+func (c *TagClient) Unset(t *testing.T, objectType sdk.ObjectType, id sdk.ObjectIdentifier, unsetTags []sdk.ObjectIdentifier) {
+	t.Helper()
+	ctx := context.Background()
+
+	err := c.client().Unset(ctx, sdk.NewUnsetTagRequest(objectType, id).WithUnsetTags(unsetTags))
+	require.NoError(t, err)
+}
+
+func (c *TagClient) Set(t *testing.T, objectType sdk.ObjectType, id sdk.ObjectIdentifier, setTags []sdk.TagAssociation) {
+	t.Helper()
+	ctx := context.Background()
+
+	err := c.client().Set(ctx, sdk.NewSetTagRequest(objectType, id).WithSetTags(setTags))
+	require.NoError(t, err)
+}
+
 func (c *TagClient) Alter(t *testing.T, req *sdk.AlterTagRequest) {
 	t.Helper()
 	ctx := context.Background()
@@ -75,3 +91,11 @@
 
 	return c.client().ShowByID(ctx, id)
 }
+
+func (c *TagClient) GetForObject(t *testing.T, tagId sdk.SchemaObjectIdentifier, objectId sdk.ObjectIdentifier, objectType sdk.ObjectType) (*string, error) {
+	t.Helper()
+	ctx := context.Background()
+	client := c.context.client.SystemFunctions
+
+	return client.GetTag(ctx, tagId, objectId, objectType)
+}
diff --git a/pkg/helpers/helpers.go b/pkg/helpers/helpers.go
index 6b3e39d4cf..b4dde1acd7 100644
--- a/pkg/helpers/helpers.go
+++ b/pkg/helpers/helpers.go
@@ -142,7 +142,8 @@ func DecodeSnowflakeAccountIdentifier(identifier string) (sdk.AccountIdentifier,
 	}
 }
 
-// TODO: use slices.Concat in Go 1.22
+// TODO [SNOW-1844769]: use slices.Concat func ConcatSlices[T any](slices ...[]T) []T { var tmp []T for _, s := range slices { diff --git a/pkg/resources/diff_suppressions.go b/pkg/resources/diff_suppressions.go index 597529e4ec..14efa760b2 100644 --- a/pkg/resources/diff_suppressions.go +++ b/pkg/resources/diff_suppressions.go @@ -9,10 +9,15 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) func NormalizeAndCompare[T comparable](normalize func(string) (T, error)) schema.SchemaDiffSuppressFunc { + return NormalizeAndCompareUsingFunc(normalize, func(a, b T) bool { return a == b }) +} + +func NormalizeAndCompareUsingFunc[T any](normalize func(string) (T, error), compareFunc func(a, b T) bool) schema.SchemaDiffSuppressFunc { return func(_, oldValue, newValue string, _ *schema.ResourceData) bool { oldNormalized, err := normalize(oldValue) if err != nil { @@ -22,10 +27,15 @@ func NormalizeAndCompare[T comparable](normalize func(string) (T, error)) schema if err != nil { return false } - return oldNormalized == newNormalized + + return compareFunc(oldNormalized, newNormalized) } } +// DiffSuppressDataTypes handles data type suppression taking into account data type attributes for each type. +// It falls back to Snowflake defaults for arguments if no arguments were provided for the data type. 
+var DiffSuppressDataTypes = NormalizeAndCompareUsingFunc(datatypes.ParseDataType, datatypes.AreTheSame) + // NormalizeAndCompareIdentifiersInSet is a diff suppression function that should be used at top-level TypeSet fields that // hold identifiers to avoid diffs like: // - "DATABASE"."SCHEMA"."OBJECT" @@ -254,3 +264,15 @@ func IgnoreNewEmptyListOrSubfields(ignoredSubfields ...string) schema.SchemaDiff return len(parts) == 3 && slices.Contains(ignoredSubfields, parts[2]) && new == "" } } + +func ignoreTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + return strings.TrimSpace(old) == strings.TrimSpace(new) +} + +func ignoreCaseSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + return strings.EqualFold(old, new) +} + +func ignoreCaseAndTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + return strings.EqualFold(strings.TrimSpace(old), strings.TrimSpace(new)) +} diff --git a/pkg/resources/external_function.go b/pkg/resources/external_function.go index 9459adab6a..2580fb6141 100644 --- a/pkg/resources/external_function.go +++ b/pkg/resources/external_function.go @@ -8,11 +8,11 @@ import ( "strconv" "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" @@ -227,11 +227,11 @@ func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, for _, arg := range v.([]interface{}) { argName := arg.(map[string]interface{})["name"].(string) argType := 
arg.(map[string]interface{})["type"].(string) - argDataType, err := sdk.ToDataType(argType) + argDataType, err := datatypes.ParseDataType(argType) if err != nil { return diag.FromErr(err) } - args = append(args, sdk.ExternalFunctionArgumentRequest{ArgName: argName, ArgDataType: argDataType}) + args = append(args, sdk.ExternalFunctionArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) } } argTypes := make([]sdk.DataType, 0, len(args)) @@ -241,13 +241,13 @@ func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, id := sdk.NewSchemaObjectIdentifierWithArguments(database, schemaName, name, argTypes...) returnType := d.Get("return_type").(string) - resultDataType, err := sdk.ToDataType(returnType) + resultDataType, err := datatypes.ParseDataType(returnType) if err != nil { return diag.FromErr(err) } apiIntegration := sdk.NewAccountObjectIdentifier(d.Get("api_integration").(string)) urlOfProxyAndResource := d.Get("url_of_proxy_and_resource").(string) - req := sdk.NewCreateExternalFunctionRequest(id.SchemaObjectId(), resultDataType, &apiIntegration, urlOfProxyAndResource) + req := sdk.NewCreateExternalFunctionRequest(id.SchemaObjectId(), sdk.LegacyDataTypeFrom(resultDataType), &apiIntegration, urlOfProxyAndResource) // Set optionals if len(args) > 0 { diff --git a/pkg/resources/external_function_state_upgraders.go b/pkg/resources/external_function_state_upgraders.go index aba74585aa..315d7f6caa 100644 --- a/pkg/resources/external_function_state_upgraders.go +++ b/pkg/resources/external_function_state_upgraders.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type v085ExternalFunctionId struct { @@ -52,11 +53,11 @@ func v085ExternalFunctionStateUpgrader(ctx context.Context, rawState map[string] argDataTypes := make([]sdk.DataType, 0) if 
parsedV085ExternalFunctionId.ExternalFunctionArgTypes != "" { for _, argType := range strings.Split(parsedV085ExternalFunctionId.ExternalFunctionArgTypes, "-") { - argDataType, err := sdk.ToDataType(argType) + argDataType, err := datatypes.ParseDataType(argType) if err != nil { return nil, err } - argDataTypes = append(argDataTypes, argDataType) + argDataTypes = append(argDataTypes, sdk.LegacyDataTypeFrom(argDataType)) } } diff --git a/pkg/resources/external_table.go b/pkg/resources/external_table.go index 833b8c4dd8..56404cf703 100644 --- a/pkg/resources/external_table.go +++ b/pkg/resources/external_table.go @@ -5,17 +5,14 @@ import ( "fmt" "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/hashicorp/terraform-plugin-sdk/v2/diag" - + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" - - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) var externalTableSchema = map[string]*schema.Schema{ @@ -59,11 +56,11 @@ var externalTableSchema = map[string]*schema.Schema{ ForceNew: true, }, "type": { - Type: schema.TypeString, - Required: true, - Description: "Column type, e.g. VARIANT", - ForceNew: true, - ValidateFunc: IsDataType(), + Type: schema.TypeString, + Required: true, + Description: "Column type, e.g. 
VARIANT", + ForceNew: true, + ValidateDiagFunc: IsDataTypeValid, }, "as": { Type: schema.TypeString, diff --git a/pkg/resources/function.go b/pkg/resources/function.go index 314439b96d..19415f91f1 100644 --- a/pkg/resources/function.go +++ b/pkg/resources/function.go @@ -7,11 +7,11 @@ import ( "regexp" "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" @@ -311,7 +311,7 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter functionDefinition := d.Get("statement").(string) handler := d.Get("handler").(string) // create request with required - request := sdk.NewCreateForScalaFunctionRequest(id, returnDataType, handler) + request := sdk.NewCreateForScalaFunctionRequest(id, sdk.LegacyDataTypeFrom(returnDataType), handler) request.WithFunctionDefinition(functionDefinition) // Set optionals @@ -739,16 +739,16 @@ func parseFunctionArguments(d *schema.ResourceData) ([]sdk.FunctionArgumentReque if diags != nil { return nil, diags } - args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataType: argDataType}) + args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) } } return args, nil } -func convertFunctionDataType(s string) (sdk.DataType, diag.Diagnostics) { - dataType, err := sdk.ToDataType(s) +func convertFunctionDataType(s string) (datatypes.DataType, diag.Diagnostics) { + dataType, err := 
datatypes.ParseDataType(s) if err != nil { - return dataType, diag.FromErr(err) + return nil, diag.FromErr(err) } return dataType, nil } @@ -759,13 +759,13 @@ func convertFunctionColumns(s string) ([]sdk.FunctionColumn, diag.Diagnostics) { var columns []sdk.FunctionColumn for _, match := range matches { if len(match) == 3 { - dataType, err := sdk.ToDataType(match[2]) + dataType, err := datatypes.ParseDataType(match[2]) if err != nil { return nil, diag.FromErr(err) } columns = append(columns, sdk.FunctionColumn{ ColumnName: match[1], - ColumnDataType: dataType, + ColumnDataType: sdk.LegacyDataTypeFrom(dataType), }) } } @@ -789,7 +789,7 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType)) + returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } diff --git a/pkg/resources/function_state_upgraders.go b/pkg/resources/function_state_upgraders.go index 501e44f1dc..7be3c5b9b8 100644 --- a/pkg/resources/function_state_upgraders.go +++ b/pkg/resources/function_state_upgraders.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type v085FunctionId struct { @@ -48,11 +49,11 @@ func v085FunctionIdStateUpgrader(ctx context.Context, rawState map[string]interf argDataTypes := make([]sdk.DataType, len(parsedV085FunctionId.ArgTypes)) for i, argType := range parsedV085FunctionId.ArgTypes { - argDataType, err := sdk.ToDataType(argType) + argDataType, err := datatypes.ParseDataType(argType) if err != nil { return nil, err } - argDataTypes[i] = argDataType + argDataTypes[i] = sdk.LegacyDataTypeFrom(argDataType) } schemaObjectIdentifierWithArguments := 
sdk.NewSchemaObjectIdentifierWithArgumentsOld(parsedV085FunctionId.DatabaseName, parsedV085FunctionId.SchemaName, parsedV085FunctionId.FunctionName, argDataTypes) diff --git a/pkg/resources/grant_ownership.go b/pkg/resources/grant_ownership.go index 7173f18380..24077bcd48 100644 --- a/pkg/resources/grant_ownership.go +++ b/pkg/resources/grant_ownership.go @@ -408,7 +408,7 @@ func ReadGrantOwnership(ctx context.Context, d *schema.ResourceData, meta any) d } // TODO(SNOW-1229218): Make sdk.ObjectType + string objectName to sdk.ObjectIdentifier mapping available in the sdk (for all object types). -func getOnObjectIdentifier(objectType sdk.ObjectType, objectName string) (sdk.ObjectIdentifier, error) { +func GetOnObjectIdentifier(objectType sdk.ObjectType, objectName string) (sdk.ObjectIdentifier, error) { switch objectType { case sdk.ObjectTypeComputePool, sdk.ObjectTypeDatabase, @@ -458,6 +458,9 @@ func getOnObjectIdentifier(objectType sdk.ObjectType, objectName string) (sdk.Ob sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction: return sdk.ParseSchemaObjectIdentifierWithArguments(objectName) + case sdk.ObjectTypeColumn: + return sdk.ParseTableColumnIdentifier(objectName) + default: return nil, sdk.NewError(fmt.Sprintf("object_type %s is not supported, please create a feature request for the provider if given object_type should be supported", objectType)) } @@ -475,7 +478,7 @@ func getOwnershipGrantOn(d *schema.ResourceData) (*sdk.OwnershipGrantOn, error) switch { case len(onObjectType) > 0 && len(onObjectName) > 0: objectType := sdk.ObjectType(strings.ToUpper(onObjectType)) - objectName, err := getOnObjectIdentifier(objectType, onObjectName) + objectName, err := GetOnObjectIdentifier(objectType, onObjectName) if err != nil { return nil, err } @@ -626,7 +629,7 @@ func createGrantOwnershipIdFromSchema(d *schema.ResourceData) (*GrantOwnershipId case len(objectType) > 0 && len(objectName) > 0: id.Kind = OnObjectGrantOwnershipKind objectType := 
sdk.ObjectType(objectType) - objectName, err := getOnObjectIdentifier(objectType, objectName) + objectName, err := GetOnObjectIdentifier(objectType, objectName) if err != nil { return nil, err } diff --git a/pkg/resources/grant_ownership_identifier.go b/pkg/resources/grant_ownership_identifier.go index 2b233d6932..2eff1f7f43 100644 --- a/pkg/resources/grant_ownership_identifier.go +++ b/pkg/resources/grant_ownership_identifier.go @@ -125,7 +125,7 @@ func ParseGrantOwnershipId(id string) (*GrantOwnershipId, error) { return grantOwnershipId, sdk.NewError(`grant ownership identifier should consist of 6 parts "|||OnObject||"`) } objectType := sdk.ObjectType(parts[4]) - objectName, err := getOnObjectIdentifier(objectType, parts[5]) + objectName, err := GetOnObjectIdentifier(objectType, parts[5]) if err != nil { return nil, err } diff --git a/pkg/resources/grant_ownership_test.go b/pkg/resources/grant_ownership_test.go index 208346de46..4cfe2fe3ae 100644 --- a/pkg/resources/grant_ownership_test.go +++ b/pkg/resources/grant_ownership_test.go @@ -89,7 +89,7 @@ func TestGetOnObjectIdentifier(t *testing.T) { for _, tt := range testCases { tt := tt t.Run(tt.Name, func(t *testing.T) { - id, err := getOnObjectIdentifier(tt.ObjectType, tt.ObjectName) + id, err := GetOnObjectIdentifier(tt.ObjectType, tt.ObjectName) if tt.Error == "" { assert.NoError(t, err) assert.Equal(t, tt.Expected, id) diff --git a/pkg/resources/grant_privileges_to_database_role_acceptance_test.go b/pkg/resources/grant_privileges_to_database_role_acceptance_test.go index c1da413547..8b701424e4 100644 --- a/pkg/resources/grant_privileges_to_database_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_database_role_acceptance_test.go @@ -27,6 +27,7 @@ func TestAcc_GrantPrivilegesToDatabaseRole_OnDatabase(t *testing.T) { configVariables := config.Variables{ "name": config.StringVariable(databaseRoleId.Name()), "privileges": config.ListVariable( + 
config.StringVariable(string(sdk.AccountObjectPrivilegeApplyBudget)), config.StringVariable(string(sdk.AccountObjectPrivilegeCreateSchema)), config.StringVariable(string(sdk.AccountObjectPrivilegeModify)), config.StringVariable(string(sdk.AccountObjectPrivilegeUsage)), @@ -53,13 +54,14 @@ func TestAcc_GrantPrivilegesToDatabaseRole_OnDatabase(t *testing.T) { ConfigVariables: configVariables, Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr(resourceName, "database_role_name", databaseRoleId.FullyQualifiedName()), - resource.TestCheckResourceAttr(resourceName, "privileges.#", "3"), - resource.TestCheckResourceAttr(resourceName, "privileges.0", string(sdk.AccountObjectPrivilegeCreateSchema)), - resource.TestCheckResourceAttr(resourceName, "privileges.1", string(sdk.AccountObjectPrivilegeModify)), - resource.TestCheckResourceAttr(resourceName, "privileges.2", string(sdk.AccountObjectPrivilegeUsage)), + resource.TestCheckResourceAttr(resourceName, "privileges.#", "4"), + resource.TestCheckResourceAttr(resourceName, "privileges.0", string(sdk.AccountObjectPrivilegeApplyBudget)), + resource.TestCheckResourceAttr(resourceName, "privileges.1", string(sdk.AccountObjectPrivilegeCreateSchema)), + resource.TestCheckResourceAttr(resourceName, "privileges.2", string(sdk.AccountObjectPrivilegeModify)), + resource.TestCheckResourceAttr(resourceName, "privileges.3", string(sdk.AccountObjectPrivilegeUsage)), resource.TestCheckResourceAttr(resourceName, "on_database", acc.TestClient().Ids.DatabaseId().FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "with_grant_option", "true"), - resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("%s|true|false|CREATE SCHEMA,MODIFY,USAGE|OnDatabase|%s", databaseRoleId.FullyQualifiedName(), acc.TestClient().Ids.DatabaseId().FullyQualifiedName())), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("%s|true|false|APPLYBUDGET,CREATE SCHEMA,MODIFY,USAGE|OnDatabase|%s", 
databaseRoleId.FullyQualifiedName(), acc.TestClient().Ids.DatabaseId().FullyQualifiedName())), ), }, { diff --git a/pkg/resources/helper_expansion.go b/pkg/resources/helper_expansion.go index 82a00efcad..fe29b26c15 100644 --- a/pkg/resources/helper_expansion.go +++ b/pkg/resources/helper_expansion.go @@ -1,7 +1,10 @@ package resources import ( + "fmt" "slices" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" ) // borrowed from https://github.com/terraform-providers/terraform-provider-aws/blob/master/aws/structure.go#L924:6 @@ -27,6 +30,29 @@ func expandStringList(configured []interface{}) []string { return vs } +func ExpandObjectIdentifierSet(configured []any, objectType sdk.ObjectType) ([]sdk.ObjectIdentifier, error) { + vs := expandStringList(configured) + ids := make([]sdk.ObjectIdentifier, len(vs)) + for i, idRaw := range vs { + var id sdk.ObjectIdentifier + var err error + // TODO(SNOW-1229218): Use a common mapper to get object id. + if objectType == sdk.ObjectTypeAccount { + id, err = sdk.ParseAccountIdentifier(idRaw) + if err != nil { + return nil, fmt.Errorf("invalid account id: %w", err) + } + } else { + id, err = GetOnObjectIdentifier(objectType, idRaw) + if err != nil { + return nil, fmt.Errorf("invalid object id: %w", err) + } + } + ids[i] = id + } + return ids, nil +} + func expandStringListAllowEmpty(configured []interface{}) []string { // Allow empty values during expansion vs := make([]string, 0, len(configured)) diff --git a/pkg/resources/helpers.go b/pkg/resources/helpers.go index 3c47984e7e..6752840e1e 100644 --- a/pkg/resources/helpers.go +++ b/pkg/resources/helpers.go @@ -3,87 +3,25 @@ package resources import ( "fmt" "slices" - "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" 
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) -func dataTypeValidateFunc(val interface{}, _ string) (warns []string, errs []error) { - if ok := sdk.IsValidDataType(val.(string)); !ok { - errs = append(errs, fmt.Errorf("%v is not a valid data type", val)) - } - return -} - -func dataTypeDiffSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - oldDT, err := sdk.ToDataType(old) - if err != nil { - return false - } - newDT, err := sdk.ToDataType(new) - if err != nil { - return false - } - return oldDT == newDT -} - -// DataTypeIssue3007DiffSuppressFunc is a temporary solution to handle data type suppression problems. -// Currently, it handles only number and text data types. -// It falls back to Snowflake defaults for arguments if no arguments were provided for the data type. -// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework -func DataTypeIssue3007DiffSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - oldDataType, err := sdk.ToDataType(old) - if err != nil { - return false - } - newDataType, err := sdk.ToDataType(new) - if err != nil { - return false - } - if oldDataType != newDataType { - return false - } - switch v := oldDataType; v { - case sdk.DataTypeNumber: - logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Handling number data type diff suppression") - oldPrecision, oldScale := sdk.ParseNumberDataTypeRaw(old) - newPrecision, newScale := sdk.ParseNumberDataTypeRaw(new) - return oldPrecision == newPrecision && oldScale == newScale - case sdk.DataTypeVARCHAR: - logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Handling text data type diff suppression") - oldLength := sdk.ParseVarcharDataTypeRaw(old) - newLength := sdk.ParseVarcharDataTypeRaw(new) - return oldLength == newLength +func getTagObjectIdentifier(obj map[string]any) sdk.ObjectIdentifier { + database := obj["database"].(string) + schema := obj["schema"].(string) + name := 
obj["name"].(string) + switch { + case schema != "": + return sdk.NewSchemaObjectIdentifier(database, schema, name) + case database != "": + return sdk.NewDatabaseObjectIdentifier(database, name) default: - logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Diff suppression for %s can't be currently handled", v) - } - return true -} - -func ignoreTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - return strings.TrimSpace(old) == strings.TrimSpace(new) -} - -func ignoreCaseSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - return strings.EqualFold(old, new) -} - -func ignoreCaseAndTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - return strings.EqualFold(strings.TrimSpace(old), strings.TrimSpace(new)) -} - -func getTagObjectIdentifier(v map[string]any) sdk.ObjectIdentifier { - if _, ok := v["database"]; ok { - if _, ok := v["schema"]; ok { - return sdk.NewSchemaObjectIdentifier(v["database"].(string), v["schema"].(string), v["name"].(string)) - } - return sdk.NewDatabaseObjectIdentifier(v["database"].(string), v["name"].(string)) + return sdk.NewAccountObjectIdentifier(name) } - return sdk.NewAccountObjectIdentifier(v["name"].(string)) } func getPropertyTags(d *schema.ResourceData, key string) []sdk.TagAssociation { @@ -310,25 +248,35 @@ func JoinDiags(diagnostics ...diag.Diagnostics) diag.Diagnostics { return result } -// ListDiff Compares two lists (before and after), then compares and returns two lists that include +// ListDiff compares two lists (before and after), then compares and returns two lists that include // added and removed items between those lists. 
 func ListDiff[T comparable](beforeList []T, afterList []T) (added []T, removed []T) {
+	added, removed, _ = ListDiffWithCommonItems(beforeList, afterList)
+	return
+}
+
+// ListDiffWithCommonItems compares two lists (before and after), then compares and returns three lists that include
+// added, removed and common items between those lists.
+func ListDiffWithCommonItems[T comparable](beforeList []T, afterList []T) (added []T, removed []T, common []T) {
 	added = make([]T, 0)
 	removed = make([]T, 0)
+	common = make([]T, 0)
 
-	for _, privilegeBeforeChange := range beforeList {
-		if !slices.Contains(afterList, privilegeBeforeChange) {
-			removed = append(removed, privilegeBeforeChange)
+	for _, beforeItem := range beforeList {
+		if !slices.Contains(afterList, beforeItem) {
+			removed = append(removed, beforeItem)
+		} else {
+			common = append(common, beforeItem)
 		}
 	}
 
-	for _, privilegeAfterChange := range afterList {
-		if !slices.Contains(beforeList, privilegeAfterChange) {
-			added = append(added, privilegeAfterChange)
+	for _, afterItem := range afterList {
+		if !slices.Contains(beforeList, afterItem) {
+			added = append(added, afterItem)
 		}
 	}
 
-	return added, removed
+	return added, removed, common
 }
 
 // parseSchemaObjectIdentifierSet is a helper function to parse a given schema object identifier list from ResourceData.
diff --git a/pkg/resources/helpers_test.go b/pkg/resources/helpers_test.go index c143d40f03..a29a8658a4 100644 --- a/pkg/resources/helpers_test.go +++ b/pkg/resources/helpers_test.go @@ -11,6 +11,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-testing/terraform" "github.com/stretchr/testify/assert" @@ -260,7 +261,84 @@ func TestListDiff(t *testing.T) { } } -func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { +func TestListDiffWithCommonItems(t *testing.T) { + testCases := []struct { + Name string + Before []any + After []any + Added []any + Removed []any + Common []any + }{ + { + Name: "no changes", + Before: []any{1, 2, 3, 4}, + After: []any{1, 2, 3, 4}, + Removed: []any{}, + Added: []any{}, + Common: []any{1, 2, 3, 4}, + }, + { + Name: "only removed", + Before: []any{1, 2, 3, 4}, + After: []any{}, + Removed: []any{1, 2, 3, 4}, + Added: []any{}, + Common: []any{}, + }, + { + Name: "only added", + Before: []any{}, + After: []any{1, 2, 3, 4}, + Removed: []any{}, + Added: []any{1, 2, 3, 4}, + Common: []any{}, + }, + { + Name: "added repeated items", + Before: []any{2}, + After: []any{1, 2, 1}, + Removed: []any{}, + Added: []any{1, 1}, + Common: []any{2}, + }, + { + Name: "removed repeated items", + Before: []any{1, 2, 1}, + After: []any{2}, + Removed: []any{1, 1}, + Added: []any{}, + Common: []any{2}, + }, + { + Name: "simple diff: ints", + Before: []any{1, 2, 3, 4, 5, 6, 7, 8, 9}, + After: []any{1, 3, 5, 7, 9, 12, 13, 14}, + Removed: []any{2, 4, 6, 8}, + Added: []any{12, 13, 14}, + Common: []any{1, 3, 5, 7, 9}, + }, + { + Name: "simple diff: strings", + Before: []any{"one", "two", "three", "four"}, + After: 
[]any{"five", "two", "four", "six"}, + Removed: []any{"one", "three"}, + Added: []any{"five", "six"}, + Common: []any{"two", "four"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + added, removed, common := resources.ListDiffWithCommonItems(tc.Before, tc.After) + assert.Equal(t, tc.Added, added) + assert.Equal(t, tc.Removed, removed) + assert.Equal(t, tc.Common, common) + }) + } +} + +func Test_DataTypeDiffSuppressFunc(t *testing.T) { testCases := []struct { name string old string @@ -324,7 +402,7 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { { name: "synonym number data type precision implicit and same", old: "NUMBER", - new: fmt.Sprintf("DECIMAL(%d)", sdk.DefaultNumberPrecision), + new: fmt.Sprintf("DECIMAL(%d)", datatypes.DefaultNumberPrecision), expected: true, }, { @@ -348,7 +426,7 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { { name: "synonym number data type default scale implicit and explicit", old: "NUMBER(30)", - new: fmt.Sprintf("DECIMAL(30, %d)", sdk.DefaultNumberScale), + new: fmt.Sprintf("DECIMAL(30, %d)", datatypes.DefaultNumberScale), expected: true, }, { @@ -360,13 +438,13 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { { name: "synonym number data type both precision and scale implicit and explicit", old: "NUMBER", - new: fmt.Sprintf("DECIMAL(%d, %d)", sdk.DefaultNumberPrecision, sdk.DefaultNumberScale), + new: fmt.Sprintf("DECIMAL(%d, %d)", datatypes.DefaultNumberPrecision, datatypes.DefaultNumberScale), expected: true, }, { name: "synonym number data type both precision and scale implicit and scale different", old: "NUMBER", - new: fmt.Sprintf("DECIMAL(%d, 2)", sdk.DefaultNumberPrecision), + new: fmt.Sprintf("DECIMAL(%d, 2)", datatypes.DefaultNumberPrecision), expected: false, }, { @@ -384,7 +462,7 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { { name: "synonym text data type length implicit and same", old: "VARCHAR", - new: 
fmt.Sprintf("TEXT(%d)", sdk.DefaultVarcharLength), + new: fmt.Sprintf("TEXT(%d)", datatypes.DefaultVarcharLength), expected: true, }, { @@ -398,7 +476,7 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - result := resources.DataTypeIssue3007DiffSuppressFunc("", tc.old, tc.new, nil) + result := resources.DiffSuppressDataTypes("", tc.old, tc.new, nil) require.Equal(t, tc.expected, result) }) } diff --git a/pkg/resources/masking_policy.go b/pkg/resources/masking_policy.go index 4acd32bbb1..f35df310cd 100644 --- a/pkg/resources/masking_policy.go +++ b/pkg/resources/masking_policy.go @@ -6,13 +6,12 @@ import ( "fmt" "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" @@ -55,8 +54,8 @@ var maskingPolicySchema = map[string]*schema.Schema{ "type": { Type: schema.TypeString, Required: true, - DiffSuppressFunc: NormalizeAndCompare(sdk.ToDataType), - ValidateDiagFunc: sdkValidation(sdk.ToDataType), + DiffSuppressFunc: DiffSuppressDataTypes, + ValidateDiagFunc: IsDataTypeValid, Description: dataTypeFieldDescription("The argument type. 
VECTOR data types are not yet supported."), ForceNew: true, }, @@ -77,8 +76,8 @@ var maskingPolicySchema = map[string]*schema.Schema{ Required: true, Description: dataTypeFieldDescription("The return data type must match the input data type of the first column that is specified as an input column."), ForceNew: true, - DiffSuppressFunc: NormalizeAndCompare(sdk.ToDataType), - ValidateDiagFunc: sdkValidation(sdk.ToDataType), + DiffSuppressFunc: DiffSuppressDataTypes, + ValidateDiagFunc: IsDataTypeValid, }, "exempt_other_policies": { Type: schema.TypeString, @@ -198,17 +197,17 @@ func CreateMaskingPolicy(ctx context.Context, d *schema.ResourceData, meta any) args := make([]sdk.TableColumnSignature, 0) for _, arg := range arguments { v := arg.(map[string]any) - dataType, err := sdk.ToDataType(v["type"].(string)) + dataType, err := datatypes.ParseDataType(v["type"].(string)) if err != nil { return diag.FromErr(err) } args = append(args, sdk.TableColumnSignature{ Name: v["name"].(string), - Type: dataType, + Type: sdk.LegacyDataTypeFrom(dataType), }) } - returns, err := sdk.ToDataType(returnDataType) + returns, err := datatypes.ParseDataType(returnDataType) if err != nil { return diag.FromErr(err) } @@ -226,7 +225,7 @@ func CreateMaskingPolicy(ctx context.Context, d *schema.ResourceData, meta any) opts.ExemptOtherPolicies = sdk.Pointer(parsed) } - err = client.MaskingPolicies.Create(ctx, id, args, returns, expression, opts) + err = client.MaskingPolicies.Create(ctx, id, args, sdk.LegacyDataTypeFrom(returns), expression, opts) if err != nil { return diag.FromErr(err) } diff --git a/pkg/resources/object_parameter_acceptance_test.go b/pkg/resources/object_parameter_acceptance_test.go index a1f3021ba4..b03ff6450d 100644 --- a/pkg/resources/object_parameter_acceptance_test.go +++ b/pkg/resources/object_parameter_acceptance_test.go @@ -13,6 +13,16 @@ import ( ) func TestAcc_ObjectParameter(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + + // 
TODO(SNOW-1528546): Remove after parameter-setting resources are using UNSET in the delete operation. + t.Cleanup(func() { + acc.TestClient().Database.Alter(t, acc.TestClient().Ids.DatabaseId(), &sdk.AlterDatabaseOptions{ + Unset: &sdk.DatabaseUnset{ + UserTaskTimeoutMs: sdk.Bool(true), + }, + }) + }) resource.Test(t, resource.TestCase{ ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, TerraformVersionChecks: []tfversion.TerraformVersionCheck{ @@ -34,6 +44,8 @@ func TestAcc_ObjectParameter(t *testing.T) { } func TestAcc_ObjectParameterAccount(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + // TODO(SNOW-1528546): Remove after parameter-setting resources are using UNSET in the delete operation. t.Cleanup(func() { acc.TestClient().Parameter.UnsetAccountParameter(t, sdk.AccountParameterDataRetentionTimeInDays) diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index adb80061bb..aa8b557250 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -8,11 +8,11 @@ import ( "slices" "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" @@ -60,9 +60,9 @@ var procedureSchema = map[string]*schema.Schema{ "type": { Type: schema.TypeString, Required: true, - ValidateFunc: dataTypeValidateFunc, - DiffSuppressFunc: dataTypeDiffSuppressFunc, Description: "The argument type", + ValidateDiagFunc: IsDataTypeValid, + DiffSuppressFunc: DiffSuppressDataTypes, 
}, }, }, @@ -322,7 +322,7 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta return diags } procedureDefinition := d.Get("statement").(string) - req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), returnDataType, procedureDefinition) + req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.LegacyDataTypeFrom(returnDataType), procedureDefinition) if len(args) > 0 { req.WithArguments(args) } @@ -735,16 +735,16 @@ func getProcedureArguments(d *schema.ResourceData) ([]sdk.ProcedureArgumentReque if diags != nil { return nil, diags } - args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: argDataType}) + args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) } } return args, nil } -func convertProcedureDataType(s string) (sdk.DataType, diag.Diagnostics) { - dataType, err := sdk.ToDataType(s) +func convertProcedureDataType(s string) (datatypes.DataType, diag.Diagnostics) { + dataType, err := datatypes.ParseDataType(s) if err != nil { - return dataType, diag.FromErr(err) + return nil, diag.FromErr(err) } return dataType, nil } @@ -755,13 +755,13 @@ func convertProcedureColumns(s string) ([]sdk.ProcedureColumn, diag.Diagnostics) var columns []sdk.ProcedureColumn for _, match := range matches { if len(match) == 3 { - dataType, err := sdk.ToDataType(match[2]) + dataType, err := datatypes.ParseDataType(match[2]) if err != nil { return nil, diag.FromErr(err) } columns = append(columns, sdk.ProcedureColumn{ ColumnName: match[1], - ColumnDataType: dataType, + ColumnDataType: sdk.LegacyDataTypeFrom(dataType), }) } } @@ -785,7 +785,7 @@ func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag. 
if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } @@ -807,7 +807,7 @@ func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } diff --git a/pkg/resources/procedure_state_upgraders.go b/pkg/resources/procedure_state_upgraders.go index 24e47d7d9f..610822401d 100644 --- a/pkg/resources/procedure_state_upgraders.go +++ b/pkg/resources/procedure_state_upgraders.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type v085ProcedureId struct { @@ -48,11 +49,11 @@ func v085ProcedureStateUpgrader(ctx context.Context, rawState map[string]interfa argDataTypes := make([]sdk.DataType, len(parsedV085ProcedureId.ArgTypes)) for i, argType := range parsedV085ProcedureId.ArgTypes { - argDataType, err := sdk.ToDataType(argType) + argDataType, err := datatypes.ParseDataType(argType) if err != nil { return nil, err } - argDataTypes[i] = argDataType + argDataTypes[i] = sdk.LegacyDataTypeFrom(argDataType) } schemaObjectIdentifierWithArguments := sdk.NewSchemaObjectIdentifierWithArgumentsOld(parsedV085ProcedureId.DatabaseName, parsedV085ProcedureId.SchemaName, parsedV085ProcedureId.ProcedureName, argDataTypes) diff --git a/pkg/resources/row_access_policy.go b/pkg/resources/row_access_policy.go index 12c3050cb3..5722b148ba 100644 --- a/pkg/resources/row_access_policy.go +++ b/pkg/resources/row_access_policy.go @@ -6,13 +6,12 @@ import ( "fmt" "log" - 
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" @@ -54,8 +53,8 @@ var rowAccessPolicySchema = map[string]*schema.Schema{ "type": { Type: schema.TypeString, Required: true, - DiffSuppressFunc: NormalizeAndCompare(sdk.ToDataType), - ValidateDiagFunc: sdkValidation(sdk.ToDataType), + DiffSuppressFunc: DiffSuppressDataTypes, + ValidateDiagFunc: IsDataTypeValid, Description: dataTypeFieldDescription("The argument type. 
VECTOR data types are not yet supported."), ForceNew: true, }, @@ -179,11 +178,11 @@ func CreateRowAccessPolicy(ctx context.Context, d *schema.ResourceData, meta any args := make([]sdk.CreateRowAccessPolicyArgsRequest, 0) for _, arg := range arguments { v := arg.(map[string]any) - dataType, err := sdk.ToDataType(v["type"].(string)) + dataType, err := datatypes.ParseDataType(v["type"].(string)) if err != nil { return diag.FromErr(err) } - args = append(args, *sdk.NewCreateRowAccessPolicyArgsRequest(v["name"].(string), dataType)) + args = append(args, *sdk.NewCreateRowAccessPolicyArgsRequest(v["name"].(string), sdk.LegacyDataTypeFrom(dataType))) } createRequest := sdk.NewCreateRowAccessPolicyRequest(id, args, rowAccessExpression) diff --git a/pkg/resources/storage_integration.go b/pkg/resources/storage_integration.go index 85a40729d3..cd15e0cc6f 100644 --- a/pkg/resources/storage_integration.go +++ b/pkg/resources/storage_integration.go @@ -7,6 +7,9 @@ import ( "slices" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" @@ -112,10 +115,10 @@ var storageIntegrationSchema = map[string]*schema.Schema{ // StorageIntegration returns a pointer to the resource representing a storage integration. 
func StorageIntegration() *schema.Resource { return &schema.Resource{ - Create: CreateStorageIntegration, - Read: ReadStorageIntegration, - Update: UpdateStorageIntegration, - Delete: DeleteStorageIntegration, + CreateContext: TrackingCreateWrapper(resources.StorageIntegration, CreateStorageIntegration), + ReadContext: TrackingReadWrapper(resources.StorageIntegration, ReadStorageIntegration), + UpdateContext: TrackingUpdateWrapper(resources.StorageIntegration, UpdateStorageIntegration), + DeleteContext: TrackingDeleteWrapper(resources.StorageIntegration, DeleteStorageIntegration), Schema: storageIntegrationSchema, Importer: &schema.ResourceImporter{ @@ -124,9 +127,8 @@ func StorageIntegration() *schema.Resource { } } -func CreateStorageIntegration(d *schema.ResourceData, meta any) error { +func CreateStorageIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() name := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(d.Get("name").(string)) enabled := d.Get("enabled").(bool) @@ -161,12 +163,12 @@ func CreateStorageIntegration(d *schema.ResourceData, meta any) error { case slices.Contains(sdk.AllS3Protocols, sdk.S3Protocol(storageProvider)): s3Protocol, err := sdk.ToS3Protocol(storageProvider) if err != nil { - return err + return diag.FromErr(err) } v, ok := d.GetOk("storage_aws_role_arn") if !ok { - return fmt.Errorf("if you use the S3 storage provider you must specify a storage_aws_role_arn") + return diag.FromErr(fmt.Errorf("if you use the S3 storage provider you must specify a storage_aws_role_arn")) } s3Params := sdk.NewS3StorageParamsRequest(s3Protocol, v.(string)) @@ -177,29 +179,28 @@ func CreateStorageIntegration(d *schema.ResourceData, meta any) error { case storageProvider == "AZURE": v, ok := d.GetOk("azure_tenant_id") if !ok { - return fmt.Errorf("if you use the Azure storage provider you must specify an azure_tenant_id") + return 
diag.FromErr(fmt.Errorf("if you use the Azure storage provider you must specify an azure_tenant_id")) } req.WithAzureStorageProviderParams(*sdk.NewAzureStorageParamsRequest(sdk.String(v.(string)))) case storageProvider == "GCS": req.WithGCSStorageProviderParams(*sdk.NewGCSStorageParamsRequest()) default: - return fmt.Errorf("unexpected provider %v", storageProvider) + return diag.FromErr(fmt.Errorf("unexpected provider %v", storageProvider)) } if err := client.StorageIntegrations.Create(ctx, req); err != nil { - return fmt.Errorf("error creating storage integration: %w", err) + return diag.FromErr(fmt.Errorf("error creating storage integration: %w", err)) } d.SetId(helpers.EncodeSnowflakeID(name)) - return ReadStorageIntegration(d, meta) + return ReadStorageIntegration(ctx, d, meta) } -func ReadStorageIntegration(d *schema.ResourceData, meta any) error { +func ReadStorageIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - ctx := context.Background() id, ok := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier) if !ok { - return fmt.Errorf("storage integration read, error decoding id: %s as sdk.AccountObjectIdentifier, got: %T", d.Id(), id) + return diag.FromErr(fmt.Errorf("storage integration read, error decoding id: %s as sdk.AccountObjectIdentifier, got: %T", d.Id(), id)) } s, err := client.StorageIntegrations.ShowByID(ctx, id) @@ -209,91 +210,90 @@ func ReadStorageIntegration(d *schema.ResourceData, meta any) error { return nil } if err := d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()); err != nil { - return err + return diag.FromErr(err) } if s.Category != "STORAGE" { - return fmt.Errorf("expected %v to be a STORAGE integration, got %v", d.Id(), s.Category) + return diag.FromErr(fmt.Errorf("expected %v to be a STORAGE integration, got %v", d.Id(), s.Category)) } if err := d.Set("name", s.Name); err != nil { - return err + return diag.FromErr(err) } if err 
 	if err := d.Set("type", s.StorageType); err != nil {
-		return err
+		return diag.FromErr(err)
 	}
 	if err := d.Set("created_on", s.CreatedOn.String()); err != nil {
-		return err
+		return diag.FromErr(err)
 	}
 	if err := d.Set("enabled", s.Enabled); err != nil {
-		return err
+		return diag.FromErr(err)
 	}
 	if err := d.Set("comment", s.Comment); err != nil {
-		return err
+		return diag.FromErr(err)
 	}
 
 	storageIntegrationProps, err := client.StorageIntegrations.Describe(ctx, id)
 	if err != nil {
-		return fmt.Errorf("could not describe storage integration (%s), err = %w", d.Id(), err)
+		return diag.FromErr(fmt.Errorf("could not describe storage integration (%s), err = %w", d.Id(), err))
 	}
 
 	for _, prop := range storageIntegrationProps {
 		switch prop.Name {
 		case "STORAGE_PROVIDER":
 			if err := d.Set("storage_provider", prop.Value); err != nil {
-				return err
+				return diag.FromErr(err)
 			}
 		case "STORAGE_ALLOWED_LOCATIONS":
 			if err := d.Set("storage_allowed_locations", strings.Split(prop.Value, ",")); err != nil {
-				return err
+				return diag.FromErr(err)
 			}
 		case "STORAGE_BLOCKED_LOCATIONS":
 			if prop.Value != "" {
 				if err := d.Set("storage_blocked_locations", strings.Split(prop.Value, ",")); err != nil {
-					return err
+					return diag.FromErr(err)
 				}
 			}
 		case "STORAGE_AWS_IAM_USER_ARN":
 			if err := d.Set("storage_aws_iam_user_arn", prop.Value); err != nil {
-				return err
+				return diag.FromErr(err)
 			}
 		case "STORAGE_AWS_OBJECT_ACL":
 			if prop.Value != "" {
 				if err := d.Set("storage_aws_object_acl", prop.Value); err != nil {
-					return err
+					return diag.FromErr(err)
 				}
 			}
 		case "STORAGE_AWS_ROLE_ARN":
 			if err := d.Set("storage_aws_role_arn", prop.Value); err != nil {
-				return err
+				return diag.FromErr(err)
 			}
 		case "STORAGE_AWS_EXTERNAL_ID":
 			if err := d.Set("storage_aws_external_id", prop.Value); err != nil {
-				return err
+				return diag.FromErr(err)
 			}
 		case "STORAGE_GCP_SERVICE_ACCOUNT":
 			if err := d.Set("storage_gcp_service_account", prop.Value); err != nil {
-				return err
+				return diag.FromErr(err)
 			}
 		case "AZURE_CONSENT_URL":
 			if err := d.Set("azure_consent_url", prop.Value); err != nil {
-				return err
+				return diag.FromErr(err)
 			}
 		case "AZURE_MULTI_TENANT_APP_NAME":
 			if err := d.Set("azure_multi_tenant_app_name", prop.Value); err != nil {
-				return err
+				return diag.FromErr(err)
 			}
 		}
 	}
 
-	return err
+	return diag.FromErr(err)
 }
 
-func UpdateStorageIntegration(d *schema.ResourceData, meta any) error {
+func UpdateStorageIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
 	client := meta.(*provider.Context).Client
-	ctx := context.Background()
 	id, ok := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier)
 	if !ok {
-		return fmt.Errorf("storage integration update, error decoding id: %s as sdk.AccountObjectIdentifier, got: %T", d.Id(), id)
+		return diag.FromErr(fmt.Errorf("storage integration update, error decoding id: %s as sdk.AccountObjectIdentifier, got: %T", d.Id(), id))
 	}
 
 	var runSetStatement bool
@@ -327,7 +327,7 @@ func UpdateStorageIntegration(d *schema.ResourceData, meta any) error {
 		if len(v) == 0 {
 			if err := client.StorageIntegrations.Alter(ctx, sdk.NewAlterStorageIntegrationRequest(id).
 				WithUnset(*sdk.NewStorageIntegrationUnsetRequest().WithStorageBlockedLocations(true))); err != nil {
-				return fmt.Errorf("error unsetting storage_blocked_locations, err = %w", err)
+				return diag.FromErr(fmt.Errorf("error unsetting storage_blocked_locations, err = %w", err))
 			}
 		} else {
 			runSetStatement = true
@@ -352,7 +352,7 @@ func UpdateStorageIntegration(d *schema.ResourceData, meta any) error {
 		} else {
 			if err := client.StorageIntegrations.Alter(ctx, sdk.NewAlterStorageIntegrationRequest(id).
 				WithUnset(*sdk.NewStorageIntegrationUnsetRequest().WithStorageAwsObjectAcl(true))); err != nil {
-				return fmt.Errorf("error unsetting storage_aws_object_acl, err = %w", err)
+				return diag.FromErr(fmt.Errorf("error unsetting storage_aws_object_acl, err = %w", err))
 			}
 		}
 	}
@@ -367,22 +367,21 @@ func UpdateStorageIntegration(d *schema.ResourceData, meta any) error {
 
 	if runSetStatement {
 		if err := client.StorageIntegrations.Alter(ctx, sdk.NewAlterStorageIntegrationRequest(id).WithSet(*setReq)); err != nil {
-			return fmt.Errorf("error updating storage integration, err = %w", err)
+			return diag.FromErr(fmt.Errorf("error updating storage integration, err = %w", err))
 		}
 	}
 
-	return ReadStorageIntegration(d, meta)
+	return ReadStorageIntegration(ctx, d, meta)
 }
 
-func DeleteStorageIntegration(d *schema.ResourceData, meta any) error {
+func DeleteStorageIntegration(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
 	client := meta.(*provider.Context).Client
-	ctx := context.Background()
 	id, ok := helpers.DecodeSnowflakeID(d.Id()).(sdk.AccountObjectIdentifier)
 	if !ok {
-		return fmt.Errorf("storage integration delete, error decoding id: %s as sdk.AccountObjectIdentifier, got: %T", d.Id(), id)
+		return diag.FromErr(fmt.Errorf("storage integration delete, error decoding id: %s as sdk.AccountObjectIdentifier, got: %T", d.Id(), id))
 	}
 
 	if err := client.StorageIntegrations.Drop(ctx, sdk.NewDropStorageIntegrationRequest(id)); err != nil {
-		return fmt.Errorf("error dropping storage integration (%s), err = %w", d.Id(), err)
+		return diag.FromErr(fmt.Errorf("error dropping storage integration (%s), err = %w", d.Id(), err))
 	}
 
 	d.SetId("")
diff --git a/pkg/resources/table.go b/pkg/resources/table.go
index 017ad799d8..f0d4d77ea4 100644
--- a/pkg/resources/table.go
+++ b/pkg/resources/table.go
@@ -7,16 +7,15 @@ import (
 	"strconv"
 	"strings"
 
-	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources"
-	"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
-	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff"
-
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas"
-
-	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake"
+	"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
+	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
 )
@@ -64,8 +63,8 @@ var tableSchema = map[string]*schema.Schema{
 				Type:        schema.TypeString,
 				Required:    true,
 				Description: "Column type, e.g. VARIANT. For a full list of column types, see [Summary of Data Types](https://docs.snowflake.com/en/sql-reference/intro-summary-data-types).",
-				ValidateFunc:     dataTypeValidateFunc,
-				DiffSuppressFunc: DataTypeIssue3007DiffSuppressFunc,
+				ValidateDiagFunc: IsDataTypeValid,
+				DiffSuppressFunc: DiffSuppressDataTypes,
 			},
 			"nullable": {
 				Type: schema.TypeBool,
@@ -388,9 +387,13 @@ func getColumns(from interface{}) (to columns) {
 	return to
 }
 
-func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest {
+func getTableColumnRequest(from interface{}) (*sdk.TableColumnRequest, error) {
 	c := from.(map[string]interface{})
 	_type := c["type"].(string)
+	dataType, err := datatypes.ParseDataType(_type)
+	if err != nil {
+		return nil, err
+	}
 	nameInQuotes := fmt.Sprintf(`"%v"`, snowflake.EscapeString(c["name"].(string)))
 	request := sdk.NewTableColumnRequest(nameInQuotes, sdk.DataType(_type))
 
@@ -400,7 +403,7 @@ func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest {
 	if len(_default) == 1 {
 		if c, ok := _default[0].(map[string]interface{})["constant"]; ok {
 			if constant, ok := c.(string); ok && len(constant) > 0 {
-				if sdk.IsStringType(_type) {
+				if datatypes.IsTextDataType(dataType) {
 					expression = snowflake.EscapeSnowflakeString(constant)
 				} else {
 					expression = constant
@@ -415,7 +418,7 @@ func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest {
 	}
 
 	if s, ok := _default[0].(map[string]interface{})["sequence"]; ok {
-		if seq := s.(string); ok && len(seq) > 0 {
+		if seq, ok2 := s.(string); ok2 && len(seq) > 0 {
 			expression = fmt.Sprintf(`%v.NEXTVAL`, seq)
 		}
 	}
@@ -435,22 +438,26 @@ func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest {
 		request.WithMaskingPolicy(sdk.NewColumnMaskingPolicyRequest(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(maskingPolicy)))
 	}
 
-	if sdk.IsStringType(_type) {
+	if datatypes.IsTextDataType(dataType) {
 		request.WithCollate(sdk.String(c["collate"].(string)))
 	}
 
 	return request.
 		WithNotNull(sdk.Bool(!c["nullable"].(bool))).
-		WithComment(sdk.String(c["comment"].(string)))
+		WithComment(sdk.String(c["comment"].(string))), nil
 }
 
-func getTableColumnRequests(from interface{}) []sdk.TableColumnRequest {
+func getTableColumnRequests(from interface{}) ([]sdk.TableColumnRequest, error) {
 	cols := from.([]interface{})
 	to := make([]sdk.TableColumnRequest, len(cols))
 	for i, c := range cols {
-		to[i] = *getTableColumnRequest(c)
+		cReq, err := getTableColumnRequest(c)
+		if err != nil {
+			return nil, err
+		}
+		to[i] = *cReq
 	}
-	return to
+	return to, nil
 }
 
 type primarykey struct {
@@ -577,7 +584,10 @@ func CreateTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Dia
 	name := d.Get("name").(string)
 	id := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name)
 
-	tableColumnRequests := getTableColumnRequests(d.Get("column").([]interface{}))
+	tableColumnRequests, err := getTableColumnRequests(d.Get("column").([]interface{}))
+	if err != nil {
+		return diag.FromErr(err)
+	}
 
 	createRequest := sdk.NewCreateTableRequest(id, tableColumnRequests)
 
@@ -620,7 +630,7 @@ func CreateTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Dia
 		createRequest.WithTags(tagAssociationRequests)
 	}
 
-	err := client.Tables.Create(ctx, createRequest)
+	err = client.Tables.Create(ctx, createRequest)
 	if err != nil {
 		return diag.FromErr(fmt.Errorf("error creating table %v err = %w", name, err))
 	}
diff --git a/pkg/resources/table_acceptance_test.go b/pkg/resources/table_acceptance_test.go
index caeeb1daaf..8ba4f24a8a 100644
--- a/pkg/resources/table_acceptance_test.go
+++ b/pkg/resources/table_acceptance_test.go
@@ -15,6 +15,7 @@ import (
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes"
 	"github.com/hashicorp/terraform-plugin-testing/config"
 	"github.com/hashicorp/terraform-plugin-testing/helper/resource"
 	"github.com/hashicorp/terraform-plugin-testing/plancheck"
@@ -2097,7 +2098,7 @@ func TestAcc_Table_issue3007_textColumn(t *testing.T) {
 	tableId := acc.TestClient().Ids.RandomSchemaObjectIdentifier()
 	resourceName := "snowflake_table.test_table"
 
-	defaultVarchar := fmt.Sprintf("VARCHAR(%d)", sdk.DefaultVarcharLength)
+	defaultVarchar := fmt.Sprintf("VARCHAR(%d)", datatypes.DefaultVarcharLength)
 
 	resource.Test(t, resource.TestCase{
 		PreCheck: func() { acc.TestAccPreCheck(t) },
@@ -2170,7 +2171,7 @@ func TestAcc_Table_issue3007_numberColumn(t *testing.T) {
 	tableId := acc.TestClient().Ids.RandomSchemaObjectIdentifier()
 	resourceName := "snowflake_table.test_table"
 
-	defaultNumber := fmt.Sprintf("NUMBER(%d,%d)", sdk.DefaultNumberPrecision, sdk.DefaultNumberScale)
+	defaultNumber := fmt.Sprintf("NUMBER(%d,%d)", datatypes.DefaultNumberPrecision, datatypes.DefaultNumberScale)
 
 	resource.Test(t, resource.TestCase{
 		PreCheck: func() { acc.TestAccPreCheck(t) },
diff --git a/pkg/resources/tag_association.go b/pkg/resources/tag_association.go
index f9141ab1fe..e76e5fc69f 100644
--- a/pkg/resources/tag_association.go
+++ b/pkg/resources/tag_association.go
@@ -2,21 +2,21 @@ package resources
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"log"
-	"slices"
-	"strings"
 	"time"
 
-	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources"
-
+	"github.com/hashicorp/go-cty/cty"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
 
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
 )
@@ -26,54 +26,37 @@ var tagAssociationSchema = map[string]*schema.Schema{
 		Optional:    true,
 		Description: "Specifies the object identifier for the tag association.",
 		ForceNew:    true,
-		Deprecated:  "Use `object_identifier` instead",
+		Deprecated:  "Use `object_identifiers` instead",
 	},
-	"object_identifier": {
-		Type:     schema.TypeList,
+	"object_identifiers": {
+		Type:     schema.TypeSet,
 		MinItems: 1,
 		Required: true,
-		Description: "Specifies the object identifier for the tag association.",
-		Elem: &schema.Resource{
-			Schema: map[string]*schema.Schema{
-				"name": {
-					Type:        schema.TypeString,
-					Required:    true,
-					ForceNew:    true,
-					Description: "Name of the object to associate the tag with.",
-				},
-				"database": {
-					Type:        schema.TypeString,
-					Optional:    true,
-					ForceNew:    true,
-					Description: "Name of the database that the object was created in.",
-				},
-				"schema": {
-					Type:        schema.TypeString,
-					Optional:    true,
-					ForceNew:    true,
-					Description: "Name of the schema that the object was created in.",
-				},
-			},
+		Description: "Specifies the object identifiers for the tag association.",
+		Elem: &schema.Schema{
+			Type: schema.TypeString,
 		},
+		DiffSuppressFunc: NormalizeAndCompareIdentifiersInSet("object_identifiers"),
 	},
 	"object_type": {
-		Type:         schema.TypeString,
-		Required:     true,
-		Description:  fmt.Sprintf("Specifies the type of object to add a tag. Allowed object types: %v.", sdk.TagAssociationAllowedObjectTypesString),
-		ValidateFunc: validation.StringInSlice(sdk.TagAssociationAllowedObjectTypesString, true),
-		ForceNew:     true,
+		Type:             schema.TypeString,
+		Required:         true,
+		Description:      fmt.Sprintf("Specifies the type of object to add a tag. Allowed object types: %v.", sdk.TagAssociationAllowedObjectTypesString),
+		ValidateFunc:     validation.StringInSlice(sdk.TagAssociationAllowedObjectTypesString, true),
+		DiffSuppressFunc: ignoreCaseSuppressFunc,
+		ForceNew:         true,
 	},
 	"tag_id": {
-		Type:        schema.TypeString,
-		Required:    true,
-		Description: "Specifies the identifier for the tag. Note: format must follow: \"databaseName\".\"schemaName\".\"tagName\" or \"databaseName.schemaName.tagName\" or \"databaseName|schemaName.tagName\" (snowflake_tag.tag.id)",
-		ForceNew:    true,
+		Type:             schema.TypeString,
+		Required:         true,
+		Description:      "Specifies the identifier for the tag.",
+		ForceNew:         true,
+		DiffSuppressFunc: suppressIdentifierQuoting,
 	},
 	"tag_value": {
 		Type:        schema.TypeString,
 		Required:    true,
 		Description: "Specifies the value of the tag, (e.g. 'finance' or 'engineering')",
-		ForceNew:    true,
 	},
 	"skip_validation": {
 		Type: schema.TypeBool,
@@ -86,81 +69,89 @@ var tagAssociationSchema = map[string]*schema.Schema{
 
 // TagAssociation returns a pointer to the resource representing a schema.
 func TagAssociation() *schema.Resource {
 	return &schema.Resource{
+		SchemaVersion: 1,
+
 		CreateContext: TrackingCreateWrapper(resources.TagAssociation, CreateContextTagAssociation),
 		ReadContext:   TrackingReadWrapper(resources.TagAssociation, ReadContextTagAssociation),
 		UpdateContext: TrackingUpdateWrapper(resources.TagAssociation, UpdateContextTagAssociation),
 		DeleteContext: TrackingDeleteWrapper(resources.TagAssociation, DeleteContextTagAssociation),
 
+		Description: "Resource used to manage tag associations. For more information, check [object tagging documentation](https://docs.snowflake.com/en/user-guide/object-tagging).",
+
 		Schema: tagAssociationSchema,
 		Importer: &schema.ResourceImporter{
-			StateContext: schema.ImportStatePassthroughContext,
+			StateContext: ImportTagAssociation,
 		},
 		Timeouts: &schema.ResourceTimeout{
 			Create: schema.DefaultTimeout(70 * time.Minute),
 		},
+
+		StateUpgraders: []schema.StateUpgrader{
+			{
+				Version: 0,
+				// setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject
+				Type:    cty.EmptyObject,
+				Upgrade: v0_98_0_TagAssociationStateUpgrader,
+			},
+		},
 	}
 }
 
-func TagIdentifierAndObjectIdentifier(d *schema.ResourceData) (sdk.SchemaObjectIdentifier, []sdk.ObjectIdentifier, sdk.ObjectType) {
+func ImportTagAssociation(ctx context.Context, d *schema.ResourceData, meta any) ([]*schema.ResourceData, error) {
+	log.Printf("[DEBUG] Starting tag association import")
+	idParts := helpers.ParseResourceIdentifier(d.Id())
+	if len(idParts) != 3 {
+		return nil, fmt.Errorf("invalid resource id: expected 3 arguments, but got %d", len(idParts))
+	}
+	objectType, err := sdk.ToObjectType(idParts[2])
+	if err != nil {
+		return nil, err
+	}
+
+	if err := d.Set("tag_id", idParts[0]); err != nil {
+		return nil, err
+	}
+	if err := d.Set("tag_value", idParts[1]); err != nil {
+		return nil, err
+	}
+	if err := d.Set("object_type", objectType); err != nil {
+		return nil, err
+	}
+	return []*schema.ResourceData{d}, nil
+}
+
+func TagIdentifierAndObjectIdentifier(d *schema.ResourceData) (sdk.SchemaObjectIdentifier, []sdk.ObjectIdentifier, sdk.ObjectType, error) {
 	tag := d.Get("tag_id").(string)
-	objectType := sdk.ObjectType(d.Get("object_type").(string))
-
-	tagDatabase, tagSchema, tagName := ParseFullyQualifiedObjectID(tag)
-	tid := sdk.NewSchemaObjectIdentifier(tagDatabase, tagSchema, tagName)
-
-	var identifiers []sdk.ObjectIdentifier
-	for _, item := range d.Get("object_identifier").([]interface{}) {
-		m := item.(map[string]interface{})
-		name := strings.Trim(m["name"].(string), `"`)
-		var databaseName, schemaName string
-		if v, ok := m["database"]; ok {
-			databaseName = strings.Trim(v.(string), `"`)
-			if databaseName == "" && slices.Contains(sdk.TagAssociationTagObjectTypeIsSchemaObjectType, objectType) {
-				databaseName = tagDatabase
-			}
-		}
-		if v, ok := m["schema"]; ok {
-			schemaName = strings.Trim(v.(string), `"`)
-			if schemaName == "" && slices.Contains(sdk.TagAssociationTagObjectTypeIsSchemaObjectType, objectType) {
-				schemaName = tagSchema
-			}
-		}
-		switch {
-		case databaseName != "" && schemaName != "":
-			if objectType == sdk.ObjectTypeColumn {
-				fields := strings.Split(name, ".")
-				if len(fields) > 1 {
-					tableName := strings.ReplaceAll(fields[0], `"`, "")
-					var parts []string
-					for i := 1; i < len(fields); i++ {
-						parts = append(parts, strings.ReplaceAll(fields[i], `"`, ""))
-					}
-					columnName := strings.Join(parts, ".")
-					identifiers = append(identifiers, sdk.NewTableColumnIdentifier(databaseName, schemaName, tableName, columnName))
-				} else {
-					identifiers = append(identifiers, sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name))
-				}
-			} else {
-				identifiers = append(identifiers, sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name))
-			}
-		case databaseName != "":
-			identifiers = append(identifiers, sdk.NewDatabaseObjectIdentifier(databaseName, name))
-		default:
-			identifiers = append(identifiers, sdk.NewAccountObjectIdentifier(name))
-		}
+	tagId, err := sdk.ParseSchemaObjectIdentifier(tag)
+	if err != nil {
+		return sdk.SchemaObjectIdentifier{}, nil, "", fmt.Errorf("invalid tag id: %w", err)
+	}
+
+	objectType, err := sdk.ToObjectType(d.Get("object_type").(string))
+	if err != nil {
+		return sdk.SchemaObjectIdentifier{}, nil, "", err
+	}
+
+	ids, err := ExpandObjectIdentifierSet(d.Get("object_identifiers").(*schema.Set).List(), objectType)
+	if err != nil {
+		return sdk.SchemaObjectIdentifier{}, nil, "", err
 	}
-	return tid, identifiers, objectType
+
+	return tagId, ids, objectType, nil
 }
 
-func CreateContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
+func CreateContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
 	client := meta.(*provider.Context).Client
 	tagValue := d.Get("tag_value").(string)
-	tid, ids, ot := TagIdentifierAndObjectIdentifier(d)
+	tagId, ids, objectType, err := TagIdentifierAndObjectIdentifier(d)
+	if err != nil {
+		return diag.FromErr(err)
+	}
 	for _, oid := range ids {
-		request := sdk.NewSetTagRequest(ot, oid).WithSetTags([]sdk.TagAssociation{
+		request := sdk.NewSetTagRequest(objectType, oid).WithSetTags([]sdk.TagAssociation{
 			{
-				Name:  tid,
+				Name:  tagId,
 				Value: tagValue,
 			},
 		})
@@ -171,12 +162,12 @@ func CreateContextTagAssociation(ctx context.Context, d *schema.ResourceData, me
 		if !skipValidate {
 			log.Println("[DEBUG] validating tag creation")
 			if err := retry.RetryContext(ctx, d.Timeout(schema.TimeoutCreate)-time.Minute, func() *retry.RetryError {
-				tag, err := client.SystemFunctions.GetTag(ctx, tid, oid, ot)
+				tag, err := client.SystemFunctions.GetTag(ctx, tagId, oid, objectType)
 				if err != nil {
 					return retry.NonRetryableError(fmt.Errorf("error getting tag: %w", err))
 				}
 				// if length of response is zero, tag association was not found. retry
-				if len(tag) == 0 {
+				if tag == nil {
 					return retry.RetryableError(fmt.Errorf("expected tag association to be created but not yet created"))
 				}
 				return nil
@@ -185,63 +176,124 @@ func CreateContextTagAssociation(ctx context.Context, d *schema.ResourceData, me
 			}
 		}
 	}
-	d.SetId(helpers.EncodeSnowflakeID(tid.DatabaseName(), tid.SchemaName(), tid.Name()))
+	d.SetId(helpers.EncodeResourceIdentifier(tagId.FullyQualifiedName(), tagValue, string(objectType)))
 	return ReadContextTagAssociation(ctx, d, meta)
 }
 
-func ReadContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
-	diags := diag.Diagnostics{}
+func ReadContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
 	client := meta.(*provider.Context).Client
+	tagValue := d.Get("tag_value").(string)
-	tid, ids, ot := TagIdentifierAndObjectIdentifier(d)
+	tagId, ids, objectType, err := TagIdentifierAndObjectIdentifier(d)
+	if err != nil {
+		return diag.FromErr(err)
+	}
+	var correctObjectIds []string
 	for _, oid := range ids {
-		tagValue, err := client.SystemFunctions.GetTag(ctx, tid, oid, ot)
+		objectTagValue, err := client.SystemFunctions.GetTag(ctx, tagId, oid, objectType)
 		if err != nil {
 			return diag.FromErr(err)
 		}
-		if err := d.Set("tag_value", tagValue); err != nil {
-			return diag.FromErr(err)
+		if objectTagValue != nil && *objectTagValue == tagValue {
+			correctObjectIds = append(correctObjectIds, oid.FullyQualifiedName())
 		}
 	}
-	return diags
+	if err := d.Set("object_identifiers", correctObjectIds); err != nil {
+		return diag.FromErr(err)
+	}
+	// ensure that object_type is upper case in the state
+	if err := d.Set("object_type", objectType); err != nil {
+		return diag.FromErr(err)
+	}
+	return nil
 }
 
-func UpdateContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
+func UpdateContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
 	client := meta.(*provider.Context).Client
-	tid, ids, ot := TagIdentifierAndObjectIdentifier(d)
-	for _, oid := range ids {
-		if d.HasChange("skip_validation") {
-			o, n := d.GetChange("skip_validation")
-			log.Printf("[DEBUG] skip_validation changed from %v to %v", o, n)
+	tagId, _, objectType, err := TagIdentifierAndObjectIdentifier(d)
+	if err != nil {
+		return diag.FromErr(err)
+	}
+	if d.HasChanges("object_identifiers", "tag_value") {
+		tagValue := d.Get("tag_value").(string)
+
+		o, n := d.GetChange("object_identifiers")
+
+		oldIds, err := ExpandObjectIdentifierSet(o.(*schema.Set).List(), objectType)
+		if err != nil {
+			return diag.FromErr(err)
+		}
+		newIds, err := ExpandObjectIdentifierSet(n.(*schema.Set).List(), objectType)
+		if err != nil {
+			return diag.FromErr(err)
+		}
+
+		addedIds, removedIds, commonIds := ListDiffWithCommonItems(oldIds, newIds)
+
+		for _, id := range addedIds {
+			request := sdk.NewSetTagRequest(objectType, id).WithSetTags([]sdk.TagAssociation{
+				{
+					Name:  tagId,
+					Value: tagValue,
+				},
+			})
+			if err := client.Tags.Set(ctx, request); err != nil {
+				return diag.FromErr(err)
+			}
+		}
+
+		for _, id := range removedIds {
+			if objectType == sdk.ObjectTypeColumn {
+				skip, err := skipColumnIfDoesNotExist(ctx, client, id)
+				if err != nil {
+					return diag.FromErr(err)
+				}
+				if skip {
+					continue
+				}
+			}
+			request := sdk.NewUnsetTagRequest(objectType, id).WithUnsetTags([]sdk.ObjectIdentifier{tagId}).WithIfExists(true)
+			if err := client.Tags.Unset(ctx, request); err != nil {
+				return diag.FromErr(err)
+			}
 		}
+
 		if d.HasChange("tag_value") {
-			tagValue, ok := d.GetOk("tag_value")
-			if ok {
-				request := sdk.NewSetTagRequest(ot, oid).WithSetTags([]sdk.TagAssociation{
+			for _, id := range commonIds {
+				request := sdk.NewSetTagRequest(objectType, id).WithSetTags([]sdk.TagAssociation{
 					{
-						Name:  tid,
-						Value: tagValue.(string),
+						Name:  tagId,
+						Value: tagValue,
 					},
 				})
 				if err := client.Tags.Set(ctx, request); err != nil {
 					return diag.FromErr(err)
 				}
-			} else {
-				request := sdk.NewUnsetTagRequest(ot, oid).WithUnsetTags([]sdk.ObjectIdentifier{tid})
-				if err := client.Tags.Unset(ctx, request); err != nil {
-					return diag.FromErr(err)
-				}
 			}
+			d.SetId(helpers.EncodeResourceIdentifier(tagId.FullyQualifiedName(), tagValue, string(objectType)))
 		}
 	}
+
 	return ReadContextTagAssociation(ctx, d, meta)
 }
 
 func DeleteContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
 	client := meta.(*provider.Context).Client
-	tid, ids, ot := TagIdentifierAndObjectIdentifier(d)
-	for _, oid := range ids {
-		request := sdk.NewUnsetTagRequest(ot, oid).WithUnsetTags([]sdk.ObjectIdentifier{tid})
+	tagId, ids, objectType, err := TagIdentifierAndObjectIdentifier(d)
+	if err != nil {
+		return diag.FromErr(err)
+	}
+	for _, id := range ids {
+		if objectType == sdk.ObjectTypeColumn {
+			skip, err := skipColumnIfDoesNotExist(ctx, client, id)
+			if err != nil {
+				return diag.FromErr(err)
+			}
+			if skip {
+				continue
+			}
+		}
+		request := sdk.NewUnsetTagRequest(objectType, id).WithUnsetTags([]sdk.ObjectIdentifier{tagId}).WithIfExists(true)
 		if err := client.Tags.Unset(ctx, request); err != nil {
 			return diag.FromErr(err)
 		}
@@ -249,3 +301,31 @@ func DeleteContextTagAssociation(ctx context.Context, d *schema.ResourceData, me
 	d.SetId("")
 	return nil
 }
+
+// we need to skip the column manually, because ALTER COLUMN lacks IF EXISTS
+func skipColumnIfDoesNotExist(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) (bool, error) {
+	columnId, ok := id.(sdk.TableColumnIdentifier)
+	if !ok {
+		return false, errors.New("invalid column identifier")
+	}
+	// TODO [SNOW-1007542]: use SHOW COLUMNS
+	_, err := client.Tables.ShowByID(ctx, columnId.SchemaObjectId())
+	if err != nil {
+		if errors.Is(err, sdk.ErrObjectNotFound) {
+			log.Printf("[DEBUG] table %s not found, skipping\n", columnId.SchemaObjectId())
+			return true, nil
+		}
+		return false, err
+	}
+	columns, err := client.Tables.DescribeColumns(ctx, sdk.NewDescribeTableColumnsRequest(columnId.SchemaObjectId()))
+	if err != nil {
+		return false, err
+	}
+	if _, err := collections.FindFirst(columns, func(c sdk.TableColumnDetails) bool {
+		return c.Name == columnId.Name()
+	}); err != nil {
+		log.Printf("[DEBUG] column %s not found in table %s, skipping\n", columnId.Name(), columnId.SchemaObjectId())
+		return true, nil
+	}
+	return false, nil
+}
diff --git a/pkg/resources/tag_association_acceptance_test.go b/pkg/resources/tag_association_acceptance_test.go
index 37e1d54804..86e616e529 100644
--- a/pkg/resources/tag_association_acceptance_test.go
+++ b/pkg/resources/tag_association_acceptance_test.go
@@ -3,27 +3,41 @@ package resources_test
 import (
 	"context"
 	"fmt"
+	"strings"
 	"testing"
 
 	acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/assert"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/assert/resourceassert"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/config"
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/bettertestspoc/config/model"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/testenvs"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider"
 	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
-	"github.com/hashicorp/terraform-plugin-testing/config"
+	tfconfig "github.com/hashicorp/terraform-plugin-testing/config"
 	"github.com/hashicorp/terraform-plugin-testing/helper/resource"
+	"github.com/hashicorp/terraform-plugin-testing/plancheck"
 	"github.com/hashicorp/terraform-plugin-testing/terraform"
 	"github.com/hashicorp/terraform-plugin-testing/tfversion"
 )
 
 func TestAcc_TagAssociation(t *testing.T) {
+	_ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance)
 	tagId := acc.TestClient().Ids.RandomSchemaObjectIdentifier()
+	tag2Id := acc.TestClient().Ids.RandomSchemaObjectIdentifier()
+	tagValue := "foo"
+	tagValue2 := "bar"
+	databaseId := acc.TestClient().Ids.DatabaseId()
 	resourceName := "snowflake_tag_association.test"
-	m := func() map[string]config.Variable {
-		return map[string]config.Variable{
-			"tag_name": config.StringVariable(tagId.Name()),
-			"database": config.StringVariable(acc.TestDatabaseName),
-			"schema":   config.StringVariable(acc.TestSchemaName),
+	m := func(tagId sdk.SchemaObjectIdentifier, tagValue string) map[string]tfconfig.Variable {
+		return map[string]tfconfig.Variable{
+			"tag_name":                      tfconfig.StringVariable(tagId.Name()),
+			"tag_value":                     tfconfig.StringVariable(tagValue),
+			"database":                      tfconfig.StringVariable(databaseId.Name()),
+			"schema":                        tfconfig.StringVariable(acc.TestSchemaName),
+			"database_fully_qualified_name": tfconfig.StringVariable(databaseId.FullyQualifiedName()),
 		}
 	}
 	resource.Test(t, resource.TestCase{
@@ -32,15 +46,229 @@ func TestAcc_TagAssociation(t *testing.T) {
 		TerraformVersionChecks: []tfversion.TerraformVersionCheck{
 			tfversion.RequireAbove(tfversion.Version1_5_0),
 		},
-		CheckDestroy: nil,
+		CheckDestroy: acc.CheckResourceTagUnset(t),
 		Steps: []resource.TestStep{
 			{
 				ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/basic"),
-				ConfigVariables: m(),
-				Check: resource.ComposeTestCheckFunc(
-					resource.TestCheckResourceAttr(resourceName, "object_type", "DATABASE"),
+				ConfigVariables: m(tagId, tagValue),
+				Check: resource.ComposeAggregateTestCheckFunc(
+					resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), tagValue, string(sdk.ObjectTypeDatabase))),
+					resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeDatabase)),
+					resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"),
+					resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", acc.TestClient().Ids.DatabaseId().FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_value", tagValue),
+				),
+			},
+			// external change - unset tag
+			{
+				PreConfig: func() {
+					acc.TestClient().Tag.Unset(t, sdk.ObjectTypeDatabase, databaseId, []sdk.ObjectIdentifier{tagId})
+				},
+				ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/basic"),
+				ConfigVariables: m(tagId, tagValue),
+				Check: resource.ComposeAggregateTestCheckFunc(
+					resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), tagValue, string(sdk.ObjectTypeDatabase))),
+					resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeDatabase)),
+					resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"),
+					resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", acc.TestClient().Ids.DatabaseId().FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_value", tagValue),
+				),
+			},
+			// external change - set a different value
+			{
+				PreConfig: func() {
+					acc.TestClient().Tag.Set(t, sdk.ObjectTypeDatabase, databaseId, []sdk.TagAssociation{
+						{
+							Name:  tagId,
+							Value: "external",
+						},
+					})
+				},
+				ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/basic"),
+				ConfigVariables: m(tagId, tagValue),
+				Check: resource.ComposeAggregateTestCheckFunc(
+					resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), tagValue, string(sdk.ObjectTypeDatabase))),
+					resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeDatabase)),
+					resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"),
+					resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", acc.TestClient().Ids.DatabaseId().FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_value", tagValue),
+				),
+			},
+			// change tag value
+			{
+				ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/basic"),
+				ConfigVariables: m(tagId, tagValue2),
+				ConfigPlanChecks: resource.ConfigPlanChecks{
+					PreApply: []plancheck.PlanCheck{
+						plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionUpdate),
+					},
+				},
+				Check: resource.ComposeAggregateTestCheckFunc(
+					resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), tagValue2, string(sdk.ObjectTypeDatabase))),
+					resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeDatabase)),
+					resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"),
+					resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", acc.TestClient().Ids.DatabaseId().FullyQualifiedName()),
 					resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()),
-					resource.TestCheckResourceAttr(resourceName, "tag_value", "finance"),
+					resource.TestCheckResourceAttr(resourceName, "tag_value", tagValue2),
+				),
+			},
+			// change tag id
+			{
+				ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/basic"),
+				ConfigVariables: m(tag2Id, tagValue2),
+				ConfigPlanChecks: resource.ConfigPlanChecks{
+					PreApply: []plancheck.PlanCheck{
+						plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionDestroyBeforeCreate),
+					},
+				},
+				Check: resource.ComposeAggregateTestCheckFunc(
+					resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tag2Id.FullyQualifiedName(), tagValue2, string(sdk.ObjectTypeDatabase))),
+					resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeDatabase)),
+					resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"),
+					resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", acc.TestClient().Ids.DatabaseId().FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_id", tag2Id.FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_value", tagValue2),
+					acc.CheckTagUnset(t, tagId, acc.TestClient().Ids.DatabaseId(), sdk.ObjectTypeDatabase),
+				),
+			},
+			{
+				ConfigVariables:   m(tag2Id, tagValue2),
+				ConfigDirectory:   acc.ConfigurationDirectory("TestAcc_TagAssociation/basic"),
+				ResourceName:      resourceName,
+				ImportState:       true,
+				ImportStateVerify: true,
+				// object_identifiers does not get set because during the import, the configuration is considered as empty
+				ImportStateVerifyIgnore: []string{"skip_validation", "object_identifiers.#", "object_identifiers.0"},
+			},
+			// after refreshing the state, object_identifiers is correct
+			{
+				RefreshState: true,
+				Check: resource.ComposeAggregateTestCheckFunc(
+					resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tag2Id.FullyQualifiedName(), tagValue2, string(sdk.ObjectTypeDatabase))),
+					resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeDatabase)),
+					resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"),
+					resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", acc.TestClient().Ids.DatabaseId().FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_id", tag2Id.FullyQualifiedName()),
+					resource.TestCheckResourceAttr(resourceName, "tag_value", tagValue2),
+				),
+			},
+		},
+	})
+}
+
+func TestAcc_TagAssociation_objectIdentifiers(t *testing.T) {
+	_ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance)
+	acc.TestAccPreCheck(t)
+
+	tag, tagCleanup := acc.TestClient().Tag.CreateTag(t)
+	t.Cleanup(tagCleanup)
+	dbRole1, dbRole1Cleanup := acc.TestClient().DatabaseRole.CreateDatabaseRole(t)
+	t.Cleanup(dbRole1Cleanup)
+	dbRole2, dbRole2Cleanup := acc.TestClient().DatabaseRole.CreateDatabaseRole(t)
+	t.Cleanup(dbRole2Cleanup)
+	dbRole3, dbRole3Cleanup := acc.TestClient().DatabaseRole.CreateDatabaseRole(t)
+	t.Cleanup(dbRole3Cleanup)
+
+	model12 := model.TagAssociation("test", []sdk.ObjectIdentifier{dbRole1.ID(), dbRole2.ID()}, string(sdk.ObjectTypeDatabaseRole), tag.ID().FullyQualifiedName(), "foo")
+	model123 := model.TagAssociation("test", []sdk.ObjectIdentifier{dbRole1.ID(), dbRole2.ID(), dbRole3.ID()}, string(sdk.ObjectTypeDatabaseRole), tag.ID().FullyQualifiedName(), "foo")
+	model13 := model.TagAssociation("test", []sdk.ObjectIdentifier{dbRole1.ID(), dbRole3.ID()}, string(sdk.ObjectTypeDatabaseRole), tag.ID().FullyQualifiedName(), "foo")
+
+	resource.Test(t, resource.TestCase{
+		ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories,
+		TerraformVersionChecks: []tfversion.TerraformVersionCheck{
+			tfversion.RequireAbove(tfversion.Version1_5_0),
+		},
+		CheckDestroy: resource.ComposeAggregateTestCheckFunc(
+			acc.CheckResourceTagUnset(t),
+		),
+		Steps: []resource.TestStep{
+			{
+				Config: config.FromModel(t, model12),
+				Check: assert.AssertThat(t, resourceassert.TagAssociationResource(t, model12.ResourceReference()).
+					HasObjectTypeString(string(sdk.ObjectTypeDatabaseRole)).
+					HasTagIdString(tag.ID().FullyQualifiedName()).
+					HasObjectIdentifiersLength(2).
+					HasTagValueString("foo"),
+					assert.Check(resource.TestCheckTypeSetElemAttr(model12.ResourceReference(), "object_identifiers.*", dbRole1.ID().FullyQualifiedName())),
+					assert.Check(resource.TestCheckTypeSetElemAttr(model12.ResourceReference(), "object_identifiers.*", dbRole2.ID().FullyQualifiedName())),
+				),
+			},
+			{
+				Config: config.FromModel(t, model123),
+				Check: assert.AssertThat(t, resourceassert.TagAssociationResource(t, model12.ResourceReference()).
+					HasObjectTypeString(string(sdk.ObjectTypeDatabaseRole)).
+					HasTagIdString(tag.ID().FullyQualifiedName()).
+					HasObjectIdentifiersLength(3).
+					HasTagValueString("foo"),
+					assert.Check(resource.TestCheckTypeSetElemAttr(model12.ResourceReference(), "object_identifiers.*", dbRole1.ID().FullyQualifiedName())),
+					assert.Check(resource.TestCheckTypeSetElemAttr(model12.ResourceReference(), "object_identifiers.*", dbRole2.ID().FullyQualifiedName())),
+					assert.Check(resource.TestCheckTypeSetElemAttr(model12.ResourceReference(), "object_identifiers.*", dbRole3.ID().FullyQualifiedName())),
+				),
+			},
+			{
+				Config: config.FromModel(t, model13),
+				Check: assert.AssertThat(t, resourceassert.TagAssociationResource(t, model13.ResourceReference()).
+					HasObjectTypeString(string(sdk.ObjectTypeDatabaseRole)).
+					HasTagIdString(tag.ID().FullyQualifiedName()).
+					HasObjectIdentifiersLength(2).
+					HasTagValueString("foo"),
+					assert.Check(resource.TestCheckTypeSetElemAttr(model13.ResourceReference(), "object_identifiers.*", dbRole1.ID().FullyQualifiedName())),
+					assert.Check(resource.TestCheckTypeSetElemAttr(model13.ResourceReference(), "object_identifiers.*", dbRole3.ID().FullyQualifiedName())),
+					assert.Check(acc.CheckTagUnset(t, tag.ID(), dbRole2.ID(), sdk.ObjectTypeDatabaseRole)),
+				),
+			},
+		},
+	})
+}
+
+func TestAcc_TagAssociation_objectType(t *testing.T) {
+	_ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance)
+	acc.TestAccPreCheck(t)
+
+	tag, tagCleanup := acc.TestClient().Tag.CreateTag(t)
+	t.Cleanup(tagCleanup)
+	role, roleCleanup := acc.TestClient().Role.CreateRole(t)
+	t.Cleanup(roleCleanup)
+	dbRole, dbRoleCleanup := acc.TestClient().DatabaseRole.CreateDatabaseRole(t)
+	t.Cleanup(dbRoleCleanup)
+
+	baseModel := model.TagAssociation("test", []sdk.ObjectIdentifier{role.ID()}, string(sdk.ObjectTypeRole), tag.ID().FullyQualifiedName(), "foo")
+	modelWithDifferentObjectType := model.TagAssociation("test", []sdk.ObjectIdentifier{dbRole.ID()}, string(sdk.ObjectTypeDatabaseRole), tag.ID().FullyQualifiedName(), "foo")
+
+	resource.Test(t, resource.TestCase{
+		ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories,
+ TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: resource.ComposeAggregateTestCheckFunc( + acc.CheckResourceTagUnset(t), + ), + Steps: []resource.TestStep{ + { + Config: config.FromModel(t, baseModel), + Check: assert.AssertThat(t, resourceassert.TagAssociationResource(t, baseModel.ResourceReference()). + HasObjectTypeString(string(sdk.ObjectTypeRole)). + HasTagIdString(tag.ID().FullyQualifiedName()). + HasObjectIdentifiersLength(1). + HasTagValueString("foo"), + ), + }, + { + Config: config.FromModel(t, modelWithDifferentObjectType), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(modelWithDifferentObjectType.ResourceReference(), plancheck.ResourceActionDestroyBeforeCreate), + }, + }, + Check: assert.AssertThat(t, resourceassert.TagAssociationResource(t, baseModel.ResourceReference()). + HasObjectTypeString(string(sdk.ObjectTypeDatabaseRole)). + HasTagIdString(tag.ID().FullyQualifiedName()). + HasObjectIdentifiersLength(1). 
+ HasTagValueString("foo"), + assert.Check(acc.CheckTagUnset(t, tag.ID(), role.ID(), sdk.ObjectTypeRole)), ), }, }, @@ -48,13 +276,16 @@ func TestAcc_TagAssociation(t *testing.T) { } func TestAcc_TagAssociationSchema(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) tagId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + schemaId := acc.TestClient().Ids.SchemaId() resourceName := "snowflake_tag_association.test" - m := func() map[string]config.Variable { - return map[string]config.Variable{ - "tag_name": config.StringVariable(tagId.Name()), - "database": config.StringVariable(acc.TestDatabaseName), - "schema": config.StringVariable(acc.TestSchemaName), + m := func() map[string]tfconfig.Variable { + return map[string]tfconfig.Variable{ + "tag_name": tfconfig.StringVariable(tagId.Name()), + "database": tfconfig.StringVariable(acc.TestDatabaseName), + "schema": tfconfig.StringVariable(acc.TestSchemaName), + "schema_fully_qualified_name": tfconfig.StringVariable(schemaId.FullyQualifiedName()), } } resource.Test(t, resource.TestCase{ @@ -68,8 +299,46 @@ func TestAcc_TagAssociationSchema(t *testing.T) { { ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/schema"), ConfigVariables: m(), - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "object_type", "SCHEMA"), + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), "TAG_VALUE", string(sdk.ObjectTypeSchema))), + resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeSchema)), + resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"), + resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", schemaId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_value", "TAG_VALUE"), 
+ ), + }, + }, + }) +} + +// proves https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/3235 is fixed +func TestAcc_TagAssociation_lowercaseObjectType(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + acc.TestAccPreCheck(t) + + tag, tagCleanup := acc.TestClient().Tag.CreateTag(t) + t.Cleanup(tagCleanup) + objectType := strings.ToLower(string(sdk.ObjectTypeSchema)) + objectId := acc.TestClient().Ids.SchemaId() + + model := model.TagAssociation("test", []sdk.ObjectIdentifier{objectId}, objectType, tag.ID().FullyQualifiedName(), "foo") + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: nil, + Steps: []resource.TestStep{ + { + Config: config.FromModel(t, model), + Check: assert.AssertThat(t, resourceassert.TagAssociationResource(t, model.ResourceReference()). + HasIdString(helpers.EncodeSnowflakeID(tag.ID().FullyQualifiedName(), "foo", string(sdk.ObjectTypeSchema))). + HasObjectTypeString(string(sdk.ObjectTypeSchema)). + HasTagIdString(tag.ID().FullyQualifiedName()). + HasObjectIdentifiersLength(1). 
+ HasTagValueString("foo"), ), }, }, @@ -77,15 +346,20 @@ func TestAcc_TagAssociationSchema(t *testing.T) { } func TestAcc_TagAssociationColumn(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + tagId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() - tableName := acc.TestClient().Ids.Alpha() + tableId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + columnId := sdk.NewTableColumnIdentifier(tableId.DatabaseName(), tableId.SchemaName(), tableId.Name(), "column") resourceName := "snowflake_tag_association.test" - m := func() map[string]config.Variable { - return map[string]config.Variable{ - "tag_name": config.StringVariable(tagId.Name()), - "table_name": config.StringVariable(tableName), - "database": config.StringVariable(acc.TestDatabaseName), - "schema": config.StringVariable(acc.TestSchemaName), + m := func() map[string]tfconfig.Variable { + return map[string]tfconfig.Variable{ + "tag_name": tfconfig.StringVariable(tagId.Name()), + "table_name": tfconfig.StringVariable(tableId.Name()), + "database": tfconfig.StringVariable(acc.TestDatabaseName), + "schema": tfconfig.StringVariable(acc.TestSchemaName), + "column": tfconfig.StringVariable("column"), + "column_fully_qualified_name": tfconfig.StringVariable(columnId.FullyQualifiedName()), } } resource.Test(t, resource.TestCase{ @@ -99,29 +373,31 @@ func TestAcc_TagAssociationColumn(t *testing.T) { { ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/column"), ConfigVariables: m(), - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "object_type", "COLUMN"), + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), "TAG_VALUE", string(sdk.ObjectTypeColumn))), + resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeColumn)), + resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"), + 
resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", columnId.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "tag_value", "TAG_VALUE"), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.%", "3"), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.name", fmt.Sprintf("%s.column_name", tableName)), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.database", acc.TestDatabaseName), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.schema", acc.TestSchemaName)), + ), }, }, }) } func TestAcc_TagAssociationIssue1202(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + tagId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() tableName := acc.TestClient().Ids.Alpha() resourceName := "snowflake_tag_association.test" - m := func() map[string]config.Variable { - return map[string]config.Variable{ - "tag_name": config.StringVariable(tagId.Name()), - "table_name": config.StringVariable(tableName), - "database": config.StringVariable(acc.TestDatabaseName), - "schema": config.StringVariable(acc.TestSchemaName), + m := func() map[string]tfconfig.Variable { + return map[string]tfconfig.Variable{ + "tag_name": tfconfig.StringVariable(tagId.Name()), + "table_name": tfconfig.StringVariable(tableName), + "database": tfconfig.StringVariable(acc.TestDatabaseName), + "schema": tfconfig.StringVariable(acc.TestSchemaName), } } resource.Test(t, resource.TestCase{ @@ -135,7 +411,7 @@ func TestAcc_TagAssociationIssue1202(t *testing.T) { { ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/issue1202"), ConfigVariables: m(), - Check: resource.ComposeTestCheckFunc( + Check: resource.ComposeAggregateTestCheckFunc( resource.TestCheckResourceAttr(resourceName, "object_type", "TABLE"), resource.TestCheckResourceAttr(resourceName, "tag_id", 
tagId.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "tag_value", "v1"), @@ -146,21 +422,24 @@ func TestAcc_TagAssociationIssue1202(t *testing.T) { } func TestAcc_TagAssociationIssue1909(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + tagId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() - tableName := acc.TestClient().Ids.Alpha() - tableName2 := acc.TestClient().Ids.Alpha() - columnName := "test.column" + tableId1 := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + tableId2 := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + columnId1 := sdk.NewTableColumnIdentifier(tableId1.DatabaseName(), tableId1.SchemaName(), tableId1.Name(), "test.column") + columnId2 := sdk.NewTableColumnIdentifier(tableId2.DatabaseName(), tableId2.SchemaName(), tableId2.Name(), "test.column") resourceName := "snowflake_tag_association.test" - objectID := sdk.NewTableColumnIdentifier(acc.TestDatabaseName, acc.TestSchemaName, tableName, columnName) - objectID2 := sdk.NewTableColumnIdentifier(acc.TestDatabaseName, acc.TestSchemaName, tableName2, columnName) - m := func() map[string]config.Variable { - return map[string]config.Variable{ - "tag_name": config.StringVariable(tagId.Name()), - "table_name": config.StringVariable(tableName), - "table_name2": config.StringVariable(tableName2), - "column_name": config.StringVariable("test.column"), - "database": config.StringVariable(acc.TestDatabaseName), - "schema": config.StringVariable(acc.TestSchemaName), + m := func() map[string]tfconfig.Variable { + return map[string]tfconfig.Variable{ + "tag_name": tfconfig.StringVariable(tagId.Name()), + "table_name": tfconfig.StringVariable(tableId1.Name()), + "table_name2": tfconfig.StringVariable(tableId2.Name()), + "column_name": tfconfig.StringVariable("test.column"), + "column_fully_qualified_name": tfconfig.StringVariable(columnId1.FullyQualifiedName()), + "column2_fully_qualified_name": 
tfconfig.StringVariable(columnId2.FullyQualifiedName()), + "database": tfconfig.StringVariable(acc.TestDatabaseName), + "schema": tfconfig.StringVariable(acc.TestSchemaName), } } resource.Test(t, resource.TestCase{ @@ -174,12 +453,12 @@ func TestAcc_TagAssociationIssue1909(t *testing.T) { { ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/issue1909"), ConfigVariables: m(), - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "object_type", "COLUMN"), + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeColumn)), resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "tag_value", "v1"), - testAccCheckTableColumnTagAssociation(tagId, objectID, "v1"), - testAccCheckTableColumnTagAssociation(tagId, objectID2, "v1"), + testAccCheckTableColumnTagAssociation(tagId, columnId1, "v1"), + testAccCheckTableColumnTagAssociation(tagId, columnId2, "v1"), ), }, }, @@ -194,25 +473,30 @@ func testAccCheckTableColumnTagAssociation(tagID sdk.SchemaObjectIdentifier, obj if err != nil { return err } - if tagValue != tv { - return fmt.Errorf("expected tag value %s, got %s", tagValue, tv) + if tv == nil { + return fmt.Errorf("expected tag value %s, got nil", tagValue) + } + if tagValue != *tv { + return fmt.Errorf("expected tag value %s, got %s", tagValue, *tv) } return nil } } +// TODO(SNOW-1165821): use a separate account with ORGADMIN in CI + func TestAcc_TagAssociationAccountIssues1910(t *testing.T) { - // todo: use role with ORGADMIN in CI (SNOW-1165821) - _ = testenvs.GetOrSkipTest(t, testenvs.TestAccountCreate) + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + tagId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() - accountName := acc.TestClient().Ids.Alpha() + accountId := acc.TestClient().Context.CurrentAccountIdentifier(t) resourceName := 
"snowflake_tag_association.test" - m := func() map[string]config.Variable { - return map[string]config.Variable{ - "tag_name": config.StringVariable(tagId.Name()), - "account_name": config.StringVariable(accountName), - "database": config.StringVariable(acc.TestDatabaseName), - "schema": config.StringVariable(acc.TestSchemaName), + m := func() map[string]tfconfig.Variable { + return map[string]tfconfig.Variable{ + "tag_name": tfconfig.StringVariable(tagId.Name()), + "account_fully_qualified_name": tfconfig.StringVariable(accountId.FullyQualifiedName()), + "database": tfconfig.StringVariable(acc.TestDatabaseName), + "schema": tfconfig.StringVariable(acc.TestSchemaName), } } @@ -227,9 +511,11 @@ func TestAcc_TagAssociationAccountIssues1910(t *testing.T) { { ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/issue1910"), ConfigVariables: m(), - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "object_type", "ACCOUNT"), - resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.Name()), + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeAccount)), + resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"), + resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", accountId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "tag_value", "v1"), ), }, @@ -238,29 +524,34 @@ func TestAcc_TagAssociationAccountIssues1910(t *testing.T) { } func TestAcc_TagAssociationIssue1926(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + tagId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() - tableName := acc.TestClient().Ids.Alpha() + tableId1 := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + columnId1 := sdk.NewTableColumnIdentifier(tableId1.DatabaseName(), 
tableId1.SchemaName(), tableId1.Name(), "init") resourceName := "snowflake_tag_association.test" - columnName := "test.column" - m := func() map[string]config.Variable { - return map[string]config.Variable{ - "tag_name": config.StringVariable(tagId.Name()), - "table_name": config.StringVariable(tableName), - "column_name": config.StringVariable(columnName), - "database": config.StringVariable(acc.TestDatabaseName), - "schema": config.StringVariable(acc.TestSchemaName), + m := func() map[string]tfconfig.Variable { + return map[string]tfconfig.Variable{ + "tag_name": tfconfig.StringVariable(tagId.Name()), + "table_name": tfconfig.StringVariable(tableId1.Name()), + "column_name": tfconfig.StringVariable(columnId1.Name()), + "column_fully_qualified_name": tfconfig.StringVariable(columnId1.FullyQualifiedName()), + "database": tfconfig.StringVariable(acc.TestDatabaseName), + "schema": tfconfig.StringVariable(acc.TestSchemaName), } } m2 := m() - tableName2 := "table.test" - columnName2 := "column" - columnName3 := "column.test" - m2["table_name"] = config.StringVariable(tableName2) - m2["column_name"] = config.StringVariable(columnName2) + tableId2 := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithPrefix("table.test") + columnId2 := sdk.NewTableColumnIdentifier(tableId2.DatabaseName(), tableId2.SchemaName(), tableId2.Name(), "column") + columnId3 := sdk.NewTableColumnIdentifier(tableId2.DatabaseName(), tableId2.SchemaName(), tableId2.Name(), "column.test") + m2["table_name"] = tfconfig.StringVariable(tableId2.Name()) + m2["column_name"] = tfconfig.StringVariable(columnId2.Name()) + m2["column_fully_qualified_name"] = tfconfig.StringVariable(columnId2.FullyQualifiedName()) m3 := m() - m3["table_name"] = config.StringVariable(tableName2) - m3["column_name"] = config.StringVariable(columnName3) + m3["table_name"] = tfconfig.StringVariable(tableId2.Name()) + m3["column_name"] = tfconfig.StringVariable(columnId3.Name()) + m3["column_fully_qualified_name"] = 
tfconfig.StringVariable(columnId3.FullyQualifiedName()) resource.Test(t, resource.TestCase{ ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, PreCheck: func() { acc.TestAccPreCheck(t) }, @@ -272,44 +563,122 @@ func TestAcc_TagAssociationIssue1926(t *testing.T) { { ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/issue1926"), ConfigVariables: m(), - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "object_type", "COLUMN"), + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), "TAG_VALUE", string(sdk.ObjectTypeColumn))), + resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeColumn)), + resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"), + resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", columnId1.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), - resource.TestCheckResourceAttr(resourceName, "tag_value", "v1"), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.%", "3"), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.name", fmt.Sprintf("%s.%s", tableName, columnName)), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.database", acc.TestDatabaseName), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.schema", acc.TestSchemaName), + resource.TestCheckResourceAttr(resourceName, "tag_value", "TAG_VALUE"), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/issue1926"), + ConfigVariables: m2, + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), "TAG_VALUE", string(sdk.ObjectTypeColumn))), + resource.TestCheckResourceAttr(resourceName, "object_type", 
string(sdk.ObjectTypeColumn)), + resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"), + resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", columnId2.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_value", "TAG_VALUE"), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/issue1926"), + ConfigVariables: m3, + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), "TAG_VALUE", string(sdk.ObjectTypeColumn))), + resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeColumn)), + resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"), + resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", columnId3.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_value", "TAG_VALUE"), ), }, - /* - todo: (SNOW-1205719) uncomment this - { - ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/issue1926"), - ConfigVariables: m2, - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "object_type", "COLUMN"), - resource.TestCheckResourceAttr(resourceName, "tag_id", fmt.Sprintf("%s|%s|%s", acc.TestDatabaseName, acc.TestSchemaName, tagName)), - resource.TestCheckResourceAttr(resourceName, "tag_value", "v1"), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.%", "3"), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.name", fmt.Sprintf("%s.%s", tableName2, columnName2)), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.database", acc.TestDatabaseName), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.schema", 
acc.TestSchemaName), - ), + }, + }) +} + +func TestAcc_TagAssociation_migrateFromVersion_0_98_0(t *testing.T) { + t.Setenv(string(testenvs.ConfigureClientOnce), "") + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) + acc.TestAccPreCheck(t) + tagId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + resourceName := "snowflake_tag_association.test" + schemaId := acc.TestClient().Ids.SchemaId() + + m := func() tfconfig.Variables { + return tfconfig.Variables{ + "tag_name": tfconfig.StringVariable(tagId.Name()), + "database": tfconfig.StringVariable(acc.TestDatabaseName), + "schema": tfconfig.StringVariable(acc.TestSchemaName), + "schema_fully_qualified_name": tfconfig.StringVariable(schemaId.FullyQualifiedName()), + } + } + + resource.Test(t, resource.TestCase{ + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + + Steps: []resource.TestStep{ + { + ExternalProviders: acc.ExternalProviderWithExactVersion("0.98.0"), + Config: tagAssociation_v_0_98_0(tagId, "TAG_VALUE", sdk.ObjectTypeSchema, schemaId), + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.DatabaseName(), tagId.SchemaName(), tagId.Name())), + resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeSchema)), + resource.TestCheckResourceAttr(resourceName, "object_identifier.#", "1"), + resource.TestCheckResourceAttr(resourceName, "object_identifier.0.name", schemaId.Name()), + resource.TestCheckResourceAttr(resourceName, "object_identifier.0.database", schemaId.DatabaseName()), + resource.TestCheckResourceAttr(resourceName, "object_identifier.0.schema", ""), + resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_value", "TAG_VALUE"), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + ConfigDirectory: 
acc.ConfigurationDirectory("TestAcc_TagAssociation/schema"), + ConfigVariables: m(), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionNoop), + }, + PostApplyPostRefresh: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionNoop), + }, }, - { - ConfigDirectory: acc.ConfigurationDirectory("TestAcc_TagAssociation/issue1926"), - ConfigVariables: m3, - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "object_type", "COLUMN"), - resource.TestCheckResourceAttr(resourceName, "tag_id", fmt.Sprintf("%s|%s|%s", acc.TestDatabaseName, acc.TestSchemaName, tagName)), - resource.TestCheckResourceAttr(resourceName, "tag_value", "v1"), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.%", "3"), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.name", fmt.Sprintf("%s.%s", tableName2, columnName3)), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.database", acc.TestDatabaseName), - resource.TestCheckResourceAttr(resourceName, "object_identifier.0.schema", acc.TestSchemaName), - ), - },*/ + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), "TAG_VALUE", string(sdk.ObjectTypeSchema))), + resource.TestCheckResourceAttr(resourceName, "object_type", string(sdk.ObjectTypeSchema)), + resource.TestCheckResourceAttr(resourceName, "object_identifiers.#", "1"), + resource.TestCheckTypeSetElemAttr(resourceName, "object_identifiers.*", schemaId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_id", tagId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "tag_value", "TAG_VALUE"), + ), + }, }, }) } + +func tagAssociation_v_0_98_0(tagId sdk.SchemaObjectIdentifier, tagValue string, objectType sdk.ObjectType, objectId 
sdk.DatabaseObjectIdentifier) string { + s := ` +resource "snowflake_tag_association" "test" { + tag_id = snowflake_tag.test.fully_qualified_name + tag_value = "%[1]s" + object_type = "%[2]s" + object_identifier { + name = "%[3]s" + database = "%[4]s" + } +} + +resource "snowflake_tag" "test" { + name = "%[5]s" + database = "%[6]s" + schema = "%[7]s" +} +` + return fmt.Sprintf(s, tagValue, objectType, objectId.Name(), objectId.DatabaseName(), tagId.Name(), tagId.DatabaseName(), tagId.SchemaName()) +} diff --git a/pkg/resources/tag_association_state_upgraders.go b/pkg/resources/tag_association_state_upgraders.go new file mode 100644 index 0000000000..5072c55d47 --- /dev/null +++ b/pkg/resources/tag_association_state_upgraders.go @@ -0,0 +1,38 @@ +package resources + +import ( + "context" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" +) + +func v0_98_0_TagAssociationStateUpgrader(ctx context.Context, rawState map[string]any, meta any) (map[string]any, error) { + if rawState == nil { + return rawState, nil + } + tagId, err := sdk.ParseSchemaObjectIdentifier(rawState["tag_id"].(string)) + if err != nil { + return nil, err + } + tagValue := rawState["tag_value"].(string) + objectType := rawState["object_type"].(string) + + rawState["id"] = helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), tagValue, objectType) + + objectIdentifiersOld := rawState["object_identifier"].([]any) + objectIdentifiers := make([]string, 0, len(objectIdentifiersOld)) + for _, objectIdentifierOld := range objectIdentifiersOld { + obj := objectIdentifierOld.(map[string]any) + var id sdk.ObjectIdentifier + if objectType == string(sdk.ObjectTypeAccount) { + id = sdk.NewAccountIdentifierFromFullyQualifiedName(obj["name"].(string)) + } else { + id = getTagObjectIdentifier(obj) + } + objectIdentifiers = append(objectIdentifiers, id.FullyQualifiedName()) + } + rawState["object_identifiers"] = 
objectIdentifiers + + return rawState, nil +} diff --git a/pkg/resources/tag_association_test.go b/pkg/resources/tag_association_test.go index 57ebe30aff..efd605d3bc 100644 --- a/pkg/resources/tag_association_test.go +++ b/pkg/resources/tag_association_test.go @@ -7,75 +7,118 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTagIdentifierAndObjectIdentifier(t *testing.T) { + tagId := sdk.NewSchemaObjectIdentifier("test_db", "test_schema", "test_tag") + t.Run("account identifier", func(t *testing.T) { + in := map[string]any{ + "tag_id": tagId.FullyQualifiedName(), + "object_type": "ACCOUNT", + "object_identifiers": []any{ + "orgname.accountname", + }, + } + d := schema.TestResourceDataRaw(t, resources.TagAssociation().Schema, in) + tid, identifiers, objectType, err := resources.TagIdentifierAndObjectIdentifier(d) + require.NoError(t, err) + assert.Equal(t, tagId, tid) + assert.Len(t, identifiers, 1) + assert.Equal(t, `"orgname"."accountname"`, identifiers[0].FullyQualifiedName()) + assert.Equal(t, sdk.ObjectTypeAccount, objectType) + }) t.Run("account object identifier", func(t *testing.T) { - in := map[string]interface{}{ - "tag_id": "\"test_db\".\"test_schema\".\"test_tag\"", + in := map[string]any{ + "tag_id": tagId.FullyQualifiedName(), "object_type": "DATABASE", - "object_identifier": []interface{}{map[string]interface{}{ - "name": "test_db", - }}, + "object_identifiers": []any{ + "test_db", + }, } d := schema.TestResourceDataRaw(t, resources.TagAssociation().Schema, in) - tid, identifiers, objectType := resources.TagIdentifierAndObjectIdentifier(d) - assert.Equal(t, sdk.NewSchemaObjectIdentifier("test_db", "test_schema", "test_tag"), tid) + tid, identifiers, objectType, err := resources.TagIdentifierAndObjectIdentifier(d) + require.NoError(t, err) + assert.Equal(t, tagId, tid) 
assert.Len(t, identifiers, 1) assert.Equal(t, "\"test_db\"", identifiers[0].FullyQualifiedName()) assert.Equal(t, sdk.ObjectTypeDatabase, objectType) }) t.Run("database object identifier", func(t *testing.T) { - in := map[string]interface{}{ - "tag_id": "\"test_db\".\"test_schema\".\"test_tag\"", + in := map[string]any{ + "tag_id": tagId.FullyQualifiedName(), "object_type": "SCHEMA", - "object_identifier": []interface{}{map[string]interface{}{ - "name": "test_schema", - "database": "test_db", - }}, + "object_identifiers": []any{ + "test_db.test_schema", + }, } d := schema.TestResourceDataRaw(t, resources.TagAssociation().Schema, in) - tid, identifiers, objectType := resources.TagIdentifierAndObjectIdentifier(d) - assert.Equal(t, sdk.NewSchemaObjectIdentifier("test_db", "test_schema", "test_tag"), tid) + tid, identifiers, objectType, err := resources.TagIdentifierAndObjectIdentifier(d) + require.NoError(t, err) + assert.Equal(t, tagId, tid) assert.Len(t, identifiers, 1) assert.Equal(t, "\"test_db\".\"test_schema\"", identifiers[0].FullyQualifiedName()) assert.Equal(t, sdk.ObjectTypeSchema, objectType) }) t.Run("schema object identifier", func(t *testing.T) { - in := map[string]interface{}{ - "tag_id": "\"test_db\".\"test_schema\".\"test_tag\"", + in := map[string]any{ + "tag_id": tagId.FullyQualifiedName(), "object_type": "TABLE", - "object_identifier": []interface{}{map[string]interface{}{ - "name": "test_table", - "database": "test_db", - "schema": "test_schema", - }}, + "object_identifiers": []any{ + "test_db.test_schema.test_table", + }, } d := schema.TestResourceDataRaw(t, resources.TagAssociation().Schema, in) - tid, identifiers, objectType := resources.TagIdentifierAndObjectIdentifier(d) - assert.Equal(t, sdk.NewSchemaObjectIdentifier("test_db", "test_schema", "test_tag"), tid) + tid, identifiers, objectType, err := resources.TagIdentifierAndObjectIdentifier(d) + require.NoError(t, err) + assert.Equal(t, tagId, tid) assert.Len(t, identifiers, 1) 
assert.Equal(t, "\"test_db\".\"test_schema\".\"test_table\"", identifiers[0].FullyQualifiedName()) assert.Equal(t, sdk.ObjectTypeTable, objectType) }) t.Run("column object identifier", func(t *testing.T) { - in := map[string]interface{}{ + in := map[string]any{ "tag_id": "\"test_db\".\"test_schema\".\"test_tag\"", "object_type": "COLUMN", - "object_identifier": []interface{}{map[string]interface{}{ - "name": "test_table.test_column", - "database": "test_db", - "schema": "test_schema", - }}, + "object_identifiers": []any{ + "test_db.test_schema.test_table.test_column", + }, } d := schema.TestResourceDataRaw(t, resources.TagAssociation().Schema, in) - tid, identifiers, objectType := resources.TagIdentifierAndObjectIdentifier(d) + tid, identifiers, objectType, err := resources.TagIdentifierAndObjectIdentifier(d) + require.NoError(t, err) assert.Equal(t, sdk.NewSchemaObjectIdentifier("test_db", "test_schema", "test_tag"), tid) assert.Len(t, identifiers, 1) assert.Equal(t, "\"test_db\".\"test_schema\".\"test_table\".\"test_column\"", identifiers[0].FullyQualifiedName()) assert.Equal(t, sdk.ObjectTypeColumn, objectType) }) + + t.Run("invalid object identifier", func(t *testing.T) { + in := map[string]any{ + "tag_id": tagId.FullyQualifiedName(), + "object_type": "COLUMN", + "object_identifiers": []any{ + "\"", + }, + } + d := schema.TestResourceDataRaw(t, resources.TagAssociation().Schema, in) + _, _, _, err := resources.TagIdentifierAndObjectIdentifier(d) + require.ErrorContains(t, err, `unable to read identifier: ", err = parse error on line 1, column 2: extraneous or missing " in quoted-field`) + }) + + t.Run("invalid tag identifier", func(t *testing.T) { + in := map[string]any{ + "tag_id": "\"test_schema\".\"test_tag\"", + "object_type": "DATABASE", + "object_identifiers": []any{ + "test_db", + }, + } + d := schema.TestResourceDataRaw(t, resources.TagAssociation().Schema, in) + _, _, _, err := resources.TagIdentifierAndObjectIdentifier(d) + require.ErrorContains(t, 
err, `unexpected number of parts 2 in identifier "test_schema"."test_tag", expected 3 in a form of ".."`) + }) } diff --git a/pkg/resources/task.go b/pkg/resources/task.go index e7e6c940d3..f9f697fda3 100644 --- a/pkg/resources/task.go +++ b/pkg/resources/task.go @@ -617,7 +617,7 @@ func ReadTask(withExternalChangesMarking bool) schema.ReadContextFunc { d.Set("config", task.Config), d.Set("comment", task.Comment), d.Set("sql_statement", task.Definition), - d.Set("after", collections.Map(task.Predecessors, sdk.SchemaObjectIdentifier.FullyQualifiedName)), + d.Set("after", collections.Map(task.TaskRelations.Predecessors, sdk.SchemaObjectIdentifier.FullyQualifiedName)), handleTaskParameterRead(d, taskParameters), d.Set(FullyQualifiedNameAttributeName, id.FullyQualifiedName()), d.Set(ShowOutputAttributeName, []map[string]any{schemas.TaskToSchema(task)}), diff --git a/pkg/resources/task_acceptance_test.go b/pkg/resources/task_acceptance_test.go index 0a0545397f..b2c50e76d6 100644 --- a/pkg/resources/task_acceptance_test.go +++ b/pkg/resources/task_acceptance_test.go @@ -1724,7 +1724,7 @@ func TestAcc_Task_ConvertStandaloneTaskToFinalizer(t *testing.T) { HasSuspendTaskAfterNumFailuresString("2"), resourceshowoutputassert.TaskShowOutput(t, rootTaskModel.ResourceReference()). HasScheduleMinutes(schedule). - // TODO(SNOW-1348116 - next pr): Create ticket and report; this field in task relations seems to have mixed chances of appearing (needs deeper digging, doesn't affect the resource; could be removed for now) + // TODO(SNOW-1843489): Create ticket and report; this field in task relations seems to have mixed chances of appearing (needs deeper digging, doesn't affect the resource; could be removed for now) // HasTaskRelations(sdk.TaskRelations{FinalizerTask: &finalizerTaskId}). HasState(sdk.TaskStateStarted), resourceassert.TaskResource(t, childTaskModel.ResourceReference()). 
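The `v0_98_0_TagAssociationStateUpgrader` above re-encodes the resource ID from the old `"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"` form into the new pipe-delimited `tag FQN|tag value|object type` form. A minimal sketch of that encoding — assuming, as the migration guide's new ID format suggests, that `helpers.EncodeSnowflakeID` joins its parts with a `|` delimiter (the provider's real helper is not reproduced here):

```go
package main

import (
	"fmt"
	"strings"
)

// encodeTagAssociationID sketches how the state upgrader builds the new
// resource ID: the tag's fully qualified name, the tag value, and the
// object type, joined with a "|" delimiter. This mimics, rather than
// reuses, the provider's helpers.EncodeSnowflakeID.
func encodeTagAssociationID(tagFQN, tagValue, objectType string) string {
	return strings.Join([]string{tagFQN, tagValue, objectType}, "|")
}

func main() {
	// Matches the documented format: "TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"|TAG_VALUE|OBJECT_TYPE
	id := encodeTagAssociationID(`"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"`, "gold", "WAREHOUSE")
	fmt.Println(id)
}
```

Because the tag value and object type are part of the ID, promoting an object from one tag value to another is modeled as moving its identifier between resources, not mutating a field in place.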
diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/basic/test.tf b/pkg/resources/testdata/TestAcc_TagAssociation/basic/test.tf index ba9ae97faa..69d7ab1172 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/basic/test.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/basic/test.tf @@ -2,15 +2,13 @@ resource "snowflake_tag" "test" { name = var.tag_name database = var.database schema = var.schema - allowed_values = ["finance", "hr"] + allowed_values = ["bar", "foo", "external"] comment = "Terraform acceptance test" } resource "snowflake_tag_association" "test" { - object_identifier { - name = var.database - } - object_type = "DATABASE" - tag_id = snowflake_tag.test.id - tag_value = "finance" + object_identifiers = [var.database_fully_qualified_name] + object_type = "DATABASE" + tag_id = snowflake_tag.test.fully_qualified_name + tag_value = var.tag_value } diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/basic/variables.tf b/pkg/resources/testdata/TestAcc_TagAssociation/basic/variables.tf index 45a4ceb1c8..10ad654156 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/basic/variables.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/basic/variables.tf @@ -9,3 +9,11 @@ variable "database" { variable "schema" { type = string } + +variable "database_fully_qualified_name" { + type = string +} + +variable "tag_value" { + type = string +} diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/column/test.tf b/pkg/resources/testdata/TestAcc_TagAssociation/column/test.tf index d10067998b..fea0c22120 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/column/test.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/column/test.tf @@ -10,19 +10,16 @@ resource "snowflake_table" "test" { schema = var.schema column { - name = "column_name" + name = var.column type = "VARIANT" } } resource "snowflake_tag_association" "test" { - object_identifier { - database = var.database - schema = var.schema - name = 
"${snowflake_table.test.name}.${snowflake_table.test.column[0].name}" - } + object_identifiers = [var.column_fully_qualified_name] object_type = "COLUMN" - tag_id = snowflake_tag.test.id + tag_id = snowflake_tag.test.fully_qualified_name tag_value = "TAG_VALUE" + depends_on = [snowflake_table.test] } diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/column/variables.tf b/pkg/resources/testdata/TestAcc_TagAssociation/column/variables.tf index bf4d1ec509..8707532841 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/column/variables.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/column/variables.tf @@ -13,3 +13,11 @@ variable "database" { variable "schema" { type = string } + +variable "column" { + type = string +} + +variable "column_fully_qualified_name" { + type = string +} diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/issue1202/main.tf b/pkg/resources/testdata/TestAcc_TagAssociation/issue1202/main.tf index 4c316e627b..dd81b42a82 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/issue1202/main.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/issue1202/main.tf @@ -15,17 +15,8 @@ resource "snowflake_table" "test" { } resource "snowflake_tag_association" "test" { - // we need to set the object_identifier to avoid the following error: - // provider_test.go:17: err: resource snowflake_tag_association: object_identifier: Optional or Required must be set, not both - // we should consider deprecating object_identifier in favor of object_name - // https://github.com/Snowflake-Labs/terraform-provider-snowflake/pull/2534#discussion_r1507570740 - // object_name = "\"${var.database}\".\"${var.schema}\".\"${var.table_name}\"" - object_identifier { - database = var.database - schema = var.schema - name = snowflake_table.test.name - } - object_type = "TABLE" - tag_id = snowflake_tag.test.id - tag_value = "v1" + object_identifiers = [snowflake_table.test.fully_qualified_name] + object_type = "TABLE" + tag_id = 
snowflake_tag.test.fully_qualified_name + tag_value = "v1" } diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/issue1909/test.tf b/pkg/resources/testdata/TestAcc_TagAssociation/issue1909/test.tf index 01940e1183..519df8f05e 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/issue1909/test.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/issue1909/test.tf @@ -24,17 +24,9 @@ resource "snowflake_table" "test2" { } resource "snowflake_tag_association" "test" { - object_identifier { - database = var.database - schema = var.schema - name = "${snowflake_table.test.name}.${snowflake_table.test.column[0].name}" - } - object_identifier { - database = var.database - schema = var.schema - name = "${snowflake_table.test2.name}.${snowflake_table.test2.column[0].name}" - } - object_type = "COLUMN" - tag_id = snowflake_tag.test.id - tag_value = "v1" + object_identifiers = [var.column_fully_qualified_name, var.column2_fully_qualified_name] + object_type = "COLUMN" + tag_id = snowflake_tag.test.fully_qualified_name + tag_value = "v1" + depends_on = [snowflake_table.test, snowflake_table.test2] } diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/issue1909/variables.tf b/pkg/resources/testdata/TestAcc_TagAssociation/issue1909/variables.tf index 58c3e62640..ff45812815 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/issue1909/variables.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/issue1909/variables.tf @@ -21,3 +21,11 @@ variable "database" { variable "schema" { type = string } + +variable "column_fully_qualified_name" { + type = string +} + +variable "column2_fully_qualified_name" { + type = string +} diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/issue1910/test.tf b/pkg/resources/testdata/TestAcc_TagAssociation/issue1910/test.tf index c05081252c..8f5408595e 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/issue1910/test.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/issue1910/test.tf @@ -4,23 +4,9 @@ 
resource "snowflake_tag" "test" { schema = var.schema } -resource "snowflake_account" "test" { - name = var.account_name - admin_name = "someadmin" - admin_password = "123456" - first_name = "Ad" - last_name = "Min" - email = "admin@example.com" - must_change_password = false - edition = "BUSINESS_CRITICAL" - grace_period_in_days = 4 -} - resource "snowflake_tag_association" "test" { - object_identifier { - name = snowflake_account.test.name - } - object_type = "ACCOUNT" - tag_id = snowflake_tag.test.id - tag_value = "v1" + object_identifiers = [var.account_fully_qualified_name] + object_type = "ACCOUNT" + tag_id = snowflake_tag.test.fully_qualified_name + tag_value = "v1" } diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/issue1910/variables.tf b/pkg/resources/testdata/TestAcc_TagAssociation/issue1910/variables.tf index 032442851b..e7f2ec926a 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/issue1910/variables.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/issue1910/variables.tf @@ -2,7 +2,7 @@ variable "tag_name" { type = string } -variable "account_name" { +variable "account_fully_qualified_name" { type = string } diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/issue1926/test.tf b/pkg/resources/testdata/TestAcc_TagAssociation/issue1926/test.tf index 44e3c95437..8401d09080 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/issue1926/test.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/issue1926/test.tf @@ -7,18 +7,21 @@ resource "snowflake_table" "test" { name = var.table_name database = var.database schema = var.schema + // TODO(SNOW-1348114): use only one column, if possible. + // We need a dummy column here because a table must have at least one column, and when we rename the second one in the config, it gets dropped for a moment. 
+ column { + name = "DUMMY" + type = "VARIANT" + } column { name = var.column_name type = "VARIANT" } } resource "snowflake_tag_association" "test" { - object_identifier { - database = var.database - schema = var.schema - name = "${snowflake_table.test.name}.${snowflake_table.test.column[0].name}" - } - object_type = "COLUMN" - tag_id = snowflake_tag.test.id - tag_value = "v1" + object_identifiers = [var.column_fully_qualified_name] + object_type = "COLUMN" + tag_id = snowflake_tag.test.fully_qualified_name + tag_value = "TAG_VALUE" + depends_on = [snowflake_table.test] } diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/issue1926/variables.tf b/pkg/resources/testdata/TestAcc_TagAssociation/issue1926/variables.tf index 222b3e3b4e..778af61e64 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/issue1926/variables.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/issue1926/variables.tf @@ -17,3 +17,7 @@ variable "database" { variable "schema" { type = string } + +variable "column_fully_qualified_name" { + type = string +} diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/schema/test.tf b/pkg/resources/testdata/TestAcc_TagAssociation/schema/test.tf index 5ebee25710..ba06156efc 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/schema/test.tf +++ b/pkg/resources/testdata/TestAcc_TagAssociation/schema/test.tf @@ -6,12 +6,9 @@ resource "snowflake_tag" "test" { } resource "snowflake_tag_association" "test" { - object_identifier { - database = var.database - name = var.schema - } + object_identifiers = [var.schema_fully_qualified_name] object_type = "SCHEMA" - tag_id = snowflake_tag.test.id + tag_id = snowflake_tag.test.fully_qualified_name tag_value = "TAG_VALUE" } diff --git a/pkg/resources/testdata/TestAcc_TagAssociation/schema/variables.tf b/pkg/resources/testdata/TestAcc_TagAssociation/schema/variables.tf index 45a4ceb1c8..aa9cba3679 100644 --- a/pkg/resources/testdata/TestAcc_TagAssociation/schema/variables.tf +++ 
b/pkg/resources/testdata/TestAcc_TagAssociation/schema/variables.tf @@ -9,3 +9,7 @@ variable "database" { variable "schema" { type = string } + +variable "schema_fully_qualified_name" { + type = string +} diff --git a/pkg/resources/validators.go b/pkg/resources/validators.go index 071a73a33c..d9fcb14edb 100644 --- a/pkg/resources/validators.go +++ b/pkg/resources/validators.go @@ -7,28 +7,12 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider/validators" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) -func IsDataType() schema.SchemaValidateFunc { //nolint:staticcheck - return func(value any, key string) (warnings []string, errors []error) { - stringValue, ok := value.(string) - if !ok { - errors = append(errors, fmt.Errorf("expected type of %s to be string, got %T", key, value)) - return warnings, errors - } - - _, err := sdk.ToDataType(stringValue) - if err != nil { - errors = append(errors, fmt.Errorf("expected %s to be one of %T values, got %s", key, sdk.DataTypeString, stringValue)) - } - - return warnings, errors - } -} - func IsValidIdentifier[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | sdk.SchemaObjectIdentifier | sdk.TableColumnIdentifier]() schema.SchemaValidateDiagFunc { return validators.IsValidIdentifier[T]() } @@ -98,6 +82,8 @@ func sdkValidation[T any](normalize func(string) (T, error)) schema.SchemaValida return validators.NormalizeValidation(normalize) } +var IsDataTypeValid = sdkValidation(datatypes.ParseDataType) + func isNotEqualTo(notExpectedValue string, errorMessage string) schema.SchemaValidateDiagFunc { return func(value any, path cty.Path) diag.Diagnostics { if value != nil { 
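The `validators.go` hunk above replaces the hand-rolled `IsDataType()` validator with `var IsDataTypeValid = sdkValidation(datatypes.ParseDataType)`, i.e. any parser of shape `func(string) (T, error)` is wrapped into a validator. A self-contained sketch of that pattern — the names `normalizeValidation` and `toyParseDataType` are illustrative stand-ins, not the provider's real signatures:

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeValidation turns any normalize/parse function into a validator:
// the value is valid exactly when the parser accepts it. This mirrors the
// shape of the provider's sdkValidation helper, returning a plain error
// instead of a terraform-plugin-sdk diagnostic.
func normalizeValidation[T any](normalize func(string) (T, error)) func(value any) error {
	return func(value any) error {
		s, ok := value.(string)
		if !ok {
			return fmt.Errorf("expected string, got %T", value)
		}
		if _, err := normalize(s); err != nil {
			return fmt.Errorf("invalid value %q: %w", s, err)
		}
		return nil
	}
}

// toyParseDataType stands in for datatypes.ParseDataType.
func toyParseDataType(s string) (string, error) {
	switch t := strings.ToUpper(s); t {
	case "NUMBER", "VARCHAR", "BOOLEAN":
		return t, nil
	default:
		return "", fmt.Errorf("unknown data type %q", s)
	}
}

func main() {
	isDataTypeValid := normalizeValidation(toyParseDataType)
	fmt.Println(isDataTypeValid("number"))     // <nil>
	fmt.Println(isDataTypeValid("not a type")) // non-nil error
}
```

The payoff is the one-liner seen in the diff: when a new parser lands in the `datatypes` package, the corresponding schema validator is a single wrapped declaration rather than a bespoke validate function with its own error handling.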
diff --git a/pkg/resources/validators_test.go b/pkg/resources/validators_test.go index 9125bfd98c..af2408697e 100644 --- a/pkg/resources/validators_test.go +++ b/pkg/resources/validators_test.go @@ -9,48 +9,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestIsDataType(t *testing.T) { - isDataType := IsDataType() - key := "tag" - - testCases := []struct { - Name string - Value any - Error string - }{ - { - Name: "validation: correct DataType value", - Value: "NUMBER", - }, - { - Name: "validation: correct DataType value in lowercase", - Value: "number", - }, - { - Name: "validation: incorrect DataType value", - Value: "invalid data type", - Error: "expected tag to be one of", - }, - { - Name: "validation: incorrect value type", - Value: 123, - Error: "expected type of tag to be string", - }, - } - - for _, tt := range testCases { - t.Run(tt.Name, func(t *testing.T) { - _, errors := isDataType(tt.Value, key) - if tt.Error != "" { - assert.Len(t, errors, 1) - assert.ErrorContains(t, errors[0], tt.Error) - } else { - assert.Len(t, errors, 0) - } - }) - } -} - func Test_sdkValidation(t *testing.T) { genericNormalize := func(value string) (any, error) { if value == "ok" { diff --git a/pkg/sdk/common_types.go b/pkg/sdk/common_types.go index 1627ab9d2a..7a4975a78e 100644 --- a/pkg/sdk/common_types.go +++ b/pkg/sdk/common_types.go @@ -6,6 +6,8 @@ import ( "strconv" "strings" "time" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) var ( @@ -91,13 +93,13 @@ func ParseTableColumnSignature(signature string) ([]TableColumnSignature, error) if len(parts) < 2 { return []TableColumnSignature{}, fmt.Errorf("expected argument name and type, got %s", elem) } - dataType, err := ToDataType(parts[len(parts)-1]) + dataType, err := datatypes.ParseDataType(parts[len(parts)-1]) if err != nil { return []TableColumnSignature{}, err } arguments[i] = TableColumnSignature{ Name: strings.Join(parts[:len(parts)-1], " "), - Type: dataType, + Type: 
LegacyDataTypeFrom(dataType), } } return arguments, nil diff --git a/pkg/sdk/data_types.go b/pkg/sdk/data_types.go deleted file mode 100644 index 40ced4cb83..0000000000 --- a/pkg/sdk/data_types.go +++ /dev/null @@ -1,176 +0,0 @@ -package sdk - -import ( - "fmt" - "slices" - "strconv" - "strings" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/util" -) - -// DataType is based on https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. -type DataType string - -var allowedVectorInnerTypes = []DataType{ - DataTypeInt, - DataTypeFloat, -} - -const ( - DataTypeNumber DataType = "NUMBER" - DataTypeInt DataType = "INT" - DataTypeFloat DataType = "FLOAT" - DataTypeVARCHAR DataType = "VARCHAR" - DataTypeString DataType = "STRING" - DataTypeBinary DataType = "BINARY" - DataTypeBoolean DataType = "BOOLEAN" - DataTypeDate DataType = "DATE" - DataTypeTime DataType = "TIME" - DataTypeTimestamp DataType = "TIMESTAMP" - DataTypeTimestampLTZ DataType = "TIMESTAMP_LTZ" - DataTypeTimestampNTZ DataType = "TIMESTAMP_NTZ" - DataTypeTimestampTZ DataType = "TIMESTAMP_TZ" - DataTypeVariant DataType = "VARIANT" - DataTypeObject DataType = "OBJECT" - DataTypeArray DataType = "ARRAY" - DataTypeGeography DataType = "GEOGRAPHY" - DataTypeGeometry DataType = "GEOMETRY" -) - -var ( - DataTypeNumberSynonyms = []string{"NUMBER", "DECIMAL", "NUMERIC", "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"} - DataTypeFloatSynonyms = []string{"FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DOUBLE PRECISION", "REAL"} - DataTypeVarcharSynonyms = []string{"VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"} - DataTypeBinarySynonyms = []string{"BINARY", "VARBINARY"} - DataTypeBooleanSynonyms = []string{"BOOLEAN", "BOOL"} - DataTypeTimestampLTZSynonyms = []string{"TIMESTAMP_LTZ"} - DataTypeTimestampTZSynonyms = []string{"TIMESTAMP_TZ"} - DataTypeTimestampNTZSynonyms = []string{"DATETIME", 
"TIMESTAMP", "TIMESTAMP_NTZ"} - DataTypeTimeSynonyms = []string{"TIME"} - DataTypeVectorSynonyms = []string{"VECTOR"} -) - -const ( - DefaultNumberPrecision = 38 - DefaultNumberScale = 0 - DefaultVarcharLength = 16777216 -) - -func ToDataType(s string) (DataType, error) { - dType := strings.ToUpper(s) - - switch dType { - case "DATE": - return DataTypeDate, nil - case "VARIANT": - return DataTypeVariant, nil - case "OBJECT": - return DataTypeObject, nil - case "ARRAY": - return DataTypeArray, nil - case "GEOGRAPHY": - return DataTypeGeography, nil - case "GEOMETRY": - return DataTypeGeometry, nil - } - - if slices.ContainsFunc(DataTypeNumberSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeNumber, nil - } - if slices.ContainsFunc(DataTypeFloatSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeFloat, nil - } - if slices.ContainsFunc(DataTypeVarcharSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeVARCHAR, nil - } - if slices.ContainsFunc(DataTypeBinarySynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeBinary, nil - } - if slices.Contains(DataTypeBooleanSynonyms, dType) { - return DataTypeBoolean, nil - } - if slices.ContainsFunc(DataTypeTimestampLTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeTimestampLTZ, nil - } - if slices.ContainsFunc(DataTypeTimestampTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeTimestampTZ, nil - } - if slices.ContainsFunc(DataTypeTimestampNTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeTimestampNTZ, nil - } - if slices.ContainsFunc(DataTypeTimeSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeTime, nil - } - if slices.ContainsFunc(DataTypeVectorSynonyms, func(e string) bool { return strings.HasPrefix(dType, e) }) { - 
return DataType(dType), nil - } - return "", fmt.Errorf("invalid data type: %s", s) -} - -func IsStringType(_type string) bool { - t := strings.ToUpper(_type) - return strings.HasPrefix(t, "STRING") || - strings.HasPrefix(t, "VARCHAR") || - strings.HasPrefix(t, "CHAR") || - strings.HasPrefix(t, "TEXT") || - strings.HasPrefix(t, "NVARCHAR") || - strings.HasPrefix(t, "NCHAR") -} - -// ParseNumberDataTypeRaw extracts precision and scale from the raw number data type input. -// It returns defaults if it can't parse arguments, data type is different, or no arguments were provided. -// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework -func ParseNumberDataTypeRaw(rawDataType string) (int, int) { - r := util.TrimAllPrefixes(strings.TrimSpace(strings.ToUpper(rawDataType)), DataTypeNumberSynonyms...) - r = strings.TrimSpace(r) - if strings.HasPrefix(r, "(") && strings.HasSuffix(r, ")") { - parts := strings.Split(r[1:len(r)-1], ",") - switch l := len(parts); l { - case 1: - precision, err := strconv.Atoi(strings.TrimSpace(parts[0])) - if err == nil { - return precision, DefaultNumberScale - } else { - logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s", err: %v`, parts[0], err) - } - case 2: - precision, err1 := strconv.Atoi(strings.TrimSpace(parts[0])) - scale, err2 := strconv.Atoi(strings.TrimSpace(parts[1])) - if err1 == nil && err2 == nil { - return precision, scale - } else { - logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s" or scale "%s", errs: %v, %v`, parts[0], parts[1], err1, err2) - } - default: - logging.DebugLogger.Printf("[DEBUG] Unexpected length of number arguments") - } - } - logging.DebugLogger.Printf("[DEBUG] Returning default number precision and scale") - return DefaultNumberPrecision, DefaultNumberScale -} - -// ParseVarcharDataTypeRaw extracts length from the raw text data type input. 
-// It returns default if it can't parse arguments, data type is different, or no length argument was provided. -// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework -func ParseVarcharDataTypeRaw(rawDataType string) int { - r := util.TrimAllPrefixes(strings.TrimSpace(strings.ToUpper(rawDataType)), DataTypeVarcharSynonyms...) - r = strings.TrimSpace(r) - if strings.HasPrefix(r, "(") && strings.HasSuffix(r, ")") { - parts := strings.Split(r[1:len(r)-1], ",") - switch l := len(parts); l { - case 1: - length, err := strconv.Atoi(strings.TrimSpace(parts[0])) - if err == nil { - return length - } else { - logging.DebugLogger.Printf(`[DEBUG] Could not parse varchar length "%s", err: %v`, parts[0], err) - } - default: - logging.DebugLogger.Printf("[DEBUG] Unexpected length of varchar arguments") - } - } - logging.DebugLogger.Printf("[DEBUG] Returning default varchar length") - return DefaultVarcharLength -} diff --git a/pkg/sdk/data_types_deprecated.go b/pkg/sdk/data_types_deprecated.go new file mode 100644 index 0000000000..0d0315ad5e --- /dev/null +++ b/pkg/sdk/data_types_deprecated.go @@ -0,0 +1,51 @@ +package sdk + +import ( + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" +) + +// DataType is based on https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. 
+type DataType string + +var allowedVectorInnerTypes = []DataType{ + DataTypeInt, + DataTypeFloat, +} + +const ( + DataTypeNumber DataType = "NUMBER" + DataTypeInt DataType = "INT" + DataTypeFloat DataType = "FLOAT" + DataTypeVARCHAR DataType = "VARCHAR" + DataTypeString DataType = "STRING" + DataTypeBinary DataType = "BINARY" + DataTypeBoolean DataType = "BOOLEAN" + DataTypeDate DataType = "DATE" + DataTypeTime DataType = "TIME" + DataTypeTimestampLTZ DataType = "TIMESTAMP_LTZ" + DataTypeTimestampNTZ DataType = "TIMESTAMP_NTZ" + DataTypeTimestampTZ DataType = "TIMESTAMP_TZ" + DataTypeVariant DataType = "VARIANT" + DataTypeObject DataType = "OBJECT" + DataTypeArray DataType = "ARRAY" + DataTypeGeography DataType = "GEOGRAPHY" + DataTypeGeometry DataType = "GEOMETRY" +) + +// IsStringType is a legacy method. datatypes.IsTextDataType should be used instead. +// TODO [SNOW-1348114]: remove with tables rework +func IsStringType(_type string) bool { + t := strings.ToUpper(_type) + return strings.HasPrefix(t, "STRING") || + strings.HasPrefix(t, "VARCHAR") || + strings.HasPrefix(t, "CHAR") || + strings.HasPrefix(t, "TEXT") || + strings.HasPrefix(t, "NVARCHAR") || + strings.HasPrefix(t, "NCHAR") +} + +func LegacyDataTypeFrom(newDataType datatypes.DataType) DataType { + return DataType(newDataType.ToLegacyDataTypeSql()) +} diff --git a/pkg/sdk/data_types_deprecated_test.go b/pkg/sdk/data_types_deprecated_test.go new file mode 100644 index 0000000000..ae4445f746 --- /dev/null +++ b/pkg/sdk/data_types_deprecated_test.go @@ -0,0 +1,56 @@ +package sdk + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsStringType(t *testing.T) { + type test struct { + input string + want bool + } + + tests := []test{ + // case insensitive. + {input: "STRING", want: true}, + {input: "string", want: true}, + {input: "String", want: true}, + + // varchar types. 
+ {input: "VARCHAR", want: true}, + {input: "NVARCHAR", want: true}, + {input: "NVARCHAR2", want: true}, + {input: "CHAR", want: true}, + {input: "NCHAR", want: true}, + {input: "CHAR VARYING", want: true}, + {input: "NCHAR VARYING", want: true}, + {input: "TEXT", want: true}, + + // with length + {input: "VARCHAR(100)", want: true}, + {input: "NVARCHAR(100)", want: true}, + {input: "NVARCHAR2(100)", want: true}, + {input: "CHAR(100)", want: true}, + {input: "NCHAR(100)", want: true}, + {input: "CHAR VARYING(100)", want: true}, + {input: "NCHAR VARYING(100)", want: true}, + {input: "TEXT(100)", want: true}, + + // binary is not string types. + {input: "binary", want: false}, + {input: "varbinary", want: false}, + + // other types + {input: "boolean", want: false}, + {input: "number", want: false}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + got := IsStringType(tc.input) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/pkg/sdk/data_types_test.go b/pkg/sdk/data_types_test.go deleted file mode 100644 index 156e5dc8f2..0000000000 --- a/pkg/sdk/data_types_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package sdk - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestToDataType(t *testing.T) { - type test struct { - input string - want DataType - } - - tests := []test{ - // case insensitive. - {input: "STRING", want: DataTypeVARCHAR}, - {input: "string", want: DataTypeVARCHAR}, - {input: "String", want: DataTypeVARCHAR}, - - // number types. - {input: "number", want: DataTypeNumber}, - {input: "decimal", want: DataTypeNumber}, - {input: "numeric", want: DataTypeNumber}, - {input: "int", want: DataTypeNumber}, - {input: "integer", want: DataTypeNumber}, - {input: "bigint", want: DataTypeNumber}, - {input: "smallint", want: DataTypeNumber}, - {input: "tinyint", want: DataTypeNumber}, - {input: "byteint", want: DataTypeNumber}, - - // float types. 
- {input: "float", want: DataTypeFloat}, - {input: "float4", want: DataTypeFloat}, - {input: "float8", want: DataTypeFloat}, - {input: "double", want: DataTypeFloat}, - {input: "double precision", want: DataTypeFloat}, - {input: "real", want: DataTypeFloat}, - - // varchar types. - {input: "varchar", want: DataTypeVARCHAR}, - {input: "char", want: DataTypeVARCHAR}, - {input: "character", want: DataTypeVARCHAR}, - {input: "string", want: DataTypeVARCHAR}, - {input: "text", want: DataTypeVARCHAR}, - - // binary types. - {input: "binary", want: DataTypeBinary}, - {input: "varbinary", want: DataTypeBinary}, - {input: "boolean", want: DataTypeBoolean}, - - // boolean types. - {input: "boolean", want: DataTypeBoolean}, - {input: "bool", want: DataTypeBoolean}, - - // timestamp ntz types. - {input: "datetime", want: DataTypeTimestampNTZ}, - {input: "timestamp", want: DataTypeTimestampNTZ}, - {input: "timestamp_ntz", want: DataTypeTimestampNTZ}, - - // timestamp tz types. - {input: "timestamp_tz", want: DataTypeTimestampTZ}, - {input: "timestamp_tz(9)", want: DataTypeTimestampTZ}, - - // timestamp ltz types. - {input: "timestamp_ltz", want: DataTypeTimestampLTZ}, - {input: "timestamp_ltz(9)", want: DataTypeTimestampLTZ}, - - // time types. 
- {input: "time", want: DataTypeTime}, - {input: "time(9)", want: DataTypeTime}, - - // all othertypes - {input: "date", want: DataTypeDate}, - {input: "variant", want: DataTypeVariant}, - {input: "object", want: DataTypeObject}, - {input: "array", want: DataTypeArray}, - {input: "geography", want: DataTypeGeography}, - {input: "geometry", want: DataTypeGeometry}, - {input: "VECTOR(INT, 10)", want: "VECTOR(INT, 10)"}, - {input: "VECTOR(INT, 20)", want: "VECTOR(INT, 20)"}, - {input: "VECTOR(FLOAT, 10)", want: "VECTOR(FLOAT, 10)"}, - {input: "VECTOR(FLOAT, 20)", want: "VECTOR(FLOAT, 20)"}, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got, err := ToDataType(tc.input) - require.NoError(t, err) - require.Equal(t, tc.want, got) - }) - } -} - -func TestIsStringType(t *testing.T) { - type test struct { - input string - want bool - } - - tests := []test{ - // case insensitive. - {input: "STRING", want: true}, - {input: "string", want: true}, - {input: "String", want: true}, - - // varchar types. - {input: "VARCHAR", want: true}, - {input: "NVARCHAR", want: true}, - {input: "NVARCHAR2", want: true}, - {input: "CHAR", want: true}, - {input: "NCHAR", want: true}, - {input: "CHAR VARYING", want: true}, - {input: "NCHAR VARYING", want: true}, - {input: "TEXT", want: true}, - - // with length - {input: "VARCHAR(100)", want: true}, - {input: "NVARCHAR(100)", want: true}, - {input: "NVARCHAR2(100)", want: true}, - {input: "CHAR(100)", want: true}, - {input: "NCHAR(100)", want: true}, - {input: "CHAR VARYING(100)", want: true}, - {input: "NCHAR VARYING(100)", want: true}, - {input: "TEXT(100)", want: true}, - - // binary is not string types. 
- {input: "binary", want: false}, - {input: "varbinary", want: false}, - - // other types - {input: "boolean", want: false}, - {input: "number", want: false}, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got := IsStringType(tc.input) - require.Equal(t, tc.want, got) - }) - } -} - -func Test_ParseNumberDataTypeRaw(t *testing.T) { - type test struct { - input string - expectedPrecision int - expectedScale int - } - defaults := func(input string) test { - return test{input: input, expectedPrecision: DefaultNumberPrecision, expectedScale: DefaultNumberScale} - } - - tests := []test{ - {input: "NUMBER(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale}, - {input: "NUMBER(30, 2)", expectedPrecision: 30, expectedScale: 2}, - {input: "decimal(30, 2)", expectedPrecision: 30, expectedScale: 2}, - {input: "NUMBER( 30 , 2 )", expectedPrecision: 30, expectedScale: 2}, - {input: " NUMBER ( 30 , 2 ) ", expectedPrecision: 30, expectedScale: 2}, - - // returns defaults if it can't parse arguments, data type is different, or no arguments were provided - defaults("VARCHAR(1, 2)"), - defaults("VARCHAR(1)"), - defaults("VARCHAR"), - defaults("NUMBER"), - defaults("NUMBER()"), - defaults("NUMBER(x)"), - defaults(fmt.Sprintf("NUMBER(%d)", DefaultNumberPrecision)), - defaults(fmt.Sprintf("NUMBER(%d, x)", DefaultNumberPrecision)), - defaults(fmt.Sprintf("NUMBER(x, %d)", DefaultNumberScale)), - defaults("NUMBER(1, 2, 3)"), - } - - for _, tc := range tests { - tc := tc - t.Run(tc.input, func(t *testing.T) { - precision, scale := ParseNumberDataTypeRaw(tc.input) - assert.Equal(t, tc.expectedPrecision, precision) - assert.Equal(t, tc.expectedScale, scale) - }) - } -} - -func Test_ParseVarcharDataTypeRaw(t *testing.T) { - type test struct { - input string - expectedLength int - } - defaults := func(input string) test { - return test{input: input, expectedLength: DefaultVarcharLength} - } - - tests := []test{ - {input: "VARCHAR(30)", expectedLength: 
30}, - {input: "text(30)", expectedLength: 30}, - {input: "VARCHAR( 30 )", expectedLength: 30}, - {input: " VARCHAR ( 30 ) ", expectedLength: 30}, - - // returns defaults if it can't parse arguments, data type is different, or no arguments were provided - defaults("VARCHAR(1, 2)"), - defaults("VARCHAR(x)"), - defaults("VARCHAR"), - defaults("NUMBER"), - defaults("NUMBER()"), - defaults("NUMBER(x)"), - defaults(fmt.Sprintf("VARCHAR(%d)", DefaultVarcharLength)), - } - - for _, tc := range tests { - tc := tc - t.Run(tc.input, func(t *testing.T) { - length := ParseVarcharDataTypeRaw(tc.input) - assert.Equal(t, tc.expectedLength, length) - }) - } -} diff --git a/pkg/sdk/datatypes/array.go b/pkg/sdk/datatypes/array.go new file mode 100644 index 0000000000..eb7247f6e6 --- /dev/null +++ b/pkg/sdk/datatypes/array.go @@ -0,0 +1,21 @@ +package datatypes + +// ArrayDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-semistructured#array +// It does not have synonyms. It does not have any attributes. +type ArrayDataType struct { + underlyingType string +} + +func (t *ArrayDataType) ToSql() string { + return t.underlyingType +} + +func (t *ArrayDataType) ToLegacyDataTypeSql() string { + return ArrayLegacyDataType +} + +var ArrayDataTypeSynonyms = []string{ArrayLegacyDataType} + +func parseArrayDataTypeRaw(raw sanitizedDataTypeRaw) (*ArrayDataType, error) { + return &ArrayDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/binary.go b/pkg/sdk/datatypes/binary.go new file mode 100644 index 0000000000..c50dba0570 --- /dev/null +++ b/pkg/sdk/datatypes/binary.go @@ -0,0 +1,51 @@ +package datatypes + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +const DefaultBinarySize = 8388608 + +// BinaryDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-text#data-types-for-binary-strings +// It does have synonyms that allow specifying size. 
+type BinaryDataType struct {
+	size           int
+	underlyingType string
+}
+
+func (t *BinaryDataType) ToSql() string {
+	return fmt.Sprintf("%s(%d)", t.underlyingType, t.size)
+}
+
+func (t *BinaryDataType) ToLegacyDataTypeSql() string {
+	return BinaryLegacyDataType
+}
+
+var BinaryDataTypeSynonyms = []string{BinaryLegacyDataType, "VARBINARY"}
+
+func parseBinaryDataTypeRaw(raw sanitizedDataTypeRaw) (*BinaryDataType, error) {
+	r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
+	if r == "" {
+		logging.DebugLogger.Printf("[DEBUG] Returning default size for binary")
+		return &BinaryDataType{DefaultBinarySize, raw.matchedByType}, nil
+	}
+	if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") {
+		logging.DebugLogger.Printf(`[DEBUG] binary %s could not be parsed, use "%s(size)" format`, raw.raw, raw.matchedByType)
+		return nil, fmt.Errorf(`binary %s could not be parsed, use "%s(size)" format`, raw.raw, raw.matchedByType)
+	}
+	sizeRaw := r[1 : len(r)-1]
+	size, err := strconv.Atoi(strings.TrimSpace(sizeRaw))
+	if err != nil {
+		logging.DebugLogger.Printf(`[DEBUG] Could not parse binary size "%s", err: %v`, sizeRaw, err)
+		return nil, fmt.Errorf(`could not parse the binary's size: "%s", err: %w`, sizeRaw, err)
+	}
+	return &BinaryDataType{size, raw.matchedByType}, nil
+}
+
+func areBinaryDataTypesTheSame(a, b *BinaryDataType) bool {
+	return a.size == b.size
+}
diff --git a/pkg/sdk/datatypes/boolean.go b/pkg/sdk/datatypes/boolean.go
new file mode 100644
index 0000000000..4e84979f40
--- /dev/null
+++ b/pkg/sdk/datatypes/boolean.go
@@ -0,0 +1,21 @@
+package datatypes
+
+// BooleanDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-logical
+// It does not have synonyms. It does not have any attributes.
+type BooleanDataType struct { + underlyingType string +} + +func (t *BooleanDataType) ToSql() string { + return t.underlyingType +} + +func (t *BooleanDataType) ToLegacyDataTypeSql() string { + return BooleanLegacyDataType +} + +var BooleanDataTypeSynonyms = []string{BooleanLegacyDataType} + +func parseBooleanDataTypeRaw(raw sanitizedDataTypeRaw) (*BooleanDataType, error) { + return &BooleanDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/data_types.go b/pkg/sdk/datatypes/data_types.go new file mode 100644 index 0000000000..e1c0065855 --- /dev/null +++ b/pkg/sdk/datatypes/data_types.go @@ -0,0 +1,149 @@ +package datatypes + +import ( + "fmt" + "reflect" + "slices" + "strings" +) + +// TODO [SNOW-1843440]: generalize definitions for different types; generalize the ParseDataType function +// TODO [SNOW-1843440]: generalize implementation in types (i.e. the internal struct implementing ToLegacyDataTypeSql and containing the underlyingType) +// TODO [SNOW-1843440]: consider known/unknown to use Snowflake defaults and allow better handling in terraform resources +// TODO [SNOW-1843440]: replace old DataTypes + +// DataType is the common interface that represents all Snowflake datatypes documented in https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. +type DataType interface { + ToSql() string + ToLegacyDataTypeSql() string +} + +type sanitizedDataTypeRaw struct { + raw string + matchedByType string +} + +// ParseDataType is the entry point to get the implementation of the DataType from input raw string. +// TODO [SNOW-1843440]: order currently matters (e.g. 
HasPrefix(TIME) can match also TIMESTAMP*, make the checks more precise and order-independent) +func ParseDataType(raw string) (DataType, error) { + dataTypeRaw := strings.TrimSpace(strings.ToUpper(raw)) + + if idx := slices.IndexFunc(AllNumberDataTypes, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseNumberDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, AllNumberDataTypes[idx]}) + } + if slices.Contains(FloatDataTypeSynonyms, dataTypeRaw) { + return parseFloatDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if idx := slices.IndexFunc(AllTextDataTypes, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTextDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, AllTextDataTypes[idx]}) + } + if idx := slices.IndexFunc(BinaryDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseBinaryDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, BinaryDataTypeSynonyms[idx]}) + } + if slices.Contains(BooleanDataTypeSynonyms, dataTypeRaw) { + return parseBooleanDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(DateDataTypeSynonyms, dataTypeRaw) { + return parseDateDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if idx := slices.IndexFunc(TimestampLtzDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTimestampLtzDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, TimestampLtzDataTypeSynonyms[idx]}) + } + if idx := slices.IndexFunc(TimestampNtzDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTimestampNtzDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, TimestampNtzDataTypeSynonyms[idx]}) + } + if idx := slices.IndexFunc(TimestampTzDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return 
parseTimestampTzDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, TimestampTzDataTypeSynonyms[idx]}) + } + if idx := slices.IndexFunc(TimeDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTimeDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, TimeDataTypeSynonyms[idx]}) + } + if slices.Contains(VariantDataTypeSynonyms, dataTypeRaw) { + return parseVariantDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(ObjectDataTypeSynonyms, dataTypeRaw) { + return parseObjectDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(ArrayDataTypeSynonyms, dataTypeRaw) { + return parseArrayDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(GeographyDataTypeSynonyms, dataTypeRaw) { + return parseGeographyDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(GeometryDataTypeSynonyms, dataTypeRaw) { + return parseGeometryDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if idx := slices.IndexFunc(VectorDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseVectorDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, VectorDataTypeSynonyms[idx]}) + } + + return nil, fmt.Errorf("invalid data type: %s", raw) +} + +// AreTheSame compares any two data types. +// If both data types are nil it returns true. +// If only one data type is nil it returns false. +// It returns false for different underlying types. +// For the same type it performs type-specific comparison. 
+func AreTheSame(a DataType, b DataType) bool { + if a == nil && b == nil { + return true + } + if a == nil && b != nil || a != nil && b == nil { + return false + } + if reflect.TypeOf(a) != reflect.TypeOf(b) { + return false + } + switch v := a.(type) { + case *ArrayDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *BinaryDataType: + return castSuccessfully(v, b, areBinaryDataTypesTheSame) + case *BooleanDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *DateDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *FloatDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *GeographyDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *GeometryDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *NumberDataType: + return castSuccessfully(v, b, areNumberDataTypesTheSame) + case *ObjectDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *TextDataType: + return castSuccessfully(v, b, areTextDataTypesTheSame) + case *TimeDataType: + return castSuccessfully(v, b, areTimeDataTypesTheSame) + case *TimestampLtzDataType: + return castSuccessfully(v, b, areTimestampLtzDataTypesTheSame) + case *TimestampNtzDataType: + return castSuccessfully(v, b, areTimestampNtzDataTypesTheSame) + case *TimestampTzDataType: + return castSuccessfully(v, b, areTimestampTzDataTypesTheSame) + case *VariantDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *VectorDataType: + return castSuccessfully(v, b, areVectorDataTypesTheSame) + } + return false +} + +func IsTextDataType(a DataType) bool { + _, ok := a.(*TextDataType) + return ok +} + +func castSuccessfully[T any](a T, b DataType, invoke func(a T, b T) bool) bool { + if dCasted, ok := b.(T); ok { + return invoke(a, dCasted) + } + return false +} + +func noArgsDataTypesAreTheSame[T DataType](_ T, _ T) bool { + return true +} diff --git 
a/pkg/sdk/datatypes/data_types_test.go b/pkg/sdk/datatypes/data_types_test.go new file mode 100644 index 0000000000..21525fded8 --- /dev/null +++ b/pkg/sdk/datatypes/data_types_test.go @@ -0,0 +1,1148 @@ +package datatypes + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ParseDataType_Number(t *testing.T) { + type test struct { + input string + expectedPrecision int + expectedScale int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedPrecision: DefaultNumberPrecision, + expectedScale: DefaultNumberScale, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "NUMBER(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale, expectedUnderlyingType: "NUMBER"}, + {input: "NUMBER(30, 2)", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "NUMBER"}, + {input: "dec(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale, expectedUnderlyingType: "DEC"}, + {input: "dec(30, 2)", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "DEC"}, + {input: "decimal(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale, expectedUnderlyingType: "DECIMAL"}, + {input: "decimal(30, 2)", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "DECIMAL"}, + {input: "NuMeRiC(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale, expectedUnderlyingType: "NUMERIC"}, + {input: "NuMeRiC(30, 2)", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "NUMERIC"}, + {input: "NUMBER( 30 , 2 )", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "NUMBER"}, + {input: " NUMBER ( 30 , 2 ) ", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "NUMBER"}, + {input: fmt.Sprintf("NUMBER(%d)", 
DefaultNumberPrecision), expectedPrecision: DefaultNumberPrecision, expectedScale: DefaultNumberScale, expectedUnderlyingType: "NUMBER"}, + {input: fmt.Sprintf("NUMBER(%d, %d)", DefaultNumberPrecision, DefaultNumberScale), expectedPrecision: DefaultNumberPrecision, expectedScale: DefaultNumberScale, expectedUnderlyingType: "NUMBER"}, + + defaults("NUMBER"), + defaults("DEC"), + defaults("DECIMAL"), + defaults("NUMERIC"), + defaults(" NUMBER "), + + defaults("INT"), + defaults("INTEGER"), + defaults("BIGINT"), + defaults("SMALLINT"), + defaults("TINYINT"), + defaults("BYTEINT"), + defaults("int"), + defaults("integer"), + defaults("bigint"), + defaults("smallint"), + defaults("tinyint"), + defaults("byteint"), + } + + negativeTestCases := []test{ + negative("other(1, 2)"), + negative("other(1)"), + negative("other"), + negative("NUMBER()"), + negative("NUMBER(x)"), + negative(fmt.Sprintf("NUMBER(%d, x)", DefaultNumberPrecision)), + negative(fmt.Sprintf("NUMBER(x, %d)", DefaultNumberScale)), + negative("NUMBER(1, 2, 3)"), + negative("NUMBER("), + negative("NUMBER)"), + negative("NUM BER"), + negative("INT(30)"), + negative("INT(30, 2)"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &NumberDataType{}, parsed) + + assert.Equal(t, tc.expectedPrecision, parsed.(*NumberDataType).precision) + assert.Equal(t, tc.expectedScale, parsed.(*NumberDataType).scale) + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*NumberDataType).underlyingType) + + assert.Equal(t, NumberLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d, %d)", parsed.(*NumberDataType).underlyingType, parsed.(*NumberDataType).precision, parsed.(*NumberDataType).scale), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + 
require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Float(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" FLOAT "), + defaults("FLOAT"), + defaults("FLOAT4"), + defaults("FLOAT8"), + defaults("DOUBLE PRECISION"), + defaults("DOUBLE"), + defaults("REAL"), + defaults("float"), + defaults("float4"), + defaults("float8"), + defaults("double precision"), + defaults("double"), + defaults("real"), + } + + negativeTestCases := []test{ + negative("FLOAT(38, 0)"), + negative("FLOAT(38, 2)"), + negative("FLOAT(38)"), + negative("FLOAT()"), + negative("F L O A T"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &FloatDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*FloatDataType).underlyingType) + + assert.Equal(t, FloatLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Text(t *testing.T) { + type test struct { + input string + expectedLength int + expectedUnderlyingType string + } + defaultsVarchar := func(input string) test { + return test{ + input: input, + expectedLength: DefaultVarcharLength, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + defaultsChar := func(input string) test { + return test{ + input: input, + expectedLength: 
DefaultCharLength, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "VARCHAR(30)", expectedLength: 30, expectedUnderlyingType: "VARCHAR"}, + {input: "string(30)", expectedLength: 30, expectedUnderlyingType: "STRING"}, + {input: "VARCHAR( 30 )", expectedLength: 30, expectedUnderlyingType: "VARCHAR"}, + {input: " VARCHAR ( 30 ) ", expectedLength: 30, expectedUnderlyingType: "VARCHAR"}, + {input: fmt.Sprintf("VARCHAR(%d)", DefaultVarcharLength), expectedLength: DefaultVarcharLength, expectedUnderlyingType: "VARCHAR"}, + + {input: "CHAR(30)", expectedLength: 30, expectedUnderlyingType: "CHAR"}, + {input: "character(30)", expectedLength: 30, expectedUnderlyingType: "CHARACTER"}, + {input: "CHAR( 30 )", expectedLength: 30, expectedUnderlyingType: "CHAR"}, + {input: " CHAR ( 30 ) ", expectedLength: 30, expectedUnderlyingType: "CHAR"}, + {input: fmt.Sprintf("CHAR(%d)", DefaultCharLength), expectedLength: DefaultCharLength, expectedUnderlyingType: "CHAR"}, + + defaultsVarchar(" VARCHAR "), + defaultsVarchar("VARCHAR"), + defaultsVarchar("STRING"), + defaultsVarchar("TEXT"), + defaultsVarchar("NVARCHAR"), + defaultsVarchar("NVARCHAR2"), + defaultsVarchar("CHAR VARYING"), + defaultsVarchar("NCHAR VARYING"), + defaultsVarchar("varchar"), + defaultsVarchar("string"), + defaultsVarchar("text"), + defaultsVarchar("nvarchar"), + defaultsVarchar("nvarchar2"), + defaultsVarchar("char varying"), + defaultsVarchar("nchar varying"), + + defaultsChar(" CHAR "), + defaultsChar("CHAR"), + defaultsChar("CHARACTER"), + defaultsChar("NCHAR"), + defaultsChar("char"), + defaultsChar("character"), + defaultsChar("nchar"), + } + + negativeTestCases := []test{ + negative("other(1, 2)"), + negative("other(1)"), + negative("other"), + negative("VARCHAR()"), + negative("VARCHAR(x)"), + negative("VARCHAR( )"), + negative("CHAR()"), + negative("CHAR(x)"), + 
negative("CHAR( )"), + negative("VARCHAR(1, 2)"), + negative("VARCHAR("), + negative("VARCHAR)"), + negative("VAR CHAR"), + negative("CHAR(1, 2)"), + negative("CHAR("), + negative("CHAR)"), + negative("CH AR"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TextDataType{}, parsed) + + assert.Equal(t, tc.expectedLength, parsed.(*TextDataType).length) + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TextDataType).underlyingType) + + assert.Equal(t, VarcharLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TextDataType).underlyingType, parsed.(*TextDataType).length), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Binary(t *testing.T) { + type test struct { + input string + expectedSize int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedSize: DefaultBinarySize, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "BINARY(30)", expectedSize: 30, expectedUnderlyingType: "BINARY"}, + {input: "varbinary(30)", expectedSize: 30, expectedUnderlyingType: "VARBINARY"}, + {input: "BINARY( 30 )", expectedSize: 30, expectedUnderlyingType: "BINARY"}, + {input: " BINARY ( 30 ) ", expectedSize: 30, expectedUnderlyingType: "BINARY"}, + {input: fmt.Sprintf("BINARY(%d)", DefaultBinarySize), expectedSize: DefaultBinarySize, expectedUnderlyingType: "BINARY"}, + + defaults(" BINARY "), + defaults("BINARY"), + defaults("VARBINARY"), + defaults("binary"), + defaults("varbinary"), + } + + 
negativeTestCases := []test{ + negative("other(1, 2)"), + negative("other(1)"), + negative("other"), + negative("BINARY()"), + negative("BINARY(x)"), + negative("BINARY( )"), + negative("BINARY(1, 2)"), + negative("BINARY("), + negative("BINARY)"), + negative("BIN ARY"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &BinaryDataType{}, parsed) + + assert.Equal(t, tc.expectedSize, parsed.(*BinaryDataType).size) + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*BinaryDataType).underlyingType) + + assert.Equal(t, BinaryLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*BinaryDataType).underlyingType, parsed.(*BinaryDataType).size), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Boolean(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" BOOLEAN "), + defaults("BOOLEAN"), + defaults("boolean"), + } + + negativeTestCases := []test{ + negative("BOOLEAN(38, 0)"), + negative("BOOLEAN(38, 2)"), + negative("BOOLEAN(38)"), + negative("BOOLEAN()"), + negative("BOOL"), + negative("bool"), + negative("B O O L E A N"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &BooleanDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, 
parsed.(*BooleanDataType).underlyingType) + + assert.Equal(t, BooleanLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Date(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" DATE "), + defaults("DATE"), + defaults("date"), + } + + negativeTestCases := []test{ + negative("DATE(38, 0)"), + negative("DATE(38, 2)"), + negative("DATE(38)"), + negative("DATE()"), + negative("D A T E"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &DateDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*DateDataType).underlyingType) + + assert.Equal(t, DateLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Time(t *testing.T) { + type test struct { + input string + expectedPrecision int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedPrecision: DefaultTimePrecision, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) 
test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" TIME "), + defaults("TIME"), + defaults("time"), + {input: "TIME(5)", expectedPrecision: 5, expectedUnderlyingType: "TIME"}, + {input: "time(5)", expectedPrecision: 5, expectedUnderlyingType: "TIME"}, + } + + negativeTestCases := []test{ + negative("TIME(38, 0)"), + negative("TIME(38, 2)"), + negative("TIME()"), + negative("T I M E"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TimeDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TimeDataType).underlyingType) + assert.Equal(t, tc.expectedPrecision, parsed.(*TimeDataType).precision) + + assert.Equal(t, TimeLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", tc.expectedUnderlyingType, tc.expectedPrecision), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_TimestampLtz(t *testing.T) { + type test struct { + input string + expectedPrecision int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedPrecision: DefaultTimestampPrecision, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "TIMESTAMP_LTZ(4)", expectedPrecision: 4, expectedUnderlyingType: "TIMESTAMP_LTZ"}, + {input: "timestamp with local time zone(5)", expectedPrecision: 5, expectedUnderlyingType: "TIMESTAMP WITH LOCAL TIME ZONE"}, + {input: "TIMESTAMP_LTZ( 2 )", expectedPrecision: 2, expectedUnderlyingType: "TIMESTAMP_LTZ"}, + {input: " TIMESTAMP_LTZ ( 7 ) 
", expectedPrecision: 7, expectedUnderlyingType: "TIMESTAMP_LTZ"},
+		{input: fmt.Sprintf("TIMESTAMP_LTZ(%d)", DefaultTimestampPrecision), expectedPrecision: DefaultTimestampPrecision, expectedUnderlyingType: "TIMESTAMP_LTZ"},
+
+		defaults(" TIMESTAMP_LTZ "),
+		defaults("TIMESTAMP_LTZ"),
+		defaults("TIMESTAMPLTZ"),
+		defaults("TIMESTAMP WITH LOCAL TIME ZONE"),
+		defaults("timestamp_ltz"),
+		defaults("timestampltz"),
+		defaults("timestamp with local time zone"),
+	}
+
+	negativeTestCases := []test{
+		negative("TIMESTAMP_LTZ(38, 0)"),
+		negative("TIMESTAMP_LTZ(38, 2)"),
+		negative("TIMESTAMP_LTZ()"),
+		negative("T I M E S T A M P _ L T Z"),
+		negative("other"),
+		negative("other(3)"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &TimestampLtzDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedPrecision, parsed.(*TimestampLtzDataType).precision)
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TimestampLtzDataType).underlyingType)
+
+			assert.Equal(t, TimestampLtzLegacyDataType, parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampLtzDataType).underlyingType, parsed.(*TimestampLtzDataType).precision), parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
+
+func Test_ParseDataType_TimestampNtz(t *testing.T) {
+	type test struct {
+		input                  string
+		expectedPrecision      int
+		expectedUnderlyingType string
+	}
+	defaults := func(input string) test {
+		return test{
+			input:                  input,
+			expectedPrecision:      DefaultTimestampPrecision,
+			expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)),
+		}
+	}
+	negative := func(input string) test {
+		return test{input: input}
+	}
+
+	positiveTestCases := []test{
+		{input: "TIMESTAMP_NTZ(4)", expectedPrecision: 4, expectedUnderlyingType: "TIMESTAMP_NTZ"},
+		{input: "timestamp without time zone(5)", expectedPrecision: 5, expectedUnderlyingType: "TIMESTAMP WITHOUT TIME ZONE"},
+		{input: "TIMESTAMP_NTZ( 2 )", expectedPrecision: 2, expectedUnderlyingType: "TIMESTAMP_NTZ"},
+		{input: " TIMESTAMP_NTZ ( 7 ) ", expectedPrecision: 7, expectedUnderlyingType: "TIMESTAMP_NTZ"},
+		{input: fmt.Sprintf("TIMESTAMP_NTZ(%d)", DefaultTimestampPrecision), expectedPrecision: DefaultTimestampPrecision, expectedUnderlyingType: "TIMESTAMP_NTZ"},
+
+		defaults(" TIMESTAMP_NTZ "),
+		defaults("TIMESTAMP_NTZ"),
+		defaults("TIMESTAMPNTZ"),
+		defaults("TIMESTAMP WITHOUT TIME ZONE"),
+		defaults("DATETIME"),
+		defaults("timestamp_ntz"),
+		defaults("timestampntz"),
+		defaults("timestamp without time zone"),
+		defaults("datetime"),
+	}
+
+	negativeTestCases := []test{
+		negative("TIMESTAMP_NTZ(38, 0)"),
+		negative("TIMESTAMP_NTZ(38, 2)"),
+		negative("TIMESTAMP_NTZ()"),
+		negative("T I M E S T A M P _ N T Z"),
+		negative("other"),
+		negative("other(3)"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &TimestampNtzDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedPrecision, parsed.(*TimestampNtzDataType).precision)
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TimestampNtzDataType).underlyingType)
+
+			assert.Equal(t, TimestampNtzLegacyDataType, parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampNtzDataType).underlyingType, parsed.(*TimestampNtzDataType).precision), parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
+
+func Test_ParseDataType_TimestampTz(t *testing.T) {
+	type test struct {
+		input                  string
+		expectedPrecision      int
+		expectedUnderlyingType string
+	}
+	defaults := func(input string) test {
+		return test{
+			input:                  input,
+			expectedPrecision:      DefaultTimestampPrecision,
+			expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)),
+		}
+	}
+	negative := func(input string) test {
+		return test{input: input}
+	}
+
+	positiveTestCases := []test{
+		{input: "TIMESTAMP_TZ(4)", expectedPrecision: 4, expectedUnderlyingType: "TIMESTAMP_TZ"},
+		{input: "timestamp with time zone(5)", expectedPrecision: 5, expectedUnderlyingType: "TIMESTAMP WITH TIME ZONE"},
+		{input: "TIMESTAMP_TZ( 2 )", expectedPrecision: 2, expectedUnderlyingType: "TIMESTAMP_TZ"},
+		{input: " TIMESTAMP_TZ ( 7 ) ", expectedPrecision: 7, expectedUnderlyingType: "TIMESTAMP_TZ"},
+		{input: fmt.Sprintf("TIMESTAMP_TZ(%d)", DefaultTimestampPrecision), expectedPrecision: DefaultTimestampPrecision, expectedUnderlyingType: "TIMESTAMP_TZ"},
+
+		defaults(" TIMESTAMP_TZ "),
+		defaults("TIMESTAMP_TZ"),
+		defaults("TIMESTAMPTZ"),
+		defaults("TIMESTAMP WITH TIME ZONE"),
+		defaults("timestamp_tz"),
+		defaults("timestamptz"),
+		defaults("timestamp with time zone"),
+	}
+
+	negativeTestCases := []test{
+		negative("TIMESTAMP_TZ(38, 0)"),
+		negative("TIMESTAMP_TZ(38, 2)"),
+		negative("TIMESTAMP_TZ()"),
+		negative("T I M E S T A M P _ T Z"),
+		negative("other"),
+		negative("other(3)"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &TimestampTzDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedPrecision, parsed.(*TimestampTzDataType).precision)
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TimestampTzDataType).underlyingType)
+
+			assert.Equal(t, TimestampTzLegacyDataType, parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampTzDataType).underlyingType, parsed.(*TimestampTzDataType).precision), parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
+
+func Test_ParseDataType_Variant(t *testing.T) {
+	type test struct {
+		input                  string
+		expectedUnderlyingType string
+	}
+	defaults := func(input string) test {
+		return test{
+			input:                  input,
+			expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)),
+		}
+	}
+	negative := func(input string) test {
+		return test{input: input}
+	}
+
+	positiveTestCases := []test{
+		defaults(" VARIANT "),
+		defaults("VARIANT"),
+		defaults("variant"),
+	}
+
+	negativeTestCases := []test{
+		negative("VARIANT(38, 0)"),
+		negative("VARIANT(38, 2)"),
+		negative("VARIANT(38)"),
+		negative("VARIANT()"),
+		negative("V A R I A N T"),
+		negative("other"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &VariantDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.(*VariantDataType).underlyingType)
+
+			assert.Equal(t, VariantLegacyDataType, parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
+
+func Test_ParseDataType_Object(t *testing.T) {
+	type test struct {
+		input                  string
+		expectedUnderlyingType string
+	}
+	defaults := func(input string) test {
+		return test{
+			input:                  input,
+			expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)),
+		}
+	}
+	negative := func(input string) test {
+		return test{input: input}
+	}
+
+	positiveTestCases := []test{
+		defaults(" OBJECT "),
+		defaults("OBJECT"),
+		defaults("object"),
+	}
+
+	negativeTestCases := []test{
+		negative("OBJECT(38, 0)"),
+		negative("OBJECT(38, 2)"),
+		negative("OBJECT(38)"),
+		negative("OBJECT()"),
+		negative("O B J E C T"),
+		negative("other"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &ObjectDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.(*ObjectDataType).underlyingType)
+
+			assert.Equal(t, ObjectLegacyDataType, parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
+
+func Test_ParseDataType_Array(t *testing.T) {
+	type test struct {
+		input                  string
+		expectedUnderlyingType string
+	}
+	defaults := func(input string) test {
+		return test{
+			input:                  input,
+			expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)),
+		}
+	}
+	negative := func(input string) test {
+		return test{input: input}
+	}
+
+	positiveTestCases := []test{
+		defaults(" ARRAY "),
+		defaults("ARRAY"),
+		defaults("array"),
+	}
+
+	negativeTestCases := []test{
+		negative("ARRAY(38, 0)"),
+		negative("ARRAY(38, 2)"),
+		negative("ARRAY(38)"),
+		negative("ARRAY()"),
+		negative("A R R A Y"),
+		negative("other"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &ArrayDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.(*ArrayDataType).underlyingType)
+
+			assert.Equal(t, ArrayLegacyDataType, parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
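The Variant, Object, and Array cases above all exercise the same parsing rule: trim whitespace, fold case, and accept only an exact match from a fixed synonym list, rejecting any argument list. A minimal self-contained sketch of that rule (`parseSimpleType` here is a hypothetical stand-in for the no-argument branch of the SDK's `ParseDataType`, not the actual implementation):

```go
package main

import (
	"fmt"
	"slices"
	"strings"
)

// parseSimpleType normalizes a raw input (trim whitespace, upper-case) and
// requires an exact match against a list of argument-less type names.
// Inputs carrying an argument list, such as "ARRAY(38)", or with internal
// whitespace, such as "A R R A Y", fail the exact match and are rejected.
func parseSimpleType(raw string, synonyms []string) (string, error) {
	normalized := strings.ToUpper(strings.TrimSpace(raw))
	if !slices.Contains(synonyms, normalized) {
		return "", fmt.Errorf("invalid data type: %s", raw)
	}
	return normalized, nil
}

func main() {
	for _, input := range []string{" ARRAY ", "array", "ARRAY(38)", "A R R A Y"} {
		parsed, err := parseSimpleType(input, []string{"ARRAY"})
		fmt.Printf("%q parsed=%q err=%v\n", input, parsed, err)
	}
}
```

Because matching happens only after normalization, ` ARRAY ` and `array` succeed while `ARRAY(38)` and `A R R A Y` fail, which is exactly the split the positive and negative test cases above assert.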
+
+func Test_ParseDataType_Geography(t *testing.T) {
+	type test struct {
+		input                  string
+		expectedUnderlyingType string
+	}
+	defaults := func(input string) test {
+		return test{
+			input:                  input,
+			expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)),
+		}
+	}
+	negative := func(input string) test {
+		return test{input: input}
+	}
+
+	positiveTestCases := []test{
+		defaults(" GEOGRAPHY "),
+		defaults("GEOGRAPHY"),
+		defaults("geography"),
+	}
+
+	negativeTestCases := []test{
+		negative("GEOGRAPHY(38, 0)"),
+		negative("GEOGRAPHY(38, 2)"),
+		negative("GEOGRAPHY(38)"),
+		negative("GEOGRAPHY()"),
+		negative("G E O G R A P H Y"),
+		negative("other"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &GeographyDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.(*GeographyDataType).underlyingType)
+
+			assert.Equal(t, GeographyLegacyDataType, parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
+
+func Test_ParseDataType_Geometry(t *testing.T) {
+	type test struct {
+		input                  string
+		expectedUnderlyingType string
+	}
+	defaults := func(input string) test {
+		return test{
+			input:                  input,
+			expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)),
+		}
+	}
+	negative := func(input string) test {
+		return test{input: input}
+	}
+
+	positiveTestCases := []test{
+		defaults(" GEOMETRY "),
+		defaults("GEOMETRY"),
+		defaults("geometry"),
+	}
+
+	negativeTestCases := []test{
+		negative("GEOMETRY(38, 0)"),
+		negative("GEOMETRY(38, 2)"),
+		negative("GEOMETRY(38)"),
+		negative("GEOMETRY()"),
+		negative("G E O M E T R Y"),
+		negative("other"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &GeometryDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.(*GeometryDataType).underlyingType)
+
+			assert.Equal(t, GeometryLegacyDataType, parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
+
+func Test_ParseDataType_Vector(t *testing.T) {
+	type test struct {
+		input             string
+		expectedInnerType string
+		expectedDimension int
+	}
+	negative := func(input string) test {
+		return test{input: input}
+	}
+
+	positiveTestCases := []test{
+		{input: "VECTOR(INT, 2)", expectedInnerType: "INT", expectedDimension: 2},
+		{input: "VECTOR(FLOAT, 2)", expectedInnerType: "FLOAT", expectedDimension: 2},
+		{input: "VeCtOr ( InT , 40 )", expectedInnerType: "INT", expectedDimension: 40},
+		{input: " VECTOR ( INT , 40 )", expectedInnerType: "INT", expectedDimension: 40},
+	}
+
+	negativeTestCases := []test{
+		negative("VECTOR(1, 2)"),
+		negative("VECTOR(1)"),
+		negative("VECTOR(2, INT)"),
+		negative("VECTOR()"),
+		negative("VECTOR"),
+		negative("VECTOR(INT, 2, 3)"),
+		negative("VECTOR(INT)"),
+		negative("VECTOR(x, 2)"),
+		negative("VECTOR("),
+		negative("VECTOR)"),
+		negative("VEC TOR"),
+	}
+
+	for _, tc := range positiveTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.NoError(t, err)
+			require.IsType(t, &VectorDataType{}, parsed)
+
+			assert.Equal(t, tc.expectedInnerType, parsed.(*VectorDataType).innerType)
+			assert.Equal(t, tc.expectedDimension, parsed.(*VectorDataType).dimension)
+			assert.Equal(t, "VECTOR", parsed.(*VectorDataType).underlyingType)
+
+			assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.ToLegacyDataTypeSql())
+			assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.ToSql())
+		})
+	}
+
+	for _, tc := range negativeTestCases {
+		tc := tc
+		t.Run(tc.input, func(t *testing.T) {
+			parsed, err := ParseDataType(tc.input)
+
+			require.Error(t, err)
+			require.Nil(t, parsed)
+		})
+	}
+}
+
+func Test_AreTheSame(t *testing.T) {
+	// empty d1/d2 means nil DataType input
+	type test struct {
+		d1              string
+		d2              string
+		expectedOutcome bool
+	}
+
+	testCases := []test{
+		{d1: "", d2: "", expectedOutcome: true},
+		{d1: "", d2: "NUMBER", expectedOutcome: false},
+		{d1: "NUMBER", d2: "", expectedOutcome: false},
+
+		{d1: "NUMBER(20)", d2: "NUMBER(20, 2)", expectedOutcome: false},
+		{d1: "NUMBER(20, 1)", d2: "NUMBER(20, 2)", expectedOutcome: false},
+		{d1: "NUMBER", d2: "NUMBER(20, 2)", expectedOutcome: false},
+		{d1: "NUMBER", d2: fmt.Sprintf("NUMBER(%d, %d)", DefaultNumberPrecision, DefaultNumberScale), expectedOutcome: true},
+		{d1: fmt.Sprintf("NUMBER(%d)", DefaultNumberPrecision), d2: fmt.Sprintf("NUMBER(%d, %d)", DefaultNumberPrecision, DefaultNumberScale), expectedOutcome: true},
+		{d1: "NUMBER", d2: "NUMBER", expectedOutcome: true},
+		{d1: "NUMBER(20)", d2: "NUMBER(20)", expectedOutcome: true},
+		{d1: "NUMBER(20, 2)", d2: "NUMBER(20, 2)", expectedOutcome: true},
+		{d1: "INT", d2: "NUMBER", expectedOutcome: true},
+		{d1: "INT", d2: fmt.Sprintf("NUMBER(%d, %d)", DefaultNumberPrecision, DefaultNumberScale), expectedOutcome: true},
+		{d1: "INT", d2: "NUMBER(20)", expectedOutcome: false},
+		{d1: "NUMBER", d2: "VARCHAR", expectedOutcome: false},
+		{d1: "NUMBER(20)", d2: "VARCHAR(20)", expectedOutcome: false},
+		{d1: "CHAR", d2: "VARCHAR", expectedOutcome: false},
+		{d1: "CHAR", d2: fmt.Sprintf("VARCHAR(%d)", DefaultCharLength), expectedOutcome: true},
+		{d1: fmt.Sprintf("CHAR(%d)", DefaultVarcharLength), d2: "VARCHAR", expectedOutcome: true},
+		{d1: "BINARY", d2: "BINARY", expectedOutcome: true},
+		{d1: "BINARY", d2: "VARBINARY", expectedOutcome: true},
+		{d1: "BINARY(20)", d2: "BINARY(20)", expectedOutcome: true},
+		{d1: "BINARY(20)", d2: "BINARY(30)", expectedOutcome: false},
+		{d1: "BINARY", d2: "BINARY(30)", expectedOutcome: false},
+		{d1: fmt.Sprintf("BINARY(%d)", DefaultBinarySize), d2: "BINARY", expectedOutcome: true},
+		{d1: "FLOAT", d2: "FLOAT4", expectedOutcome: true},
+		{d1: "DOUBLE", d2: "FLOAT8", expectedOutcome: true},
+		{d1: "DOUBLE PRECISION", d2: "REAL", expectedOutcome: true},
+		{d1: "TIMESTAMPLTZ", d2: "TIMESTAMPNTZ", expectedOutcome: false},
+		{d1: "TIMESTAMPLTZ", d2: "TIMESTAMPTZ", expectedOutcome: false},
+		{d1: "TIMESTAMPLTZ", d2: fmt.Sprintf("TIMESTAMPLTZ(%d)", DefaultTimestampPrecision), expectedOutcome: true},
+		{d1: "VECTOR(INT, 20)", d2: "VECTOR(INT, 20)", expectedOutcome: true},
+		{d1: "VECTOR(INT, 20)", d2: "VECTOR(INT, 30)", expectedOutcome: false},
+		{d1: "VECTOR(FLOAT, 20)", d2: "VECTOR(INT, 30)", expectedOutcome: false},
+		{d1: "VECTOR(FLOAT, 20)", d2: "VECTOR(INT, 20)", expectedOutcome: false},
+		{d1: "VECTOR(FLOAT, 20)", d2: "VECTOR(FLOAT, 20)", expectedOutcome: true},
+		{d1: "VECTOR(FLOAT, 20)", d2: "FLOAT", expectedOutcome: false},
+		{d1: "TIME", d2: "TIME", expectedOutcome: true},
+		{d1: "TIME", d2: "TIME(5)", expectedOutcome: false},
+		{d1: "TIME", d2: fmt.Sprintf("TIME(%d)", DefaultTimePrecision), expectedOutcome: true},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(fmt.Sprintf(`compare "%s" with "%s" expecting %t`, tc.d1, tc.d2, tc.expectedOutcome), func(t *testing.T) {
+			var p1, p2 DataType
+			var err error
+
+			if tc.d1 != "" {
+				p1, err = ParseDataType(tc.d1)
+				require.NoError(t, err)
+			}
+
+			if tc.d2 != "" {
+				p2, err = ParseDataType(tc.d2)
+				require.NoError(t, err)
+			}
+
+			require.Equal(t, tc.expectedOutcome, AreTheSame(p1, p2))
+		})
+	}
+}
diff --git a/pkg/sdk/datatypes/date.go b/pkg/sdk/datatypes/date.go
new file mode 100644
index 0000000000..92ee7c27bc
--- /dev/null
+++ b/pkg/sdk/datatypes/date.go
@@ -0,0 +1,21 @@
+package datatypes
+
+// DateDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#date
+// It does not have synonyms. It does not have any attributes.
+type DateDataType struct {
+	underlyingType string
+}
+
+func (t *DateDataType) ToSql() string {
+	return t.underlyingType
+}
+
+func (t *DateDataType) ToLegacyDataTypeSql() string {
+	return DateLegacyDataType
+}
+
+var DateDataTypeSynonyms = []string{DateLegacyDataType}
+
+func parseDateDataTypeRaw(raw sanitizedDataTypeRaw) (*DateDataType, error) {
+	return &DateDataType{raw.matchedByType}, nil
+}
diff --git a/pkg/sdk/datatypes/float.go b/pkg/sdk/datatypes/float.go
new file mode 100644
index 0000000000..a0ca84863b
--- /dev/null
+++ b/pkg/sdk/datatypes/float.go
@@ -0,0 +1,21 @@
+package datatypes
+
+// FloatDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-numeric#data-types-for-floating-point-numbers
+// It does have synonyms. It does not have any attributes.
+type FloatDataType struct {
+	underlyingType string
+}
+
+func (t *FloatDataType) ToSql() string {
+	return t.underlyingType
+}
+
+func (t *FloatDataType) ToLegacyDataTypeSql() string {
+	return FloatLegacyDataType
+}
+
+var FloatDataTypeSynonyms = []string{"FLOAT8", "FLOAT4", FloatLegacyDataType, "DOUBLE PRECISION", "DOUBLE", "REAL"}
+
+func parseFloatDataTypeRaw(raw sanitizedDataTypeRaw) (*FloatDataType, error) {
+	return &FloatDataType{raw.matchedByType}, nil
+}
diff --git a/pkg/sdk/datatypes/geography.go b/pkg/sdk/datatypes/geography.go
new file mode 100644
index 0000000000..4a024a20b0
--- /dev/null
+++ b/pkg/sdk/datatypes/geography.go
@@ -0,0 +1,21 @@
+package datatypes
+
+// GeographyDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-geospatial#geography-data-type
+// It does not have synonyms. It does not have any attributes.
+type GeographyDataType struct {
+	underlyingType string
+}
+
+func (t *GeographyDataType) ToSql() string {
+	return t.underlyingType
+}
+
+func (t *GeographyDataType) ToLegacyDataTypeSql() string {
+	return GeographyLegacyDataType
+}
+
+var GeographyDataTypeSynonyms = []string{GeographyLegacyDataType}
+
+func parseGeographyDataTypeRaw(raw sanitizedDataTypeRaw) (*GeographyDataType, error) {
+	return &GeographyDataType{raw.matchedByType}, nil
+}
diff --git a/pkg/sdk/datatypes/geometry.go b/pkg/sdk/datatypes/geometry.go
new file mode 100644
index 0000000000..d09ebd3eea
--- /dev/null
+++ b/pkg/sdk/datatypes/geometry.go
@@ -0,0 +1,21 @@
+package datatypes
+
+// GeometryDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-geospatial#geometry-data-type
+// It does not have synonyms. It does not have any attributes.
+type GeometryDataType struct {
+	underlyingType string
+}
+
+func (t *GeometryDataType) ToSql() string {
+	return t.underlyingType
+}
+
+func (t *GeometryDataType) ToLegacyDataTypeSql() string {
+	return GeometryLegacyDataType
+}
+
+var GeometryDataTypeSynonyms = []string{GeometryLegacyDataType}
+
+func parseGeometryDataTypeRaw(raw sanitizedDataTypeRaw) (*GeometryDataType, error) {
+	return &GeometryDataType{raw.matchedByType}, nil
+}
diff --git a/pkg/sdk/datatypes/legacy.go b/pkg/sdk/datatypes/legacy.go
new file mode 100644
index 0000000000..5a0e249cd7
--- /dev/null
+++ b/pkg/sdk/datatypes/legacy.go
@@ -0,0 +1,19 @@
+package datatypes
+
+const (
+	ArrayLegacyDataType        = "ARRAY"
+	BinaryLegacyDataType       = "BINARY"
+	BooleanLegacyDataType      = "BOOLEAN"
+	DateLegacyDataType         = "DATE"
+	FloatLegacyDataType        = "FLOAT"
+	GeographyLegacyDataType    = "GEOGRAPHY"
+	GeometryLegacyDataType     = "GEOMETRY"
+	NumberLegacyDataType       = "NUMBER"
+	ObjectLegacyDataType       = "OBJECT"
+	VarcharLegacyDataType      = "VARCHAR"
+	TimeLegacyDataType         = "TIME"
+	TimestampLtzLegacyDataType = "TIMESTAMP_LTZ"
+	TimestampNtzLegacyDataType = "TIMESTAMP_NTZ"
+	TimestampTzLegacyDataType  = "TIMESTAMP_TZ"
+	VariantLegacyDataType      = "VARIANT"
+)
diff --git a/pkg/sdk/datatypes/number.go b/pkg/sdk/datatypes/number.go
new file mode 100644
index 0000000000..14ac2696fc
--- /dev/null
+++ b/pkg/sdk/datatypes/number.go
@@ -0,0 +1,104 @@
+package datatypes
+
+import (
+	"fmt"
+	"slices"
+	"strconv"
+	"strings"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging"
+)
+
+const (
+	DefaultNumberPrecision = 38
+	DefaultNumberScale     = 0
+)
+
+// NumberDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-numeric#data-types-for-fixed-point-numbers
+// It has aliases that allow specifying precision and scale; here called synonyms.
+// It has aliases that do not allow specifying precision and scale; here called subtypes.
+type NumberDataType struct {
+	precision      int
+	scale          int
+	underlyingType string
+}
+
+func (t *NumberDataType) ToSql() string {
+	return fmt.Sprintf("%s(%d, %d)", t.underlyingType, t.precision, t.scale)
+}
+
+func (t *NumberDataType) ToLegacyDataTypeSql() string {
+	return NumberLegacyDataType
+}
+
+var (
+	NumberDataTypeSynonyms = []string{NumberLegacyDataType, "DECIMAL", "DEC", "NUMERIC"}
+	NumberDataTypeSubTypes = []string{"INTEGER", "INT", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"}
+	AllNumberDataTypes     = append(NumberDataTypeSynonyms, NumberDataTypeSubTypes...)
+)
+
+func parseNumberDataTypeRaw(raw sanitizedDataTypeRaw) (*NumberDataType, error) {
+	switch {
+	case slices.Contains(NumberDataTypeSynonyms, raw.matchedByType):
+		return parseNumberDataTypeWithPrecisionAndScale(raw)
+	case slices.Contains(NumberDataTypeSubTypes, raw.matchedByType):
+		return parseNumberDataTypeWithoutPrecisionAndScale(raw)
+	default:
+		return nil, fmt.Errorf("unknown number data type: %s", raw.raw)
+	}
+}
+
+// parseNumberDataTypeWithPrecisionAndScale extracts precision and scale from the raw number data type input.
+// It returns the defaults if no arguments were provided. It returns an error if any part is not parseable.
+func parseNumberDataTypeWithPrecisionAndScale(raw sanitizedDataTypeRaw) (*NumberDataType, error) {
+	r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
+	if r == "" {
+		logging.DebugLogger.Printf("[DEBUG] Returning default number precision and scale")
+		return &NumberDataType{DefaultNumberPrecision, DefaultNumberScale, raw.matchedByType}, nil
+	}
+	if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") {
+		logging.DebugLogger.Printf(`number %s could not be parsed, use "%s(precision, scale)" format`, raw.raw, raw.matchedByType)
+		return nil, fmt.Errorf(`number %s could not be parsed, use "%s(precision, scale)" format`, raw.raw, raw.matchedByType)
+	}
+	onlyArgs := r[1 : len(r)-1]
+	parts := strings.Split(onlyArgs, ",")
+	switch l := len(parts); l {
+	case 1:
+		precision, err := strconv.Atoi(strings.TrimSpace(parts[0]))
+		if err != nil {
+			logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s", err: %v`, parts[0], err)
+			return nil, fmt.Errorf(`could not parse the number's precision: "%s", err: %w`, parts[0], err)
+		}
+		return &NumberDataType{precision, DefaultNumberScale, raw.matchedByType}, nil
+	case 2:
+		precision, err := strconv.Atoi(strings.TrimSpace(parts[0]))
+		if err != nil {
+			logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s", err: %v`, parts[0], err)
+			return nil, fmt.Errorf(`could not parse the number's precision: "%s", err: %w`, parts[0], err)
+		}
+		scale, err := strconv.Atoi(strings.TrimSpace(parts[1]))
+		if err != nil {
+			logging.DebugLogger.Printf(`[DEBUG] Could not parse number scale "%s", err: %v`, parts[1], err)
+			return nil, fmt.Errorf(`could not parse the number's scale: "%s", err: %w`, parts[1], err)
+		}
+		return &NumberDataType{precision, scale, raw.matchedByType}, nil
+	default:
+		logging.DebugLogger.Printf("[DEBUG] Unexpected length of number arguments")
+		return nil, fmt.Errorf(`number cannot have %d arguments: "%s"; only precision and scale are allowed`, l, onlyArgs)
+	}
+}
+
+func parseNumberDataTypeWithoutPrecisionAndScale(raw sanitizedDataTypeRaw) (*NumberDataType, error) {
+	if raw.raw != raw.matchedByType {
+		args := strings.TrimPrefix(raw.raw, raw.matchedByType)
+		logging.DebugLogger.Printf("[DEBUG] Number type %s cannot have arguments: %s", raw.matchedByType, args)
+		return nil, fmt.Errorf("number type %s cannot have arguments: %s", raw.matchedByType, args)
+	} else {
+		logging.DebugLogger.Printf("[DEBUG] Returning default number precision and scale")
+		return &NumberDataType{DefaultNumberPrecision, DefaultNumberScale, raw.matchedByType}, nil
+	}
+}
+
+func areNumberDataTypesTheSame(a, b *NumberDataType) bool {
+	return a.precision == b.precision && a.scale == b.scale
+}
diff --git a/pkg/sdk/datatypes/object.go b/pkg/sdk/datatypes/object.go
new file mode 100644
index 0000000000..fe333aa7b0
--- /dev/null
+++ b/pkg/sdk/datatypes/object.go
@@ -0,0 +1,21 @@
+package datatypes
+
+// ObjectDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-semistructured#object
+// It does not have synonyms. It does not have any attributes.
+type ObjectDataType struct {
+	underlyingType string
+}
+
+func (t *ObjectDataType) ToSql() string {
+	return t.underlyingType
+}
+
+func (t *ObjectDataType) ToLegacyDataTypeSql() string {
+	return ObjectLegacyDataType
+}
+
+var ObjectDataTypeSynonyms = []string{ObjectLegacyDataType}
+
+func parseObjectDataTypeRaw(raw sanitizedDataTypeRaw) (*ObjectDataType, error) {
+	return &ObjectDataType{raw.matchedByType}, nil
+}
diff --git a/pkg/sdk/datatypes/text.go b/pkg/sdk/datatypes/text.go
new file mode 100644
index 0000000000..2598253101
--- /dev/null
+++ b/pkg/sdk/datatypes/text.go
@@ -0,0 +1,69 @@
+package datatypes
+
+import (
+	"fmt"
+	"slices"
+	"strconv"
+	"strings"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging"
+)
+
+const (
+	DefaultVarcharLength = 16777216
+	DefaultCharLength    = 1
+)
+
+// TextDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-text#data-types-for-text-strings
+// It has aliases that allow specifying length and share VARCHAR's default length; here called synonyms.
+// It has aliases that allow specifying length but use a different default length when length is omitted; here called subtypes.
+type TextDataType struct {
+	length         int
+	underlyingType string
+}
+
+func (t *TextDataType) ToSql() string {
+	return fmt.Sprintf("%s(%d)", t.underlyingType, t.length)
+}
+
+func (t *TextDataType) ToLegacyDataTypeSql() string {
+	return VarcharLegacyDataType
+}
+
+var (
+	TextDataTypeSynonyms = []string{VarcharLegacyDataType, "STRING", "TEXT", "NVARCHAR2", "NVARCHAR", "CHAR VARYING", "NCHAR VARYING"}
+	TextDataTypeSubtypes = []string{"CHARACTER", "CHAR", "NCHAR"}
+	AllTextDataTypes     = append(TextDataTypeSynonyms, TextDataTypeSubtypes...)
+)
+
+// parseTextDataTypeRaw extracts length from the raw text data type input.
+// It returns the type-appropriate default length when no length argument was provided, and an error when the arguments cannot be parsed or the data type is unknown.
+func parseTextDataTypeRaw(raw sanitizedDataTypeRaw) (*TextDataType, error) {
+	r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
+	if r == "" {
+		logging.DebugLogger.Printf("[DEBUG] Returning default length for text")
+		switch {
+		case slices.Contains(TextDataTypeSynonyms, raw.matchedByType):
+			return &TextDataType{DefaultVarcharLength, raw.matchedByType}, nil
+		case slices.Contains(TextDataTypeSubtypes, raw.matchedByType):
+			return &TextDataType{DefaultCharLength, raw.matchedByType}, nil
+		default:
+			return nil, fmt.Errorf("unknown text data type: %s", raw.raw)
+		}
+	}
+	if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") {
+		logging.DebugLogger.Printf(`text %s could not be parsed, use "%s(length)" format`, raw.raw, raw.matchedByType)
+		return nil, fmt.Errorf(`text %s could not be parsed, use "%s(length)" format`, raw.raw, raw.matchedByType)
+	}
+	lengthRaw := r[1 : len(r)-1]
+	length, err := strconv.Atoi(strings.TrimSpace(lengthRaw))
+	if err != nil {
+		logging.DebugLogger.Printf(`[DEBUG] Could not parse varchar length "%s", err: %v`, lengthRaw, err)
+		return nil, fmt.Errorf(`could not parse the varchar's length: "%s", err: %w`, lengthRaw, err)
+	}
+	return &TextDataType{length, raw.matchedByType}, nil
+}
+
+func areTextDataTypesTheSame(a, b *TextDataType) bool {
+	return a.length == b.length
+}
diff --git a/pkg/sdk/datatypes/time.go b/pkg/sdk/datatypes/time.go
new file mode 100644
index 0000000000..ee79421122
--- /dev/null
+++ b/pkg/sdk/datatypes/time.go
@@ -0,0 +1,51 @@
+package datatypes
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging"
+)
+
+const DefaultTimePrecision = 9
+
+// TimeDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#time
+// It does not have synonyms. It does have an optional precision attribute.
+type TimeDataType struct {
+	precision      int
+	underlyingType string
+}
+
+func (t *TimeDataType) ToSql() string {
+	return fmt.Sprintf("%s(%d)", t.underlyingType, t.precision)
+}
+
+func (t *TimeDataType) ToLegacyDataTypeSql() string {
+	return TimeLegacyDataType
+}
+
+var TimeDataTypeSynonyms = []string{TimeLegacyDataType}
+
+func parseTimeDataTypeRaw(raw sanitizedDataTypeRaw) (*TimeDataType, error) {
+	r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
+	if r == "" {
+		logging.DebugLogger.Printf("[DEBUG] Returning default precision for time")
+		return &TimeDataType{DefaultTimePrecision, raw.matchedByType}, nil
+	}
+	if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") {
+		logging.DebugLogger.Printf(`time %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType)
+		return nil, fmt.Errorf(`time %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType)
+	}
+	precisionRaw := r[1 : len(r)-1]
+	precision, err := strconv.Atoi(strings.TrimSpace(precisionRaw))
+	if err != nil {
+		logging.DebugLogger.Printf(`[DEBUG] Could not parse time precision "%s", err: %v`, precisionRaw, err)
+		return nil, fmt.Errorf(`could not parse the time's precision: "%s", err: %w`, precisionRaw, err)
+	}
+	return &TimeDataType{precision, raw.matchedByType}, nil
+}
+
+func areTimeDataTypesTheSame(a, b *TimeDataType) bool {
+	return a.precision == b.precision
+}
diff --git a/pkg/sdk/datatypes/timestamp.go b/pkg/sdk/datatypes/timestamp.go
new file mode 100644
index 0000000000..82b22b74d2
--- /dev/null
+++ b/pkg/sdk/datatypes/timestamp.go
@@ -0,0 +1,3 @@
+package datatypes
+
+const DefaultTimestampPrecision = 9
diff --git a/pkg/sdk/datatypes/timestamp_ltz.go b/pkg/sdk/datatypes/timestamp_ltz.go
new file mode 100644
index 0000000000..f844ec537f
--- /dev/null
+++ b/pkg/sdk/datatypes/timestamp_ltz.go
@@ -0,0 +1,49 @@
+package datatypes
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging"
+)
+
+// TimestampLtzDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp-ltz-timestamp-ntz-timestamp-tz
+// It does have synonyms. It does have an optional precision attribute.
+type TimestampLtzDataType struct {
+	precision      int
+	underlyingType string
+}
+
+func (t *TimestampLtzDataType) ToSql() string {
+	return fmt.Sprintf("%s(%d)", t.underlyingType, t.precision)
+}
+
+func (t *TimestampLtzDataType) ToLegacyDataTypeSql() string {
+	return TimestampLtzLegacyDataType
+}
+
+var TimestampLtzDataTypeSynonyms = []string{TimestampLtzLegacyDataType, "TIMESTAMPLTZ", "TIMESTAMP WITH LOCAL TIME ZONE"}
+
+func parseTimestampLtzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampLtzDataType, error) {
+	r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
+	if r == "" {
+		logging.DebugLogger.Printf("[DEBUG] Returning default precision for timestamp ltz")
+		return &TimestampLtzDataType{DefaultTimestampPrecision, raw.matchedByType}, nil
+	}
+	if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") {
+		logging.DebugLogger.Printf(`timestamp ltz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType)
+		return nil, fmt.Errorf(`timestamp ltz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType)
+	}
+	precisionRaw := r[1 : len(r)-1]
+	precision, err := strconv.Atoi(strings.TrimSpace(precisionRaw))
+	if err != nil {
+		logging.DebugLogger.Printf(`[DEBUG] Could not parse timestamp ltz precision "%s", err: %v`, precisionRaw, err)
+		return nil, fmt.Errorf(`could not parse the timestamp's precision: "%s", err: %w`, precisionRaw, err)
+	}
+	return &TimestampLtzDataType{precision, raw.matchedByType}, nil
+}
+
+func areTimestampLtzDataTypesTheSame(a, b *TimestampLtzDataType) bool {
+	return a.precision == b.precision
+}
diff --git a/pkg/sdk/datatypes/timestamp_ntz.go b/pkg/sdk/datatypes/timestamp_ntz.go
new file mode 100644
index 0000000000..86aa5f0a0c
--- /dev/null
+++ b/pkg/sdk/datatypes/timestamp_ntz.go
@@ -0,0 +1,49 @@
+package datatypes
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging"
+)
+
+// TimestampNtzDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp-ltz-timestamp-ntz-timestamp-tz
+// It does have synonyms. It does have an optional precision attribute.
+type TimestampNtzDataType struct {
+	precision      int
+	underlyingType string
+}
+
+func (t *TimestampNtzDataType) ToSql() string {
+	return fmt.Sprintf("%s(%d)", t.underlyingType, t.precision)
+}
+
+func (t *TimestampNtzDataType) ToLegacyDataTypeSql() string {
+	return TimestampNtzLegacyDataType
+}
+
+var TimestampNtzDataTypeSynonyms = []string{TimestampNtzLegacyDataType, "TIMESTAMPNTZ", "TIMESTAMP WITHOUT TIME ZONE", "DATETIME"}
+
+func parseTimestampNtzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampNtzDataType, error) {
+	r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType))
+	if r == "" {
+		logging.DebugLogger.Printf("[DEBUG] Returning default precision for timestamp ntz")
+		return &TimestampNtzDataType{DefaultTimestampPrecision, raw.matchedByType}, nil
+	}
+	if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") {
+		logging.DebugLogger.Printf(`timestamp ntz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType)
+		return nil, fmt.Errorf(`timestamp ntz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType)
+	}
+	precisionRaw := r[1 : len(r)-1]
+	precision, err := strconv.Atoi(strings.TrimSpace(precisionRaw))
+	if err != nil {
+		logging.DebugLogger.Printf(`[DEBUG] Could not parse timestamp ntz precision "%s", err: %v`, precisionRaw, err)
+		return nil, fmt.Errorf(`could not parse the timestamp's precision: "%s", err: %w`, precisionRaw, err)
+	}
+	return &TimestampNtzDataType{precision, raw.matchedByType}, nil
+}
+
+func 
areTimestampNtzDataTypesTheSame(a, b *TimestampNtzDataType) bool { + return a.precision == b.precision +} diff --git a/pkg/sdk/datatypes/timestamp_tz.go b/pkg/sdk/datatypes/timestamp_tz.go new file mode 100644 index 0000000000..44e6cafeb6 --- /dev/null +++ b/pkg/sdk/datatypes/timestamp_tz.go @@ -0,0 +1,49 @@ +package datatypes + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +// TimestampTzDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp-ltz-timestamp-ntz-timestamp-tz +// It does have synonyms. It does have optional precision attribute. +type TimestampTzDataType struct { + precision int + underlyingType string +} + +func (t *TimestampTzDataType) ToSql() string { + return fmt.Sprintf("%s(%d)", t.underlyingType, t.precision) +} + +func (t *TimestampTzDataType) ToLegacyDataTypeSql() string { + return TimestampTzLegacyDataType +} + +var TimestampTzDataTypeSynonyms = []string{TimestampTzLegacyDataType, "TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"} + +func parseTimestampTzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampTzDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" { + logging.DebugLogger.Printf("[DEBUG] Returning default precision for timestamp tz") + return &TimestampTzDataType{DefaultTimestampPrecision, raw.matchedByType}, nil + } + if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") { + logging.DebugLogger.Printf(`timestamp tz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`timestamp tz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + } + precisionRaw := r[1 : len(r)-1] + precision, err := strconv.Atoi(strings.TrimSpace(precisionRaw)) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse timestamp tz precision "%s", err: %v`, precisionRaw, err) + return nil, 
fmt.Errorf(`could not parse the timestamp's precision: "%s", err: %w`, precisionRaw, err) + } + return &TimestampTzDataType{precision, raw.matchedByType}, nil +} + +func areTimestampTzDataTypesTheSame(a, b *TimestampTzDataType) bool { + return a.precision == b.precision +} diff --git a/pkg/sdk/datatypes/variant.go b/pkg/sdk/datatypes/variant.go new file mode 100644 index 0000000000..b096084934 --- /dev/null +++ b/pkg/sdk/datatypes/variant.go @@ -0,0 +1,21 @@ +package datatypes + +// VariantDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-semistructured#variant +// It does not have synonyms. It does not have any attributes. +type VariantDataType struct { + underlyingType string +} + +func (t *VariantDataType) ToSql() string { + return t.underlyingType +} + +func (t *VariantDataType) ToLegacyDataTypeSql() string { + return VariantLegacyDataType +} + +var VariantDataTypeSynonyms = []string{VariantLegacyDataType} + +func parseVariantDataTypeRaw(raw sanitizedDataTypeRaw) (*VariantDataType, error) { + return &VariantDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/vector.go b/pkg/sdk/datatypes/vector.go new file mode 100644 index 0000000000..a535ca2b58 --- /dev/null +++ b/pkg/sdk/datatypes/vector.go @@ -0,0 +1,65 @@ +package datatypes + +import ( + "fmt" + "slices" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +// VectorDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-vector#vector +// It does not have synonyms. It does have type (int or float) and dimension required attributes. 
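The three timestamp parsers above (LTZ, NTZ, TZ) share the same optional-precision parsing pattern: a bare type name falls back to the default precision, `TYPE(n)` yields `n`, and anything else is rejected. A minimal standalone sketch of that pattern (names and the default precision of 9 are illustrative assumptions, not the SDK's unexported internals):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// defaultTimestampPrecision is an assumed default; the SDK uses its own
// DefaultTimestampPrecision constant.
const defaultTimestampPrecision = 9

// parseTimestampPrecision mirrors the shared logic of the timestamp parsers:
// a bare type name yields the default precision, "TYPE(n)" yields n, and any
// other suffix is an error.
func parseTimestampPrecision(matched, raw string) (int, error) {
	r := strings.TrimSpace(strings.TrimPrefix(raw, matched))
	if r == "" {
		return defaultTimestampPrecision, nil
	}
	if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") {
		return 0, fmt.Errorf(`%s could not be parsed, use "%s(precision)" format`, raw, matched)
	}
	return strconv.Atoi(strings.TrimSpace(r[1 : len(r)-1]))
}

func main() {
	for _, raw := range []string{"TIMESTAMP_LTZ", "TIMESTAMP_LTZ(3)", "TIMESTAMP_LTZ 3"} {
		p, err := parseTimestampPrecision("TIMESTAMP_LTZ", raw)
		fmt.Println(raw, p, err)
	}
}
```

Factoring the pattern this way also makes it clear why `areTimestamp*DataTypesTheSame` only compares `precision`: the underlying type string is just the synonym that was matched, not part of the semantic identity.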
+type VectorDataType struct { + innerType string + dimension int + underlyingType string +} + +func (t *VectorDataType) ToSql() string { + return fmt.Sprintf("%s(%s, %d)", t.underlyingType, t.innerType, t.dimension) +} + +// ToLegacyDataTypeSql for vector is the only one that is correct because in the old implementation it was returned as DataType(dType), so it was already in a proper format. +func (t *VectorDataType) ToLegacyDataTypeSql() string { + return t.ToSql() +} + +var ( + VectorDataTypeSynonyms = []string{"VECTOR"} + VectorAllowedInnerTypes = []string{"INT", "FLOAT"} +) + +// parseVectorDataTypeRaw extracts type and dimension from the raw vector data type input. +// Both attributes are required so no defaults are returned in case any of them is missing. +func parseVectorDataTypeRaw(raw sanitizedDataTypeRaw) (*VectorDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) { + logging.DebugLogger.Printf(`vector %s could not be parsed, use "%s(type, dimension)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`vector %s could not be parsed, use "%s(type, dimension)" format`, raw.raw, raw.matchedByType) + } + onlyArgs := r[1 : len(r)-1] + parts := strings.Split(onlyArgs, ",") + switch l := len(parts); l { + case 2: + vectorType := strings.TrimSpace(parts[0]) + if !slices.Contains(VectorAllowedInnerTypes, vectorType) { + logging.DebugLogger.Printf(`[DEBUG] Inner type for vector could not be recognized: "%s"; use one of %s`, parts[0], strings.Join(VectorAllowedInnerTypes, ",")) + return nil, fmt.Errorf(`could not parse vector's inner type: "%s"; use one of %s`, parts[0], strings.Join(VectorAllowedInnerTypes, ",")) + } + dimension, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse vector's dimension "%s", err: %v`, parts[1], err) + return nil, 
"%s", err: %w`, parts[1], err) + } + return &VectorDataType{vectorType, dimension, raw.matchedByType}, nil + default: + logging.DebugLogger.Printf("[DEBUG] Unexpected length of vector arguments") + return nil, fmt.Errorf(`vector cannot have %d arguments: "%s"; use "%s(type, dimension)" format`, l, onlyArgs, raw.matchedByType) + } +} + +func areVectorDataTypesTheSame(a, b *VectorDataType) bool { + return a.innerType == b.innerType && a.dimension == b.dimension +} diff --git a/pkg/sdk/dynamic_table.go b/pkg/sdk/dynamic_table.go index dac13dc576..35457cff6a 100644 --- a/pkg/sdk/dynamic_table.go +++ b/pkg/sdk/dynamic_table.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type DynamicTables interface { @@ -232,10 +233,10 @@ type dynamicTableDetailsRow struct { } func (row dynamicTableDetailsRow) convert() *DynamicTableDetails { - typ, _ := ToDataType(row.Type) + typ, _ := datatypes.ParseDataType(row.Type) dtd := &DynamicTableDetails{ Name: row.Name, - Type: typ, + Type: LegacyDataTypeFrom(typ), Kind: row.Kind, IsNull: row.IsNull == "Y", PrimaryKey: row.PrimaryKey, diff --git a/pkg/sdk/grants_test.go b/pkg/sdk/grants_test.go index 82dfc4fff4..9f0823b769 100644 --- a/pkg/sdk/grants_test.go +++ b/pkg/sdk/grants_test.go @@ -1,8 +1,6 @@ package sdk import ( - "errors" - "fmt" "testing" ) @@ -476,12 +474,6 @@ func TestGrants_GrantPrivilegesToDatabaseRole(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObjectIn", "InDatabase", "InSchema")) }) - t.Run("validation: unsupported database privilege", func(t *testing.T) { - opts := defaultGrantsForDb() - opts.privileges.DatabasePrivileges = []AccountObjectPrivilege{AccountObjectPrivilegeCreateDatabaseRole} - assertOptsInvalidJoinedErrors(t, opts, fmt.Errorf("privilege CREATE DATABASE ROLE is not allowed")) - }) - t.Run("on database", func(t *testing.T) { 
opts := defaultGrantsForDb() assertOptsValidAndSQLEquals(t, opts, `GRANT CREATE SCHEMA ON DATABASE %s TO DATABASE ROLE %s`, dbId.FullyQualifiedName(), databaseRoleId.FullyQualifiedName()) @@ -673,12 +665,6 @@ func TestGrants_RevokePrivilegesFromDatabaseRoleRole(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("GrantOnSchemaObjectIn", "InDatabase", "InSchema")) }) - t.Run("validation: unsupported database privilege", func(t *testing.T) { - opts := defaultGrantsForDb() - opts.privileges.DatabasePrivileges = []AccountObjectPrivilege{AccountObjectPrivilegeCreateDatabaseRole} - assertOptsInvalidJoinedErrors(t, opts, errors.New("privilege CREATE DATABASE ROLE is not allowed")) - }) - t.Run("on database", func(t *testing.T) { opts := defaultGrantsForDb() assertOptsValidAndSQLEquals(t, opts, `REVOKE CREATE SCHEMA ON DATABASE %s FROM DATABASE ROLE %s`, dbId.FullyQualifiedName(), databaseRoleId.FullyQualifiedName()) diff --git a/pkg/sdk/grants_validations.go b/pkg/sdk/grants_validations.go index 2d1727ee85..16f36413c5 100644 --- a/pkg/sdk/grants_validations.go +++ b/pkg/sdk/grants_validations.go @@ -293,19 +293,6 @@ func (v *DatabaseRoleGrantPrivileges) validate() error { if !exactlyOneValueSet(v.DatabasePrivileges, v.SchemaPrivileges, v.SchemaObjectPrivileges, v.AllPrivileges) { errs = append(errs, errExactlyOneOf("DatabaseRoleGrantPrivileges", "DatabasePrivileges", "SchemaPrivileges", "SchemaObjectPrivileges", "AllPrivileges")) } - if valueSet(v.DatabasePrivileges) { - allowedPrivileges := []AccountObjectPrivilege{ - AccountObjectPrivilegeCreateSchema, - AccountObjectPrivilegeModify, - AccountObjectPrivilegeMonitor, - AccountObjectPrivilegeUsage, - } - for _, p := range v.DatabasePrivileges { - if !slices.Contains(allowedPrivileges, p) { - errs = append(errs, fmt.Errorf("privilege %s is not allowed", p.String())) - } - } - } return errors.Join(errs...) 
} diff --git a/pkg/sdk/identifier_helpers.go b/pkg/sdk/identifier_helpers.go index 95ea8e894f..90d1acdf44 100644 --- a/pkg/sdk/identifier_helpers.go +++ b/pkg/sdk/identifier_helpers.go @@ -4,6 +4,8 @@ import ( "fmt" "log" "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type Identifier interface { @@ -247,8 +249,8 @@ func NewSchemaObjectIdentifierFromFullyQualifiedName(fullyQualifiedName string) if trimmedArg == "" { continue } - dt, _ := ToDataType(trimmedArg) - id.arguments = append(id.arguments, dt) + dt, _ := datatypes.ParseDataType(trimmedArg) + id.arguments = append(id.arguments, LegacyDataTypeFrom(dt)) } } else { // this is every other kind of schema object id.name = strings.Trim(parts[2], `"`) @@ -318,11 +320,11 @@ func NewSchemaObjectIdentifierWithArguments(databaseName, schemaName, name strin // Arguments have to be "normalized" with ToDataType, so the signature would match with the one returned by Snowflake. normalizedArguments := make([]DataType, len(argumentDataTypes)) for i, argument := range argumentDataTypes { - normalizedArgument, err := ToDataType(string(argument)) + normalizedArgument, err := datatypes.ParseDataType(string(argument)) if err != nil { log.Printf("[DEBUG] failed to normalize argument %d: %v, err = %v", i, argument, err) } - normalizedArguments[i] = normalizedArgument + normalizedArguments[i] = LegacyDataTypeFrom(normalizedArgument) } return SchemaObjectIdentifierWithArguments{ databaseName: strings.Trim(databaseName, `"`), diff --git a/pkg/sdk/masking_policy.go b/pkg/sdk/masking_policy.go index b6c87a2d0a..f92b2e89b6 100644 --- a/pkg/sdk/masking_policy.go +++ b/pkg/sdk/masking_policy.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) var _ MaskingPolicies = (*maskingPolicies)(nil) @@ -360,14 +361,14 @@ type maskingPolicyDetailsRow struct { } func 
(row maskingPolicyDetailsRow) toMaskingPolicyDetails() *MaskingPolicyDetails { - dataType, err := ToDataType(row.ReturnType) + dataType, err := datatypes.ParseDataType(row.ReturnType) if err != nil { return nil } v := &MaskingPolicyDetails{ Name: row.Name, Signature: []TableColumnSignature{}, - ReturnType: dataType, + ReturnType: LegacyDataTypeFrom(dataType), Body: row.Body, } diff --git a/pkg/sdk/object_types.go b/pkg/sdk/object_types.go index 6c65d694fe..d2cecefe37 100644 --- a/pkg/sdk/object_types.go +++ b/pkg/sdk/object_types.go @@ -1,6 +1,7 @@ package sdk import ( + "fmt" "slices" "strings" ) @@ -89,6 +90,82 @@ func (o ObjectType) IsWithArguments() bool { return slices.Contains([]ObjectType{ObjectTypeExternalFunction, ObjectTypeFunction, ObjectTypeProcedure}, o) } +var allObjectTypes = []ObjectType{ + ObjectTypeAccount, + ObjectTypeManagedAccount, + ObjectTypeUser, + ObjectTypeDatabaseRole, + ObjectTypeDataset, + ObjectTypeRole, + ObjectTypeIntegration, + ObjectTypeNetworkPolicy, + ObjectTypePasswordPolicy, + ObjectTypeSessionPolicy, + ObjectTypePrivacyPolicy, + ObjectTypeReplicationGroup, + ObjectTypeFailoverGroup, + ObjectTypeConnection, + ObjectTypeParameter, + ObjectTypeWarehouse, + ObjectTypeResourceMonitor, + ObjectTypeDatabase, + ObjectTypeSchema, + ObjectTypeShare, + ObjectTypeTable, + ObjectTypeDynamicTable, + ObjectTypeCortexSearchService, + ObjectTypeExternalTable, + ObjectTypeEventTable, + ObjectTypeView, + ObjectTypeMaterializedView, + ObjectTypeSequence, + ObjectTypeSnapshot, + ObjectTypeFunction, + ObjectTypeExternalFunction, + ObjectTypeProcedure, + ObjectTypeStream, + ObjectTypeTask, + ObjectTypeMaskingPolicy, + ObjectTypeRowAccessPolicy, + ObjectTypeTag, + ObjectTypeSecret, + ObjectTypeStage, + ObjectTypeFileFormat, + ObjectTypePipe, + ObjectTypeAlert, + ObjectTypeBudget, + ObjectTypeClassification, + ObjectTypeApplication, + ObjectTypeApplicationPackage, + ObjectTypeApplicationRole, + ObjectTypeStreamlit, + ObjectTypeColumn, + 
ObjectTypeIcebergTable, + ObjectTypeExternalVolume, + ObjectTypeNetworkRule, + ObjectTypeNotebook, + ObjectTypePackagesPolicy, + ObjectTypeComputePool, + ObjectTypeAggregationPolicy, + ObjectTypeAuthenticationPolicy, + ObjectTypeHybridTable, + ObjectTypeImageRepository, + ObjectTypeProjectionPolicy, + ObjectTypeDataMetricFunction, + ObjectTypeGitRepository, + ObjectTypeModel, + ObjectTypeService, +} + +// TODO(SNOW-1834370): use ToObjectType in other places with type conversion (instead of sdk.ObjectType) +func ToObjectType(s string) (ObjectType, error) { + s = strings.ToUpper(s) + if !slices.Contains(allObjectTypes, ObjectType(s)) { + return "", fmt.Errorf("invalid object type: %s", s) + } + return ObjectType(s), nil +} + func objectTypeSingularToPluralMap() map[ObjectType]PluralObjectType { return map[ObjectType]PluralObjectType{ ObjectTypeAccount: PluralObjectTypeAccounts, diff --git a/pkg/sdk/object_types_test.go b/pkg/sdk/object_types_test.go new file mode 100644 index 0000000000..3a9d34871f --- /dev/null +++ b/pkg/sdk/object_types_test.go @@ -0,0 +1,105 @@ +package sdk + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_ToObjectType(t *testing.T) { + type test struct { + input string + want ObjectType + } + + valid := []test{ + // Case insensitive. + {input: "schema", want: ObjectTypeSchema}, + + // Supported Values. 
+ {input: "ACCOUNT", want: ObjectTypeAccount}, + {input: "MANAGED ACCOUNT", want: ObjectTypeManagedAccount}, + {input: "USER", want: ObjectTypeUser}, + {input: "DATABASE ROLE", want: ObjectTypeDatabaseRole}, + {input: "DATASET", want: ObjectTypeDataset}, + {input: "ROLE", want: ObjectTypeRole}, + {input: "INTEGRATION", want: ObjectTypeIntegration}, + {input: "NETWORK POLICY", want: ObjectTypeNetworkPolicy}, + {input: "PASSWORD POLICY", want: ObjectTypePasswordPolicy}, + {input: "SESSION POLICY", want: ObjectTypeSessionPolicy}, + {input: "PRIVACY POLICY", want: ObjectTypePrivacyPolicy}, + {input: "REPLICATION GROUP", want: ObjectTypeReplicationGroup}, + {input: "FAILOVER GROUP", want: ObjectTypeFailoverGroup}, + {input: "CONNECTION", want: ObjectTypeConnection}, + {input: "PARAMETER", want: ObjectTypeParameter}, + {input: "WAREHOUSE", want: ObjectTypeWarehouse}, + {input: "RESOURCE MONITOR", want: ObjectTypeResourceMonitor}, + {input: "DATABASE", want: ObjectTypeDatabase}, + {input: "SCHEMA", want: ObjectTypeSchema}, + {input: "SHARE", want: ObjectTypeShare}, + {input: "TABLE", want: ObjectTypeTable}, + {input: "DYNAMIC TABLE", want: ObjectTypeDynamicTable}, + {input: "CORTEX SEARCH SERVICE", want: ObjectTypeCortexSearchService}, + {input: "EXTERNAL TABLE", want: ObjectTypeExternalTable}, + {input: "EVENT TABLE", want: ObjectTypeEventTable}, + {input: "VIEW", want: ObjectTypeView}, + {input: "MATERIALIZED VIEW", want: ObjectTypeMaterializedView}, + {input: "SEQUENCE", want: ObjectTypeSequence}, + {input: "SNAPSHOT", want: ObjectTypeSnapshot}, + {input: "FUNCTION", want: ObjectTypeFunction}, + {input: "EXTERNAL FUNCTION", want: ObjectTypeExternalFunction}, + {input: "PROCEDURE", want: ObjectTypeProcedure}, + {input: "STREAM", want: ObjectTypeStream}, + {input: "TASK", want: ObjectTypeTask}, + {input: "MASKING POLICY", want: ObjectTypeMaskingPolicy}, + {input: "ROW ACCESS POLICY", want: ObjectTypeRowAccessPolicy}, + {input: "TAG", want: ObjectTypeTag}, + {input: 
"SECRET", want: ObjectTypeSecret}, + {input: "STAGE", want: ObjectTypeStage}, + {input: "FILE FORMAT", want: ObjectTypeFileFormat}, + {input: "PIPE", want: ObjectTypePipe}, + {input: "ALERT", want: ObjectTypeAlert}, + {input: "SNOWFLAKE.CORE.BUDGET", want: ObjectTypeBudget}, + {input: "SNOWFLAKE.ML.CLASSIFICATION", want: ObjectTypeClassification}, + {input: "APPLICATION", want: ObjectTypeApplication}, + {input: "APPLICATION PACKAGE", want: ObjectTypeApplicationPackage}, + {input: "APPLICATION ROLE", want: ObjectTypeApplicationRole}, + {input: "STREAMLIT", want: ObjectTypeStreamlit}, + {input: "COLUMN", want: ObjectTypeColumn}, + {input: "ICEBERG TABLE", want: ObjectTypeIcebergTable}, + {input: "EXTERNAL VOLUME", want: ObjectTypeExternalVolume}, + {input: "NETWORK RULE", want: ObjectTypeNetworkRule}, + {input: "NOTEBOOK", want: ObjectTypeNotebook}, + {input: "PACKAGES POLICY", want: ObjectTypePackagesPolicy}, + {input: "COMPUTE POOL", want: ObjectTypeComputePool}, + {input: "AGGREGATION POLICY", want: ObjectTypeAggregationPolicy}, + {input: "AUTHENTICATION POLICY", want: ObjectTypeAuthenticationPolicy}, + {input: "HYBRID TABLE", want: ObjectTypeHybridTable}, + {input: "IMAGE REPOSITORY", want: ObjectTypeImageRepository}, + {input: "PROJECTION POLICY", want: ObjectTypeProjectionPolicy}, + {input: "DATA METRIC FUNCTION", want: ObjectTypeDataMetricFunction}, + {input: "GIT REPOSITORY", want: ObjectTypeGitRepository}, + {input: "MODEL", want: ObjectTypeModel}, + {input: "SERVICE", want: ObjectTypeService}, + } + + invalid := []test{ + {input: ""}, + {input: "foo"}, + } + + for _, tc := range valid { + t.Run(tc.input, func(t *testing.T) { + got, err := ToObjectType(tc.input) + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } + + for _, tc := range invalid { + t.Run(tc.input, func(t *testing.T) { + _, err := ToObjectType(tc.input) + require.Error(t, err) + }) + } +} diff --git a/pkg/sdk/parameters.go b/pkg/sdk/parameters.go index f341ed503a..cf29fa1da4 
100644 --- a/pkg/sdk/parameters.go +++ b/pkg/sdk/parameters.go @@ -656,7 +656,6 @@ var AllUserParameters = []UserParameter{ type TaskParameter string -// TODO(SNOW-1348116 - next prs): Handle task parameters const ( // Task Parameters TaskParameterSuspendTaskAfterNumFailures TaskParameter = "SUSPEND_TASK_AFTER_NUM_FAILURES" diff --git a/pkg/sdk/shares.go b/pkg/sdk/shares.go index aa92fb0790..9a7656fa06 100644 --- a/pkg/sdk/shares.go +++ b/pkg/sdk/shares.go @@ -341,6 +341,7 @@ type shareDetailsRow struct { func (row *shareDetailsRow) toShareInfo() *ShareInfo { objectType := ObjectType(row.Kind) trimmedS := strings.Trim(row.Name, "\"") + // TODO(SNOW-1229218): Use a common mapper to get object id. id := objectType.GetObjectIdentifier(trimmedS) return &ShareInfo{ Kind: objectType, diff --git a/pkg/sdk/system_functions.go b/pkg/sdk/system_functions.go index 6777a013a4..9ea79c1671 100644 --- a/pkg/sdk/system_functions.go +++ b/pkg/sdk/system_functions.go @@ -2,6 +2,7 @@ package sdk import ( "context" + "database/sql" "encoding/json" "fmt" "slices" @@ -11,7 +12,7 @@ import ( ) type SystemFunctions interface { - GetTag(ctx context.Context, tagID ObjectIdentifier, objectID ObjectIdentifier, objectType ObjectType) (string, error) + GetTag(ctx context.Context, tagID ObjectIdentifier, objectID ObjectIdentifier, objectType ObjectType) (*string, error) PipeStatus(pipeId SchemaObjectIdentifier) (PipeExecutionState, error) // PipeForceResume unpauses a pipe after ownership transfer. Snowflake will throw an error whenever a pipe changes its owner, // and someone tries to unpause it. To unpause a pipe after ownership transfer, this system function has to be called instead of ALTER PIPE. 
@@ -26,21 +27,24 @@ type systemFunctions struct { client *Client } -func (c *systemFunctions) GetTag(ctx context.Context, tagID ObjectIdentifier, objectID ObjectIdentifier, objectType ObjectType) (string, error) { +func (c *systemFunctions) GetTag(ctx context.Context, tagID ObjectIdentifier, objectID ObjectIdentifier, objectType ObjectType) (*string, error) { objectType, err := normalizeGetTagObjectType(objectType) if err != nil { - return "", err + return nil, err } s := &struct { - Tag string `db:"TAG"` + Tag sql.NullString `db:"TAG"` }{} sql := fmt.Sprintf(`SELECT SYSTEM$GET_TAG('%s', '%s', '%v') AS "TAG"`, tagID.FullyQualifiedName(), objectID.FullyQualifiedName(), objectType) err = c.client.queryOne(ctx, s, sql) if err != nil { - return "", err + return nil, err + } + if !s.Tag.Valid { + return nil, nil } - return s.Tag, nil + return &s.Tag.String, nil } // normalize object types for some values because of errors like below diff --git a/pkg/sdk/tables_test.go b/pkg/sdk/tables_test.go index bcd807a75f..9c2c32ed7e 100644 --- a/pkg/sdk/tables_test.go +++ b/pkg/sdk/tables_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -379,7 +380,9 @@ func TestTableCreate(t *testing.T) { tableComment := random.Comment() collation := "de" columnName := "FIRST_COLUMN" - columnType, err := ToDataType("VARCHAR") + columnTypeRaw, err := datatypes.ParseDataType("VARCHAR") + require.NoError(t, err) + columnType := LegacyDataTypeFrom(columnTypeRaw) maskingPolicy := ColumnMaskingPolicy{ Name: randomSchemaObjectIdentifier(), Using: []string{"FOO", "BAR"}, @@ -551,8 +554,9 @@ func TestTableCreateAsSelect(t *testing.T) { t.Run("with complete options", func(t *testing.T) { id := randomSchemaObjectIdentifier() columnName := "FIRST_COLUMN" - columnType, err := 
ToDataType("VARCHAR") + columnTypeRaw, err := datatypes.ParseDataType("VARCHAR") require.NoError(t, err) + columnType := LegacyDataTypeFrom(columnTypeRaw) maskingPolicy := TableAsSelectColumnMaskingPolicy{ Name: randomSchemaObjectIdentifier(), } diff --git a/pkg/sdk/tags.go b/pkg/sdk/tags.go index 0fb31322ce..a0f2b8f6c2 100644 --- a/pkg/sdk/tags.go +++ b/pkg/sdk/tags.go @@ -30,6 +30,7 @@ type setTagOptions struct { type unsetTagOptions struct { alter bool `ddl:"static" sql:"ALTER"` objectType ObjectType `ddl:"keyword"` + IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` objectName ObjectIdentifier `ddl:"identifier"` column *string `ddl:"parameter,no_equals,double_quotes" sql:"MODIFY COLUMN"` UnsetTags []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` diff --git a/pkg/sdk/tags_dto.go b/pkg/sdk/tags_dto.go index 10580f9ceb..95eb54fdc2 100644 --- a/pkg/sdk/tags_dto.go +++ b/pkg/sdk/tags_dto.go @@ -22,6 +22,7 @@ type UnsetTagRequest struct { objectType ObjectType // required objectName ObjectIdentifier // required + IfExists *bool UnsetTags []ObjectIdentifier } diff --git a/pkg/sdk/tags_dto_builders.go b/pkg/sdk/tags_dto_builders.go index 505d0ca854..20610294d7 100644 --- a/pkg/sdk/tags_dto_builders.go +++ b/pkg/sdk/tags_dto_builders.go @@ -24,6 +24,11 @@ func (s *UnsetTagRequest) WithUnsetTags(tags []ObjectIdentifier) *UnsetTagReques return s } +func (s *UnsetTagRequest) WithIfExists(ifExists bool) *UnsetTagRequest { + s.IfExists = &ifExists + return s +} + func NewSetTagOnCurrentAccountRequest() *SetTagOnCurrentAccountRequest { return &SetTagOnCurrentAccountRequest{} } diff --git a/pkg/sdk/tags_impl.go b/pkg/sdk/tags_impl.go index e80c934ff1..882189deb1 100644 --- a/pkg/sdk/tags_impl.go +++ b/pkg/sdk/tags_impl.go @@ -163,6 +163,7 @@ func (s *SetTagRequest) toOpts() *setTagOptions { func (s *UnsetTagRequest) toOpts() *unsetTagOptions { o := &unsetTagOptions{ objectType: s.objectType, + IfExists: s.IfExists, objectName: s.objectName, UnsetTags: s.UnsetTags, } diff --git 
a/pkg/sdk/tags_test.go b/pkg/sdk/tags_test.go index 94ed4c641c..0a9046ab78 100644 --- a/pkg/sdk/tags_test.go +++ b/pkg/sdk/tags_test.go @@ -388,6 +388,21 @@ func TestTagSet(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `ALTER %s %s SET TAG "tag1" = 'value1'`, opts.objectType, id.FullyQualifiedName()) }) + t.Run("set on account", func(t *testing.T) { + accountId := randomAccountIdentifier() + opts := &setTagOptions{ + objectType: ObjectTypeStage, + objectName: accountId, + SetTags: []TagAssociation{ + { + Name: NewAccountObjectIdentifier("tag1"), + Value: "value1", + }, + }, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER %s %s SET TAG "tag1" = 'value1'`, opts.objectType, accountId.FullyQualifiedName()) + }) + t.Run("set with column", func(t *testing.T) { objectId := randomTableColumnIdentifierInSchemaObject(id) tagId := randomSchemaObjectIdentifier() @@ -434,7 +449,21 @@ func TestTagUnset(t *testing.T) { NewAccountObjectIdentifier("tag1"), NewAccountObjectIdentifier("tag2"), } - assertOptsValidAndSQLEquals(t, opts, `ALTER %s %s UNSET TAG "tag1", "tag2"`, opts.objectType, id.FullyQualifiedName()) + opts.IfExists = Pointer(true) + assertOptsValidAndSQLEquals(t, opts, `ALTER %s IF EXISTS %s UNSET TAG "tag1", "tag2"`, opts.objectType, id.FullyQualifiedName()) + }) + + t.Run("unset on account", func(t *testing.T) { + accountId := randomAccountIdentifier() + opts := &unsetTagOptions{ + objectType: ObjectTypeStage, + objectName: accountId, + UnsetTags: []ObjectIdentifier{ + NewAccountObjectIdentifier("tag1"), + NewAccountObjectIdentifier("tag2"), + }, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER %s %s UNSET TAG "tag1", "tag2"`, opts.objectType, accountId.FullyQualifiedName()) }) t.Run("unset with column", func(t *testing.T) { @@ -448,8 +477,9 @@ func TestTagUnset(t *testing.T) { tagId1, tagId2, }, + IfExists: Pointer(true), } opts := request.toOpts() - assertOptsValidAndSQLEquals(t, opts, `ALTER %s %s MODIFY COLUMN "%s" UNSET TAG %s, %s`, opts.objectType, 
id.FullyQualifiedName(), objectId.Name(), tagId1.FullyQualifiedName(), tagId2.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER %s IF EXISTS %s MODIFY COLUMN "%s" UNSET TAG %s, %s`, opts.objectType, id.FullyQualifiedName(), objectId.Name(), tagId1.FullyQualifiedName(), tagId2.FullyQualifiedName()) }) } diff --git a/pkg/sdk/tasks_impl_gen.go b/pkg/sdk/tasks_impl_gen.go index 3d4b530194..5f7bd10a82 100644 --- a/pkg/sdk/tasks_impl_gen.go +++ b/pkg/sdk/tasks_impl_gen.go @@ -401,7 +401,6 @@ func (r taskDBRow) convert() *Task { return &task } -// TODO(SNOW-1348116 - next prs): Remove and use Task.TaskRelations instead func getPredecessors(predecessors string) ([]string, error) { // Since 2022_03, Snowflake returns this as a JSON array (even empty) // The list is formatted, e.g.: diff --git a/pkg/sdk/testint/data_types_integration_test.go b/pkg/sdk/testint/data_types_integration_test.go new file mode 100644 index 0000000000..e62a59ef5f --- /dev/null +++ b/pkg/sdk/testint/data_types_integration_test.go @@ -0,0 +1,349 @@ +package testint + +import ( + "fmt" + "slices" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInt_DataTypes(t *testing.T) { + client := testClient(t) + ctx := testContext(t) + + incorrectBooleanDatatypes := []string{ + "BOOLEAN()", + "BOOLEAN(1)", + "BOOL", + } + incorrectFloatDatatypes := []string{ + "DOUBLE()", + "DOUBLE(1)", + "DOUBLE PRECISION(1)", + } + incorrectlyCorrectFloatDatatypes := []string{ + "FLOAT()", + "FLOAT(20)", + "FLOAT4(20)", + "FLOAT8(20)", + "REAL(20)", + } + incorrectNumberDatatypes := []string{ + "NUMBER()", + "NUMBER(x)", + "INT()", + "NUMBER(36, 5, 7)", + } + incorrectTextDatatypes := []string{ + "VARCHAR()", + "VARCHAR(x)", + "VARCHAR(36, 5)", + } + vectorInnerTypesSynonyms := 
helpers.ConcatSlices(datatypes.AllNumberDataTypes, datatypes.FloatDataTypeSynonyms) + vectorInnerTypeSynonymsThatWork := []string{ + "INTEGER", + "INT", + "FLOAT8", + "FLOAT4", + "FLOAT", + } + + for _, c := range datatypes.ArrayDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of array datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT []::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT []::%s(36)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '36'") + }) + } + + for _, c := range datatypes.BinaryDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of binary datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT TO_BINARY('AB')::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT TO_BINARY('AB')::%s(36)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT TO_BINARY('AB')::%s(36, 2)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "','") + assert.ErrorContains(t, err, "')'") + }) + } + + for _, c := range datatypes.BooleanDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of boolean datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT TRUE::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range incorrectBooleanDatatypes { + t.Run(fmt.Sprintf("check behavior of boolean datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT TRUE::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + require.Error(t, err) + }) + } + + for _, c := range datatypes.DateDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of date datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '2024-12-02'::%s", c) + _, err := 
client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.FloatDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of float datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT 1.1::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range incorrectFloatDatatypes { + t.Run(fmt.Sprintf("check behavior of float datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT 1.1::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + require.Error(t, err) + }) + } + + // There is no attribute documented for float numbers: https://docs.snowflake.com/en/sql-reference/data-types-numeric#float-float4-float8. + // However, adding it succeeds for FLOAT, FLOAT4, FLOAT8, and REAL, but it fails for both DOUBLE and DOUBLE PRECISION. + for _, c := range incorrectlyCorrectFloatDatatypes { + t.Run(fmt.Sprintf("document incorrect behavior of float datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT 1.1::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + require.NoError(t, err) + }) + } + + // Testing on table creation here because casting (::GEOGRAPHY) was ending with errors (even for the "correct" cases). 
+ for _, c := range datatypes.GeographyDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of geography datatype: %s", c), func(t *testing.T) { + tableId := testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql := fmt.Sprintf("CREATE TABLE %s (i %s)", tableId.FullyQualifiedName(), c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + + tableId = testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql = fmt.Sprintf("CREATE TABLE %s (i %s())", tableId.FullyQualifiedName(), c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '('") + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + }) + } + + // Testing on table creation here because casting (::GEOMETRY) was ending with errors (even for the "correct" cases). + for _, c := range datatypes.GeometryDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of geometry datatype: %s", c), func(t *testing.T) { + tableId := testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql := fmt.Sprintf("CREATE TABLE %s (i %s)", tableId.FullyQualifiedName(), c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + + tableId = testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql = fmt.Sprintf("CREATE TABLE %s (i %s())", tableId.FullyQualifiedName(), c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '('") + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + }) + } + + for _, c := range datatypes.NumberDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of number datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT 1::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT 1::%s(36)", c) + _, err = 
client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT 1::%s(36, 5)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.NumberDataTypeSubTypes { + t.Run(fmt.Sprintf("check behavior of number data type subtype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT 1::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT 1::%s(36)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '36'") + }) + } + + for _, c := range incorrectNumberDatatypes { + t.Run(fmt.Sprintf("check behavior of number datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT 1::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + require.Error(t, err) + }) + } + + for _, c := range datatypes.ObjectDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of object data type: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT {}::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT {}::%s(36)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '36'") + }) + } + + for _, c := range datatypes.AllTextDataTypes { + t.Run(fmt.Sprintf("check behavior of text data type: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT 'A'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT 'ABC'::%s(36)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range incorrectTextDatatypes { + t.Run(fmt.Sprintf("check behavior of text datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT ABC::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + require.Error(t, err) + }) + } + + for _, c := range 
datatypes.TimeDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of time data type: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '00:00:00'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT '00:00:00'::%s(5)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.TimestampLtzDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of timestamp ltz data types: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s(3)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.TimestampNtzDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of timestamp ntz data types: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s(3)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.TimestampTzDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of timestamp tz data types: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s(3)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.VariantDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of variant data type: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT TO_VARIANT(1)::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT TO_VARIANT(1)::%s(36)", c) + _, err = 
client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '36'") + }) + } + + // Testing on table creation here because VECTOR is apparently not supported in query results by the gosnowflake driver. + // It ends with "unsupported data type" from https://github.com/snowflakedb/gosnowflake/blob/171ddf2540f3a24f2a990e8453dc425ea864a4a0/converter.go#L1599. + for _, c := range datatypes.VectorDataTypeSynonyms { + for _, inner := range datatypes.VectorAllowedInnerTypes { + t.Run(fmt.Sprintf("check behavior of vector data type: %s, %s", c, inner), func(t *testing.T) { + tableId := testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql := fmt.Sprintf("CREATE TABLE %s (i %s(%s, 2))", tableId.FullyQualifiedName(), c, inner) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + + tableId = testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql = fmt.Sprintf("CREATE TABLE %s (i %s(%s))", tableId.FullyQualifiedName(), c, inner) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected ')'") + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + }) + } + } + + // Testing on table creation here because VECTOR is apparently not supported in query results by the gosnowflake driver. + // It ends with "unsupported data type" from https://github.com/snowflakedb/gosnowflake/blob/171ddf2540f3a24f2a990e8453dc425ea864a4a0/converter.go#L1599. 
+ for _, c := range vectorInnerTypesSynonyms { + t.Run(fmt.Sprintf("document behavior of vector data type synonyms: %s", c), func(t *testing.T) { + tableId := testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql := fmt.Sprintf("CREATE TABLE %s (i VECTOR(%s, 3))", tableId.FullyQualifiedName(), c) + _, err := client.QueryUnsafe(ctx, sql) + if slices.Contains(vectorInnerTypeSynonymsThatWork, c) { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, "SQL compilation error") + switch { + case slices.Contains(datatypes.NumberDataTypeSynonyms, c): + assert.ErrorContains(t, err, fmt.Sprintf("unexpected '%s'", c)) + case slices.Contains(datatypes.NumberDataTypeSubTypes, c): + assert.ErrorContains(t, err, "Unsupported vector element type 'NUMBER(38,0)'") + case slices.Contains(datatypes.FloatDataTypeSynonyms, c): + assert.ErrorContains(t, err, "Unsupported vector element type 'FLOAT'") + default: + t.Fail() + } + } + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + }) + } +} diff --git a/pkg/sdk/testint/databases_integration_test.go b/pkg/sdk/testint/databases_integration_test.go index 8dc2408eee..0c0b513e69 100644 --- a/pkg/sdk/testint/databases_integration_test.go +++ b/pkg/sdk/testint/databases_integration_test.go @@ -138,11 +138,11 @@ func TestInt_DatabasesCreate(t *testing.T) { tag1Value, err := client.SystemFunctions.GetTag(ctx, tagTest.ID(), database.ID(), sdk.ObjectTypeDatabase) require.NoError(t, err) - assert.Equal(t, "v1", tag1Value) + assert.Equal(t, sdk.Pointer("v1"), tag1Value) tag2Value, err := client.SystemFunctions.GetTag(ctx, tag2Test.ID(), database.ID(), sdk.ObjectTypeDatabase) require.NoError(t, err) - assert.Equal(t, "v2", tag2Value) + assert.Equal(t, sdk.Pointer("v2"), tag2Value) }) } @@ -249,7 +249,7 @@ func TestInt_DatabasesCreateShared(t *testing.T) { tag1Value, err := client.SystemFunctions.GetTag(ctx, testTag.ID(), database.ID(), sdk.ObjectTypeDatabase) require.NoError(t, err) - assert.Equal(t, "v1", tag1Value) + 
assert.Equal(t, sdk.Pointer("v1"), tag1Value) } func TestInt_DatabasesCreateSecondary(t *testing.T) { diff --git a/pkg/sdk/testint/external_tables_integration_test.go b/pkg/sdk/testint/external_tables_integration_test.go index 26f0e0a720..d614aa10d9 100644 --- a/pkg/sdk/testint/external_tables_integration_test.go +++ b/pkg/sdk/testint/external_tables_integration_test.go @@ -27,7 +27,7 @@ func TestInt_ExternalTables(t *testing.T) { return []*sdk.ExternalTableColumnRequest{ sdk.NewExternalTableColumnRequest("filename", sdk.DataTypeString, "metadata$filename::string"), sdk.NewExternalTableColumnRequest("city", sdk.DataTypeString, "value:city:findname::string"), - sdk.NewExternalTableColumnRequest("time", sdk.DataTypeTimestamp, "to_timestamp(value:time::int)"), + sdk.NewExternalTableColumnRequest("time", sdk.DataTypeTimestampLTZ, "to_timestamp_ltz(value:time::int)"), sdk.NewExternalTableColumnRequest("weather", sdk.DataTypeVariant, "value:weather::variant"), } } diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index 1fe0a04bb9..44bb8b898a 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -9,6 +9,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -518,7 +519,6 @@ func TestInt_FunctionsShowByID(t *testing.T) { *sdk.NewFunctionArgumentRequest("M", sdk.DataTypeDate), *sdk.NewFunctionArgumentRequest("N", "DATETIME"), *sdk.NewFunctionArgumentRequest("O", sdk.DataTypeTime), - *sdk.NewFunctionArgumentRequest("P", sdk.DataTypeTimestamp), *sdk.NewFunctionArgumentRequest("R", sdk.DataTypeTimestampLTZ), *sdk.NewFunctionArgumentRequest("S", sdk.DataTypeTimestampNTZ), 
*sdk.NewFunctionArgumentRequest("T", sdk.DataTypeTimestampTZ), @@ -536,14 +536,15 @@ func TestInt_FunctionsShowByID(t *testing.T) { "add", ). WithArguments(args). - WithFunctionDefinition("def add(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, R, S, T, U, V, W, X, Y, Z): A + A"), + WithFunctionDefinition("def add(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, R, S, T, U, V, W, X, Y, Z): A + A"), ) require.NoError(t, err) dataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - dataTypes[i], err = sdk.ToDataType(string(arg.ArgDataType)) + dataType, err := datatypes.ParseDataType(string(arg.ArgDataType)) require.NoError(t, err) + dataTypes[i] = sdk.LegacyDataTypeFrom(dataType) } idWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), dataTypes...) diff --git a/pkg/sdk/testint/roles_integration_test.go b/pkg/sdk/testint/roles_integration_test.go index 5c173a1ada..2c9551bb88 100644 --- a/pkg/sdk/testint/roles_integration_test.go +++ b/pkg/sdk/testint/roles_integration_test.go @@ -78,11 +78,11 @@ func TestInt_Roles(t *testing.T) { // verify tags tag1Value, err := client.SystemFunctions.GetTag(ctx, tag.ID(), role.ID(), sdk.ObjectTypeRole) require.NoError(t, err) - assert.Equal(t, "v1", tag1Value) + assert.Equal(t, sdk.Pointer("v1"), tag1Value) tag2Value, err := client.SystemFunctions.GetTag(ctx, tag2.ID(), role.ID(), sdk.ObjectTypeRole) require.NoError(t, err) - assert.Equal(t, "v2", tag2Value) + assert.Equal(t, sdk.Pointer("v2"), tag2Value) }) t.Run("alter rename to", func(t *testing.T) { diff --git a/pkg/sdk/testint/row_access_policies_gen_integration_test.go b/pkg/sdk/testint/row_access_policies_gen_integration_test.go index 4241a7eab1..2833210fe8 100644 --- a/pkg/sdk/testint/row_access_policies_gen_integration_test.go +++ b/pkg/sdk/testint/row_access_policies_gen_integration_test.go @@ -6,6 +6,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" 
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -221,7 +222,7 @@ func TestInt_RowAccessPolicies(t *testing.T) { t.Run("describe row access policy: with timestamp data type normalization", func(t *testing.T) { argName := random.AlphaN(5) - argType := sdk.DataTypeTimestamp + argType := sdk.DataTypeTimestampLTZ args := sdk.NewCreateRowAccessPolicyArgsRequest(argName, argType) body := "true" @@ -234,7 +235,7 @@ func TestInt_RowAccessPolicies(t *testing.T) { assertRowAccessPolicyDescription(t, returnedRowAccessPolicyDescription, rowAccessPolicy.ID(), []sdk.TableColumnSignature{{ Name: argName, - Type: sdk.DataTypeTimestampNTZ, + Type: sdk.DataTypeTimestampLTZ, }}, body) }) @@ -317,7 +318,6 @@ func TestInt_RowAccessPoliciesDescribe(t *testing.T) { *sdk.NewCreateRowAccessPolicyArgsRequest("M", sdk.DataTypeDate), *sdk.NewCreateRowAccessPolicyArgsRequest("N", "DATETIME"), *sdk.NewCreateRowAccessPolicyArgsRequest("O", sdk.DataTypeTime), - *sdk.NewCreateRowAccessPolicyArgsRequest("P", sdk.DataTypeTimestamp), *sdk.NewCreateRowAccessPolicyArgsRequest("R", sdk.DataTypeTimestampLTZ), *sdk.NewCreateRowAccessPolicyArgsRequest("S", sdk.DataTypeTimestampNTZ), *sdk.NewCreateRowAccessPolicyArgsRequest("T", sdk.DataTypeTimestampTZ), @@ -342,11 +342,11 @@ func TestInt_RowAccessPoliciesDescribe(t *testing.T) { require.NoError(t, err) wantArgs := make([]sdk.TableColumnSignature, len(args)) for i, arg := range args { - dataType, err := sdk.ToDataType(string(arg.Type)) + dataType, err := datatypes.ParseDataType(string(arg.Type)) require.NoError(t, err) wantArgs[i] = sdk.TableColumnSignature{ Name: arg.Name, - Type: dataType, + Type: sdk.LegacyDataTypeFrom(dataType), } } assert.Equal(t, wantArgs, policyDetails.Signature) diff --git 
a/pkg/sdk/testint/schemas_integration_test.go b/pkg/sdk/testint/schemas_integration_test.go index a1eaa75c3c..d395339d44 100644 --- a/pkg/sdk/testint/schemas_integration_test.go +++ b/pkg/sdk/testint/schemas_integration_test.go @@ -153,7 +153,7 @@ func TestInt_Schemas(t *testing.T) { tv, err := client.SystemFunctions.GetTag(ctx, tag.ID(), schemaID, sdk.ObjectTypeSchema) require.NoError(t, err) - assert.Equal(t, tagValue, tv) + assert.Equal(t, &tagValue, tv) }) t.Run("create: complete", func(t *testing.T) { @@ -245,11 +245,11 @@ func TestInt_Schemas(t *testing.T) { tag1Value, err := client.SystemFunctions.GetTag(ctx, tagTest.ID(), schema.ID(), sdk.ObjectTypeSchema) require.NoError(t, err) - assert.Equal(t, "v1", tag1Value) + assert.Equal(t, sdk.Pointer("v1"), tag1Value) tag2Value, err := client.SystemFunctions.GetTag(ctx, tag2Test.ID(), schema.ID(), sdk.ObjectTypeSchema) require.NoError(t, err) - assert.Equal(t, "v2", tag2Value) + assert.Equal(t, sdk.Pointer("v2"), tag2Value) }) t.Run("alter: rename to", func(t *testing.T) { diff --git a/pkg/sdk/testint/streams_gen_integration_test.go b/pkg/sdk/testint/streams_gen_integration_test.go index f5cd1b4581..8fce3d0cb3 100644 --- a/pkg/sdk/testint/streams_gen_integration_test.go +++ b/pkg/sdk/testint/streams_gen_integration_test.go @@ -46,7 +46,7 @@ func TestInt_Streams(t *testing.T) { tag1Value, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeStream) require.NoError(t, err) - assert.Equal(t, "v1", tag1Value) + assert.Equal(t, sdk.Pointer("v1"), tag1Value) assertions.AssertThatObject(t, objectassert.Stream(t, id). HasName(id.Name()). 
diff --git a/pkg/sdk/testint/system_functions_integration_test.go b/pkg/sdk/testint/system_functions_integration_test.go index 02208e84c2..7e018eceae 100644 --- a/pkg/sdk/testint/system_functions_integration_test.go +++ b/pkg/sdk/testint/system_functions_integration_test.go @@ -34,7 +34,7 @@ func TestInt_GetTag(t *testing.T) { require.NoError(t, err) s, err := client.SystemFunctions.GetTag(ctx, tagTest.ID(), maskingPolicyTest.ID(), sdk.ObjectTypeMaskingPolicy) require.NoError(t, err) - assert.Equal(t, tagValue, s) + assert.Equal(t, &tagValue, s) }) t.Run("masking policy with no set tag", func(t *testing.T) { @@ -42,8 +42,8 @@ func TestInt_GetTag(t *testing.T) { t.Cleanup(maskingPolicyCleanup) s, err := client.SystemFunctions.GetTag(ctx, tagTest.ID(), maskingPolicyTest.ID(), sdk.ObjectTypeMaskingPolicy) - require.Error(t, err) - assert.Equal(t, "", s) + require.NoError(t, err) + assert.Nil(t, s) }) t.Run("unsupported object type", func(t *testing.T) { _, err := client.SystemFunctions.GetTag(ctx, tagTest.ID(), testClientHelper().Ids.RandomAccountObjectIdentifier(), sdk.ObjectTypeSequence) diff --git a/pkg/sdk/testint/tables_integration_test.go b/pkg/sdk/testint/tables_integration_test.go index 19d7e8732f..7678ea5c36 100644 --- a/pkg/sdk/testint/tables_integration_test.go +++ b/pkg/sdk/testint/tables_integration_test.go @@ -13,6 +13,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/snowflakeroles" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,9 +44,9 @@ func TestInt_Table(t *testing.T) { require.Len(t, createdColumns, len(expectedColumns)) for i, expectedColumn := range expectedColumns { assert.Equal(t, strings.ToUpper(expectedColumn.Name), createdColumns[i].ColumnName) - 
createdColumnDataType, err := sdk.ToDataType(createdColumns[i].DataType) + createdColumnDataType, err := datatypes.ParseDataType(createdColumns[i].DataType) assert.NoError(t, err) - assert.Equal(t, expectedColumn.Type, createdColumnDataType) + assert.Equal(t, expectedColumn.Type, sdk.LegacyDataTypeFrom(createdColumnDataType)) } } diff --git a/pkg/sdk/testint/tags_integration_test.go b/pkg/sdk/testint/tags_integration_test.go index c3d9414750..996e44e070 100644 --- a/pkg/sdk/testint/tags_integration_test.go +++ b/pkg/sdk/testint/tags_integration_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/require" ) +// TODO(SNOW-1813223): cleanup tests func TestInt_Tags(t *testing.T) { client := testClient(t) ctx := context.Background() @@ -154,6 +155,16 @@ func TestInt_Tags(t *testing.T) { err := client.Tags.Alter(ctx, sdk.NewAlterTagRequest(id).WithSet(set)) require.NoError(t, err) + ref, err := testClientHelper().PolicyReferences.GetPolicyReference(t, tag.ID(), sdk.PolicyEntityDomainTag) + require.NoError(t, err) + assert.Equal(t, policyTest.ID().Name(), ref.PolicyName) + assert.Equal(t, sdk.PolicyKindMaskingPolicy, ref.PolicyKind) + + // assert that setting masking policy does not apply the tag on the masking policy + returnedTagValue, err := client.SystemFunctions.GetTag(ctx, id, policyTest.ID(), sdk.ObjectTypeMaskingPolicy) + require.NoError(t, err) + assert.Nil(t, returnedTagValue) + unset := sdk.NewTagUnsetRequest().WithMaskingPolicies(policies) err = client.Tags.Alter(ctx, sdk.NewAlterTagRequest(id).WithUnset(unset)) require.NoError(t, err) @@ -309,19 +320,28 @@ func TestInt_TagsAssociations(t *testing.T) { tag.ID(), } - testTagSet := func(id sdk.ObjectIdentifier, objectType sdk.ObjectType) { - err := client.Tags.Set(ctx, sdk.NewSetTagRequest(objectType, id).WithSetTags(tags)) + assertTagSet := func(id sdk.ObjectIdentifier, objectType sdk.ObjectType) { + returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, objectType) 
require.NoError(t, err) + assert.Equal(t, sdk.Pointer(tagValue), returnedTagValue) + } + assertTagUnset := func(id sdk.ObjectIdentifier, objectType sdk.ObjectType) { returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, objectType) require.NoError(t, err) - assert.Equal(t, tagValue, returnedTagValue) + assert.Nil(t, returnedTagValue) + } + + testTagSet := func(id sdk.ObjectIdentifier, objectType sdk.ObjectType) { + err := client.Tags.Set(ctx, sdk.NewSetTagRequest(objectType, id).WithSetTags(tags)) + require.NoError(t, err) + + assertTagSet(id, objectType) err = client.Tags.Unset(ctx, sdk.NewUnsetTagRequest(objectType, id).WithUnsetTags(unsetTags)) require.NoError(t, err) - _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, objectType) - require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported") + assertTagUnset(id, objectType) } t.Run("TestInt_TagAssociationForAccountLocator", func(t *testing.T) { @@ -331,31 +351,25 @@ func TestInt_TagsAssociations(t *testing.T) { }) require.NoError(t, err) - returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeAccount) - require.NoError(t, err) - assert.Equal(t, tagValue, returnedTagValue) + assertTagSet(id, sdk.ObjectTypeAccount) err = client.Accounts.Alter(ctx, &sdk.AlterAccountOptions{ UnsetTag: unsetTags, }) require.NoError(t, err) - _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeAccount) - require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported") + assertTagUnset(id, sdk.ObjectTypeAccount) // test tag sdk method err = client.Tags.SetOnCurrentAccount(ctx, sdk.NewSetTagOnCurrentAccountRequest().WithSetTags(tags)) require.NoError(t, err) - returnedTagValue, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeAccount) - require.NoError(t, err) - assert.Equal(t, tagValue, returnedTagValue) + 
assertTagSet(id, sdk.ObjectTypeAccount) err = client.Tags.UnsetOnCurrentAccount(ctx, sdk.NewUnsetTagOnCurrentAccountRequest().WithUnsetTags(unsetTags)) require.NoError(t, err) - _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeAccount) - require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported") + assertTagUnset(id, sdk.ObjectTypeAccount) }) t.Run("TestInt_TagAssociationForAccount", func(t *testing.T) { @@ -363,15 +377,12 @@ func TestInt_TagsAssociations(t *testing.T) { err := client.Tags.Set(ctx, sdk.NewSetTagRequest(sdk.ObjectTypeAccount, id).WithSetTags(tags)) require.NoError(t, err) - returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeAccount) - require.NoError(t, err) - assert.Equal(t, tagValue, returnedTagValue) + assertTagSet(id, sdk.ObjectTypeAccount) err = client.Tags.Unset(ctx, sdk.NewUnsetTagRequest(sdk.ObjectTypeAccount, id).WithUnsetTags(unsetTags)) require.NoError(t, err) - _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeAccount) - require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported") + assertTagUnset(id, sdk.ObjectTypeAccount) }) accountObjectTestCases := []struct { @@ -634,15 +645,12 @@ func TestInt_TagsAssociations(t *testing.T) { err := tc.setTags(id, tags) require.NoError(t, err) - returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, tc.objectType) - require.NoError(t, err) - assert.Equal(t, tagValue, returnedTagValue) + assertTagSet(id, tc.objectType) err = tc.unsetTags(id, unsetTags) require.NoError(t, err) - _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, tc.objectType) - require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported") + assertTagUnset(id, tc.objectType) // test object methods testTagSet(id, tc.objectType) @@ -725,15 +733,12 
@@ func TestInt_TagsAssociations(t *testing.T) { err := tc.setTags(id, tags) require.NoError(t, err) - returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, tc.objectType) - require.NoError(t, err) - assert.Equal(t, tagValue, returnedTagValue) + assertTagSet(id, tc.objectType) err = tc.unsetTags(id, unsetTags) require.NoError(t, err) - _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, tc.objectType) - require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported") + assertTagUnset(id, tc.objectType) // test object methods testTagSet(id, tc.objectType) @@ -794,23 +799,6 @@ func TestInt_TagsAssociations(t *testing.T) { }) }, }, - { - name: "MaskingPolicy", - objectType: sdk.ObjectTypeMaskingPolicy, - setupObject: func() (IDProvider[sdk.SchemaObjectIdentifier], func()) { - return testClientHelper().MaskingPolicy.CreateMaskingPolicy(t) - }, - setTags: func(id sdk.SchemaObjectIdentifier, tags []sdk.TagAssociation) error { - return client.MaskingPolicies.Alter(ctx, id, &sdk.AlterMaskingPolicyOptions{ - SetTag: tags, - }) - }, - unsetTags: func(id sdk.SchemaObjectIdentifier, tags []sdk.ObjectIdentifier) error { - return client.MaskingPolicies.Alter(ctx, id, &sdk.AlterMaskingPolicyOptions{ - UnsetTag: tags, - }) - }, - }, { name: "RowAccessPolicy", objectType: sdk.ObjectTypeRowAccessPolicy, @@ -929,21 +917,45 @@ func TestInt_TagsAssociations(t *testing.T) { err := tc.setTags(id, tags) require.NoError(t, err) - returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, tc.objectType) - require.NoError(t, err) - assert.Equal(t, tagValue, returnedTagValue) + assertTagSet(id, tc.objectType) err = tc.unsetTags(id, unsetTags) require.NoError(t, err) - _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, tc.objectType) - require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported") + assertTagUnset(id, 
 			tc.objectType)

 			// test object methods
 			testTagSet(id, tc.objectType)
 		})
 	}
+	t.Run("schema object MaskingPolicy", func(t *testing.T) {
+		maskingPolicy, cleanup := testClientHelper().MaskingPolicy.CreateMaskingPolicy(t)
+		t.Cleanup(cleanup)
+		id := maskingPolicy.ID()
+		err := client.MaskingPolicies.Alter(ctx, id, &sdk.AlterMaskingPolicyOptions{
+			SetTag: tags,
+		})
+		require.NoError(t, err)
+
+		assertTagSet(id, sdk.ObjectTypeMaskingPolicy)
+
+		// assert that setting masking policy does not apply the tag on the masking policy
+		refs, err := testClientHelper().PolicyReferences.GetPolicyReferences(t, tag.ID(), sdk.PolicyEntityDomainTag)
+		require.NoError(t, err)
+		assert.Len(t, refs, 0)
+
+		err = client.MaskingPolicies.Alter(ctx, id, &sdk.AlterMaskingPolicyOptions{
+			UnsetTag: unsetTags,
+		})
+		require.NoError(t, err)
+
+		assertTagUnset(id, sdk.ObjectTypeMaskingPolicy)
+
+		// test object methods
+		testTagSet(id, sdk.ObjectTypeMaskingPolicy)
+	})
+
 	columnTestCases := []struct {
 		name        string
 		setupObject func() (sdk.TableColumnIdentifier, func())
@@ -994,15 +1006,12 @@ func TestInt_TagsAssociations(t *testing.T) {
 			err := tc.setTags(id, tags)
 			require.NoError(t, err)

-			returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeColumn)
-			require.NoError(t, err)
-			assert.Equal(t, tagValue, returnedTagValue)
+			assertTagSet(id, sdk.ObjectTypeColumn)

 			err = tc.unsetTags(id, unsetTags)
 			require.NoError(t, err)

-			_, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, sdk.ObjectTypeColumn)
-			require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported")
+			assertTagUnset(id, sdk.ObjectTypeColumn)

 			// test object methods
 			testTagSet(id, sdk.ObjectTypeColumn)
@@ -1071,15 +1080,12 @@ func TestInt_TagsAssociations(t *testing.T) {
 			err := tc.setTags(id, tags)
 			require.NoError(t, err)

-			returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, tc.objectType)
-			require.NoError(t, err)
-			assert.Equal(t, tagValue, returnedTagValue)
+			assertTagSet(id, tc.objectType)

 			err = tc.unsetTags(id, unsetTags)
 			require.NoError(t, err)

-			_, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, tc.objectType)
-			require.ErrorContains(t, err, "sql: Scan error on column index 0, name \"TAG\": converting NULL to string is unsupported")
+			assertTagUnset(id, tc.objectType)

 			// test object methods
 			testTagSet(id, tc.objectType)
diff --git a/pkg/sdk/testint/tasks_gen_integration_test.go b/pkg/sdk/testint/tasks_gen_integration_test.go
index b47c7c9139..6304a830d4 100644
--- a/pkg/sdk/testint/tasks_gen_integration_test.go
+++ b/pkg/sdk/testint/tasks_gen_integration_test.go
@@ -356,25 +356,25 @@ func TestInt_Tasks(t *testing.T) {
 		rootId := testClientHelper().Ids.RandomSchemaObjectIdentifier()
 		root, rootCleanup := testClientHelper().Task.CreateWithRequest(t, sdk.NewCreateTaskRequest(rootId, sql).WithSchedule("10 MINUTE"))
 		t.Cleanup(rootCleanup)
-		require.Empty(t, root.Predecessors)
+		require.Empty(t, root.TaskRelations.Predecessors)

 		t1, t1Cleanup := testClientHelper().Task.CreateWithAfter(t, rootId)
 		t.Cleanup(t1Cleanup)
-		require.Equal(t, []sdk.SchemaObjectIdentifier{rootId}, t1.Predecessors)
+		require.Equal(t, []sdk.SchemaObjectIdentifier{rootId}, t1.TaskRelations.Predecessors)

 		t2, t2Cleanup := testClientHelper().Task.CreateWithAfter(t, t1.ID(), rootId)
 		t.Cleanup(t2Cleanup)
-		require.Contains(t, t2.Predecessors, rootId)
-		require.Contains(t, t2.Predecessors, t1.ID())
-		require.Len(t, t2.Predecessors, 2)
+		require.Contains(t, t2.TaskRelations.Predecessors, rootId)
+		require.Contains(t, t2.TaskRelations.Predecessors, t1.ID())
+		require.Len(t, t2.TaskRelations.Predecessors, 2)

 		t3, t3Cleanup := testClientHelper().Task.CreateWithAfter(t, t2.ID(), t1.ID())
 		t.Cleanup(t3Cleanup)
-		require.Contains(t, t3.Predecessors, t2.ID())
-		require.Contains(t, t3.Predecessors, t1.ID())
-		require.Len(t, t3.Predecessors, 2)
+		require.Contains(t, t3.TaskRelations.Predecessors, t2.ID())
+		require.Contains(t, t3.TaskRelations.Predecessors, t1.ID())
+		require.Len(t, t3.TaskRelations.Predecessors, 2)

 		rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, rootId)
 		require.NoError(t, err)
@@ -433,19 +433,19 @@ func TestInt_Tasks(t *testing.T) {
 		root1Id := testClientHelper().Ids.RandomSchemaObjectIdentifier()
 		root1, root1Cleanup := testClientHelper().Task.CreateWithRequest(t, sdk.NewCreateTaskRequest(root1Id, sql).WithSchedule("10 MINUTE"))
 		t.Cleanup(root1Cleanup)
-		require.Empty(t, root1.Predecessors)
+		require.Empty(t, root1.TaskRelations.Predecessors)

 		root2Id := testClientHelper().Ids.RandomSchemaObjectIdentifier()
 		root2, root2Cleanup := testClientHelper().Task.CreateWithRequest(t, sdk.NewCreateTaskRequest(root2Id, sql).WithSchedule("10 MINUTE"))
 		t.Cleanup(root2Cleanup)
-		require.Empty(t, root2.Predecessors)
+		require.Empty(t, root2.TaskRelations.Predecessors)

 		t1, t1Cleanup := testClientHelper().Task.CreateWithAfter(t, root1.ID(), root2.ID())
 		t.Cleanup(t1Cleanup)
-		require.Contains(t, t1.Predecessors, root1Id)
-		require.Contains(t, t1.Predecessors, root2Id)
-		require.Len(t, t1.Predecessors, 2)
+		require.Contains(t, t1.TaskRelations.Predecessors, root1Id)
+		require.Contains(t, t1.TaskRelations.Predecessors, root2Id)
+		require.Len(t, t1.TaskRelations.Predecessors, 2)

 		rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, t1.ID())
 		require.NoError(t, err)
@@ -493,7 +493,7 @@ func TestInt_Tasks(t *testing.T) {
 		returnedTagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), task.ID(), sdk.ObjectTypeTask)
 		require.NoError(t, err)
-		assert.Equal(t, "v1", returnedTagValue)
+		assert.Equal(t, sdk.Pointer("v1"), returnedTagValue)
 	})

 	t.Run("clone task: default", func(t *testing.T) {
@@ -522,7 +522,7 @@ func TestInt_Tasks(t *testing.T) {
 		assert.Equal(t, sourceTask.Config, task.Config)
 		assert.Equal(t, sourceTask.Condition, task.Condition)
 		assert.Equal(t, sourceTask.Warehouse, task.Warehouse)
-		assert.Equal(t, sourceTask.Predecessors, task.Predecessors)
+		assert.Equal(t, sourceTask.TaskRelations.Predecessors, task.TaskRelations.Predecessors)
 		assert.Equal(t, sourceTask.AllowOverlappingExecution, task.AllowOverlappingExecution)
 		assert.Equal(t, sourceTask.Comment, task.Comment)
 		assert.Equal(t, sourceTask.ErrorIntegration, task.ErrorIntegration)
@@ -613,7 +613,7 @@ func TestInt_Tasks(t *testing.T) {
 		t.Cleanup(taskCleanup)

 		err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(task.ID()).WithSet(*sdk.NewTaskSetRequest().
-			// TODO(SNOW-1348116): Cannot set warehouse due to Snowflake error
+			// TODO(SNOW-1843489): Cannot set warehouse due to Snowflake error
 			// WithWarehouse(testClientHelper().Ids.WarehouseId()).
 			WithErrorIntegration(errorIntegration.ID()).
 			WithSessionParameters(sessionParametersSet).
@@ -754,7 +754,7 @@ func TestInt_Tasks(t *testing.T) {
 		task, taskCleanup := testClientHelper().Task.CreateWithAfter(t, rootTask.ID())
 		t.Cleanup(taskCleanup)

-		assert.Contains(t, task.Predecessors, rootTask.ID())
+		assert.Contains(t, task.TaskRelations.Predecessors, rootTask.ID())

 		err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(task.ID()).WithRemoveAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))
 		require.NoError(t, err)
@@ -762,7 +762,7 @@ func TestInt_Tasks(t *testing.T) {
 		task, err = client.Tasks.ShowByID(ctx, task.ID())
 		require.NoError(t, err)
-		assert.Empty(t, task.Predecessors)
+		assert.Empty(t, task.TaskRelations.Predecessors)

 		err = client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(task.ID()).WithAddAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))
 		require.NoError(t, err)
@@ -770,7 +770,7 @@ func TestInt_Tasks(t *testing.T) {
 		task, err = client.Tasks.ShowByID(ctx, task.ID())
 		require.NoError(t, err)
-		assert.Contains(t, task.Predecessors, rootTask.ID())
+		assert.Contains(t, task.TaskRelations.Predecessors, rootTask.ID())
 	})

 	t.Run("alter task: set and unset final task", func(t *testing.T) {
diff --git a/pkg/sdk/testint/warehouses_integration_test.go b/pkg/sdk/testint/warehouses_integration_test.go
index e2dce5e08c..998bca7c09 100644
--- a/pkg/sdk/testint/warehouses_integration_test.go
+++ b/pkg/sdk/testint/warehouses_integration_test.go
@@ -166,10 +166,10 @@ func TestInt_Warehouses(t *testing.T) {
 		tag1Value, err := client.SystemFunctions.GetTag(ctx, tag.ID(), warehouse.ID(), sdk.ObjectTypeWarehouse)
 		require.NoError(t, err)
-		assert.Equal(t, "v1", tag1Value)
+		assert.Equal(t, sdk.Pointer("v1"), tag1Value)

 		tag2Value, err := client.SystemFunctions.GetTag(ctx, tag2.ID(), warehouse.ID(), sdk.ObjectTypeWarehouse)
 		require.NoError(t, err)
-		assert.Equal(t, "v2", tag2Value)
+		assert.Equal(t, sdk.Pointer("v2"), tag2Value)
 	})

 	t.Run("create: no options", func(t *testing.T) {
@@ -616,7 +616,7 @@ func TestInt_Warehouses(t *testing.T) {
 		require.NoError(t, err)
 		assert.Equal(t, 1, result.Running)
 		assert.Equal(t, 0, result.Queued)
-		assert.Equal(t, sdk.WarehouseStateSuspended, result.State)
+		assert.Eventually(t, func() bool { return sdk.WarehouseStateSuspended == result.State }, 5*time.Second, time.Second)
 	})

 	t.Run("alter: resize with a long running-query", func(t *testing.T) {
diff --git a/pkg/sdk/validations.go b/pkg/sdk/validations.go
index ada355f2d2..d8199f2d24 100644
--- a/pkg/sdk/validations.go
+++ b/pkg/sdk/validations.go
@@ -2,10 +2,12 @@ package sdk

 import (
 	"reflect"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes"
 )

 func IsValidDataType(v string) bool {
-	_, err := ToDataType(v)
+	_, err := datatypes.ParseDataType(v)
 	return err == nil
 }
diff --git a/templates/resources/tag_association.md.tmpl b/templates/resources/tag_association.md.tmpl
new file mode 100644
index 0000000000..a514fdf39d
--- /dev/null
+++ b/templates/resources/tag_association.md.tmpl
@@ -0,0 +1,43 @@
+---
+page_title: "{{.Name}} {{.Type}} - {{.ProviderName}}"
+subcategory: ""
+description: |-
+{{ if gt (len (split .Description "<deprecation>")) 1 -}}
+{{ index (split .Description "<deprecation>") 1 | plainmarkdown | trimspace | prefixlines "  " }}
+{{- else -}}
+{{ .Description | plainmarkdown | trimspace | prefixlines "  " }}
+{{- end }}
+---
+
+!> **V1 release candidate** This resource was reworked and is a release candidate for the V1. We do not expect significant changes in it before the V1. We will welcome any feedback and adjust the resource if needed. Any errors reported will be resolved with a higher priority. We encourage checking this resource out before the V1 release. Please follow the [migration guide](https://github.com/Snowflake-Labs/terraform-provider-snowflake/blob/main/MIGRATION_GUIDE.md#v0980--v0990) to use it.
+
+-> **Note** For the `ACCOUNT` object type, only identifiers with organization name are supported. See [account identifier docs](https://docs.snowflake.com/en/user-guide/admin-account-identifier#format-1-preferred-account-name-in-your-organization) for more details.
+
+-> **Note** The tag association resource ID has the following format: `"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"|TAG_VALUE|OBJECT_TYPE`. This means that a tuple of tag ID, tag value, and object type must be unique across the resources. If you want to specify this combination for more than one object, use a single `tag_association` resource with all of the objects listed in `object_identifiers`.
+
+-> **Note** If you want to change the tag value to a value that is already present in another `tag_association` resource, first remove the relevant `object_identifiers` from the resource with the old value, run `terraform apply`, then add the relevant `object_identifiers` to the resource with the new value, and run `terraform apply` once again.
+
+# {{.Name}} ({{.Type}})
+
+{{ .Description | trimspace }}
+
+{{ if .HasExample -}}
+## Example Usage
+
+{{ tffile (printf "examples/resources/%s/resource.tf" .Name)}}
+-> **Note** Instead of using `fully_qualified_name`, you can reference objects managed outside Terraform by constructing a correct ID; consult the [identifiers guide](https://registry.terraform.io/providers/Snowflake-Labs/snowflake/latest/docs/guides/identifiers#new-computed-fully-qualified-name-field-in-resources).
+
+
+{{- end }}
+
+{{ .SchemaMarkdown | trimspace }}
+{{- if .HasImport }}
+
+## Import
+
+~> **Note** Due to technical limitations of the Terraform SDK, `object_identifiers` are not set during import state. Please run `terraform refresh` after importing to get this field populated.
+
+Import is supported using the following syntax:
+
+{{ codefile "shell" (printf "examples/resources/%s/import.sh" .Name)}}
+{{- end }}
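The template above documents the new `"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"|TAG_VALUE|OBJECT_TYPE` resource ID format. A minimal sketch of how such an ID decomposes into its three parts is shown below; `parseTagAssociationID` is a hypothetical helper for illustration only, not the provider's actual parsing code, and it assumes the tag value itself contains no `|` character:

```go
package main

import (
	"fmt"
	"strings"
)

// parseTagAssociationID splits a v0.100.0-style tag association resource ID
// `"TAG_DATABASE"."TAG_SCHEMA"."TAG_NAME"|TAG_VALUE|OBJECT_TYPE` into its
// three components. Illustrative sketch only: it naively splits on `|`, so
// a tag value containing `|` would not be handled.
func parseTagAssociationID(id string) (tagID, tagValue, objectType string, err error) {
	parts := strings.Split(id, "|")
	if len(parts) != 3 {
		return "", "", "", fmt.Errorf("expected 3 parts in %q, got %d", id, len(parts))
	}
	return parts[0], parts[1], parts[2], nil
}

func main() {
	tagID, value, objType, err := parseTagAssociationID(`"TAG_DB"."TAG_SCHEMA"."TIER"|gold|WAREHOUSE`)
	if err != nil {
		panic(err)
	}
	fmt.Println(tagID, value, objType)
}
```

This decomposition mirrors why the migration guide asks you to move `object_identifiers` between resources instead of editing `tag_value` in place: the tag value is part of the resource ID, so changing it changes which resource a given association belongs to.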