Skip to content

Commit

Permalink
Add implementation for procedures
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-asawicki committed Dec 12, 2024
1 parent ced1cef commit 47e09a4
Show file tree
Hide file tree
Showing 9 changed files with 511 additions and 102 deletions.
8 changes: 6 additions & 2 deletions pkg/resources/function_commons.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ var (
"is_secure",
"arguments",
"return_type",
"null_input_behavior",
"return_results_behavior",
"comment",
"function_definition",
Expand All @@ -100,6 +99,7 @@ var (
javaFunctionSchemaDefinition = functionSchemaDef{
additionalArguments: []string{
"runtime_version",
"null_input_behavior",
"imports",
"packages",
"handler",
Expand All @@ -117,14 +117,17 @@ var (
targetPathDescription: "The TARGET_PATH clause specifies the location to which Snowflake should write the compiled code (JAR file) after compiling the source code specified in the `function_definition`. If this clause is included, the user should manually remove the JAR file when it is no longer needed (typically when the Java UDF is dropped). If this clause is omitted, Snowflake re-compiles the source code each time the code is needed. The JAR file is not stored permanently, and the user does not need to clean up the JAR file. Snowflake returns an error if the TARGET_PATH matches an existing file; you cannot use TARGET_PATH to overwrite an existing file.",
}
javascriptFunctionSchemaDefinition = functionSchemaDef{
additionalArguments: []string{},
additionalArguments: []string{
"null_input_behavior",
},
functionDefinitionDescription: functionDefinitionTemplate("JavaScript", "https://docs.snowflake.com/en/developer-guide/udf/javascript/udf-javascript-introduction"),
functionDefinitionRequired: true,
}
pythonFunctionSchemaDefinition = functionSchemaDef{
additionalArguments: []string{
"is_aggregate",
"runtime_version",
"null_input_behavior",
"imports",
"packages",
"handler",
Expand All @@ -141,6 +144,7 @@ var (
scalaFunctionSchemaDefinition = functionSchemaDef{
additionalArguments: []string{
"runtime_version",
"null_input_behavior",
"imports",
"packages",
"handler",
Expand Down
4 changes: 2 additions & 2 deletions pkg/resources/function_javascript.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ func CreateContextFunctionJavascript(ctx context.Context, d *schema.ResourceData
if err != nil {
return diag.FromErr(err)
}
handler := d.Get("handler").(string)
functionDefinition := d.Get("function_definition").(string)

argumentDataTypes := collections.Map(argumentRequests, func(r sdk.FunctionArgumentRequest) datatypes.DataType { return r.ArgDataType })
id := sdk.NewSchemaObjectIdentifierWithArgumentsNormalized(database, sc, name, argumentDataTypes...)
request := sdk.NewCreateForJavascriptFunctionRequestDefinitionWrapped(id.SchemaObjectId(), *returns, handler).
request := sdk.NewCreateForJavascriptFunctionRequestDefinitionWrapped(id.SchemaObjectId(), *returns, functionDefinition).
WithArguments(argumentRequests)

errs := errors.Join(
Expand Down
4 changes: 2 additions & 2 deletions pkg/resources/function_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ func CreateContextFunctionSql(ctx context.Context, d *schema.ResourceData, meta
if err != nil {
return diag.FromErr(err)
}
handler := d.Get("handler").(string)
functionDefinition := d.Get("function_definition").(string)

argumentDataTypes := collections.Map(argumentRequests, func(r sdk.FunctionArgumentRequest) datatypes.DataType { return r.ArgDataType })
id := sdk.NewSchemaObjectIdentifierWithArgumentsNormalized(database, sc, name, argumentDataTypes...)
request := sdk.NewCreateForSQLFunctionRequestDefinitionWrapped(id.SchemaObjectId(), *returns, handler).
request := sdk.NewCreateForSQLFunctionRequestDefinitionWrapped(id.SchemaObjectId(), *returns, functionDefinition).
WithArguments(argumentRequests)

errs := errors.Join(
Expand Down
99 changes: 99 additions & 0 deletions pkg/resources/procedure_commons.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"errors"
"fmt"
"log"
"reflect"
"slices"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
Expand Down Expand Up @@ -401,6 +403,83 @@ func DeleteProcedure(ctx context.Context, d *schema.ResourceData, meta any) diag
return nil
}

func UpdateProcedure(language string, readFunc func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics) func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
client := meta.(*provider.Context).Client
id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id())
if err != nil {
return diag.FromErr(err)
}

if d.HasChange("name") {
newId := sdk.NewSchemaObjectIdentifierWithArgumentsInSchema(id.SchemaId(), d.Get("name").(string), id.ArgumentDataTypes()...)

err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithRenameTo(newId.SchemaObjectId()))
if err != nil {
return diag.FromErr(fmt.Errorf("error renaming procedure %v err = %w", d.Id(), err))
}

d.SetId(helpers.EncodeResourceIdentifier(newId))
id = newId
}

// Batch SET operations and UNSET operations
setRequest := sdk.NewProcedureSetRequest()
unsetRequest := sdk.NewProcedureUnsetRequest()

_ = stringAttributeUpdate(d, "comment", &setRequest.Comment, &unsetRequest.Comment)

switch language {
case "JAVA", "SCALA", "PYTHON":
err = errors.Join(
func() error {
if d.HasChange("secrets") {
return setSecretsInBuilder(d, func(references []sdk.SecretReference) *sdk.ProcedureSetRequest {
return setRequest.WithSecretsList(sdk.SecretsListRequest{SecretsList: references})
})
}
return nil
}(),
func() error {
if d.HasChange("external_access_integrations") {
return setExternalAccessIntegrationsInBuilder(d, func(references []sdk.AccountObjectIdentifier) any {
if len(references) == 0 {
return unsetRequest.WithExternalAccessIntegrations(true)
} else {
return setRequest.WithExternalAccessIntegrations(references)
}
})
}
return nil
}(),
)
if err != nil {
return diag.FromErr(err)
}
}

if updateParamDiags := handleProcedureParametersUpdate(d, setRequest, unsetRequest); len(updateParamDiags) > 0 {
return updateParamDiags
}

// Apply SET and UNSET changes
if !reflect.DeepEqual(*setRequest, *sdk.NewProcedureSetRequest()) {
err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithSet(*setRequest))
if err != nil {
return diag.FromErr(err)
}
}
if !reflect.DeepEqual(*unsetRequest, *sdk.NewProcedureUnsetRequest()) {
err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithUnset(*unsetRequest))
if err != nil {
return diag.FromErr(err)
}
}

return readFunc(ctx, d, meta)
}
}

func queryAllProcedureDetailsCommon(ctx context.Context, d *schema.ResourceData, client *sdk.Client, id sdk.SchemaObjectIdentifierWithArguments) (*allProcedureDetailsCommon, diag.Diagnostics) {
procedureDetails, err := client.Procedures.DescribeDetails(ctx, id)
if err != nil {
Expand Down Expand Up @@ -526,6 +605,26 @@ func parseProcedureReturnsCommon(d *schema.ResourceData) (*sdk.ProcedureReturnsR
return returns, nil
}

func parseProcedureSqlReturns(d *schema.ResourceData) (*sdk.ProcedureSQLReturnsRequest, error) {
returnTypeRaw := d.Get("return_type").(string)
dataType, err := datatypes.ParseDataType(returnTypeRaw)
if err != nil {
return nil, err
}
returns := sdk.NewProcedureSQLReturnsRequest()
switch v := dataType.(type) {
case *datatypes.TableDataType:
var cr []sdk.ProcedureColumnRequest
for _, c := range v.Columns() {
cr = append(cr, *sdk.NewProcedureColumnRequest(c.ColumnName(), c.ColumnType()))
}
returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr))
default:
returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(dataType))
}
return returns, nil
}

func setProcedureImportsInBuilder[T any](d *schema.ResourceData, setImports func([]sdk.ProcedureImportRequest) T) error {
imports, err := parseProcedureImportsCommon(d)
if err != nil {
Expand Down
73 changes: 1 addition & 72 deletions pkg/resources/procedure_java.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func ProcedureJava() *schema.Resource {
return &schema.Resource{
CreateContext: TrackingCreateWrapper(resources.ProcedureJava, CreateContextProcedureJava),
ReadContext: TrackingReadWrapper(resources.ProcedureJava, ReadContextProcedureJava),
UpdateContext: TrackingUpdateWrapper(resources.ProcedureJava, UpdateContextProcedureJava),
UpdateContext: TrackingUpdateWrapper(resources.ProcedureJava, UpdateProcedure("JAVA", ReadContextProcedureJava)),
DeleteContext: TrackingDeleteWrapper(resources.ProcedureJava, DeleteProcedure),
Description: "Resource used to manage java procedure objects. For more information, check [procedure documentation](https://docs.snowflake.com/en/sql-reference/sql/create-procedure).",

Expand Down Expand Up @@ -151,74 +151,3 @@ func ReadContextProcedureJava(ctx context.Context, d *schema.ResourceData, meta

return nil
}

func UpdateContextProcedureJava(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
client := meta.(*provider.Context).Client
id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id())
if err != nil {
return diag.FromErr(err)
}

if d.HasChange("name") {
newId := sdk.NewSchemaObjectIdentifierWithArgumentsInSchema(id.SchemaId(), d.Get("name").(string), id.ArgumentDataTypes()...)

err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithRenameTo(newId.SchemaObjectId()))
if err != nil {
return diag.FromErr(fmt.Errorf("error renaming procedure %v err = %w", d.Id(), err))
}

d.SetId(helpers.EncodeResourceIdentifier(newId))
id = newId
}

// Batch SET operations and UNSET operations
setRequest := sdk.NewProcedureSetRequest()
unsetRequest := sdk.NewProcedureUnsetRequest()

err = errors.Join(
stringAttributeUpdate(d, "comment", &setRequest.Comment, &unsetRequest.Comment),
func() error {
if d.HasChange("secrets") {
return setSecretsInBuilder(d, func(references []sdk.SecretReference) *sdk.ProcedureSetRequest {
return setRequest.WithSecretsList(sdk.SecretsListRequest{SecretsList: references})
})
}
return nil
}(),
func() error {
if d.HasChange("external_access_integrations") {
return setExternalAccessIntegrationsInBuilder(d, func(references []sdk.AccountObjectIdentifier) any {
if len(references) == 0 {
return unsetRequest.WithExternalAccessIntegrations(true)
} else {
return setRequest.WithExternalAccessIntegrations(references)
}
})
}
return nil
}(),
)
if err != nil {
return diag.FromErr(err)
}

if updateParamDiags := handleProcedureParametersUpdate(d, setRequest, unsetRequest); len(updateParamDiags) > 0 {
return updateParamDiags
}

// Apply SET and UNSET changes
if !reflect.DeepEqual(*setRequest, *sdk.NewProcedureSetRequest()) {
err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithSet(*setRequest))
if err != nil {
return diag.FromErr(err)
}
}
if !reflect.DeepEqual(*unsetRequest, *sdk.NewProcedureUnsetRequest()) {
err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithUnset(*unsetRequest))
if err != nil {
return diag.FromErr(err)
}
}

return ReadContextProcedureJava(ctx, d, meta)
}
Loading

0 comments on commit 47e09a4

Please sign in to comment.