Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jmichalak committed Nov 18, 2024
1 parent 7736e0a commit 6f06e25
Show file tree
Hide file tree
Showing 26 changed files with 407 additions and 272 deletions.
8 changes: 8 additions & 0 deletions pkg/acceptance/helpers/tag_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ func (c *TagClient) CreateWithRequest(t *testing.T, req *sdk.CreateTagRequest) (
return tag, c.DropTagFunc(t, req.GetName())
}

func (c *TagClient) Unset(t *testing.T, objectType sdk.ObjectType, id sdk.ObjectIdentifier, unsetTags []sdk.ObjectIdentifier) {
t.Helper()
ctx := context.Background()

err := c.client().Unset(ctx, sdk.NewUnsetTagRequest(objectType, id).WithUnsetTags(unsetTags))
require.NoError(t, err)
}

func (c *TagClient) DropTagFunc(t *testing.T, id sdk.SchemaObjectIdentifier) func() {
t.Helper()
ctx := context.Background()
Expand Down
13 changes: 13 additions & 0 deletions pkg/resources/helper_expansion.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ func expandStringList(configured []interface{}) []string {
return vs
}

func expandStringListWithMapping[T any](configured []any, mapping func(string) (T, error)) ([]T, error) {
stringList := expandStringList(configured)
vs := make([]T, 0, len(configured))
for _, v := range stringList {
val, err := mapping(v)
if err != nil {
return nil, err
}
vs = append(vs, val)
}
return vs, nil
}

func expandStringListAllowEmpty(configured []interface{}) []string {
// Allow empty values during expansion
vs := make([]string, 0, len(configured))
Expand Down
203 changes: 91 additions & 112 deletions pkg/resources/tag_association.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"fmt"
"log"
"slices"
"strings"
"time"

"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
Expand All @@ -27,32 +25,14 @@ var tagAssociationSchema = map[string]*schema.Schema{
Deprecated: "Use `object_identifier` instead",
},
"object_identifier": {
Type: schema.TypeList,
Type: schema.TypeSet,
MinItems: 1,
Required: true,
Description: "Specifies the object identifier for the tag association.",
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"name": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
Description: "Name of the object to associate the tag with.",
},
"database": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Description: "Name of the database that the object was created in.",
},
"schema": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Description: "Name of the schema that the object was created in.",
},
},
Description: "Specifies the object identifiers for the tag association.",
Elem: &schema.Schema{
Type: schema.TypeString,
},
DiffSuppressFunc: NormalizeAndCompareIdentifiersInSet("object_identifier"),
},
"object_type": {
Type: schema.TypeString,
Expand All @@ -62,10 +42,11 @@ var tagAssociationSchema = map[string]*schema.Schema{
ForceNew: true,
},
"tag_id": {
Type: schema.TypeString,
Required: true,
Description: "Specifies the identifier for the tag. Note: format must follow: \"databaseName\".\"schemaName\".\"tagName\" or \"databaseName.schemaName.tagName\" or \"databaseName|schemaName.tagName\" (snowflake_tag.tag.id)",
ForceNew: true,
Type: schema.TypeString,
Required: true,
Description: "Specifies the identifier for the tag.",
ForceNew: true,
DiffSuppressFunc: suppressIdentifierQuoting,
},
"tag_value": {
Type: schema.TypeString,
Expand All @@ -88,6 +69,7 @@ func TagAssociation() *schema.Resource {
ReadContext: ReadContextTagAssociation,
UpdateContext: UpdateContextTagAssociation,
DeleteContext: DeleteContextTagAssociation,
Description: "Resource used to manage tag associations. For more information, check [object tagging documentation](https://docs.snowflake.com/en/user-guide/object-tagging).",

Schema: tagAssociationSchema,
Importer: &schema.ResourceImporter{
Expand All @@ -99,66 +81,39 @@ func TagAssociation() *schema.Resource {
}
}

func TagIdentifierAndObjectIdentifier(d *schema.ResourceData) (sdk.SchemaObjectIdentifier, []sdk.ObjectIdentifier, sdk.ObjectType) {
func TagIdentifierAndObjectIdentifier(d *schema.ResourceData) (sdk.SchemaObjectIdentifier, []sdk.ObjectIdentifier, sdk.ObjectType, error) {
tag := d.Get("tag_id").(string)
tagId, err := sdk.ParseSchemaObjectIdentifier(tag)
if err != nil {
return sdk.SchemaObjectIdentifier{}, nil, "", fmt.Errorf("invalid tag id: %w", err)
}

objectType := sdk.ObjectType(d.Get("object_type").(string))

tagDatabase, tagSchema, tagName := ParseFullyQualifiedObjectID(tag)
tid := sdk.NewSchemaObjectIdentifier(tagDatabase, tagSchema, tagName)

var identifiers []sdk.ObjectIdentifier
for _, item := range d.Get("object_identifier").([]interface{}) {
m := item.(map[string]interface{})
name := strings.Trim(m["name"].(string), `"`)
var databaseName, schemaName string
if v, ok := m["database"]; ok {
databaseName = strings.Trim(v.(string), `"`)
if databaseName == "" && slices.Contains(sdk.TagAssociationTagObjectTypeIsSchemaObjectType, objectType) {
databaseName = tagDatabase
}
}
if v, ok := m["schema"]; ok {
schemaName = strings.Trim(v.(string), `"`)
if schemaName == "" && slices.Contains(sdk.TagAssociationTagObjectTypeIsSchemaObjectType, objectType) {
schemaName = tagSchema
}
}
switch {
case databaseName != "" && schemaName != "":
if objectType == sdk.ObjectTypeColumn {
fields := strings.Split(name, ".")
if len(fields) > 1 {
tableName := strings.ReplaceAll(fields[0], `"`, "")
var parts []string
for i := 1; i < len(fields); i++ {
parts = append(parts, strings.ReplaceAll(fields[i], `"`, ""))
}
columnName := strings.Join(parts, ".")
identifiers = append(identifiers, sdk.NewTableColumnIdentifier(databaseName, schemaName, tableName, columnName))
} else {
identifiers = append(identifiers, sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name))
}
} else {
identifiers = append(identifiers, sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name))
}
case databaseName != "":
identifiers = append(identifiers, sdk.NewDatabaseObjectIdentifier(databaseName, name))
default:
identifiers = append(identifiers, sdk.NewAccountObjectIdentifier(name))
idsRaw := expandStringList(d.Get("object_identifier").(*schema.Set).List())
ids := make([]sdk.ObjectIdentifier, len(idsRaw))
for i, idRaw := range idsRaw {
id, err := sdk.ParseObjectIdentifierString(idRaw)
if err != nil {
return sdk.SchemaObjectIdentifier{}, nil, "", fmt.Errorf("invalid object id: %w", err)
}
ids[i] = id
}
return tid, identifiers, objectType
return tagId, ids, objectType, nil
}

func CreateContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
func CreateContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
client := meta.(*provider.Context).Client
tagValue := d.Get("tag_value").(string)

tid, ids, ot := TagIdentifierAndObjectIdentifier(d)
tagId, ids, objectType, err := TagIdentifierAndObjectIdentifier(d)
if err != nil {
return diag.FromErr(err)
}
for _, oid := range ids {
request := sdk.NewSetTagRequest(ot, oid).WithSetTags([]sdk.TagAssociation{
request := sdk.NewSetTagRequest(objectType, oid).WithSetTags([]sdk.TagAssociation{
{
Name: tid,
Name: tagId,
Value: tagValue,
},
})
Expand All @@ -169,12 +124,12 @@ func CreateContextTagAssociation(ctx context.Context, d *schema.ResourceData, me
if !skipValidate {
log.Println("[DEBUG] validating tag creation")
if err := retry.RetryContext(ctx, d.Timeout(schema.TimeoutCreate)-time.Minute, func() *retry.RetryError {
tag, err := client.SystemFunctions.GetTag(ctx, tid, oid, ot)
tag, err := client.SystemFunctions.GetTag(ctx, tagId, oid, objectType)
if err != nil {
return retry.NonRetryableError(fmt.Errorf("error getting tag: %w", err))
}
// if length of response is zero, tag association was not found. retry
if len(tag) == 0 {
if tag == nil {
return retry.RetryableError(fmt.Errorf("expected tag association to be created but not yet created"))
}
return nil
Expand All @@ -183,63 +138,87 @@ func CreateContextTagAssociation(ctx context.Context, d *schema.ResourceData, me
}
}
}
d.SetId(helpers.EncodeSnowflakeID(tid.DatabaseName(), tid.SchemaName(), tid.Name()))
d.SetId(helpers.EncodeSnowflakeID(tagId.FullyQualifiedName(), tagValue, string(objectType)))
return ReadContextTagAssociation(ctx, d, meta)
}

func ReadContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
diags := diag.Diagnostics{}
func ReadContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
client := meta.(*provider.Context).Client
tagValue := d.Get("tag_value").(string)

tid, ids, ot := TagIdentifierAndObjectIdentifier(d)
tagId, ids, objectType, err := TagIdentifierAndObjectIdentifier(d)
if err != nil {
return diag.FromErr(err)
}
var correctObjectIds []string
for _, oid := range ids {
tagValue, err := client.SystemFunctions.GetTag(ctx, tid, oid, ot)
objectTagValue, err := client.SystemFunctions.GetTag(ctx, tagId, oid, objectType)
if err != nil {
return diag.FromErr(err)
}
if err := d.Set("tag_value", tagValue); err != nil {
return diag.FromErr(err)
if objectTagValue != nil && *objectTagValue == tagValue {
correctObjectIds = append(correctObjectIds, oid.FullyQualifiedName())
}
}
return diags
if err := d.Set("object_identifier", correctObjectIds); err != nil {
return diag.FromErr(err)
}
return nil
}

func UpdateContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
func UpdateContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
client := meta.(*provider.Context).Client
tid, ids, ot := TagIdentifierAndObjectIdentifier(d)
for _, oid := range ids {
if d.HasChange("skip_validation") {
o, n := d.GetChange("skip_validation")
log.Printf("[DEBUG] skip_validation changed from %v to %v", o, n)
tagId, _, objectType, err := TagIdentifierAndObjectIdentifier(d)
if err != nil {
return diag.FromErr(err)
}
tagValue := d.Get("tag_value").(string)
if d.HasChange("object_identifier") {
o, n := d.GetChange("object_identifier")

oldAllowedValues, err := expandStringListWithMapping(o.(*schema.Set).List(), sdk.ParseObjectIdentifierString)
if err != nil {
return diag.FromErr(err)
}
if d.HasChange("tag_value") {
tagValue, ok := d.GetOk("tag_value")
if ok {
request := sdk.NewSetTagRequest(ot, oid).WithSetTags([]sdk.TagAssociation{
{
Name: tid,
Value: tagValue.(string),
},
})
if err := client.Tags.Set(ctx, request); err != nil {
return diag.FromErr(err)
}
} else {
request := sdk.NewUnsetTagRequest(ot, oid).WithUnsetTags([]sdk.ObjectIdentifier{tid})
if err := client.Tags.Unset(ctx, request); err != nil {
return diag.FromErr(err)
}
newAllowedValues, err := expandStringListWithMapping(n.(*schema.Set).List(), sdk.ParseObjectIdentifierString)
if err != nil {
return diag.FromErr(err)
}

addedids, removedids := ListDiff(oldAllowedValues, newAllowedValues)

for _, id := range addedids {
request := sdk.NewSetTagRequest(objectType, id).WithSetTags([]sdk.TagAssociation{
{
Name: tagId,
Value: tagValue,
},
})
if err := client.Tags.Set(ctx, request); err != nil {
return diag.FromErr(err)
}
}

for _, id := range removedids {
request := sdk.NewUnsetTagRequest(objectType, id).WithUnsetTags([]sdk.ObjectIdentifier{tagId}).WithIfExists(true)
if err := client.Tags.Unset(ctx, request); err != nil {
return diag.FromErr(err)
}
}

}

return ReadContextTagAssociation(ctx, d, meta)
}

func DeleteContextTagAssociation(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*provider.Context).Client
tid, ids, ot := TagIdentifierAndObjectIdentifier(d)
tid, ids, ot, err := TagIdentifierAndObjectIdentifier(d)
if err != nil {
return diag.FromErr(err)
}
for _, oid := range ids {
request := sdk.NewUnsetTagRequest(ot, oid).WithUnsetTags([]sdk.ObjectIdentifier{tid})
request := sdk.NewUnsetTagRequest(ot, oid).WithUnsetTags([]sdk.ObjectIdentifier{tid}).WithIfExists(true)
if err := client.Tags.Unset(ctx, request); err != nil {
return diag.FromErr(err)
}
Expand Down
Loading

0 comments on commit 6f06e25

Please sign in to comment.