Skip to content

Commit

Permalink
fix: Fix replication for database resource (#2524)
Browse files Browse the repository at this point in the history
- Move replication config in create to the right place
- Add show replication databases to the SDK
- Add `FullyQualifiedName` to `ExternalIdentifier`

References: #2021
  • Loading branch information
sfc-gh-asawicki authored Feb 20, 2024
1 parent 61883f3 commit 767fbce
Show file tree
Hide file tree
Showing 11 changed files with 392 additions and 49 deletions.
42 changes: 22 additions & 20 deletions pkg/resources/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,6 @@ func CreateDatabase(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("error creating database %v: %w", name, err)
}
d.SetId(name)
if v, ok := d.GetOk("replication_configuration"); ok {
replicationConfiguration := v.([]interface{})[0].(map[string]interface{})
accounts := replicationConfiguration["accounts"].([]interface{})
accountIDs := make([]sdk.AccountIdentifier, len(accounts))
for i, account := range accounts {
accountIDs[i] = sdk.NewAccountIdentifierFromAccountLocator(account.(string))
}
opts := &sdk.AlterDatabaseReplicationOptions{
EnableReplication: &sdk.EnableReplication{
ToAccounts: accountIDs,
},
}
if ignoreEditionCheck, ok := replicationConfiguration["ignore_edition_check"]; ok {
opts.EnableReplication.IgnoreEditionCheck = sdk.Bool(ignoreEditionCheck.(bool))
}
err := client.Databases.AlterReplication(ctx, id, opts)
if err != nil {
return fmt.Errorf("error enabling replication for database %v: %w", name, err)
}
}
return ReadDatabase(d, meta)
}
// Is it a Secondary Database?
Expand Down Expand Up @@ -177,6 +157,28 @@ func CreateDatabase(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("error creating database %v: %w", name, err)
}
d.SetId(name)

if v, ok := d.GetOk("replication_configuration"); ok {
replicationConfiguration := v.([]interface{})[0].(map[string]interface{})
accounts := replicationConfiguration["accounts"].([]interface{})
accountIDs := make([]sdk.AccountIdentifier, len(accounts))
for i, account := range accounts {
accountIDs[i] = sdk.NewAccountIdentifierFromAccountLocator(account.(string))
}
opts := &sdk.AlterDatabaseReplicationOptions{
EnableReplication: &sdk.EnableReplication{
ToAccounts: accountIDs,
},
}
if ignoreEditionCheck, ok := replicationConfiguration["ignore_edition_check"]; ok {
opts.EnableReplication.IgnoreEditionCheck = sdk.Bool(ignoreEditionCheck.(bool))
}
err := client.Databases.AlterReplication(ctx, id, opts)
if err != nil {
return fmt.Errorf("error enabling replication for database %v: %w", name, err)
}
}

return ReadDatabase(d, meta)
}

Expand Down
87 changes: 72 additions & 15 deletions pkg/resources/database_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package resources_test
import (
"context"
"fmt"
"os"
"strings"
"testing"

acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/hashicorp/terraform-plugin-testing/config"
"github.com/hashicorp/terraform-plugin-testing/helper/acctest"
Expand All @@ -19,14 +19,14 @@ import (
)

func TestAcc_DatabaseWithUnderscore(t *testing.T) {
if _, ok := os.LookupEnv("SKIP_DATABASE_TESTS"); ok {
t.Skip("Skipping TestAcc_DatabaseWithUnderscore")
}

prefix := "_" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
resource.ParallelTest(t, resource.TestCase{
Providers: acc.TestAccProviders(),
PreCheck: func() { acc.TestAccPreCheck(t) },

resource.Test(t, resource.TestCase{
ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories,
PreCheck: func() { acc.TestAccPreCheck(t) },
TerraformVersionChecks: []tfversion.TerraformVersionCheck{
tfversion.RequireAbove(tfversion.Version1_5_0),
},
CheckDestroy: nil,
Steps: []resource.TestStep{
{
Expand All @@ -42,18 +42,17 @@ func TestAcc_DatabaseWithUnderscore(t *testing.T) {
}

func TestAcc_Database(t *testing.T) {
if _, ok := os.LookupEnv("SKIP_DATABASE_TESTS"); ok {
t.Skip("Skipping TestAcc_Database")
}

prefix := "tst-terraform" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
prefix2 := "tst-terraform" + strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))

secondaryAccountName := getSecondaryAccount(t)

resource.ParallelTest(t, resource.TestCase{
Providers: acc.TestAccProviders(),
PreCheck: func() { acc.TestAccPreCheck(t) },
resource.Test(t, resource.TestCase{
ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories,
PreCheck: func() { acc.TestAccPreCheck(t) },
TerraformVersionChecks: []tfversion.TerraformVersionCheck{
tfversion.RequireAbove(tfversion.Version1_5_0),
},
CheckDestroy: nil,
Steps: []resource.TestStep{
{
Expand Down Expand Up @@ -150,6 +149,34 @@ func TestAcc_DatabaseRemovedOutsideOfTerraform(t *testing.T) {
})
}

// proves https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2021
func TestAcc_Database_issue2021(t *testing.T) {
name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))

secondaryAccountName := getSecondaryAccount(t)

resource.Test(t, resource.TestCase{
ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories,
PreCheck: func() { acc.TestAccPreCheck(t) },
TerraformVersionChecks: []tfversion.TerraformVersionCheck{
tfversion.RequireAbove(tfversion.Version1_5_0),
},
CheckDestroy: nil,
Steps: []resource.TestStep{
{
Config: dbConfigWithReplication(name, secondaryAccountName),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr("snowflake_database.db", "name", name),
resource.TestCheckResourceAttr("snowflake_database.db", "replication_configuration.#", "1"),
resource.TestCheckResourceAttr("snowflake_database.db", "replication_configuration.0.accounts.#", "1"),
resource.TestCheckResourceAttr("snowflake_database.db", "replication_configuration.0.accounts.0", secondaryAccountName),
testAccCheckIfDatabaseIsReplicated(t, name),
),
},
},
})
}

func dbConfig(prefix string) string {
s := `
resource "snowflake_database" "db" {
Expand Down Expand Up @@ -235,3 +262,33 @@ func testAccCheckDatabaseExistence(t *testing.T, id string, shouldExist bool) fu
return nil
}
}

func testAccCheckIfDatabaseIsReplicated(t *testing.T, id string) func(state *terraform.State) error {
t.Helper()
return func(state *terraform.State) error {
client, err := sdk.NewDefaultClient()
if err != nil {
return err
}

ctx := context.Background()
replicationDatabases, err := client.ReplicationFunctions.ShowReplicationDatabases(ctx, nil)
if err != nil {
return err
}

var exists bool
for _, o := range replicationDatabases {
if o.Name == id {
exists = true
break
}
}

if !exists {
return fmt.Errorf("database %s should be replicated", id)
}

return nil
}
}
6 changes: 3 additions & 3 deletions pkg/sdk/databases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestDatabasesCreateShared(t *testing.T) {
name: databaseID,
fromShare: NewExternalObjectIdentifier(NewAccountIdentifierFromAccountLocator("account1"), NewAccountObjectIdentifier("db1")),
}
assertOptsValidAndSQLEquals(t, opts, `CREATE DATABASE "db1" FROM SHARE account1."db1"`)
assertOptsValidAndSQLEquals(t, opts, `CREATE DATABASE "db1" FROM SHARE "account1"."db1"`)
})

t.Run("with comment", func(t *testing.T) {
Expand All @@ -55,7 +55,7 @@ func TestDatabasesCreateShared(t *testing.T) {
fromShare: NewExternalObjectIdentifier(NewAccountIdentifierFromAccountLocator("account1"), NewAccountObjectIdentifier("db1")),
Comment: String("comment"),
}
assertOptsValidAndSQLEquals(t, opts, `CREATE DATABASE "db1" FROM SHARE account1."db1" COMMENT = 'comment'`)
assertOptsValidAndSQLEquals(t, opts, `CREATE DATABASE "db1" FROM SHARE "account1"."db1" COMMENT = 'comment'`)
})
}

Expand All @@ -65,7 +65,7 @@ func TestDatabasesCreateSecondary(t *testing.T) {
primaryDatabase: NewExternalObjectIdentifier(NewAccountIdentifierFromAccountLocator("account1"), NewAccountObjectIdentifier("db1")),
DataRetentionTimeInDays: Int(1),
}
assertOptsValidAndSQLEquals(t, opts, `CREATE DATABASE "db1" AS REPLICA OF account1."db1" DATA_RETENTION_TIME_IN_DAYS = 1`)
assertOptsValidAndSQLEquals(t, opts, `CREATE DATABASE "db1" AS REPLICA OF "account1"."db1" DATA_RETENTION_TIME_IN_DAYS = 1`)
}

func TestDatabasesDrop(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/sdk/failover_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestFailoverGroupsCreate(t *testing.T) {
IgnoreEditionCheck: Bool(true),
ReplicationSchedule: String("10 MINUTE"),
}
assertOptsValidAndSQLEquals(t, opts, `CREATE FAILOVER GROUP IF NOT EXISTS "fg1" OBJECT_TYPES = SHARES, DATABASES ALLOWED_DATABASES = "db1" ALLOWED_SHARES = "share1" ALLOWED_ACCOUNTS = "MY_ORG.MY_ACCOUNT" IGNORE EDITION CHECK REPLICATION_SCHEDULE = '10 MINUTE'`)
assertOptsValidAndSQLEquals(t, opts, `CREATE FAILOVER GROUP IF NOT EXISTS "fg1" OBJECT_TYPES = SHARES, DATABASES ALLOWED_DATABASES = "db1" ALLOWED_SHARES = "share1" ALLOWED_ACCOUNTS = "MY_ORG"."MY_ACCOUNT" IGNORE EDITION CHECK REPLICATION_SCHEDULE = '10 MINUTE'`)
})

t.Run("minimal", func(t *testing.T) {
Expand All @@ -39,7 +39,7 @@ func TestFailoverGroupsCreate(t *testing.T) {
NewAccountIdentifier("MY_ORG", "MY_ACCOUNT"),
},
}
assertOptsValidAndSQLEquals(t, opts, `CREATE FAILOVER GROUP IF NOT EXISTS "fg1" OBJECT_TYPES = ROLES ALLOWED_ACCOUNTS = "MY_ORG.MY_ACCOUNT"`)
assertOptsValidAndSQLEquals(t, opts, `CREATE FAILOVER GROUP IF NOT EXISTS "fg1" OBJECT_TYPES = ROLES ALLOWED_ACCOUNTS = "MY_ORG"."MY_ACCOUNT"`)
})
}

Expand All @@ -49,7 +49,7 @@ func TestCreateSecondaryReplicationGroup(t *testing.T) {
name: NewAccountObjectIdentifier("fg1"),
primaryFailoverGroup: NewExternalObjectIdentifierFromFullyQualifiedName("myorg.myaccount.fg1"),
}
assertOptsValidAndSQLEquals(t, opts, `CREATE FAILOVER GROUP IF NOT EXISTS "fg1" AS REPLICA OF myorg.myaccount."fg1"`)
assertOptsValidAndSQLEquals(t, opts, `CREATE FAILOVER GROUP IF NOT EXISTS "fg1" AS REPLICA OF "myorg"."myaccount"."fg1"`)
}

func TestFailoverGroupAlterSource(t *testing.T) {
Expand Down
9 changes: 8 additions & 1 deletion pkg/sdk/identifier_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (i ExternalObjectIdentifier) Name() string {
}

func (i ExternalObjectIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(`%v.%v`, i.accountIdentifier.Name(), i.objectIdentifier.FullyQualifiedName())
return fmt.Sprintf(`%v.%v`, i.accountIdentifier.FullyQualifiedName(), i.objectIdentifier.FullyQualifiedName())
}

type AccountIdentifier struct {
Expand Down Expand Up @@ -116,6 +116,13 @@ func (i AccountIdentifier) Name() string {
return i.accountLocator
}

func (i AccountIdentifier) FullyQualifiedName() string {
if i.organizationName != "" && i.accountName != "" {
return fmt.Sprintf(`"%s"."%s"`, i.organizationName, i.accountName)
}
return fmt.Sprintf(`"%s"`, i.accountLocator)
}

type AccountObjectIdentifier struct {
name string
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/sdk/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ func RandomSchemaObjectIdentifier() SchemaObjectIdentifier {
return NewSchemaObjectIdentifier(random.StringN(12), random.StringN(12), random.StringN(12))
}

func RandomExternalObjectIdentifier() ExternalObjectIdentifier {
return NewExternalObjectIdentifier(NewAccountIdentifierFromAccountLocator(random.StringN(12)), RandomAccountObjectIdentifier())
}

func RandomDatabaseObjectIdentifier() DatabaseObjectIdentifier {
return NewDatabaseObjectIdentifier(random.StringN(12), random.StringN(12))
}
Expand Down
99 changes: 95 additions & 4 deletions pkg/sdk/replication_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@ package sdk

import (
"context"
"database/sql"
"errors"
"time"
)

var _ validatable = new(ShowRegionsOptions)
var _ ReplicationFunctions = (*replicationFunctions)(nil)

var (
_ validatable = new(ShowRegionsOptions)
_ validatable = new(ShowReplicationDatabasesOptions)
)

var _ convertibleRow[ReplicationDatabase] = new(replicationDatabaseRow)

type ReplicationFunctions interface {
ShowReplicationAccounts(ctx context.Context) ([]*ReplicationAccount, error)
// todo: ShowReplicationDatabases(ctx context.Context, opts *ShowReplicationDatabasesOptions) ([]*ReplicationDatabase, error)
ShowReplicationDatabases(ctx context.Context, opts *ShowReplicationDatabasesOptions) ([]ReplicationDatabase, error)
ShowRegions(ctx context.Context, opts *ShowRegionsOptions) ([]*Region, error)
}

var _ ReplicationFunctions = (*replicationFunctions)(nil)

type replicationFunctions struct {
client *Client
}
Expand Down Expand Up @@ -54,6 +60,91 @@ func (c *replicationFunctions) ShowReplicationAccounts(ctx context.Context) ([]*
return replicationAccounts, nil
}

type replicationDatabaseRow struct {
RegionGroup sql.NullString `db:"region_group"`
SnowflakeRegion string `db:"snowflake_region"`
CreatedOn string `db:"created_on"`
AccountName string `db:"account_name"`
Name string `db:"name"`
Comment sql.NullString `db:"comment"`
IsPrimary bool `db:"is_primary"`
PrimaryDatabase string `db:"primary"`
ReplicationAllowedToAccounts sql.NullString `db:"replication_allowed_to_accounts"`
FailoverAllowedToAccounts sql.NullString `db:"failover_allowed_to_accounts"`
OrganizationName string `db:"organization_name"`
AccountLocator string `db:"account_locator"`
}

type ReplicationDatabase struct {
RegionGroup string
SnowflakeRegion string
CreatedOn string
AccountName string
Name string
Comment string
IsPrimary bool
PrimaryDatabase string
ReplicationAllowedToAccounts string
FailoverAllowedToAccounts string
OrganizationName string
AccountLocator string
}

func (row replicationDatabaseRow) convert() *ReplicationDatabase {
db := &ReplicationDatabase{
SnowflakeRegion: row.SnowflakeRegion,
CreatedOn: row.CreatedOn,
AccountName: row.AccountName,
Name: row.Name,
IsPrimary: row.IsPrimary,
PrimaryDatabase: row.PrimaryDatabase,
OrganizationName: row.OrganizationName,
AccountLocator: row.AccountLocator,
}
if row.RegionGroup.Valid {
db.RegionGroup = row.RegionGroup.String
}
if row.Comment.Valid {
db.Comment = row.Comment.String
}
if row.ReplicationAllowedToAccounts.Valid {
db.ReplicationAllowedToAccounts = row.ReplicationAllowedToAccounts.String
}
if row.FailoverAllowedToAccounts.Valid {
db.FailoverAllowedToAccounts = row.FailoverAllowedToAccounts.String
}
return db
}

// ShowReplicationDatabasesOptions is based on https://docs.snowflake.com/en/sql-reference/sql/show-replication-databases.
type ShowReplicationDatabasesOptions struct {
show bool `ddl:"static" sql:"SHOW"`
replicationDatabases bool `ddl:"static" sql:"REPLICATION DATABASES"`
Like *Like `ddl:"keyword" sql:"LIKE"`
WithPrimary *ExternalObjectIdentifier `ddl:"identifier" sql:"WITH PRIMARY"`
}

func (opts *ShowReplicationDatabasesOptions) validate() error {
if opts == nil {
return ErrNilOptions
}
var errs []error
if opts.WithPrimary != nil && !ValidObjectIdentifier(opts.WithPrimary) {
errs = append(errs, ErrInvalidObjectIdentifier)
}
return JoinErrors(errs...)
}

func (c *replicationFunctions) ShowReplicationDatabases(ctx context.Context, opts *ShowReplicationDatabasesOptions) ([]ReplicationDatabase, error) {
opts = createIfNil(opts)
dbRows, err := validateAndQuery[replicationDatabaseRow](c.client, ctx, opts)
if err != nil {
return nil, err
}
resultList := convertRows[replicationDatabaseRow, ReplicationDatabase](dbRows)
return resultList, nil
}

type CloudType string

const (
Expand Down
Loading

0 comments on commit 767fbce

Please sign in to comment.