From 21f069a3232a0039792f30445e25b5e88f09ab25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Wed, 27 Sep 2023 16:00:05 +0200 Subject: [PATCH] chore: Change return type of show (#2045) --- pkg/sdk/accounts.go | 28 ++++----------- pkg/sdk/alerts.go | 35 +++++------------- pkg/sdk/alerts_integration_test.go | 6 ++-- pkg/sdk/databases.go | 30 ++++++---------- pkg/sdk/failover_groups.go | 30 ++++------------ pkg/sdk/file_format.go | 36 +++++++------------ pkg/sdk/file_format_integration_test.go | 10 +++--- pkg/sdk/helper_test.go | 6 ++-- pkg/sdk/masking_policy.go | 29 ++++----------- pkg/sdk/masking_policy_integration_test.go | 6 ++-- pkg/sdk/password_policy.go | 21 +++++------ pkg/sdk/password_policy_integration_test.go | 6 ++-- pkg/sdk/resource_monitors.go | 18 +++++----- pkg/sdk/resource_monitors_integration_test.go | 8 ++--- pkg/sdk/session_policies.go | 12 +++---- pkg/sdk/shares.go | 29 +++++---------- pkg/sdk/shares_integration_test.go | 2 +- pkg/sdk/users.go | 29 ++++----------- pkg/sdk/users_integration_test.go | 8 ++--- pkg/sdk/warehouses.go | 28 ++++----------- 20 files changed, 123 insertions(+), 254 deletions(-) diff --git a/pkg/sdk/accounts.go b/pkg/sdk/accounts.go index f67bc7e89d..5871938408 100644 --- a/pkg/sdk/accounts.go +++ b/pkg/sdk/accounts.go @@ -19,7 +19,7 @@ type Accounts interface { // Alter modifies an existing account Alter(ctx context.Context, opts *AlterAccountOptions) error // Show returns a list of accounts. - Show(ctx context.Context, opts *ShowAccountOptions) ([]*Account, error) + Show(ctx context.Context, opts *ShowAccountOptions) ([]Account, error) // ShowByID returns an account by id ShowByID(ctx context.Context, id AccountObjectIdentifier) (*Account, error) } @@ -345,7 +345,7 @@ type accountDBRow struct { IsOrgAdmin bool `db:"is_org_admin"` } -func (row accountDBRow) toAccount() *Account { +func (row accountDBRow) convert() *Account { acc := &Account{ OrganizationName: row.OrganizationName, AccountName: row.AccountName, @@ -376,27 +376,13 @@ func (row accountDBRow) toAccount() *Account { return acc } -func (c *accounts) Show(ctx context.Context, opts *ShowAccountOptions) ([]*Account, error) { - if opts == nil { - opts = &ShowAccountOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) +func (c *accounts) Show(ctx context.Context, opts *ShowAccountOptions) ([]Account, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[accountDBRow](c.client, ctx, opts) if err != nil { return nil, err } - dest := []accountDBRow{} - err = c.client.query(ctx, &dest, sql) - if err != nil { - return nil, err - } - resultList := make([]*Account, len(dest)) - for i, row := range dest { - resultList[i] = row.toAccount() - } - + resultList := convertRows[accountDBRow, Account](dbRows) return resultList, nil } @@ -412,7 +398,7 @@ func (c *accounts) ShowByID(ctx context.Context, id AccountObjectIdentifier) (*A for _, account := range accounts { if account.AccountName == id.Name() || account.AccountLocator == id.Name() { - return account, nil + return &account, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/alerts.go b/pkg/sdk/alerts.go index c8a69bdab6..c8937204dd 100644 --- a/pkg/sdk/alerts.go +++ b/pkg/sdk/alerts.go @@ -25,7 +25,7 @@ type Alerts interface { // Drop removes an alert. Drop(ctx context.Context, id SchemaObjectIdentifier) error // Show returns a list of alerts - Show(ctx context.Context, opts *ShowAlertOptions) ([]*Alert, error) + Show(ctx context.Context, opts *ShowAlertOptions) ([]Alert, error) // ShowByID returns an alert by ID ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Alert, error) // Describe returns the details of an alert. @@ -252,7 +252,7 @@ type alertDBRow struct { Action string `db:"action"` } -func (row alertDBRow) toAlert() (*Alert, error) { +func (row alertDBRow) convert() *Alert { return &Alert{ CreatedOn: row.CreatedOn, Name: row.Name, @@ -265,39 +265,20 @@ func (row alertDBRow) toAlert() (*Alert, error) { State: AlertState(row.State), Condition: row.Condition, Action: row.Action, - }, nil + } } func (opts *ShowAlertOptions) validate() error { return nil } -func (v *alerts) Show(ctx context.Context, opts *ShowAlertOptions) ([]*Alert, error) { - if opts == nil { - opts = &ShowAlertOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) +func (v *alerts) Show(ctx context.Context, opts *ShowAlertOptions) ([]Alert, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[alertDBRow](v.client, ctx, opts) if err != nil { return nil, err } - dest := []alertDBRow{} - - err = v.client.query(ctx, &dest, sql) - if err != nil { - return nil, err - } - resultList := make([]*Alert, len(dest)) - for i, row := range dest { - alert, err := row.toAlert() - if err != nil { - return nil, err - } - resultList[i] = alert - } - + resultList := convertRows[alertDBRow, Alert](dbRows) return resultList, nil } @@ -316,7 +297,7 @@ func (v *alerts) ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Aler for _, alert := range alerts { if alert.ID().name == id.Name() { - return alert, nil + return &alert, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/alerts_integration_test.go b/pkg/sdk/alerts_integration_test.go index 50af32a4e9..fca3720703 100644 --- a/pkg/sdk/alerts_integration_test.go +++ b/pkg/sdk/alerts_integration_test.go @@ -42,8 +42,8 @@ func TestInt_AlertsShow(t *testing.T) { } alerts, err := client.Alerts.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, alerts, alertTest) - assert.Contains(t, alerts, alert2Test) + assert.Contains(t, alerts, *alertTest) + assert.Contains(t, alerts, *alert2Test) assert.Equal(t, 2, len(alerts)) }) @@ -58,7 +58,7 @@ func TestInt_AlertsShow(t *testing.T) { } alerts, err := client.Alerts.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, alerts, alertTest) + assert.Contains(t, alerts, *alertTest) assert.Equal(t, 1, len(alerts)) }) diff --git a/pkg/sdk/databases.go b/pkg/sdk/databases.go index b60bdd90ef..d504df159a 100644 --- a/pkg/sdk/databases.go +++ b/pkg/sdk/databases.go @@ -40,7 +40,7 @@ type Databases interface { // Undrop restores the most recent version of a dropped database Undrop(ctx context.Context, id AccountObjectIdentifier) error // Show returns a list of databases. - Show(ctx context.Context, opts *ShowDatabasesOptions) ([]*Database, error) + Show(ctx context.Context, opts *ShowDatabasesOptions) ([]Database, error) // ShowByID returns a database by ID ShowByID(ctx context.Context, id AccountObjectIdentifier) (*Database, error) // Describe returns the details of a database. @@ -94,8 +94,8 @@ type databaseRow struct { Kind sql.NullString `db:"kind"` } -func (row *databaseRow) toDatabase() *Database { - database := Database{ +func (row databaseRow) convert() *Database { + database := &Database{ CreatedOn: row.CreatedOn, Name: row.Name, } @@ -141,7 +141,7 @@ func (row *databaseRow) toDatabase() *Database { if row.Kind.Valid { database.Kind = row.Kind.String } - return &database + return database } type CreateDatabaseOptions struct { @@ -542,24 +542,14 @@ func (opts *ShowDatabasesOptions) validate() error { return nil } -func (v *databases) Show(ctx context.Context, opts *ShowDatabasesOptions) ([]*Database, error) { - if opts == nil { - opts = &ShowDatabasesOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) +func (v *databases) Show(ctx context.Context, opts *ShowDatabasesOptions) ([]Database, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[databaseRow](v.client, ctx, opts) if err != nil { return nil, err } - var rows []databaseRow - err = v.client.query(ctx, &rows, sql) - databases := make([]*Database, len(rows)) - for i, row := range rows { - databases[i] = row.toDatabase() - } - return databases, err + resultList := convertRows[databaseRow, Database](dbRows) + return resultList, nil } func (v *databases) ShowByID(ctx context.Context, id AccountObjectIdentifier) (*Database, error) { @@ -573,7 +563,7 @@ func (v *databases) ShowByID(ctx context.Context, id AccountObjectIdentifier) (* } for _, database := range databases { if database.ID() == id { - return database, nil + return &database, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/failover_groups.go b/pkg/sdk/failover_groups.go index fdd2227570..7e6750d601 100644 --- a/pkg/sdk/failover_groups.go +++ b/pkg/sdk/failover_groups.go @@ -40,7 +40,7 @@ type FailoverGroups interface { // Drop removes a failover group. Drop(ctx context.Context, id AccountObjectIdentifier, opts *DropFailoverGroupOptions) error // Show returns a list of failover groups. - Show(ctx context.Context, opts *ShowFailoverGroupOptions) ([]*FailoverGroup, error) + Show(ctx context.Context, opts *ShowFailoverGroupOptions) ([]FailoverGroup, error) // ShowByID returns a failover group by ID ShowByID(ctx context.Context, id AccountObjectIdentifier) (*FailoverGroup, error) // ShowDatabases returns a list of databases in a failover group. @@ -393,7 +393,7 @@ type failoverGroupDBRow struct { Owner sql.NullString `db:"owner"` } -func (row failoverGroupDBRow) toFailoverGroup() *FailoverGroup { +func (row failoverGroupDBRow) convert() *FailoverGroup { ots := strings.Split(row.ObjectTypes, ",") pluralObjectTypes := make([]PluralObjectType, 0, len(ots)) for _, ot := range ots { @@ -458,29 +458,13 @@ func (row failoverGroupDBRow) toFailoverGroup() *FailoverGroup { } } -// List all the failover groups by pattern. -func (v *failoverGroups) Show(ctx context.Context, opts *ShowFailoverGroupOptions) ([]*FailoverGroup, error) { - if opts == nil { - opts = &ShowFailoverGroupOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) - if err != nil { - return nil, err - } - dest := []failoverGroupDBRow{} - - err = v.client.query(ctx, &dest, sql) +func (v *failoverGroups) Show(ctx context.Context, opts *ShowFailoverGroupOptions) ([]FailoverGroup, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[failoverGroupDBRow](v.client, ctx, opts) if err != nil { return nil, err } - resultList := make([]*FailoverGroup, len(dest)) - for i, row := range dest { - resultList[i] = row.toFailoverGroup() - } - + resultList := convertRows[failoverGroupDBRow, FailoverGroup](dbRows) return resultList, nil } @@ -495,7 +479,7 @@ func (v *failoverGroups) ShowByID(ctx context.Context, id AccountObjectIdentifie } for _, failoverGroup := range failoverGroups { if failoverGroup.ID() == id && failoverGroup.AccountLocator == currentAccount { - return failoverGroup, nil + return &failoverGroup, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/file_format.go b/pkg/sdk/file_format.go index cf38b082cb..4153079780 100644 --- a/pkg/sdk/file_format.go +++ b/pkg/sdk/file_format.go @@ -28,7 +28,7 @@ type FileFormats interface { // Drop removes a FileFormat. Drop(ctx context.Context, id SchemaObjectIdentifier, opts *DropFileFormatOptions) error // Show returns a list of fileFormats. - Show(ctx context.Context, opts *ShowFileFormatsOptions) ([]*FileFormat, error) + Show(ctx context.Context, opts *ShowFileFormatsOptions) ([]FileFormat, error) // ShowByID returns a FileFormat by ID ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*FileFormat, error) // Describe returns the details of a FileFormat. @@ -114,7 +114,7 @@ type showFileFormatsOptionsResult struct { DisableAutoConvert bool `json:"DISABLE_AUTO_CONVERT"` } -func (row *FileFormatRow) toFileFormat() *FileFormat { +func (row FileFormatRow) convert() *FileFormat { inputOptions := showFileFormatsOptionsResult{} err := json.Unmarshal([]byte(row.FormatOptions), &inputOptions) if err != nil { @@ -132,9 +132,9 @@ func (row *FileFormatRow) toFileFormat() *FileFormat { Options: FileFormatTypeOptions{}, } - newNullIf := []NullString{} - for _, s := range inputOptions.NullIf { - newNullIf = append(newNullIf, NullString{s}) + newNullIf := make([]NullString, len(inputOptions.NullIf)) + for i, s := range inputOptions.NullIf { + newNullIf[i] = NullString{s} } switch ff.Type { @@ -638,24 +638,14 @@ func (opts *ShowFileFormatsOptions) validate() error { return nil } -func (v *fileFormats) Show(ctx context.Context, opts *ShowFileFormatsOptions) ([]*FileFormat, error) { - if opts == nil { - opts = &ShowFileFormatsOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) +func (v *fileFormats) Show(ctx context.Context, opts *ShowFileFormatsOptions) ([]FileFormat, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[FileFormatRow](v.client, ctx, opts) if err != nil { return nil, err } - var rows []FileFormatRow - err = v.client.query(ctx, &rows, sql) - fileFormats := make([]*FileFormat, len(rows)) - for i, row := range rows { - fileFormats[i] = row.toFileFormat() - } - return fileFormats, err + resultList := convertRows[FileFormatRow, FileFormat](dbRows) + return resultList, nil } func (v *fileFormats) ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*FileFormat, error) { @@ -670,9 +660,9 @@ func (v *fileFormats) ShowByID(ctx context.Context, id SchemaObjectIdentifier) ( if err != nil { return nil, err } - for _, FileFormat := range fileFormats { - if reflect.DeepEqual(FileFormat.ID(), id) { - return FileFormat, nil + for _, f := range fileFormats { + if reflect.DeepEqual(f.ID(), id) { + return &f, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/file_format_integration_test.go b/pkg/sdk/file_format_integration_test.go index 1b94d39cf3..5b0c7ba8d4 100644 --- a/pkg/sdk/file_format_integration_test.go +++ b/pkg/sdk/file_format_integration_test.go @@ -466,8 +466,8 @@ func TestInt_FileFormatsShow(t *testing.T) { fileFormats, err := client.FileFormats.Show(ctx, nil) require.NoError(t, err) assert.LessOrEqual(t, 2, len(fileFormats)) - assert.Contains(t, fileFormats, fileFormatTest) - assert.Contains(t, fileFormats, fileFormatTest2) + assert.Contains(t, fileFormats, *fileFormatTest) + assert.Contains(t, fileFormats, *fileFormatTest2) }) t.Run("LIKE", func(t *testing.T) { @@ -478,7 +478,7 @@ func TestInt_FileFormatsShow(t *testing.T) { }) require.NoError(t, err) assert.LessOrEqual(t, 1, len(fileFormats)) - assert.Contains(t, fileFormats, fileFormatTest) + assert.Contains(t, fileFormats, *fileFormatTest) }) t.Run("IN", func(t *testing.T) { @@ -489,8 +489,8 @@ func TestInt_FileFormatsShow(t *testing.T) { }) require.NoError(t, err) assert.LessOrEqual(t, 2, len(fileFormats)) - assert.Contains(t, fileFormats, fileFormatTest) - assert.Contains(t, fileFormats, fileFormatTest2) + assert.Contains(t, fileFormats, *fileFormatTest) + assert.Contains(t, fileFormats, *fileFormatTest2) }) } diff --git a/pkg/sdk/helper_test.go b/pkg/sdk/helper_test.go index f30c7a27ee..92213dd636 100644 --- a/pkg/sdk/helper_test.go +++ b/pkg/sdk/helper_test.go @@ -434,7 +434,7 @@ func createPasswordPolicyWithOptions(t *testing.T, client *Client, database *Dat passwordPolicyList, err := client.PasswordPolicies.Show(ctx, showOptions) require.NoError(t, err) require.Equal(t, 1, len(passwordPolicyList)) - return passwordPolicyList[0], func() { + return &passwordPolicyList[0], func() { err := client.PasswordPolicies.Drop(ctx, id, nil) require.NoError(t, err) if schemaCleanup != nil { @@ -478,7 +478,7 @@ func createMaskingPolicyWithOptions(t *testing.T, client *Client, database *Data maskingPolicyList, err := client.MaskingPolicies.Show(ctx, showOptions) require.NoError(t, err) require.Equal(t, 1, len(maskingPolicyList)) - return maskingPolicyList[0], func() { + return &maskingPolicyList[0], func() { err := client.MaskingPolicies.Drop(ctx, id) require.NoError(t, err) if schemaCleanup != nil { @@ -583,7 +583,7 @@ func createAlertWithOptions(t *testing.T, client *Client, database *Database, sc alertList, err := client.Alerts.Show(ctx, showOptions) require.NoError(t, err) require.Equal(t, 1, len(alertList)) - return alertList[0], func() { + return &alertList[0], func() { err := client.Alerts.Drop(ctx, id) require.NoError(t, err) if schemaCleanup != nil { diff --git a/pkg/sdk/masking_policy.go b/pkg/sdk/masking_policy.go index c55da1571c..491f74f618 100644 --- a/pkg/sdk/masking_policy.go +++ b/pkg/sdk/masking_policy.go @@ -31,7 +31,7 @@ type MaskingPolicies interface { // Drop removes a masking policy. Drop(ctx context.Context, id SchemaObjectIdentifier) error // Show returns a list of masking policies. - Show(ctx context.Context, opts *ShowMaskingPolicyOptions) ([]*MaskingPolicy, error) + Show(ctx context.Context, opts *ShowMaskingPolicyOptions) ([]MaskingPolicy, error) // ShowByID returns a masking policy by ID ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*MaskingPolicy, error) // Describe returns the details of a masking policy. @@ -246,7 +246,7 @@ type maskingPolicyDBRow struct { Options string `db:"options"` } -func (row maskingPolicyDBRow) toMaskingPolicy() *MaskingPolicy { +func (row maskingPolicyDBRow) convert() *MaskingPolicy { exemptOtherPolicies, err := jsonparser.GetBoolean([]byte(row.Options), "EXEMPT_OTHER_POLICIES") if err != nil { exemptOtherPolicies = false @@ -264,28 +264,13 @@ func (row maskingPolicyDBRow) toMaskingPolicy() *MaskingPolicy { } // List all the masking policies by pattern. -func (v *maskingPolicies) Show(ctx context.Context, opts *ShowMaskingPolicyOptions) ([]*MaskingPolicy, error) { - if opts == nil { - opts = &ShowMaskingPolicyOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) +func (v *maskingPolicies) Show(ctx context.Context, opts *ShowMaskingPolicyOptions) ([]MaskingPolicy, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[maskingPolicyDBRow](v.client, ctx, opts) if err != nil { return nil, err } - dest := []maskingPolicyDBRow{} - - err = v.client.query(ctx, &dest, sql) - if err != nil { - return nil, err - } - resultList := make([]*MaskingPolicy, len(dest)) - for i, row := range dest { - resultList[i] = row.toMaskingPolicy() - } - + resultList := convertRows[maskingPolicyDBRow, MaskingPolicy](dbRows) return resultList, nil } @@ -304,7 +289,7 @@ func (v *maskingPolicies) ShowByID(ctx context.Context, id SchemaObjectIdentifie for _, maskingPolicy := range maskingPolicies { if maskingPolicy.ID().name == id.Name() { - return maskingPolicy, nil + return &maskingPolicy, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/masking_policy_integration_test.go b/pkg/sdk/masking_policy_integration_test.go index cc7387b4de..482131b636 100644 --- a/pkg/sdk/masking_policy_integration_test.go +++ b/pkg/sdk/masking_policy_integration_test.go @@ -43,8 +43,8 @@ func TestInt_MaskingPoliciesShow(t *testing.T) { } maskingPolicies, err := client.MaskingPolicies.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, maskingPolicies, maskingPolicyTest) - assert.Contains(t, maskingPolicies, maskingPolicy2Test) + assert.Contains(t, maskingPolicies, *maskingPolicyTest) + assert.Contains(t, maskingPolicies, *maskingPolicy2Test) assert.Equal(t, 2, len(maskingPolicies)) }) @@ -59,7 +59,7 @@ func TestInt_MaskingPoliciesShow(t *testing.T) { } maskingPolicies, err := client.MaskingPolicies.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, maskingPolicies, maskingPolicyTest) + assert.Contains(t, maskingPolicies, *maskingPolicyTest) assert.Equal(t, 1, len(maskingPolicies)) }) diff --git a/pkg/sdk/password_policy.go b/pkg/sdk/password_policy.go index 4e589a7e06..c89ce2e555 100644 --- a/pkg/sdk/password_policy.go +++ b/pkg/sdk/password_policy.go @@ -28,7 +28,7 @@ type PasswordPolicies interface { // Drop removes a password policy. Drop(ctx context.Context, id SchemaObjectIdentifier, opts *DropPasswordPolicyOptions) error // Show returns a list of password policies. - Show(ctx context.Context, opts *PasswordPolicyShowOptions) ([]*PasswordPolicy, error) + Show(ctx context.Context, opts *PasswordPolicyShowOptions) ([]PasswordPolicy, error) // ShowByID returns a password policy by ID. ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*PasswordPolicy, error) // Describe returns the details of a password policy. @@ -287,8 +287,8 @@ type passwordPolicyDBRow struct { Options string `db:"options"` } -func (row passwordPolicyDBRow) toPasswordPolicy() *PasswordPolicy { - return &PasswordPolicy{ +func (row passwordPolicyDBRow) convert() PasswordPolicy { + return PasswordPolicy{ CreatedOn: row.CreatedOn, Name: row.Name, DatabaseName: row.DatabaseName, @@ -300,10 +300,8 @@ func (row passwordPolicyDBRow) toPasswordPolicy() *PasswordPolicy { } // List all the password policies by pattern. -func (v *passwordPolicies) Show(ctx context.Context, opts *PasswordPolicyShowOptions) ([]*PasswordPolicy, error) { - if opts == nil { - opts = &PasswordPolicyShowOptions{} - } +func (v *passwordPolicies) Show(ctx context.Context, opts *PasswordPolicyShowOptions) ([]PasswordPolicy, error) { + opts = createIfNil(opts) if err := opts.validate(); err != nil { return nil, err } @@ -311,15 +309,14 @@ func (v *passwordPolicies) Show(ctx context.Context, opts *PasswordPolicyShowOpt if err != nil { return nil, err } - dest := []passwordPolicyDBRow{} - + var dest []passwordPolicyDBRow err = v.client.query(ctx, &dest, sql) if err != nil { return nil, err } - resultList := make([]*PasswordPolicy, len(dest)) + resultList := make([]PasswordPolicy, len(dest)) for i, row := range dest { - resultList[i] = row.toPasswordPolicy() + resultList[i] = row.convert() } return resultList, nil @@ -340,7 +337,7 @@ func (v *passwordPolicies) ShowByID(ctx context.Context, id SchemaObjectIdentifi for _, passwordPolicy := range passwordPolicies { if passwordPolicy.ID().name == id.Name() { - return passwordPolicy, nil + return &passwordPolicy, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/password_policy_integration_test.go b/pkg/sdk/password_policy_integration_test.go index 82b4a322ef..d28bb950fe 100644 --- a/pkg/sdk/password_policy_integration_test.go +++ b/pkg/sdk/password_policy_integration_test.go @@ -37,8 +37,8 @@ func TestInt_PasswordPoliciesShow(t *testing.T) { } passwordPolicies, err := client.PasswordPolicies.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, passwordPolicies, passwordPolicyTest) - assert.Contains(t, passwordPolicies, passwordPolicy2Test) + assert.Contains(t, passwordPolicies, *passwordPolicyTest) + assert.Contains(t, passwordPolicies, *passwordPolicy2Test) assert.Equal(t, 2, len(passwordPolicies)) }) @@ -53,7 +53,7 @@ func TestInt_PasswordPoliciesShow(t *testing.T) { } passwordPolicies, err := client.PasswordPolicies.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, passwordPolicies, passwordPolicyTest) + assert.Contains(t, passwordPolicies, *passwordPolicyTest) assert.Equal(t, 1, len(passwordPolicies)) }) diff --git a/pkg/sdk/resource_monitors.go b/pkg/sdk/resource_monitors.go index c0e7c13a61..4105522fad 100644 --- a/pkg/sdk/resource_monitors.go +++ b/pkg/sdk/resource_monitors.go @@ -24,7 +24,7 @@ type ResourceMonitors interface { // Drop removes a resource monitor. Drop(ctx context.Context, id AccountObjectIdentifier) error // Show returns a list of resource monitor. - Show(ctx context.Context, opts *ShowResourceMonitorOptions) ([]*ResourceMonitor, error) + Show(ctx context.Context, opts *ShowResourceMonitorOptions) ([]ResourceMonitor, error) // ShowByID returns a resource monitor by ID ShowByID(ctx context.Context, id AccountObjectIdentifier) (*ResourceMonitor, error) } @@ -68,7 +68,7 @@ type resourceMonitorRow struct { NotifyUsers sql.NullString `db:"notify_users"` } -func (row *resourceMonitorRow) toResourceMonitor() (*ResourceMonitor, error) { +func (row *resourceMonitorRow) convert() (*ResourceMonitor, error) { resourceMonitor := &ResourceMonitor{ Name: row.Name, } @@ -364,10 +364,8 @@ func (opts *ShowResourceMonitorOptions) validate() error { return nil } -func (v *resourceMonitors) Show(ctx context.Context, opts *ShowResourceMonitorOptions) ([]*ResourceMonitor, error) { - if opts == nil { - opts = &ShowResourceMonitorOptions{} - } +func (v *resourceMonitors) Show(ctx context.Context, opts *ShowResourceMonitorOptions) ([]ResourceMonitor, error) { + opts = createIfNil(opts) if err := opts.validate(); err != nil { return nil, err } @@ -380,13 +378,13 @@ func (v *resourceMonitors) Show(ctx context.Context, opts *ShowResourceMonitorOp if err != nil { return nil, err } - resourceMonitors := make([]*ResourceMonitor, 0, len(rows)) + resourceMonitors := make([]ResourceMonitor, 0, len(rows)) for _, row := range rows { - resourceMonitor, err := row.toResourceMonitor() + resourceMonitor, err := row.convert() if err != nil { return nil, err } - resourceMonitors = append(resourceMonitors, resourceMonitor) + resourceMonitors = append(resourceMonitors, *resourceMonitor) } return resourceMonitors, nil } @@ -402,7 +400,7 @@ func (v *resourceMonitors) ShowByID(ctx context.Context, id AccountObjectIdentif } for _, resourceMonitor := range resourceMonitors { if resourceMonitor.Name == id.Name() { - return resourceMonitor, nil + return &resourceMonitor, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/resource_monitors_integration_test.go b/pkg/sdk/resource_monitors_integration_test.go index c174c84218..fc138fa5d9 100644 --- a/pkg/sdk/resource_monitors_integration_test.go +++ b/pkg/sdk/resource_monitors_integration_test.go @@ -23,7 +23,7 @@ func TestInt_ResourceMonitorsShow(t *testing.T) { } resourceMonitors, err := client.ResourceMonitors.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, resourceMonitors, resourceMonitorTest) + assert.Contains(t, resourceMonitors, *resourceMonitorTest) assert.Equal(t, 1, len(resourceMonitors)) }) @@ -175,7 +175,7 @@ func TestInt_ResourceMonitorAlter(t *testing.T) { }) require.NoError(t, err) assert.Equal(t, 1, len(resourceMonitors)) - resourceMonitor = resourceMonitors[0] + resourceMonitor = &resourceMonitors[0] var newNotifyTriggers []TriggerDefinition for _, threshold := range resourceMonitor.NotifyTriggers { newNotifyTriggers = append(newNotifyTriggers, TriggerDefinition{Threshold: threshold, TriggerAction: TriggerActionNotify}) @@ -205,7 +205,7 @@ func TestInt_ResourceMonitorAlter(t *testing.T) { }) require.NoError(t, err) assert.Equal(t, 1, len(resourceMonitors)) - resourceMonitor = resourceMonitors[0] + resourceMonitor = &resourceMonitors[0] assert.Equal(t, creditQuota, int(resourceMonitor.CreditQuota)) }) t.Run("when changing scheduling info", func(t *testing.T) { @@ -232,7 +232,7 @@ func TestInt_ResourceMonitorAlter(t *testing.T) { }) require.NoError(t, err) assert.Equal(t, 1, len(resourceMonitors)) - resourceMonitor = resourceMonitors[0] + resourceMonitor = &resourceMonitors[0] assert.Equal(t, *frequency, resourceMonitor.Frequency) startTime, err := ParseTimestampWithOffset(resourceMonitor.StartTime) require.NoError(t, err) diff --git a/pkg/sdk/session_policies.go b/pkg/sdk/session_policies.go index 5843573775..88f6ed2f80 100644 --- a/pkg/sdk/session_policies.go +++ b/pkg/sdk/session_policies.go @@ -19,7 +19,7 @@ type SessionPolicies interface { // Drop removes a session policy. Drop(ctx context.Context, id SchemaObjectIdentifier, opts *DropSessionPolicyOptions) error // Show returns a list of session policy. - Show(ctx context.Context) ([]*SessionPolicy, error) + Show(ctx context.Context) ([]SessionPolicy, error) // ShowByID returns a session policy by ID ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*SessionPolicy, error) // Describe returns the details of a session policy. @@ -44,7 +44,7 @@ type sessionPolicyRow struct { SchemaName string `db:"schema_name"` } -func (row *sessionPolicyRow) toSessionPolicy() *SessionPolicy { +func (row *sessionPolicyRow) convert() *SessionPolicy { return &SessionPolicy{ Name: row.Name, DatabaseName: row.DatabaseName, @@ -142,7 +142,7 @@ func (opts *sessionPolicyShowOptions) validate() error { return nil } -func (v *sessionPolicies) Show(ctx context.Context) ([]*SessionPolicy, error) { +func (v *sessionPolicies) Show(ctx context.Context) ([]SessionPolicy, error) { opts := &sessionPolicyShowOptions{} if err := opts.validate(); err != nil { return nil, err @@ -156,9 +156,9 @@ func (v *sessionPolicies) Show(ctx context.Context) ([]*SessionPolicy, error) { if err != nil { return nil, err } - sessionPolicies := make([]*SessionPolicy, 0, len(rows)) + sessionPolicies := make([]SessionPolicy, 0, len(rows)) for _, row := range rows { - sessionPolicies = append(sessionPolicies, row.toSessionPolicy()) + sessionPolicies = append(sessionPolicies, *row.convert()) } return sessionPolicies, nil } @@ -170,7 +170,7 @@ func (v *sessionPolicies) ShowByID(ctx context.Context, id SchemaObjectIdentifie } for _, sessionPolicy := range sessionPolicies { if sessionPolicy.Name == id.Name() { - return sessionPolicy, nil + return &sessionPolicy, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/shares.go b/pkg/sdk/shares.go index 2aad557c28..c56525f412 100644 --- a/pkg/sdk/shares.go +++ b/pkg/sdk/shares.go @@ -23,7 +23,7 @@ type Shares interface { // Drop removes a share. Drop(ctx context.Context, id AccountObjectIdentifier) error // Show returns a list of shares. - Show(ctx context.Context, opts *ShowShareOptions) ([]*Share, error) + Show(ctx context.Context, opts *ShowShareOptions) ([]Share, error) // ShowByID returns a share by ID. ShowByID(ctx context.Context, id AccountObjectIdentifier) (*Share, error) // Describe returns the details of an outbound share. @@ -78,7 +78,7 @@ type shareRow struct { Comment string `db:"comment"` } -func (r *shareRow) toShare() *Share { +func (r shareRow) convert() *Share { toAccounts := strings.Split(r.To, ",") var to []AccountIdentifier if len(toAccounts) != 0 { @@ -281,27 +281,14 @@ func (opts *ShowShareOptions) validate() error { return nil } -func (s *shares) Show(ctx context.Context, opts *ShowShareOptions) ([]*Share, error) { - if opts == nil { - opts = &ShowShareOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) +func (s *shares) Show(ctx context.Context, opts *ShowShareOptions) ([]Share, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[shareRow](s.client, ctx, opts) if err != nil { return nil, err } - var rows []*shareRow - err = s.client.query(ctx, &rows, sql) - if err != nil { - return nil, err - } - shares := make([]*Share, 0, len(rows)) - for _, row := range rows { - shares = append(shares, row.toShare()) - } - return shares, nil + resultList := convertRows[shareRow, Share](dbRows) + return resultList, nil } func (s *shares) ShowByID(ctx context.Context, id AccountObjectIdentifier) (*Share, error) { @@ -315,7 +302,7 @@ func (s *shares) ShowByID(ctx context.Context, id AccountObjectIdentifier) (*Sha } for _, share := range shares { if share.Name.Name() == id.Name() { - return share, nil + return &share, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/shares_integration_test.go b/pkg/sdk/shares_integration_test.go index b6599e0e17..6988e10028 100644 --- a/pkg/sdk/shares_integration_test.go +++ b/pkg/sdk/shares_integration_test.go @@ -32,7 +32,7 @@ func TestInt_SharesShow(t *testing.T) { shares, err := client.Shares.Show(ctx, showOptions) require.NoError(t, err) assert.Equal(t, 1, len(shares)) - assert.Contains(t, shares, shareTest) + assert.Contains(t, shares, *shareTest) }) t.Run("when searching a non-existent share", func(t *testing.T) { diff --git a/pkg/sdk/users.go b/pkg/sdk/users.go index d5b2e40a22..b09d50f5e8 100644 --- a/pkg/sdk/users.go +++ b/pkg/sdk/users.go @@ -26,7 +26,7 @@ type Users interface { // Describe returns the details of a user. Describe(ctx context.Context, id AccountObjectIdentifier) (*UserDetails, error) // Show returns a list of users. - Show(ctx context.Context, opts *ShowUserOptions) ([]*User, error) + Show(ctx context.Context, opts *ShowUserOptions) ([]User, error) // ShowByID returns a user by ID ShowByID(ctx context.Context, id AccountObjectIdentifier) (*User, error) } @@ -94,7 +94,7 @@ type userDBRow struct { HasRsaPublicKey bool `db:"has_rsa_public_key"` } -func (row userDBRow) toUser() *User { +func (row userDBRow) convert() *User { user := &User{ Name: row.Name, CreatedOn: row.CreatedOn, @@ -580,28 +580,13 @@ func (input *ShowUserOptions) validate() error { return nil } -func (v *users) Show(ctx context.Context, opts *ShowUserOptions) ([]*User, error) { - if opts == nil { - opts = &ShowUserOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) +func (v *users) Show(ctx context.Context, opts *ShowUserOptions) ([]User, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[userDBRow](v.client, ctx, opts) if err != nil { return nil, err } - var dest []userDBRow - - err = v.client.query(ctx, &dest, sql) - if err != nil { - return nil, err - } - resultList := make([]*User, len(dest)) - for i, row := range dest { - resultList[i] = row.toUser() - } - + resultList := convertRows[userDBRow, User](dbRows) return resultList, nil } @@ -617,7 +602,7 @@ func (v *users) ShowByID(ctx context.Context, id AccountObjectIdentifier) (*User for _, user := range users { if user.ID().name == id.Name() { - return user, nil + return &user, nil } } return nil, errObjectNotExistOrAuthorized diff --git a/pkg/sdk/users_integration_test.go b/pkg/sdk/users_integration_test.go index 0c950b6579..90a49c7816 100644 --- a/pkg/sdk/users_integration_test.go +++ b/pkg/sdk/users_integration_test.go @@ -27,7 +27,7 @@ func TestInt_UsersShow(t *testing.T) { } users, err := client.Users.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, users, userTest) + assert.Contains(t, users, *userTest) assert.Equal(t, 1, len(users)) }) @@ -37,8 +37,8 @@ func TestInt_UsersShow(t *testing.T) { } users, err := client.Users.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, users, userTest) - assert.Contains(t, users, userTest2) + assert.Contains(t, users, *userTest) + assert.Contains(t, users, *userTest2) assert.Equal(t, 2, len(users)) }) t.Run("with starts with, limit and from options", func(t *testing.T) { @@ -50,7 +50,7 @@ func TestInt_UsersShow(t *testing.T) { users, err := client.Users.Show(ctx, showOptions) require.NoError(t, err) - assert.Contains(t, users, userTest) + assert.Contains(t, users, *userTest) assert.Equal(t, 1, len(users)) }) diff --git a/pkg/sdk/warehouses.go b/pkg/sdk/warehouses.go index 904fcc63f8..84a9fc6bc4 100644 --- a/pkg/sdk/warehouses.go +++ b/pkg/sdk/warehouses.go @@ -25,7 +25,7 @@ type Warehouses interface { // Drop removes a warehouse. Drop(ctx context.Context, id AccountObjectIdentifier, opts *DropWarehouseOptions) error // Show returns a list of warehouses. - Show(ctx context.Context, opts *ShowWarehouseOptions) ([]*Warehouse, error) + Show(ctx context.Context, opts *ShowWarehouseOptions) ([]Warehouse, error) // ShowByID returns a warehouse by ID ShowByID(ctx context.Context, id AccountObjectIdentifier) (*Warehouse, error) // Describe returns the details of a warehouse. @@ -411,7 +411,7 @@ type warehouseDBRow struct { ScalingPolicy string `db:"scaling_policy"` } -func (row warehouseDBRow) toWarehouse() *Warehouse { +func (row warehouseDBRow) convert() *Warehouse { wh := &Warehouse{ Name: row.Name, State: WarehouseState(row.State), @@ -453,27 +453,13 @@ func (row warehouseDBRow) toWarehouse() *Warehouse { return wh } -func (c *warehouses) Show(ctx context.Context, opts *ShowWarehouseOptions) ([]*Warehouse, error) { - if opts == nil { - opts = &ShowWarehouseOptions{} - } - if err := opts.validate(); err != nil { - return nil, err - } - sql, err := structToSQL(opts) +func (c *warehouses) Show(ctx context.Context, opts *ShowWarehouseOptions) ([]Warehouse, error) { + opts = createIfNil(opts) + dbRows, err := validateAndQuery[warehouseDBRow](c.client, ctx, opts) if err != nil { return nil, err } - dest := []warehouseDBRow{} - err = c.client.query(ctx, &dest, sql) - if err != nil { - return nil, err - } - resultList := make([]*Warehouse, len(dest)) - for i, row := range dest { - resultList[i] = row.toWarehouse() - } - + resultList := convertRows[warehouseDBRow, Warehouse](dbRows) return resultList, nil } @@ -489,7 +475,7 @@ func (c *warehouses) ShowByID(ctx context.Context, id AccountObjectIdentifier) ( for _, warehouse := range warehouses { if warehouse.ID().name == id.Name() { - return warehouse, nil + return &warehouse, nil } } return nil, errObjectNotExistOrAuthorized