diff --git a/docs/resources/table.md b/docs/resources/table.md index fee4fc3856..3f24c433cf 100644 --- a/docs/resources/table.md +++ b/docs/resources/table.md @@ -118,7 +118,7 @@ Optional: - `comment` (String) Column comment - `default` (Block List, Max: 1) Defines the column default value; note due to limitations of Snowflake's ALTER TABLE ADD/MODIFY COLUMN updates to default will not be applied (see [below for nested schema](#nestedblock--column--default)) - `identity` (Block List, Max: 1) Defines the identity start/step values for a column. **Note** Identity/default are mutually exclusive. (see [below for nested schema](#nestedblock--column--identity)) -- `masking_policy` (String) Masking policy to apply on column +- `masking_policy` (String) Masking policy to apply on column. It has to be a fully qualified name. - `nullable` (Boolean) Whether this column can contain null values. **Note**: Depending on your Snowflake version, the default value will not suffice if this column is used in a primary key constraint. diff --git a/pkg/datasources/tables.go b/pkg/datasources/tables.go index 44b9992dc4..a04d804527 100644 --- a/pkg/datasources/tables.go +++ b/pkg/datasources/tables.go @@ -1,12 +1,12 @@ package datasources import ( + "context" "database/sql" - "errors" - "fmt" "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -58,38 +58,38 @@ func Tables() *schema.Resource { func ReadTables(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) + ctx := context.Background() + client := sdk.NewClientFromDB(db) databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) - currentTables, err := snowflake.ListTables(databaseName, schemaName, db) - if errors.Is(err, sql.ErrNoRows) { - // If not found, mark resource to be removed from state file during apply or refresh - log.Printf("[DEBUG] tables in schema (%s) not found", d.Id()) - d.SetId("") - return nil - } else if err != nil { - log.Printf("[DEBUG] unable to parse tables in schema (%s)", d.Id()) + schemaId := sdk.NewDatabaseObjectIdentifier(databaseName, schemaName) + extractedTables, err := client.Tables.Show(ctx, sdk.NewShowTableRequest().WithIn( + &sdk.In{Schema: schemaId}, + )) + if err != nil { + log.Printf("[DEBUG] failed when searching tables in schema (%s), err = %s", schemaId.FullyQualifiedName(), err.Error()) d.SetId("") return nil } - tables := []map[string]interface{}{} + tables := make([]map[string]any, 0) - for _, table := range currentTables { - tableMap := map[string]interface{}{} - - if table.IsExternal.String == "Y" { + for _, extractedTable := range extractedTables { + if extractedTable.IsExternal { continue } - tableMap["name"] = table.TableName.String - tableMap["database"] = table.DatabaseName.String - tableMap["schema"] = table.SchemaName.String - tableMap["comment"] = table.Comment.String + table := map[string]any{ + "name": extractedTable.Name, + "database": extractedTable.DatabaseName, + "schema": extractedTable.SchemaName, + "comment": extractedTable.Comment, + } - tables = append(tables, tableMap) + tables = append(tables, table) } - d.SetId(fmt.Sprintf(`%v|%v`, databaseName, schemaName)) + d.SetId(helpers.EncodeSnowflakeID(databaseName, schemaName)) return d.Set("tables", tables) } diff --git a/pkg/datasources/tables_acceptance_test.go b/pkg/datasources/tables_acceptance_test.go index 62aa32daae..9205e3a033 100644 --- a/pkg/datasources/tables_acceptance_test.go +++ b/pkg/datasources/tables_acceptance_test.go @@ -5,8 +5,11 @@ import ( "strings" "testing" + acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/tfversion" ) func TestAcc_Tables(t *testing.T) { @@ -16,7 +19,11 @@ func TestAcc_Tables(t *testing.T) { stageName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) externalTableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ - Providers: providers(), + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, CheckDestroy: nil, Steps: []resource.TestStep{ { diff --git a/pkg/resources/materialized_view.go b/pkg/resources/materialized_view.go index fe7060393c..91ce9bd080 100644 --- a/pkg/resources/materialized_view.go +++ b/pkg/resources/materialized_view.go @@ -206,13 +206,13 @@ func UpdateMaterializedView(d *schema.ResourceData, meta interface{}) error { unsetRequest := sdk.NewMaterializedViewUnsetRequest() if d.HasChange("comment") { - comment := d.Get("comment") - if c := comment.(string); c == "" { + comment := d.Get("comment").(string) + if comment == "" { runUnsetStatement = true unsetRequest.WithComment(sdk.Bool(true)) } else { runSetStatement = true - setRequest.WithComment(sdk.String(d.Get("comment").(string))) + setRequest.WithComment(sdk.String(comment)) } } if d.HasChange("is_secure") { diff --git a/pkg/resources/stream.go b/pkg/resources/stream.go index c1ffb59d0f..7cd22ef87f 100644 --- a/pkg/resources/stream.go +++ b/pkg/resources/stream.go @@ -9,8 +9,6 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -131,14 +129,12 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { } tableId := tableObjectIdentifier.(sdk.SchemaObjectIdentifier) - tq := snowflake.NewTableBuilder(tableId.Name(), tableId.DatabaseName(), tableId.SchemaName()).Show() - tableRow := snowflake.QueryRow(db, tq) - t, err := snowflake.ScanTable(tableRow) + table, err := client.Tables.ShowByID(ctx, tableId) if err != nil { return err } - if t.IsExternal.String == "Y" { + if table.IsExternal { req := sdk.NewCreateStreamOnExternalTableRequest(id, tableId) if insertOnly { req.WithInsertOnly(sdk.Bool(true)) diff --git a/pkg/resources/table.go b/pkg/resources/table.go index d71dfafe52..00a2c3fe5b 100644 --- a/pkg/resources/table.go +++ b/pkg/resources/table.go @@ -1,24 +1,23 @@ package resources import ( - "bytes" + "context" "database/sql" - "encoding/csv" - "errors" "fmt" "log" "slices" + "strconv" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) -const ( - tableIDDelimiter = '|' -) - +// TODO [SNOW-867235]: old implementation was quoting every column, SDK is not quoting them, therefore they are quoted here: decide if we quote columns or not +// TODO [SNOW-1031688]: move data manipulation logic to the SDK - SQL generation or builders part (e.g. different default types/identity) var tableSchema = map[string]*schema.Schema{ "name": { Type: schema.TypeString, @@ -79,6 +78,7 @@ var tableSchema = map[string]*schema.Schema{ MinItems: 1, MaxItems: 1, Elem: &schema.Resource{ + // TODO [SNOW-867235]: there is no such separation on SDK level. Should we keep it in V1? Schema: map[string]*schema.Schema{ "constant": { Type: schema.TypeString, @@ -137,7 +137,7 @@ var tableSchema = map[string]*schema.Schema{ Type: schema.TypeString, Optional: true, Default: "", - Description: "Masking policy to apply on column", + Description: "Masking policy to apply on column. It has to be a fully qualified name.", }, }, }, @@ -220,73 +220,12 @@ func Table() *schema.Resource { } } -type tableID struct { - DatabaseName string - SchemaName string - TableName string -} - -// String() takes in a tableID object and returns a pipe-delimited string: -// DatabaseName|SchemaName|TableName. -func (si *tableID) String() (string, error) { - var buf bytes.Buffer - csvWriter := csv.NewWriter(&buf) - csvWriter.Comma = tableIDDelimiter - dataIdentifiers := [][]string{{si.DatabaseName, si.SchemaName, si.TableName}} - if err := csvWriter.WriteAll(dataIdentifiers); err != nil { - return "", err - } - strTableID := strings.TrimSpace(buf.String()) - return strTableID, nil -} - -// tableIDFromString() takes in a pipe-delimited string: DatabaseName|SchemaName|TableName -// and returns a tableID object. -func tableIDFromString(stringID string) (*tableID, error) { - reader := csv.NewReader(strings.NewReader(stringID)) - reader.Comma = tableIDDelimiter - lines, err := reader.ReadAll() - if err != nil { - return nil, fmt.Errorf("not CSV compatible") - } - - if len(lines) != 1 { - return nil, fmt.Errorf("1 line at a time") - } - if len(lines[0]) != 3 { - return nil, fmt.Errorf("3 fields allowed") - } - - tableResult := &tableID{ - DatabaseName: lines[0][0], - SchemaName: lines[0][1], - TableName: lines[0][2], - } - return tableResult, nil -} - type columnDefault struct { constant *string expression *string sequence *string } -func (cd *columnDefault) toSnowflakeColumnDefault() *snowflake.ColumnDefault { - if cd.constant != nil { - return snowflake.NewColumnDefaultWithConstant(*cd.constant) - } - - if cd.expression != nil { - return snowflake.NewColumnDefaultWithExpression(*cd.expression) - } - - if cd.sequence != nil { - return snowflake.NewColumnDefaultWithSequence(*cd.sequence) - } - - return nil -} - func (cd *columnDefault) _type() string { if cd.constant != nil { return "constant" @@ -308,11 +247,6 @@ type columnIdentity struct { stepNum int } -func (identity *columnIdentity) toSnowflakeColumnIdentity() *snowflake.ColumnIdentity { - snowIdentity := snowflake.ColumnIdentity{} - return snowIdentity.WithStartNum(identity.startNum).WithStep(identity.stepNum) -} - type column struct { name string dataType string @@ -323,34 +257,8 @@ type column struct { maskingPolicy string } -func (c column) toSnowflakeColumn() snowflake.Column { - sC := &snowflake.Column{} - - if c._default != nil { - sC = sC.WithDefault(c._default.toSnowflakeColumnDefault()) - } - - if c.identity != nil { - sC = sC.WithIdentity(c.identity.toSnowflakeColumnIdentity()) - } - - return *sC.WithName(c.name). - WithType(c.dataType). - WithNullable(c.nullable). - WithComment(c.comment). - WithMaskingPolicy(c.maskingPolicy) -} - type columns []column -func (c columns) toSnowflakeColumns() []snowflake.Column { - sC := make([]snowflake.Column, len(c)) - for i, col := range c { - sC[i] = col.toSnowflakeColumn() - } - return sC -} - type changedColumns []changedColumn type changedColumn struct { @@ -468,6 +376,67 @@ func getColumns(from interface{}) (to columns) { return to } +func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest { + c := from.(map[string]interface{}) + _type := c["type"].(string) + + nameInQuotes := fmt.Sprintf(`"%v"`, snowflake.EscapeString(c["name"].(string))) + request := sdk.NewTableColumnRequest(nameInQuotes, sdk.DataType(_type)) + + _default := c["default"].([]interface{}) + var expression string + if len(_default) == 1 { + if c, ok := _default[0].(map[string]interface{})["constant"]; ok { + if constant, ok := c.(string); ok && len(constant) > 0 { + if strings.Contains(_type, "CHAR") || _type == "STRING" || _type == "TEXT" { + expression = snowflake.EscapeSnowflakeString(constant) + } else { + expression = constant + } + } + } + + if e, ok := _default[0].(map[string]interface{})["expression"]; ok { + if expr, ok := e.(string); ok && len(expr) > 0 { + expression = expr + } + } + + if s, ok := _default[0].(map[string]interface{})["sequence"]; ok { + if seq := s.(string); ok && len(seq) > 0 { + expression = fmt.Sprintf(`%v.NEXTVAL`, seq) + } + } + request.WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithExpression(sdk.String(expression))) + } + + identity := c["identity"].([]interface{}) + if len(identity) == 1 { + identityProp := identity[0].(map[string]interface{}) + startNum := identityProp["start_num"].(int) + stepNum := identityProp["step_num"].(int) + request.WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithIdentity(sdk.NewColumnIdentityRequest(startNum, stepNum))) + } + + maskingPolicy := c["masking_policy"].(string) + if maskingPolicy != "" { + request.WithMaskingPolicy(sdk.NewColumnMaskingPolicyRequest(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(maskingPolicy))) + } + + return request. + WithNotNull(sdk.Bool(!c["nullable"].(bool))). + WithComment(sdk.String(c["comment"].(string))) +} + +func getTableColumnRequests(from interface{}) []sdk.TableColumnRequest { + cols := from.([]interface{}) + to := make([]sdk.TableColumnRequest, len(cols)) + for i, c := range cols { + to[i] = *getTableColumnRequest(c) + } + return to +} + type primarykey struct { name string keys []string @@ -485,66 +454,158 @@ func getPrimaryKey(from interface{}) (to primarykey) { return to } -func (pk primarykey) toSnowflakePrimaryKey() snowflake.PrimaryKey { - snowPk := snowflake.PrimaryKey{} - return *snowPk.WithName(pk.name).WithKeys(pk.keys) +func toColumnConfig(descriptions []sdk.TableColumnDetails) []any { + flattened := make([]any, 0) + for _, td := range descriptions { + if td.Kind != "COLUMN" { + continue + } + + flat := map[string]any{} + flat["name"] = td.Name + flat["type"] = string(td.Type) + flat["nullable"] = td.IsNullable + + if td.Comment != nil { + flat["comment"] = *td.Comment + } + + if td.PolicyName != nil { + // TODO [SNOW-867240]: SHOW TABLE returns last part of id without double quotes... we have to quote it again. Move it to SDK. + flat["masking_policy"] = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(*td.PolicyName).FullyQualifiedName() + } + + identity := toColumnIdentityConfig(td) + if identity != nil { + flat["identity"] = []any{identity} + } else { + def := toColumnDefaultConfig(td) + if def != nil { + flat["default"] = []any{def} + } + } + flattened = append(flattened, flat) + } + return flattened +} + +func toColumnDefaultConfig(td sdk.TableColumnDetails) map[string]any { + if td.Default == nil { + return nil + } + + defaultRaw := *td.Default + def := map[string]any{} + if strings.HasSuffix(defaultRaw, ".NEXTVAL") { + // TODO [SNOW-867240]: SHOW TABLE returns last part of id without double quotes... we have to quote it again. Move it to SDK. + sequenceIdRaw := strings.TrimSuffix(defaultRaw, ".NEXTVAL") + def["sequence"] = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(sequenceIdRaw).FullyQualifiedName() + return def + } + + if strings.Contains(defaultRaw, "(") && strings.Contains(defaultRaw, ")") { + def["expression"] = defaultRaw + return def + } + + columnType := strings.ToUpper(string(td.Type)) + if strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT" { + def["constant"] = snowflake.UnescapeSnowflakeString(defaultRaw) + return def + } + + def["constant"] = defaultRaw + return def +} + +func toColumnIdentityConfig(td sdk.TableColumnDetails) map[string]any { + // if autoincrement is used this is reflected back IDENTITY START 1 INCREMENT 1 + if td.Default == nil { + return nil + } + + defaultRaw := *td.Default + + if strings.Contains(defaultRaw, "IDENTITY") { + identity := map[string]any{} + + split := strings.Split(defaultRaw, " ") + start, err := strconv.Atoi(split[2]) + if err == nil { + identity["start_num"] = start + } + step, err := strconv.Atoi(split[4]) + if err == nil { + identity["step_num"] = step + } + + return identity + } + return nil } // CreateTable implements schema.CreateFunc. func CreateTable(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - database := d.Get("database").(string) - schema := d.Get("schema").(string) + ctx := context.Background() + client := sdk.NewClientFromDB(db) + + databaseName := d.Get("database").(string) + schemaName := d.Get("schema").(string) name := d.Get("name").(string) + id := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name) - columns := getColumns(d.Get("column").([]interface{})) + tableColumnRequests := getTableColumnRequests(d.Get("column").([]interface{})) - builder := snowflake.NewTableWithColumnDefinitionsBuilder(name, database, schema, columns.toSnowflakeColumns()) + createRequest := sdk.NewCreateTableRequest(id, tableColumnRequests) - // Set optionals if v, ok := d.GetOk("comment"); ok { - builder.WithComment(v.(string)) + createRequest.WithComment(sdk.String(v.(string))) } if v, ok := d.GetOk("cluster_by"); ok { - builder.WithClustering(expandStringList(v.([]interface{}))) + createRequest.WithClusterBy(expandStringList(v.([]interface{}))) } if v, ok := d.GetOk("primary_key"); ok { - pk := getPrimaryKey(v.([]interface{})) - builder.WithPrimaryKey(pk.toSnowflakePrimaryKey()) + keysList := v.([]any) + if len(keysList) > 0 { + keys := expandStringList(keysList[0].(map[string]any)["keys"].([]interface{})) + constraintRequest := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithColumns(snowflake.QuoteStringList(keys)) + + keyName, isPresent := keysList[0].(map[string]any)["name"] + if isPresent && keyName != "" { + constraintRequest.WithName(sdk.String(keyName.(string))) + } + } } if v, ok := d.GetOk("data_retention_days"); ok { - builder.WithDataRetentionTimeInDays(v.(int)) + createRequest.WithDataRetentionTimeInDays(sdk.Int(v.(int))) } else if v, ok := d.GetOk("data_retention_time_in_days"); ok { - builder.WithDataRetentionTimeInDays(v.(int)) + createRequest.WithDataRetentionTimeInDays(sdk.Int(v.(int))) } if v, ok := d.GetOk("change_tracking"); ok { - builder.WithChangeTracking(v.(bool)) - } - - if v, ok := d.GetOk("tag"); ok { - tags := getTags(v) - builder.WithTags(tags.toSnowflakeTagValues()) + createRequest.WithChangeTracking(sdk.Bool(v.(bool))) } - stmt := builder.Create() - if err := snowflake.Exec(db, stmt); err != nil { - return fmt.Errorf("error creating table %v", name) + var tagAssociationRequests []sdk.TagAssociationRequest + if _, ok := d.GetOk("tag"); ok { + tagAssociations := getPropertyTags(d, "tag") + tagAssociationRequests = make([]sdk.TagAssociationRequest, len(tagAssociations)) + for i, t := range tagAssociations { + tagAssociationRequests[i] = *sdk.NewTagAssociationRequest(t.Name, t.Value) + } + createRequest.WithTags(tagAssociationRequests) } - tableID := &tableID{ - DatabaseName: database, - SchemaName: schema, - TableName: name, - } - dataIDInput, err := tableID.String() + err := client.Tables.Create(ctx, createRequest) if err != nil { - return err + return fmt.Errorf("error creating table %v err = %w", name, err) } - d.SetId(dataIDInput) + + d.SetId(helpers.EncodeSnowflakeID(id)) return ReadTable(d, meta) } @@ -552,59 +613,33 @@ func CreateTable(d *schema.ResourceData, meta interface{}) error { // ReadTable implements schema.ReadFunc. func ReadTable(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - tableID, err := tableIDFromString(d.Id()) - if err != nil { - return err - } - builder := snowflake.NewTableBuilder(tableID.TableName, tableID.DatabaseName, tableID.SchemaName) + ctx := context.Background() + client := sdk.NewClientFromDB(db) + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - row := snowflake.QueryRow(db, builder.Show()) - table, err := snowflake.ScanTable(row) - if errors.Is(err, sql.ErrNoRows) { - // If not found, mark resource to be removed from state file during apply or refresh + table, err := client.Tables.ShowByID(ctx, id) + if err != nil { log.Printf("[DEBUG] table (%s) not found", d.Id()) d.SetId("") return nil } - if err != nil { - return err - } - // Describe the table to read the cols - tableDescriptionRows, err := snowflake.Query(db, builder.ShowColumns()) + tableDescription, err := client.Tables.DescribeColumns(ctx, sdk.NewDescribeTableColumnsRequest(id)) if err != nil { return err } - tableDescription, err := snowflake.ScanTableDescription(tableDescriptionRows) - if err != nil { - return err - } - - /* - deprecated as it conflicts with the new table_constraint resource - showPkrows, err := snowflake.Query(db, builder.ShowPrimaryKeys()) - if err != nil { - return err - } - - pkDescription, err := snowflake.ScanPrimaryKeyDescription(showPkrows) - if err != nil { - return err - }*/ - // Set the relevant data in the state toSet := map[string]interface{}{ - "name": table.TableName.String, - "owner": table.Owner.String, - "database": tableID.DatabaseName, - "schema": tableID.SchemaName, - "comment": table.Comment.String, - "column": snowflake.NewColumns(tableDescription).Flatten(), - "cluster_by": snowflake.ClusterStatementToList(table.ClusterBy.String), - // "primary_key": snowflake.FlattenTablePrimaryKey(pkDescription), - "change_tracking": (table.ChangeTracking.String == "ON"), - "qualified_name": fmt.Sprintf(`"%s"."%s"."%s"`, tableID.DatabaseName, tableID.SchemaName, table.TableName.String), + "name": table.Name, + "owner": table.Owner, + "database": table.DatabaseName, + "schema": table.SchemaName, + "comment": table.Comment, + "column": toColumnConfig(tableDescription), + "cluster_by": table.GetClusterByKeys(), + "change_tracking": table.ChangeTracking, + "qualified_name": id.FullyQualifiedName(), } var dataRetentionKey string if _, ok := d.GetOk("data_retention_time_in_days"); ok { @@ -613,7 +648,7 @@ func ReadTable(d *schema.ResourceData, meta interface{}) error { dataRetentionKey = "data_retention_days" } if dataRetentionKey != "" { - toSet[dataRetentionKey] = table.RetentionTime.Int32 + toSet[dataRetentionKey] = table.RetentionTime } for key, val := range toSet { @@ -626,170 +661,241 @@ func ReadTable(d *schema.ResourceData, meta interface{}) error { // UpdateTable implements schema.UpdateFunc. func UpdateTable(d *schema.ResourceData, meta interface{}) error { - tid, err := tableIDFromString(d.Id()) - if err != nil { - return err + db := meta.(*sql.DB) + ctx := context.Background() + client := sdk.NewClientFromDB(db) + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) + + if d.HasChange("name") { + newName := d.Get("name").(string) + + newId := sdk.NewSchemaObjectIdentifier(id.DatabaseName(), id.SchemaName(), newName) + + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithNewName(&newId)) + if err != nil { + return fmt.Errorf("error renaming table %v err = %w", d.Id(), err) + } + + d.SetId(helpers.EncodeSnowflakeID(newId)) + id = newId } - dbName := tid.DatabaseName - schema := tid.SchemaName - tableName := tid.TableName + var runSetStatement bool + var runUnsetStatement bool + setRequest := sdk.NewTableSetRequest() + unsetRequest := sdk.NewTableUnsetRequest() + + if d.HasChange("comment") { + comment := d.Get("comment").(string) + if comment == "" { + runUnsetStatement = true + unsetRequest.WithComment(true) + } else { + runSetStatement = true + setRequest.WithComment(sdk.String(comment)) + } + } - builder := snowflake.NewTableBuilder(tableName, dbName, schema) + if d.HasChange("change_tracking") { + changeTracking := d.Get("change_tracking").(bool) + runSetStatement = true + setRequest.WithChangeTracking(sdk.Bool(changeTracking)) + } - db := meta.(*sql.DB) - if d.HasChange("name") { - name := d.Get("name") - q := builder.Rename(name.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating table name on %v", d.Id()) + checkChangeForDataRetention := func(key string) { + if d.HasChange(key) { + dataRetentionDays := d.Get(key).(int) + runSetStatement = true + setRequest.WithDataRetentionTimeInDays(sdk.Int(dataRetentionDays)) } - tableID := &tableID{ - DatabaseName: dbName, - SchemaName: schema, - TableName: name.(string), + } + checkChangeForDataRetention("data_retention_days") + checkChangeForDataRetention("data_retention_time_in_days") + + if runSetStatement { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithSet(setRequest)) + if err != nil { + return fmt.Errorf("error updating table: %w", err) } - dataIDInput, err := tableID.String() + } + + if runUnsetStatement { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithUnset(unsetRequest)) if err != nil { - return err + return fmt.Errorf("error updating table: %w", err) } - d.SetId(dataIDInput) } - if d.HasChange("comment") { - comment := d.Get("comment") - q := builder.ChangeComment(comment.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating table comment on %v", d.Id()) + + if d.HasChange("cluster_by") { + cb := expandStringList(d.Get("cluster_by").([]interface{})) + + if len(cb) != 0 { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithClusteringAction(sdk.NewTableClusteringActionRequest().WithClusterBy(cb))) + if err != nil { + return fmt.Errorf("error updating table: %w", err) + } + } else { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithClusteringAction(sdk.NewTableClusteringActionRequest().WithDropClusteringKey(sdk.Bool(true)))) + if err != nil { + return fmt.Errorf("error updating table: %w", err) + } } } + if d.HasChange("column") { t, n := d.GetChange("column") removed, added, changed := getColumns(t).diffs(getColumns(n)) - for _, cA := range removed { - q := builder.DropColumn(cA.name) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error dropping column on %v", d.Id()) + + if len(removed) > 0 { + removedColumnNames := make([]string, len(removed)) + for i, r := range removed { + removedColumnNames[i] = r.name + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithDropColumns(snowflake.QuoteStringList(removedColumnNames)))) + if err != nil { + return fmt.Errorf("error updating table: %w", err) } } + for _, cA := range added { - var q string + addRequest := sdk.NewTableColumnAddActionRequest(fmt.Sprintf("\"%s\"", cA.name), sdk.DataType(cA.dataType)). + WithInlineConstraint(sdk.NewTableColumnAddInlineConstraintRequest().WithNotNull(sdk.Bool(!cA.nullable))) - if cA.identity == nil && cA._default == nil { //nolint:gocritic // todo: please fix this to pass gocritic - q = builder.AddColumn(cA.name, cA.dataType, cA.nullable, nil, nil, cA.comment, cA.maskingPolicy) - } else if cA.identity != nil { - q = builder.AddColumn(cA.name, cA.dataType, cA.nullable, nil, cA.identity.toSnowflakeColumnIdentity(), cA.comment, cA.maskingPolicy) - } else { + if cA._default != nil { if cA._default._type() != "constant" { return fmt.Errorf("failed to add column %v => Only adding a column as a constant is supported by Snowflake", cA.name) } + var expression string + if strings.Contains(cA.dataType, "CHAR") || cA.dataType == "STRING" || cA.dataType == "TEXT" { + expression = snowflake.EscapeSnowflakeString(*cA._default.constant) + } else { + expression = *cA._default.constant + } + addRequest.WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithExpression(sdk.String(expression))) + } + + if cA.identity != nil { + addRequest.WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithIdentity(sdk.NewColumnIdentityRequest(cA.identity.startNum, cA.identity.stepNum))) + } - q = builder.AddColumn(cA.name, cA.dataType, cA.nullable, cA._default.toSnowflakeColumnDefault(), nil, cA.comment, cA.maskingPolicy) + if cA.maskingPolicy != "" { + addRequest.WithMaskingPolicy(sdk.NewColumnMaskingPolicyRequest(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(cA.maskingPolicy))) } - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error adding column on %v", d.Id()) + if cA.comment != "" { + addRequest.WithComment(sdk.String(cA.comment)) + } + + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAdd(addRequest))) + if err != nil { + return fmt.Errorf("error adding column: %w", err) } } for _, cA := range changed { if cA.changedDataType { - q := builder.ChangeColumnType(cA.newColumn.name, cA.newColumn.dataType) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithType(sdk.Pointer(sdk.DataType(cA.newColumn.dataType)))}))) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } if cA.changedNullConstraint { - q := builder.ChangeNullConstraint(cA.newColumn.name, cA.newColumn.nullable) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + nullabilityRequest := sdk.NewTableColumnNotNullConstraintRequest() + if !cA.newColumn.nullable { + nullabilityRequest.WithSet(sdk.Bool(true)) + } else { + nullabilityRequest.WithDrop(sdk.Bool(true)) + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithNotNullConstraint(nullabilityRequest)}))) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } if cA.dropedDefault { - q := builder.DropColumnDefault(cA.newColumn.name) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithDropDefault(sdk.Bool(true))}))) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } if cA.changedComment { - q := builder.ChangeColumnComment(cA.newColumn.name, cA.newColumn.comment) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + columnAlterActionRequest := sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)) + if cA.newColumn.comment == "" { + columnAlterActionRequest.WithUnsetComment(sdk.Bool(true)) + } else { + columnAlterActionRequest.WithComment(sdk.String(cA.newColumn.comment)) + } + + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*columnAlterActionRequest}))) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } if cA.changedMaskingPolicy { - q := builder.ChangeColumnMaskingPolicy(cA.newColumn.name, cA.newColumn.maskingPolicy) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + columnAction := sdk.NewTableColumnActionRequest() + if strings.TrimSpace(cA.newColumn.maskingPolicy) == "" { + columnAction.WithUnsetMaskingPolicy(sdk.NewTableColumnAlterUnsetMaskingPolicyActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name))) + } else { + columnAction.WithSetMaskingPolicy(sdk.NewTableColumnAlterSetMaskingPolicyActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name), sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(cA.newColumn.maskingPolicy), []string{}).WithForce(sdk.Bool(true))) + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(columnAction)) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } } } - if d.HasChange("cluster_by") { - cb := expandStringList(d.Get("cluster_by").([]interface{})) - - var q string - if len(cb) != 0 { - builder.WithClustering(cb) - q = builder.ChangeClusterBy(builder.GetClusterKeyString()) - } else { - q = builder.DropClustering() - } - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating table clustering on %v", d.Id()) - } - } if d.HasChange("primary_key") { - opk, npk := d.GetChange("primary_key") + o, n := d.GetChange("primary_key") - newpk := getPrimaryKey(npk) - oldpk := getPrimaryKey(opk) + newKey := getPrimaryKey(n) + oldKey := getPrimaryKey(o) - if len(oldpk.keys) > 0 || len(newpk.keys) == 0 { + if len(oldKey.keys) > 0 || len(newKey.keys) == 0 { // drop our pk if there was an old primary key, or pk has been removed - q := builder.DropPrimaryKey() - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing primary key first on %v", d.Id()) + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithConstraintAction( + sdk.NewTableConstraintActionRequest(). + WithDrop(sdk.NewTableConstraintDropActionRequest().WithPrimaryKey(sdk.Bool(true))), + )) + if err != nil { + return fmt.Errorf("error updating table: %w", err) } } - if len(newpk.keys) > 0 { - // add our new pk - q := builder.ChangePrimaryKey(newpk.toSnowflakePrimaryKey()) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + if len(newKey.keys) > 0 { + constraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithColumns(snowflake.QuoteStringList(newKey.keys)) + if newKey.name != "" { + constraint.WithName(sdk.String(newKey.name)) + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithConstraintAction( + sdk.NewTableConstraintActionRequest().WithAdd(constraint), + )) + if err != nil { + return fmt.Errorf("error updating table: %w", err) } } } - updateDataRetention := func(key string) error { - if d.HasChange(key) { - ndr := d.Get(key) - q := builder.ChangeDataRetention(ndr.(int)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + + if d.HasChange("tag") { + unsetTags, setTags := GetTagsDiff(d, "tag") + + if len(unsetTags) > 0 { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithUnsetTags(unsetTags)) + if err != nil { + return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) } } - return nil - } - err = updateDataRetention("data_retention_days") - if err != nil { - return err - } - err = updateDataRetention("data_retention_time_in_days") - if err != nil { - return err - } - if d.HasChange("change_tracking") { - nct := d.Get("change_tracking") - q := builder.ChangeChangeTracking(nct.(bool)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + if len(setTags) > 0 { + tagAssociationRequests := make([]sdk.TagAssociationRequest, len(setTags)) + for i, t := range setTags { + tagAssociationRequests[i] = *sdk.NewTagAssociationRequest(t.Name, t.Value) + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithSetTags(tagAssociationRequests)) + if err != nil { + return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) + } } } - tagChangeErr := handleTagChanges(db, d, builder) - if tagChangeErr != nil { - return tagChangeErr - } return ReadTable(d, meta) } @@ -797,21 +903,15 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { // DeleteTable implements schema.DeleteFunc. func DeleteTable(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - tableID, err := tableIDFromString(d.Id()) + ctx := context.Background() + client := sdk.NewClientFromDB(db) + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) + + err := client.Tables.Drop(ctx, sdk.NewDropTableRequest(id)) if err != nil { return err } - dbName := tableID.DatabaseName - schemaName := tableID.SchemaName - tableName := tableID.TableName - - q := snowflake.NewTableBuilder(tableName, dbName, schemaName).Drop() - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error deleting pipe %v err = %w", d.Id(), err) - } - d.SetId("") return nil diff --git a/pkg/resources/table_acceptance_test.go b/pkg/resources/table_acceptance_test.go index c0c5d72c96..cddc2b1dee 100644 --- a/pkg/resources/table_acceptance_test.go +++ b/pkg/resources/table_acceptance_test.go @@ -1,26 +1,31 @@ package resources_test import ( + "context" + "database/sql" "fmt" - "os" "strings" "testing" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/terraform" + "github.com/hashicorp/terraform-plugin-testing/tfversion" ) func TestAcc_TableWithSeparateDataRetentionObjectParameterWithoutLifecycle(t *testing.T) { - if _, ok := os.LookupEnv("SKIP_TABLE_DATA_RETENTION_TESTS"); ok { - t.Skip("Skipping TestAcc_TableWithSeparateDataRetentionObjectParameterWithoutLifecycle") - } - accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -61,15 +66,15 @@ func TestAcc_TableWithSeparateDataRetentionObjectParameterWithoutLifecycle(t *te } func TestAcc_TableWithSeparateDataRetentionObjectParameterWithLifecycle(t *testing.T) { - if _, ok := os.LookupEnv("SKIP_TABLE_DATA_RETENTION_TESTS"); ok { - t.Skip("Skipping TestAcc_TableWithSeparateDataRetentionObjectParameterWithLifecycle") - } - accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -131,10 +136,13 @@ func TestAcc_Table(t *testing.T) { table2Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) table3Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -670,7 +678,6 @@ resource "snowflake_table" "test_table" { nullable = false } primary_key { - name = "" keys = ["column2"] } } @@ -868,10 +875,13 @@ resource "snowflake_table" "test_table" { func TestAcc_TableDefaults(t *testing.T) { accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableColumnWithDefaults(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -896,7 +906,7 @@ func TestAcc_TableDefaults(t *testing.T) { resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.type", "NUMBER(38,0)"), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.constant"), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.expression"), - resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf(`"%v"."%v".%v`, acc.TestDatabaseName, acc.TestSchemaName, accName)), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf(`"%v"."%v"."%v"`, acc.TestDatabaseName, acc.TestSchemaName, accName)), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "primary_key.0"), ), }, @@ -919,7 +929,7 @@ func TestAcc_TableDefaults(t *testing.T) { resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.type", "NUMBER(38,0)"), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.constant"), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.expression"), - resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf(`"%v"."%v".%v`, acc.TestDatabaseName, acc.TestSchemaName, accName)), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf(`"%v"."%v"."%v"`, acc.TestDatabaseName, acc.TestSchemaName, accName)), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "primary_key.0"), ), }, @@ -1005,10 +1015,14 @@ func TestAcc_TableTags(t *testing.T) { accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) tagName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) tag2Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableWithTags(accName, tagName, tag2Name, acc.TestDatabaseName, acc.TestSchemaName), @@ -1080,10 +1094,13 @@ resource "snowflake_table" "test_table" { func TestAcc_TableIdentity(t *testing.T) { accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableColumnWithIdentityDefault(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -1214,10 +1231,14 @@ resource "snowflake_table" "test_table" { func TestAcc_TableRename(t *testing.T) { oldTableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) newTableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableConfigWithName(oldTableName, acc.TestDatabaseName, acc.TestSchemaName), @@ -1263,3 +1284,96 @@ resource "snowflake_table" "test_table" { ` return fmt.Sprintf(s, tableName, databaseName, schemaName) } + +func TestAcc_Table_MaskingPolicy(t *testing.T) { + accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, + Steps: []resource.TestStep{ + { + Config: tableWithMaskingPolicy(accName, acc.TestDatabaseName, acc.TestSchemaName, "policy1"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.masking_policy", sdk.NewSchemaObjectIdentifier(acc.TestDatabaseName, acc.TestSchemaName, fmt.Sprintf("%s1", accName)).FullyQualifiedName()), + ), + }, + // this step proves https://github.com/Snowflake-Labs/terraform-provider-snowflake/pull/2186 + { + Config: tableWithMaskingPolicy(accName, acc.TestDatabaseName, acc.TestSchemaName, "policy2"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.masking_policy", sdk.NewSchemaObjectIdentifier(acc.TestDatabaseName, acc.TestSchemaName, fmt.Sprintf("%s2", accName)).FullyQualifiedName()), + ), + }, + }, + }) +} + +func tableWithMaskingPolicy(name string, databaseName string, schemaName string, policy string) string { + s := ` +resource "snowflake_masking_policy" "policy1" { + name = "%[1]s1" + database = "%[2]s" + schema = "%[3]s" + signature { + column { + name = "val" + type = "VARCHAR" + } + } + masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" + return_data_type = "VARCHAR(16777216)" +} + +resource "snowflake_masking_policy" "policy2" { + name = "%[1]s2" + database = "%[2]s" + schema = "%[3]s" + signature { + column { + name = "val" + type = "VARCHAR" + } + } + masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" + return_data_type = "VARCHAR(16777216)" +} + +resource "snowflake_table" "test_table" { + name = "%[1]s" + database = "%[2]s" + schema = "%[3]s" + comment = "Terraform acceptance test" + + column { + name = "column1" + type = "VARCHAR(16)" + masking_policy = snowflake_masking_policy.%[4]s.qualified_name + } +} +` + return fmt.Sprintf(s, name, databaseName, schemaName, policy) +} + +func testAccCheckTableDestroy(s *terraform.State) error { + db := acc.TestAccProvider.Meta().(*sql.DB) + client := sdk.NewClientFromDB(db) + for _, rs := range s.RootModule().Resources { + if rs.Type != "snowflake_table" { + continue + } + ctx := context.Background() + id := sdk.NewSchemaObjectIdentifier(rs.Primary.Attributes["database"], rs.Primary.Attributes["schema"], rs.Primary.Attributes["name"]) + existingTable, err := client.Tables.ShowByID(ctx, id) + if err == nil { + return fmt.Errorf("table %v still exists", existingTable.ID().FullyQualifiedName()) + } + } + return nil +} diff --git a/pkg/resources/table_internal_test.go b/pkg/resources/table_internal_test.go deleted file mode 100644 index f35c953025..0000000000 --- a/pkg/resources/table_internal_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package resources - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestTableIDFromString(t *testing.T) { - r := require.New(t) - // Vanilla - id := "database_name|schema_name|table" - table, err := tableIDFromString(id) - r.NoError(err) - r.Equal("database_name", table.DatabaseName) - r.Equal("schema_name", table.SchemaName) - r.Equal("table", table.TableName) - - // Bad ID -- not enough fields - id = "database" - _, err = tableIDFromString(id) - r.Equal(fmt.Errorf("3 fields allowed"), err) - - // Bad ID - id = "||" - _, err = tableIDFromString(id) - r.NoError(err) - - // 0 lines - id = "" - _, err = tableIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) - - // 2 lines - id = `database_name|schema_name|table - database_name|schema_name|table` - _, err = tableIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) -} - -func TestTableStruct(t *testing.T) { - r := require.New(t) - - // Vanilla - table := &tableID{ - DatabaseName: "database_name", - SchemaName: "schema_name", - TableName: "table", - } - sID, err := table.String() - r.NoError(err) - r.Equal("database_name|schema_name|table", sID) - - // Empty grant - table = &tableID{} - sID, err = table.String() - r.NoError(err) - r.Equal("||", sID) - - // Grant with extra delimiters - table = &tableID{ - DatabaseName: "database|name", - TableName: "table|name", - } - sID, err = table.String() - r.NoError(err) - newTable, err := tableIDFromString(sID) - r.NoError(err) - r.Equal("database|name", newTable.DatabaseName) - r.Equal("table|name", newTable.TableName) -} diff --git a/pkg/resources/tag.go b/pkg/resources/tag.go index 01359edf34..9f8276da52 100644 --- a/pkg/resources/tag.go +++ b/pkg/resources/tag.go @@ -95,33 +95,6 @@ type TagBuilder interface { ChangeTag(snowflake.TagValue) string } -func handleTagChanges(db *sql.DB, d *schema.ResourceData, builder TagBuilder) error { - if d.HasChange("tag") { - o, n := d.GetChange("tag") - removed, added, changed := getTags(o).diffs(getTags(n)) - for _, tA := range removed { - q := builder.UnsetTag(tA.toSnowflakeTagValue()) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error dropping tag on %v", d.Id()) - } - } - for _, tA := range added { - q := builder.AddTag(tA.toSnowflakeTagValue()) - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error adding column on %v", d.Id()) - } - } - for _, tA := range changed { - q := builder.ChangeTag(tA.toSnowflakeTagValue()) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) - } - } - } - return nil -} - // String() takes in a schemaID object and returns a pipe-delimited string: // DatabaseName|SchemaName|TagName. func (ti *TagID) String() (string, error) { diff --git a/pkg/sdk/tables.go b/pkg/sdk/tables.go index 7712b7a4ba..2e3b124786 100644 --- a/pkg/sdk/tables.go +++ b/pkg/sdk/tables.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "strings" ) var _ convertibleRow[Table] = new(tableDBRow) @@ -171,7 +172,7 @@ type ColumnMaskingPolicy struct { // OutOfLineConstraint is based on https://docs.snowflake.com/en/sql-reference/sql/create-table-constraint#out-of-line-unique-primary-foreign-key. type OutOfLineConstraint struct { - Name string `ddl:"parameter,no_equals" sql:"CONSTRAINT"` + Name *string `ddl:"parameter,no_equals" sql:"CONSTRAINT"` Type ColumnConstraintType `ddl:"keyword"` Columns []string `ddl:"keyword,parentheses"` ForeignKey *OutOfLineForeignKey `ddl:"keyword"` @@ -269,11 +270,12 @@ type TableColumnAddAction struct { InlineConstraint *TableColumnAddInlineConstraint `ddl:"keyword"` MaskingPolicy *ColumnMaskingPolicy `ddl:"keyword"` Tags []TagAssociation `ddl:"keyword,parentheses" sql:"TAG"` + Comment *string `ddl:"parameter,no_equals,single_quotes" sql:"COMMENT"` } type TableColumnAddInlineConstraint struct { NotNull *bool `ddl:"keyword" sql:"NOT NULL"` - Name string `ddl:"parameter,no_equals" sql:"CONSTRAINT"` + Name *string `ddl:"parameter,no_equals" sql:"CONSTRAINT"` Type ColumnConstraintType `ddl:"keyword"` ForeignKey *ColumnAddForeignKey `ddl:"keyword"` } @@ -370,15 +372,14 @@ type TableConstraintAlterAction struct { Unique *bool `ddl:"keyword" sql:"UNIQUE"` ForeignKey *bool `ddl:"keyword" sql:"FOREIGN KEY"` - Columns []string `ddl:"keyword,parentheses"` - // Optional - Enforced *bool `ddl:"keyword" sql:"ENFORCED"` - NotEnforced *bool `ddl:"keyword" sql:"NOT ENFORCED"` - Validate *bool `ddl:"keyword" sql:"VALIDATE"` - NoValidate *bool `ddl:"keyword" sql:"NOVALIDATE"` - Rely *bool `ddl:"keyword" sql:"RELY"` - NoRely *bool `ddl:"keyword" sql:"NORELY"` + Columns []string `ddl:"keyword,parentheses"` + Enforced *bool `ddl:"keyword" sql:"ENFORCED"` + NotEnforced *bool `ddl:"keyword" sql:"NOT ENFORCED"` + Validate *bool `ddl:"keyword" sql:"VALIDATE"` + NoValidate *bool `ddl:"keyword" sql:"NOVALIDATE"` + Rely *bool `ddl:"keyword" sql:"RELY"` + NoRely *bool `ddl:"keyword" sql:"NORELY"` } type TableConstraintDropAction struct { @@ -388,11 +389,10 @@ type TableConstraintDropAction struct { Unique *bool `ddl:"keyword" sql:"UNIQUE"` ForeignKey *bool `ddl:"keyword" sql:"FOREIGN KEY"` - Columns []string `ddl:"keyword,parentheses"` - // Optional - Cascade *bool `ddl:"keyword" sql:"CASCADE"` - Restrict *bool `ddl:"keyword" sql:"RESTRICT"` + Columns []string `ddl:"keyword,parentheses"` + Cascade *bool `ddl:"keyword" sql:"CASCADE"` + Restrict *bool `ddl:"keyword" sql:"RESTRICT"` } type TableUnsetTags struct { @@ -412,6 +412,7 @@ type TableExternalTableColumnAddAction struct { Name string `ddl:"keyword"` Type DataType `ddl:"keyword"` Expression []string `ddl:"parameter,no_equals,parentheses" sql:"AS"` + Comment *string `ddl:"parameter,no_equals,single_quotes" sql:"COMMENT"` } type TableExternalTableColumnRenameAction struct { @@ -553,6 +554,22 @@ type Table struct { Budget *string } +// GetClusterByKeys converts the SHOW TABLES result for ClusterBy and converts it to list of keys. +func (v *Table) GetClusterByKeys() []string { + if v.ClusterBy == "" { + return nil + } + + statementWithoutLinear := strings.TrimSuffix(strings.Replace(v.ClusterBy, "LINEAR(", "", 1), ")") + keysRaw := strings.Split(statementWithoutLinear, ",") + keysClean := make([]string, 0, len(keysRaw)) + for _, key := range keysRaw { + keysClean = append(keysClean, strings.TrimSpace(key)) + } + + return keysClean +} + func (row tableDBRow) convert() *Table { table := Table{ CreatedOn: row.CreatedOn, diff --git a/pkg/sdk/tables_dto.go b/pkg/sdk/tables_dto.go index 49fb03da58..504cab4fe5 100644 --- a/pkg/sdk/tables_dto.go +++ b/pkg/sdk/tables_dto.go @@ -132,7 +132,7 @@ type ColumnInlineConstraintRequest struct { } type OutOfLineConstraintRequest struct { - Name string // required + Name *string Type ColumnConstraintType // required Columns []string ForeignKey *OutOfLineForeignKeyRequest @@ -370,11 +370,12 @@ type TableColumnAddActionRequest struct { MaskingPolicy *ColumnMaskingPolicyRequest With *bool Tags []TagAssociation + Comment *string } type TableColumnAddInlineConstraintRequest struct { NotNull *bool - Name string + Name *string Type ColumnConstraintType ForeignKey *ColumnAddForeignKey } @@ -390,8 +391,7 @@ type TableColumnRenameActionRequest struct { } type TableColumnAlterActionRequest struct { - Column bool // required - Name string // required + Name string // required // One of DropDefault *bool @@ -447,8 +447,8 @@ type TableConstraintAlterActionRequest struct { Unique *bool ForeignKey *bool - Columns []string // required // Optional + Columns []string Enforced *bool NotEnforced *bool Validate *bool @@ -464,9 +464,8 @@ type TableConstraintDropActionRequest struct { Unique *bool ForeignKey *bool - Columns []string // required - // Optional + Columns []string Cascade *bool Restrict *bool } @@ -500,6 +499,7 @@ type TableExternalTableColumnAddActionRequest struct { Name string Type DataType Expression string + Comment *string } type TableExternalTableColumnRenameActionRequest struct { diff --git a/pkg/sdk/tables_dto_generated.go b/pkg/sdk/tables_dto_generated.go index 3573a85be8..17ec7218a7 100644 --- a/pkg/sdk/tables_dto_generated.go +++ b/pkg/sdk/tables_dto_generated.go @@ -425,15 +425,18 @@ func (s *ColumnInlineConstraintRequest) WithNoRely(noRely *bool) *ColumnInlineCo } func NewOutOfLineConstraintRequest( - name string, constraintType ColumnConstraintType, ) *OutOfLineConstraintRequest { s := OutOfLineConstraintRequest{} - s.Name = name s.Type = constraintType return &s } +func (s *OutOfLineConstraintRequest) WithName(name *string) *OutOfLineConstraintRequest { + s.Name = name + return s +} + func (s *OutOfLineConstraintRequest) WithColumns(columns []string) *OutOfLineConstraintRequest { s.Columns = columns return s @@ -1184,6 +1187,11 @@ func (s *TableColumnAddActionRequest) WithTags(tags []TagAssociation) *TableColu return s } +func (s *TableColumnAddActionRequest) WithComment(comment *string) *TableColumnAddActionRequest { + s.Comment = comment + return s +} + func NewTableColumnAddInlineConstraintRequest() *TableColumnAddInlineConstraintRequest { return &TableColumnAddInlineConstraintRequest{} } @@ -1193,7 +1201,7 @@ func (s *TableColumnAddInlineConstraintRequest) WithNotNull(notNull *bool) *Tabl return s } -func (s *TableColumnAddInlineConstraintRequest) WithName(name string) *TableColumnAddInlineConstraintRequest { +func (s *TableColumnAddInlineConstraintRequest) WithName(name *string) *TableColumnAddInlineConstraintRequest { s.Name = name return s } @@ -1233,11 +1241,9 @@ func NewTableColumnRenameActionRequest( } func NewTableColumnAlterActionRequest( - column bool, name string, ) *TableColumnAlterActionRequest { s := TableColumnAlterActionRequest{} - s.Column = column s.Name = name return &s } @@ -1369,10 +1375,8 @@ func (s *TableConstraintRenameActionRequest) WithNewName(newName string) *TableC return s } -func NewTableConstraintAlterActionRequest(columns []string) *TableConstraintAlterActionRequest { - return &TableConstraintAlterActionRequest{ - Columns: columns, - } +func NewTableConstraintAlterActionRequest() *TableConstraintAlterActionRequest { + return &TableConstraintAlterActionRequest{} } func (s *TableConstraintAlterActionRequest) WithConstraintName(constraintName *string) *TableConstraintAlterActionRequest { @@ -1395,6 +1399,11 @@ func (s *TableConstraintAlterActionRequest) WithForeignKey(foreignKey *bool) *Ta return s } +func (s *TableConstraintAlterActionRequest) WithColumns(columns []string) *TableConstraintAlterActionRequest { + s.Columns = columns + return s +} + func (s *TableConstraintAlterActionRequest) WithEnforced(enforced *bool) *TableConstraintAlterActionRequest { s.Enforced = enforced return s @@ -1425,10 +1434,8 @@ func (s *TableConstraintAlterActionRequest) WithNoRely(noRely *bool) *TableConst return s } -func NewTableConstraintDropActionRequest(columns []string) *TableConstraintDropActionRequest { - return &TableConstraintDropActionRequest{ - Columns: columns, - } +func NewTableConstraintDropActionRequest() *TableConstraintDropActionRequest { + return &TableConstraintDropActionRequest{} } func (s *TableConstraintDropActionRequest) WithConstraintName(constraintName *string) *TableConstraintDropActionRequest { @@ -1451,6 +1458,11 @@ func (s *TableConstraintDropActionRequest) WithForeignKey(foreignKey *bool) *Tab return s } +func (s *TableConstraintDropActionRequest) WithColumns(columns []string) *TableConstraintDropActionRequest { + s.Columns = columns + return s +} + func (s *TableConstraintDropActionRequest) WithCascade(cascade *bool) *TableConstraintDropActionRequest { s.Cascade = cascade return s @@ -1562,6 +1574,11 @@ func (s *TableExternalTableColumnAddActionRequest) WithExpression(expression str return s } +func (s *TableExternalTableColumnAddActionRequest) WithComment(comment *string) *TableExternalTableColumnAddActionRequest { + s.Comment = comment + return s +} + func NewTableExternalTableColumnRenameActionRequest() *TableExternalTableColumnRenameActionRequest { return &TableExternalTableColumnRenameActionRequest{} } diff --git a/pkg/sdk/tables_impl.go b/pkg/sdk/tables_impl.go index b3ca029285..5a34f0ac8e 100644 --- a/pkg/sdk/tables_impl.go +++ b/pkg/sdk/tables_impl.go @@ -254,6 +254,7 @@ func (r *TableExternalTableActionRequest) toOpts() *TableExternalTableAction { Name: r.Add.Name, Type: r.Add.Type, Expression: []string{r.Add.Expression}, + Comment: r.Add.Comment, }, } } @@ -392,6 +393,7 @@ func (r *TableColumnActionRequest) toOpts() *TableColumnAction { Type: r.Add.Type, DefaultValue: defaultValue, InlineConstraint: inlineConstraint, + Comment: r.Add.Comment, }, } } diff --git a/pkg/sdk/tables_test.go b/pkg/sdk/tables_test.go index 27574b9ffe..644c9594a7 100644 --- a/pkg/sdk/tables_test.go +++ b/pkg/sdk/tables_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/random" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -406,7 +407,7 @@ func TestTableCreate(t *testing.T) { } require.NoError(t, err) outOfLineConstraint1 := OutOfLineConstraint{ - Name: "OUT_OF_LINE_CONSTRAINT", + Name: String("OUT_OF_LINE_CONSTRAINT"), Type: ColumnConstraintTypeForeignKey, Columns: []string{"COLUMN_1", "COLUMN_2"}, ForeignKey: &OutOfLineForeignKey{ @@ -475,7 +476,7 @@ func TestTableCreate(t *testing.T) { Comment: &tableComment, } assertOptsValidAndSQLEquals(t, opts, - `CREATE TABLE %s (%s %s CONSTRAINT INLINE_CONSTRAINT PRIMARY KEY NOT NULL COLLATE 'de' IDENTITY START 10 INCREMENT 1 ORDER MASKING POLICY %s USING (FOO, BAR) TAG ("db"."schema"."column_tag1" = 'v1', "db"."schema"."column_tag2" = 'v2') COMMENT '%s', CONSTRAINT OUT_OF_LINE_CONSTRAINT FOREIGN KEY (COLUMN_1, COLUMN_2) REFERENCES %s (COLUMN_3, COLUMN_4) MATCH FULL ON UPDATE SET NULL ON DELETE RESTRICT, CONSTRAINT UNIQUE (COLUMN_1) ENFORCED DEFERRABLE INITIALLY DEFERRED ENABLE RELY) CLUSTER BY (COLUMN_1, COLUMN_2) ENABLE_SCHEMA_EVOLUTION = true STAGE_FILE_FORMAT = (TYPE = CSV COMPRESSION = AUTO) STAGE_COPY_OPTIONS = (ON_ERROR = SKIP_FILE) DATA_RETENTION_TIME_IN_DAYS = 10 MAX_DATA_EXTENSION_TIME_IN_DAYS = 100 CHANGE_TRACKING = true DEFAULT_DDL_COLLATION = 'en' COPY GRANTS ROW ACCESS POLICY %s ON (COLUMN_1, COLUMN_2) TAG ("db"."schema"."table_tag1" = 'v1', "db"."schema"."table_tag2" = 'v2') COMMENT = '%s'`, + `CREATE TABLE %s (%s %s CONSTRAINT INLINE_CONSTRAINT PRIMARY KEY NOT NULL COLLATE 'de' IDENTITY START 10 INCREMENT 1 ORDER MASKING POLICY %s USING (FOO, BAR) TAG ("db"."schema"."column_tag1" = 'v1', "db"."schema"."column_tag2" = 'v2') COMMENT '%s', CONSTRAINT OUT_OF_LINE_CONSTRAINT FOREIGN KEY (COLUMN_1, COLUMN_2) REFERENCES %s (COLUMN_3, COLUMN_4) MATCH FULL ON UPDATE SET NULL ON DELETE RESTRICT, UNIQUE (COLUMN_1) ENFORCED DEFERRABLE INITIALLY DEFERRED ENABLE RELY) CLUSTER BY (COLUMN_1, COLUMN_2) ENABLE_SCHEMA_EVOLUTION = true STAGE_FILE_FORMAT = (TYPE = CSV COMPRESSION = AUTO) STAGE_COPY_OPTIONS = (ON_ERROR = SKIP_FILE) DATA_RETENTION_TIME_IN_DAYS = 10 MAX_DATA_EXTENSION_TIME_IN_DAYS = 100 CHANGE_TRACKING = true DEFAULT_DDL_COLLATION = 'en' COPY GRANTS ROW ACCESS POLICY %s ON (COLUMN_1, COLUMN_2) TAG ("db"."schema"."table_tag1" = 'v1', "db"."schema"."table_tag2" = 'v2') COMMENT = '%s'`, id.FullyQualifiedName(), columnName, columnType, @@ -813,17 +814,6 @@ func TestTableAlter(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("TableConstraintAlterAction", "ConstraintName", "PrimaryKey", "Unique", "ForeignKey", "Columns")) }) - t.Run("validation: constraint alter action - no columns", func(t *testing.T) { - opts := defaultOpts() - opts.ConstraintAction = &TableConstraintAction{ - Alter: &TableConstraintAlterAction{ - ConstraintName: String("constraint"), - Columns: []string{}, - }, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("TableConstraintAlterAction", "Columns")) - }) - t.Run("validation: constraint alter action - two options present", func(t *testing.T) { opts := defaultOpts() opts.ConstraintAction = &TableConstraintAction{ @@ -843,17 +833,6 @@ func TestTableAlter(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("TableConstraintDropAction", "ConstraintName", "PrimaryKey", "Unique", "ForeignKey", "Columns")) }) - t.Run("validation: constraint drop action - no columns", func(t *testing.T) { - opts := defaultOpts() - opts.ConstraintAction = &TableConstraintAction{ - Drop: &TableConstraintDropAction{ - ConstraintName: String("constraint"), - Columns: []string{}, - }, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("TableConstraintDropAction", "Columns")) - }) - t.Run("validation: constraint drop action - two options present", func(t *testing.T) { opts := defaultOpts() opts.ConstraintAction = &TableConstraintAction{ @@ -1191,7 +1170,7 @@ func TestTableAlter(t *testing.T) { t.Run("alter constraint: add", func(t *testing.T) { outOfLineConstraint := OutOfLineConstraint{ - Name: "OUT_OF_LINE_CONSTRAINT", + Name: String("OUT_OF_LINE_CONSTRAINT"), Type: ColumnConstraintTypeForeignKey, Columns: []string{"COLUMN_1", "COLUMN_2"}, ForeignKey: &OutOfLineForeignKey{ @@ -1571,3 +1550,29 @@ func TestTableDescribeStage(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `DESCRIBE TABLE %s TYPE = STAGE`, id.FullyQualifiedName()) }) } + +func TestTable_GetClusterByKeys(t *testing.T) { + t.Run("empty", func(t *testing.T) { + table := Table{ClusterBy: ""} + + assert.Nil(t, table.GetClusterByKeys()) + }) + + t.Run("one param", func(t *testing.T) { + table := Table{ClusterBy: "LINEAR(abc)"} + + assert.Equal(t, []string{"abc"}, table.GetClusterByKeys()) + }) + + t.Run("more params", func(t *testing.T) { + table := Table{ClusterBy: "LINEAR(abc,def)"} + + assert.Equal(t, []string{"abc", "def"}, table.GetClusterByKeys()) + }) + + t.Run("white space", func(t *testing.T) { + table := Table{ClusterBy: " LINEAR( abc , def )"} + + assert.Equal(t, []string{"abc", "def"}, table.GetClusterByKeys()) + }) +} diff --git a/pkg/sdk/tables_validations.go b/pkg/sdk/tables_validations.go index 6c6211bf9d..e1b2467753 100644 --- a/pkg/sdk/tables_validations.go +++ b/pkg/sdk/tables_validations.go @@ -226,9 +226,6 @@ func (opts *alterTableOptions) validate() error { ); !ok { errs = append(errs, errExactlyOneOf("TableConstraintAlterAction", "ConstraintName", "PrimaryKey", "Unique", "ForeignKey", "Columns")) } - if len(alterAction.Columns) == 0 { - errs = append(errs, errNotSet("TableConstraintAlterAction", "Columns")) - } } if dropAction := constraintAction.Drop; valueSet(dropAction) { if ok := exactlyOneValueSet( @@ -239,9 +236,6 @@ func (opts *alterTableOptions) validate() error { ); !ok { errs = append(errs, errExactlyOneOf("TableConstraintDropAction", "ConstraintName", "PrimaryKey", "Unique", "ForeignKey", "Columns")) } - if len(dropAction.Columns) == 0 { - errs = append(errs, errNotSet("TableConstraintDropAction", "Columns")) - } } if addAction := constraintAction.Add; valueSet(addAction) { if err := addAction.validate(); err != nil { diff --git a/pkg/sdk/testint/tables_integration_test.go b/pkg/sdk/testint/tables_integration_test.go index 2b4b850792..0228c3e732 100644 --- a/pkg/sdk/testint/tables_integration_test.go +++ b/pkg/sdk/testint/tables_integration_test.go @@ -124,7 +124,8 @@ func TestInt_Table(t *testing.T) { WithNotNull(sdk.Bool(true)), *sdk.NewTableColumnRequest("COLUMN_2", sdk.DataTypeNumber).WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithIdentity(sdk.NewColumnIdentityRequest(1, 1))), } - outOfLineConstraint := sdk.NewOutOfLineConstraintRequest("OUT_OF_LINE_CONSTRAINT", sdk.ColumnConstraintTypeForeignKey). + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypeForeignKey). + WithName(sdk.String("OUT_OF_LINE_CONSTRAINT")). WithColumns([]string{"COLUMN_1"}). WithForeignKey(sdk.NewOutOfLineForeignKeyRequest(sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, table2.Name), []string{"id"}). WithMatch(sdk.Pointer(sdk.FullMatchType)). @@ -476,7 +477,7 @@ func TestInt_Table(t *testing.T) { alterRequest := sdk.NewAlterTableRequest(id). WithColumnAction(sdk.NewTableColumnActionRequest(). - WithAdd(sdk.NewTableColumnAddActionRequest("COLUMN_3", sdk.DataTypeVARCHAR))) + WithAdd(sdk.NewTableColumnAddActionRequest("COLUMN_3", sdk.DataTypeVARCHAR).WithComment(sdk.String("some comment")))) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) @@ -670,7 +671,7 @@ func TestInt_Table(t *testing.T) { alterRequest := sdk.NewAlterTableRequest(id). WithConstraintAction(sdk.NewTableConstraintActionRequest(). - WithAdd(sdk.NewOutOfLineConstraintRequest("OUT_OF_LINE_CONSTRAINT", sdk.ColumnConstraintTypeForeignKey).WithColumns([]string{"COLUMN_1"}). + WithAdd(sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypeForeignKey).WithName(sdk.String("OUT_OF_LINE_CONSTRAINT")).WithColumns([]string{"COLUMN_1"}). WithForeignKey(sdk.NewOutOfLineForeignKeyRequest(sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, secondTableName), []string{"COLUMN_3"})))) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) @@ -685,7 +686,7 @@ func TestInt_Table(t *testing.T) { *sdk.NewTableColumnRequest("COLUMN_2", sdk.DataTypeVARCHAR), } oldConstraintName := "OUT_OF_LINE_CONSTRAINT" - outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(oldConstraintName, sdk.ColumnConstraintTypePrimaryKey).WithColumns([]string{"COLUMN_1"}) + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithName(sdk.String(oldConstraintName)).WithColumns([]string{"COLUMN_1"}) err := client.Tables.Create(ctx, sdk.NewCreateTableRequest(id, columns).WithOutOfLineConstraint(*outOfLineConstraint)) require.NoError(t, err) @@ -703,7 +704,6 @@ func TestInt_Table(t *testing.T) { // TODO [SNOW-1007542]: check altered constraint t.Run("alter constraint: alter", func(t *testing.T) { - t.Skip("Test is failing: generated statement is not compiling but it is aligned with Snowflake docs https://docs.snowflake.com/en/sql-reference/sql/alter-table#syntax. Requires further investigation.") name := random.String() id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, name) columns := []sdk.TableColumnRequest{ @@ -711,21 +711,20 @@ func TestInt_Table(t *testing.T) { *sdk.NewTableColumnRequest("COLUMN_2", sdk.DataTypeVARCHAR), } constraintName := "OUT_OF_LINE_CONSTRAINT" - outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(constraintName, sdk.ColumnConstraintTypePrimaryKey).WithColumns([]string{"COLUMN_1"}) + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithName(sdk.String(constraintName)).WithColumns([]string{"COLUMN_1"}) err := client.Tables.Create(ctx, sdk.NewCreateTableRequest(id, columns).WithOutOfLineConstraint(*outOfLineConstraint)) require.NoError(t, err) t.Cleanup(cleanupTableProvider(id)) alterRequest := sdk.NewAlterTableRequest(id). - WithConstraintAction(sdk.NewTableConstraintActionRequest().WithAlter(sdk.NewTableConstraintAlterActionRequest([]string{"COLUMN_1"}).WithConstraintName(sdk.String(constraintName)).WithEnforced(sdk.Bool(true)))) + WithConstraintAction(sdk.NewTableConstraintActionRequest().WithAlter(sdk.NewTableConstraintAlterActionRequest().WithConstraintName(sdk.String(constraintName)).WithEnforced(sdk.Bool(true)))) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) }) // TODO [SNOW-1007542]: check dropped constraint - t.Run("alter constraint: drop", func(t *testing.T) { - t.Skip("Test is failing: generated statement is not compiling but it is aligned with Snowflake docs https://docs.snowflake.com/en/sql-reference/sql/alter-table#syntax. Requires further investigation.") + t.Run("alter constraint: drop constraint with name", func(t *testing.T) { name := random.String() id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, name) columns := []sdk.TableColumnRequest{ @@ -733,19 +732,37 @@ func TestInt_Table(t *testing.T) { *sdk.NewTableColumnRequest("COLUMN_2", sdk.DataTypeVARCHAR), } constraintName := "OUT_OF_LINE_CONSTRAINT" - outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(constraintName, sdk.ColumnConstraintTypePrimaryKey).WithColumns([]string{"COLUMN_1"}) + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithName(sdk.String(constraintName)).WithColumns([]string{"COLUMN_1"}) err := client.Tables.Create(ctx, sdk.NewCreateTableRequest(id, columns).WithOutOfLineConstraint(*outOfLineConstraint)) require.NoError(t, err) t.Cleanup(cleanupTableProvider(id)) alterRequest := sdk.NewAlterTableRequest(id). - WithConstraintAction(sdk.NewTableConstraintActionRequest().WithDrop(sdk.NewTableConstraintDropActionRequest([]string{"COLUMN_1"}).WithConstraintName(sdk.String(constraintName)))) + WithConstraintAction(sdk.NewTableConstraintActionRequest().WithDrop(sdk.NewTableConstraintDropActionRequest().WithConstraintName(sdk.String(constraintName)))) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) }) - t.Run("external table: add", func(t *testing.T) { + t.Run("alter constraint: drop primary key without constraint name", func(t *testing.T) { + name := random.String() + id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, name) + columns := []sdk.TableColumnRequest{ + *sdk.NewTableColumnRequest("COLUMN_1", sdk.DataTypeVARCHAR), + } + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithColumns([]string{"COLUMN_1"}) + + err := client.Tables.Create(ctx, sdk.NewCreateTableRequest(id, columns).WithOutOfLineConstraint(*outOfLineConstraint)) + require.NoError(t, err) + t.Cleanup(cleanupTableProvider(id)) + + alterRequest := sdk.NewAlterTableRequest(id). + WithConstraintAction(sdk.NewTableConstraintActionRequest().WithDrop(sdk.NewTableConstraintDropActionRequest().WithPrimaryKey(sdk.Bool(true)))) + err = client.Tables.Alter(ctx, alterRequest) + require.NoError(t, err) + }) + + t.Run("external table: add column", func(t *testing.T) { name := random.String() id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, name) columns := []sdk.TableColumnRequest{ @@ -758,7 +775,12 @@ func TestInt_Table(t *testing.T) { t.Cleanup(cleanupTableProvider(id)) alterRequest := sdk.NewAlterTableRequest(id). - WithExternalTableAction(sdk.NewTableExternalTableActionRequest().WithAdd(sdk.NewTableExternalTableColumnAddActionRequest().WithName("COLUMN_3").WithType(sdk.DataTypeNumber).WithExpression("1 + 1"))) + WithExternalTableAction(sdk.NewTableExternalTableActionRequest().WithAdd(sdk.NewTableExternalTableColumnAddActionRequest(). + WithName("COLUMN_3"). + WithType(sdk.DataTypeNumber). + WithExpression("1 + 1"). + WithComment(sdk.String("some comment")), + )) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) diff --git a/pkg/snowflake/table.go b/pkg/snowflake/table.go index 7aa6f4c855..58ba303894 100644 --- a/pkg/snowflake/table.go +++ b/pkg/snowflake/table.go @@ -1,420 +1,10 @@ package snowflake import ( - "database/sql" - "errors" "fmt" - "log" - "sort" - "strconv" - "strings" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" - - "github.com/jmoiron/sqlx" -) - -// PrimaryKey structure that represents a tables primary key. -type PrimaryKey struct { - name string - keys []string -} - -// WithName set the primary key name. -func (pk *PrimaryKey) WithName(name string) *PrimaryKey { - pk.name = name - return pk -} - -// WithKeys set the primary key keys. -func (pk *PrimaryKey) WithKeys(keys []string) *PrimaryKey { - pk.keys = keys - return pk -} - -type ColumnDefaultType int - -const ( - columnDefaultTypeConstant = iota - columnDefaultTypeSequence - columnDefaultTypeExpression ) -type ColumnDefault struct { - _type ColumnDefaultType - expression string -} - -type ColumnIdentity struct { - startNum int - stepNum int -} - -func (id *ColumnIdentity) WithStartNum(start int) *ColumnIdentity { - id.startNum = start - return id -} - -func (id *ColumnIdentity) WithStep(step int) *ColumnIdentity { - id.stepNum = step - return id -} - -func NewColumnDefaultWithConstant(constant string) *ColumnDefault { - return &ColumnDefault{ - _type: columnDefaultTypeConstant, - expression: constant, - } -} - -func NewColumnDefaultWithExpression(expression string) *ColumnDefault { - return &ColumnDefault{ - _type: columnDefaultTypeExpression, - expression: expression, - } -} - -func NewColumnDefaultWithSequence(sequence string) *ColumnDefault { - return &ColumnDefault{ - _type: columnDefaultTypeSequence, - expression: sequence, - } -} - -func (d *ColumnDefault) String(columnType string) string { - columnType = strings.ToUpper(columnType) - - switch { - case d._type == columnDefaultTypeExpression: - return d.expression - - case d._type == columnDefaultTypeSequence: - return fmt.Sprintf(`%v.NEXTVAL`, d.expression) - - case d._type == columnDefaultTypeConstant && (strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT"): - return EscapeSnowflakeString(d.expression) - - default: - return d.expression - } -} - -func (d *ColumnDefault) UnescapeConstantSnowflakeString(columnType string) string { - columnType = strings.ToUpper(columnType) - - if d._type == columnDefaultTypeConstant && (strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT") { - return UnescapeSnowflakeString(d.expression) - } - - return d.expression -} - -// Column structure that represents a table column. -type Column struct { - name string - _type string // type is reserved - nullable bool - _default *ColumnDefault // default is reserved - identity *ColumnIdentity - comment string // pointer as value is nullable - maskingPolicy string -} - -// WithName set the column name. -func (c *Column) WithName(name string) *Column { - c.name = name - return c -} - -// WithType set the column type. -func (c *Column) WithType(t string) *Column { - c._type = t - return c -} - -// WithNullable set if the column is nullable. -func (c *Column) WithNullable(nullable bool) *Column { - c.nullable = nullable - return c -} - -func (c *Column) WithDefault(cd *ColumnDefault) *Column { - c._default = cd - return c -} - -// WithComment set the column comment. -func (c *Column) WithComment(comment string) *Column { - c.comment = comment - return c -} - -func (c *Column) WithMaskingPolicy(maskingPolicy string) *Column { - c.maskingPolicy = maskingPolicy - return c -} - -func (c *Column) WithIdentity(id *ColumnIdentity) *Column { - c.identity = id - return c -} - -func (c *Column) getColumnDefinition(withInlineConstraints bool, withComment bool) string { - if c == nil { - return "" - } - var colDef strings.Builder - colDef.WriteString(fmt.Sprintf(`"%v" %v`, EscapeString(c.name), EscapeString(c._type))) - - if withInlineConstraints { - if !c.nullable { - colDef.WriteString(` NOT NULL`) - } - } - - if c._default != nil { - colDef.WriteString(fmt.Sprintf(` DEFAULT %v`, c._default.String(c._type))) - } - - if c.identity != nil { - colDef.WriteString(fmt.Sprintf(` IDENTITY(%v, %v)`, c.identity.startNum, c.identity.stepNum)) - } - - if strings.TrimSpace(c.maskingPolicy) != "" { - colDef.WriteString(fmt.Sprintf(` WITH MASKING POLICY %v`, EscapeString(c.maskingPolicy))) - } - - if withComment { - colDef.WriteString(fmt.Sprintf(` COMMENT '%v'`, EscapeString(c.comment))) - } - - return colDef.String() -} - -func FlattenTablePrimaryKey(pkds []PrimaryKeyDescription) []interface{} { - flattened := []interface{}{} - if len(pkds) == 0 { - return flattened - } - - sort.SliceStable(pkds, func(i, j int) bool { - num1, _ := strconv.Atoi(pkds[i].KeySequence.String) - num2, _ := strconv.Atoi(pkds[j].KeySequence.String) - return num1 < num2 - }) - // sort our keys on the key sequence - - flat := map[string]interface{}{} - keys := make([]string, 0, len(pkds)) - var name string - var nameSet bool - - for _, pk := range pkds { - // set as empty string, sys_constraint means it was an unnnamed constraint - if strings.Contains(pk.ConstraintName.String, "SYS_CONSTRAINT") && !nameSet { - name = "" - nameSet = true - } - if !nameSet { - name = pk.ConstraintName.String - nameSet = true - } - - keys = append(keys, pk.ColumnName.String) - } - - flat["name"] = name - flat["keys"] = keys - flattened = append(flattened, flat) - return flattened -} - -type Columns []Column - -// NewColumns generates columns from a table description. -func NewColumns(tds []TableDescription) Columns { - cs := []Column{} - for _, td := range tds { - if td.Kind.String != "COLUMN" { - continue - } - - cs = append(cs, Column{ - name: td.Name.String, - _type: td.Type.String, - nullable: td.IsNullable(), - _default: td.ColumnDefault(), - identity: td.ColumnIdentity(), - comment: td.Comment.String, - maskingPolicy: td.MaskingPolicy.String, - }) - } - return Columns(cs) -} - -func (c Columns) Flatten() []interface{} { - flattened := []interface{}{} - for _, col := range c { - flat := map[string]interface{}{} - flat["name"] = col.name - flat["type"] = col._type - flat["nullable"] = col.nullable - flat["comment"] = col.comment - flat["masking_policy"] = col.maskingPolicy - - if col._default != nil { - def := map[string]interface{}{} - switch col._default._type { - case columnDefaultTypeConstant: - def["constant"] = col._default.UnescapeConstantSnowflakeString(col._type) - case columnDefaultTypeExpression: - def["expression"] = col._default.expression - case columnDefaultTypeSequence: - def["sequence"] = col._default.expression - } - - flat["default"] = []interface{}{def} - } - - if col.identity != nil { - id := map[string]interface{}{} - id["start_num"] = col.identity.startNum - id["step_num"] = col.identity.stepNum - flat["identity"] = []interface{}{id} - } - flattened = append(flattened, flat) - } - return flattened -} - -func (c Columns) getColumnDefinitions(withInlineConstraints bool, withComments bool) string { - // TODO(el): verify Snowflake reflects column order back in desc table calls - columnDefinitions := []string{} - for _, column := range c { - columnDefinitions = append(columnDefinitions, column.getColumnDefinition(withInlineConstraints, withComments)) - } - - // NOTE: intentionally blank leading space - return fmt.Sprintf(" (%s)", strings.Join(columnDefinitions, ", ")) -} - -// TableBuilder abstracts the creation of SQL queries for a Snowflake schema. -type TableBuilder struct { - name string - db string - schema string - columns Columns - comment string - clusterBy []string - primaryKey PrimaryKey - dataRetentionTimeInDays *int - changeTracking bool - tags []TagValue -} - -// QualifiedName prepends the db and schema if set and escapes everything nicely. -func (tb *TableBuilder) QualifiedName() string { - var n strings.Builder - - if tb.db != "" && tb.schema != "" { - n.WriteString(fmt.Sprintf(`"%v"."%v".`, tb.db, tb.schema)) - } - - if tb.db != "" && tb.schema == "" { - n.WriteString(fmt.Sprintf(`"%v"..`, tb.db)) - } - - if tb.db == "" && tb.schema != "" { - n.WriteString(fmt.Sprintf(`"%v".`, tb.schema)) - } - - n.WriteString(fmt.Sprintf(`"%v"`, tb.name)) - - return n.String() -} - -// WithComment adds a comment to the TableBuilder. -func (tb *TableBuilder) WithComment(c string) *TableBuilder { - tb.comment = c - return tb -} - -// WithColumns sets the column definitions on the TableBuilder. -func (tb *TableBuilder) WithColumns(c Columns) *TableBuilder { - tb.columns = c - return tb -} - -// WithClustering adds cluster keys/expressions to TableBuilder. -func (tb *TableBuilder) WithClustering(c []string) *TableBuilder { - tb.clusterBy = c - return tb -} - -// WithPrimaryKey sets the primary key on the TableBuilder. -func (tb *TableBuilder) WithPrimaryKey(pk PrimaryKey) *TableBuilder { - tb.primaryKey = pk - return tb -} - -// WithDataRetentionTimeInDays sets the data retention time on the TableBuilder. -func (tb *TableBuilder) WithDataRetentionTimeInDays(days int) *TableBuilder { - tb.dataRetentionTimeInDays = sdk.Int(days) - return tb -} - -// WithChangeTracking sets the change tracking on the TableBuilder. -func (tb *TableBuilder) WithChangeTracking(changeTracking bool) *TableBuilder { - tb.changeTracking = changeTracking - return tb -} - -// WithTags sets the tags on the TableBuilder. -func (tb *TableBuilder) WithTags(tags []TagValue) *TableBuilder { - tb.tags = tags - return tb -} - -// AddTag returns the SQL query that will add a new tag to the table. -func (tb *TableBuilder) AddTag(tag TagValue) string { - return fmt.Sprintf(`ALTER TABLE %s SET TAG "%v"."%v"."%v" = "%v"`, tb.QualifiedName(), tag.Database, tag.Schema, tag.Name, tag.Value) -} - -// ChangeTag returns the SQL query that will alter a tag on the table. -func (tb *TableBuilder) ChangeTag(tag TagValue) string { - return fmt.Sprintf(`ALTER TABLE %s SET TAG "%v"."%v"."%v" = "%v"`, tb.QualifiedName(), tag.Database, tag.Schema, tag.Name, tag.Value) -} - -// UnsetTag returns the SQL query that will unset a tag on the table. -func (tb *TableBuilder) UnsetTag(tag TagValue) string { - return fmt.Sprintf(`ALTER TABLE %s UNSET TAG "%v"."%v"."%v"`, tb.QualifiedName(), tag.Database, tag.Schema, tag.Name) -} - -// Function to get clustering definition. -func (tb *TableBuilder) GetClusterKeyString() string { - return JoinStringList(tb.clusterBy, ", ") -} - -func (tb *TableBuilder) GetTagValueString() string { - var q strings.Builder - for _, v := range tb.tags { - fmt.Println(v) - if v.Schema != "" { - if v.Database != "" { - q.WriteString(fmt.Sprintf(`"%v".`, v.Database)) - } - q.WriteString(fmt.Sprintf(`"%v".`, v.Schema)) - } - q.WriteString(fmt.Sprintf(`"%v" = "%v", `, v.Name, v.Value)) - } - return strings.TrimSuffix(q.String(), ", ") -} - -func JoinStringList(instrings []string, delimiter string) string { - return fmt.Sprint(strings.Join(instrings, delimiter)) -} - -func quoteStringList(instrings []string) []string { +func QuoteStringList(instrings []string) []string { clean := make([]string, 0, len(instrings)) for _, word := range instrings { quoted := fmt.Sprintf(`"%s"`, word) @@ -422,351 +12,3 @@ func quoteStringList(instrings []string) []string { } return clean } - -func (tb *TableBuilder) getCreateStatementBody() string { - var q strings.Builder - - colDef := tb.columns.getColumnDefinitions(true, true) - - if len(tb.primaryKey.keys) > 0 { - colDef = strings.TrimSuffix(colDef, ")") // strip trailing - q.WriteString(colDef) - if tb.primaryKey.name != "" { - q.WriteString(fmt.Sprintf(` ,CONSTRAINT "%v" PRIMARY KEY(%v)`, tb.primaryKey.name, JoinStringList(quoteStringList(tb.primaryKey.keys), ","))) - } else { - q.WriteString(fmt.Sprintf(` ,PRIMARY KEY(%v)`, JoinStringList(quoteStringList(tb.primaryKey.keys), ","))) - } - - q.WriteString(")") // add closing - } else { - q.WriteString(colDef) - } - - return q.String() -} - -// function to take the literal snowflake cluster statement returned from SHOW TABLES and convert it to a list of keys. -func ClusterStatementToList(clusterStatement string) []string { - if clusterStatement == "" { - return nil - } - - cleanStatement := strings.TrimSuffix(strings.Replace(clusterStatement, "LINEAR(", "", 1), ")") - // remove cluster statement and trailing parenthesis - - spCleanStatement := strings.Split(cleanStatement, ",") - clean := make([]string, 0, len(spCleanStatement)) - for _, s := range spCleanStatement { - clean = append(clean, strings.TrimSpace(s)) - } - - return clean -} - -// Table returns a pointer to a Builder that abstracts the DDL operations for a table. -// -// Supported DDL operations are: -// - ALTER TABLE -// - DROP TABLE -// - SHOW TABLES -// -// [Snowflake Reference](https://docs.snowflake.com/en/sql-reference/ddl-table.html) -func NewTableBuilder(name, db, schema string) *TableBuilder { - return &TableBuilder{ - name: name, - db: db, - schema: schema, - } -} - -// Table returns a pointer to a Builder that abstracts the DDL operations for a table. -// -// Supported DDL operations are: -// - CREATE TABLE -// -// [Snowflake Reference](https://docs.snowflake.com/en/sql-reference/ddl-table.html) -func NewTableWithColumnDefinitionsBuilder(name, db, schema string, columns Columns) *TableBuilder { - return &TableBuilder{ - name: name, - db: db, - schema: schema, - columns: columns, - } -} - -// Create returns the SQL statement required to create a table. -func (tb *TableBuilder) Create() string { - q := strings.Builder{} - q.WriteString(fmt.Sprintf(`CREATE TABLE %v`, tb.QualifiedName())) - q.WriteString(tb.getCreateStatementBody()) - - if tb.comment != "" { - q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(tb.comment))) - } - - if tb.clusterBy != nil { - // add optional clustering statement - q.WriteString(fmt.Sprintf(` CLUSTER BY LINEAR(%v)`, tb.GetClusterKeyString())) - } - - if tb.dataRetentionTimeInDays != nil { - q.WriteString(fmt.Sprintf(` DATA_RETENTION_TIME_IN_DAYS = %d`, *tb.dataRetentionTimeInDays)) - } - q.WriteString(fmt.Sprintf(` CHANGE_TRACKING = %t`, tb.changeTracking)) - - if tb.tags != nil { - q.WriteString(fmt.Sprintf(` WITH TAG (%v)`, tb.GetTagValueString())) - } - - return q.String() -} - -// ChangeClusterBy returns the SQL query to change cluastering on table. -func (tb *TableBuilder) ChangeClusterBy(cb string) string { - return fmt.Sprintf(`ALTER TABLE %v CLUSTER BY LINEAR(%v)`, tb.QualifiedName(), cb) -} - -// ChangeComment returns the SQL query that will update the comment on the table. -func (tb *TableBuilder) ChangeComment(c string) string { - return fmt.Sprintf(`ALTER TABLE %v SET COMMENT = '%v'`, tb.QualifiedName(), EscapeString(c)) -} - -// ChangeDataRetention returns the SQL query that will update the DATA_RETENTION_TIME_IN_DAYS on the table. -func (tb *TableBuilder) ChangeDataRetention(days int) string { - return fmt.Sprintf(`ALTER TABLE %v SET DATA_RETENTION_TIME_IN_DAYS = %d`, tb.QualifiedName(), days) -} - -// ChangeChangeTracking returns the SQL query that will update the CHANGE_TRACKING on the table. -func (tb *TableBuilder) ChangeChangeTracking(changeTracking bool) string { - return fmt.Sprintf(`ALTER TABLE %v SET CHANGE_TRACKING = %t`, tb.QualifiedName(), changeTracking) -} - -// AddColumn returns the SQL query that will add a new column to the table. -func (tb *TableBuilder) AddColumn(name string, dataType string, nullable bool, _default *ColumnDefault, identity *ColumnIdentity, comment string, maskingPolicy string) string { - col := Column{ - name: name, - _type: dataType, - nullable: nullable, - _default: _default, - identity: identity, - comment: comment, - maskingPolicy: maskingPolicy, - } - return fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s`, tb.QualifiedName(), col.getColumnDefinition(true, true)) -} - -// DropColumn returns the SQL query that will add a new column to the table. -func (tb *TableBuilder) DropColumn(name string) string { - return fmt.Sprintf(`ALTER TABLE %s DROP COLUMN "%s"`, tb.QualifiedName(), name) -} - -// ChangeColumnType returns the SQL query that will change the type of the named column to the given type. -func (tb *TableBuilder) ChangeColumnType(name string, dataType string) string { - col := Column{ - name: name, - _type: dataType, - } - - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN %s`, tb.QualifiedName(), col.getColumnDefinition(false, false)) -} - -func (tb *TableBuilder) ChangeColumnComment(name string, comment string) string { - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" COMMENT '%v'`, tb.QualifiedName(), EscapeString(name), EscapeString(comment)) -} - -func (tb *TableBuilder) ChangeColumnMaskingPolicy(name string, maskingPolicy string) string { - if strings.TrimSpace(maskingPolicy) == "" { - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" UNSET MASKING POLICY`, tb.QualifiedName(), EscapeString(name)) - } - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" SET MASKING POLICY %v`, tb.QualifiedName(), EscapeString(name), EscapeString(maskingPolicy)) -} - -func (tb *TableBuilder) DropColumnDefault(name string) string { - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" DROP DEFAULT`, tb.QualifiedName(), EscapeString(name)) -} - -// RemoveComment returns the SQL query that will remove the comment on the table. -func (tb *TableBuilder) RemoveComment() string { - return fmt.Sprintf(`ALTER TABLE %v UNSET COMMENT`, tb.QualifiedName()) -} - -// Return sql to set/unset null constraint on column. -func (tb *TableBuilder) ChangeNullConstraint(name string, nullable bool) string { - if nullable { - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%s" DROP NOT NULL`, tb.QualifiedName(), name) - } - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%s" SET NOT NULL`, tb.QualifiedName(), name) -} - -func (tb *TableBuilder) ChangePrimaryKey(newPk PrimaryKey) string { - tb.WithPrimaryKey(newPk) - pks := JoinStringList(quoteStringList(newPk.keys), ", ") - if tb.primaryKey.name != "" { - return fmt.Sprintf(`ALTER TABLE %s ADD CONSTRAINT "%v" PRIMARY KEY(%v)`, tb.QualifiedName(), tb.primaryKey.name, pks) - } - return fmt.Sprintf(`ALTER TABLE %s ADD PRIMARY KEY(%v)`, tb.QualifiedName(), pks) -} - -func (tb *TableBuilder) DropPrimaryKey() string { - return fmt.Sprintf(`ALTER TABLE %s DROP PRIMARY KEY`, tb.QualifiedName()) -} - -// RemoveClustering returns the SQL query that will remove data clustering from the table. -func (tb *TableBuilder) DropClustering() string { - return fmt.Sprintf(`ALTER TABLE %v DROP CLUSTERING KEY`, tb.QualifiedName()) -} - -// Drop returns the SQL query that will drop a table. -func (tb *TableBuilder) Drop() string { - return fmt.Sprintf(`DROP TABLE %v`, tb.QualifiedName()) -} - -// Show returns the SQL query that will show a table. -func (tb *TableBuilder) Show() string { - return fmt.Sprintf(`SHOW TABLES LIKE '%v' IN SCHEMA "%v"."%v"`, tb.name, tb.db, tb.schema) -} - -func (tb *TableBuilder) ShowColumns() string { - return fmt.Sprintf(`DESC TABLE %s`, tb.QualifiedName()) -} - -func (tb *TableBuilder) ShowPrimaryKeys() string { - return fmt.Sprintf(`SHOW PRIMARY KEYS IN TABLE %s`, tb.QualifiedName()) -} - -func (tb *TableBuilder) Rename(newName string) string { - oldName := tb.QualifiedName() - tb.name = newName - return fmt.Sprintf(`ALTER TABLE %s RENAME TO %s`, oldName, tb.QualifiedName()) -} - -type Table struct { - CreatedOn sql.NullString `db:"created_on"` - TableName sql.NullString `db:"name"` - DatabaseName sql.NullString `db:"database_name"` - SchemaName sql.NullString `db:"schema_name"` - Kind sql.NullString `db:"kind"` - Comment sql.NullString `db:"comment"` - ClusterBy sql.NullString `db:"cluster_by"` - Rows sql.NullString `db:"row"` - Bytes sql.NullString `db:"bytes"` - Owner sql.NullString `db:"owner"` - RetentionTime sql.NullInt32 `db:"retention_time"` - AutomaticClustering sql.NullString `db:"automatic_clustering"` - ChangeTracking sql.NullString `db:"change_tracking"` - IsExternal sql.NullString `db:"is_external"` -} - -func ScanTable(row *sqlx.Row) (*Table, error) { - t := &Table{} - e := row.StructScan(t) - return t, e -} - -type TableDescription struct { - Name sql.NullString `db:"name"` - Type sql.NullString `db:"type"` - Kind sql.NullString `db:"kind"` - Nullable sql.NullString `db:"null?"` - Default sql.NullString `db:"default"` - Comment sql.NullString `db:"comment"` - MaskingPolicy sql.NullString `db:"policy name"` -} - -func (td *TableDescription) IsNullable() bool { - return td.Nullable.String == "Y" -} - -func (td *TableDescription) ColumnDefault() *ColumnDefault { - if !td.Default.Valid { - return nil - } - - if strings.HasSuffix(td.Default.String, ".NEXTVAL") { - return NewColumnDefaultWithSequence(strings.TrimSuffix(td.Default.String, ".NEXTVAL")) - } - - if strings.Contains(td.Default.String, "(") && strings.Contains(td.Default.String, ")") { - return NewColumnDefaultWithExpression(td.Default.String) - } - - if strings.Contains(td.Type.String, "CHAR") || td.Type.String == "STRING" || td.Type.String == "TEXT" { - return NewColumnDefaultWithConstant(UnescapeSnowflakeString(td.Default.String)) - } - - if td.ColumnIdentity() != nil { - /* - Identity/autoincrement information is stored in the same column as default information. We want to handle the identity separate so will return nil - here if identity information is present. Default/identity are mutually exclusive - */ - return nil - } - - return NewColumnDefaultWithConstant(td.Default.String) -} - -func (td *TableDescription) ColumnIdentity() *ColumnIdentity { - // if autoincrement is used this is reflected back IDENTITY START 1 INCREMENT 1 - if !td.Default.Valid { - return nil - } - if strings.Contains(td.Default.String, "IDENTITY") { - split := strings.Split(td.Default.String, " ") - start, _ := strconv.Atoi(split[2]) - step, _ := strconv.Atoi(split[4]) - - return &ColumnIdentity{start, step} - } - return nil -} - -type PrimaryKeyDescription struct { - ColumnName sql.NullString `db:"column_name"` - KeySequence sql.NullString `db:"key_sequence"` - ConstraintName sql.NullString `db:"constraint_name"` -} - -func ScanTableDescription(rows *sqlx.Rows) ([]TableDescription, error) { - tds := []TableDescription{} - for rows.Next() { - td := TableDescription{} - err := rows.StructScan(&td) - if err != nil { - return nil, err - } - tds = append(tds, td) - } - return tds, rows.Err() -} - -func ScanPrimaryKeyDescription(rows *sqlx.Rows) ([]PrimaryKeyDescription, error) { - pkds := []PrimaryKeyDescription{} - for rows.Next() { - pk := PrimaryKeyDescription{} - err := rows.StructScan(&pk) - if err != nil { - return nil, err - } - pkds = append(pkds, pk) - } - return pkds, rows.Err() -} - -func ListTables(databaseName string, schemaName string, db *sql.DB) ([]Table, error) { - stmt := fmt.Sprintf(`SHOW TABLES IN SCHEMA "%s"."%v"`, databaseName, schemaName) - rows, err := Query(db, stmt) - if err != nil { - return nil, err - } - defer rows.Close() - - dbs := []Table{} - if err := sqlx.StructScan(rows, &dbs); err != nil { - if errors.Is(err, sql.ErrNoRows) { - log.Println("[DEBUG] no tables found") - return nil, nil - } - return nil, fmt.Errorf("unable to scan row for %s err = %w", stmt, err) - } - return dbs, nil -}