diff --git a/internal/acceptance/model_serving_test.go b/internal/acceptance/model_serving_test.go index 59ff216ffe..2204d57fdb 100644 --- a/internal/acceptance/model_serving_test.go +++ b/internal/acceptance/model_serving_test.go @@ -97,7 +97,7 @@ func TestUcAccModelServingProvisionedThroughput(t *testing.T) { config { served_entities{ name = "pt_model" - entity_name = "system.ai.mistral_7b_instruct_v0_1" + entity_name = "system.ai.mistral_7b_instruct_v0_2" entity_version = "1" min_provisioned_throughput = 0 max_provisioned_throughput = 970 @@ -111,6 +111,133 @@ func TestUcAccModelServingProvisionedThroughput(t *testing.T) { } } `, name), + }, step{ + Template: fmt.Sprintf(` + resource "databricks_model_serving" "endpoint" { + name = "%s" + config { + served_entities{ + name = "pt_model" + entity_name = "system.ai.mistral_7b_instruct_v0_2" + entity_version = "1" + min_provisioned_throughput = 970 + max_provisioned_throughput = 1940 + } + traffic_config { + routes { + served_model_name = "pt_model" + traffic_percentage = 100 + } + } + } + } + `, name), + }, step{ + Template: fmt.Sprintf(` + resource "databricks_model_serving" "endpoint" { + name = "%s" + config { + served_entities{ + name = "pt_model" + entity_name = "system.ai.mistral_7b_instruct_v0_2" + entity_version = "1" + min_provisioned_throughput = 0 + max_provisioned_throughput = 1940 + } + traffic_config { + routes { + served_model_name = "pt_model" + traffic_percentage = 100 + } + } + } + } + `, name), }, ) } + +func TestAccModelServingExternalModel(t *testing.T) { + loadWorkspaceEnv(t) + if isGcp(t) { + skipf(t)("not available on GCP") + } + + name := fmt.Sprintf("terraform-test-model-serving-em-%s", + acctest.RandStringFromCharSet(5, acctest.CharSetAlphaNum)) + scope_name := fmt.Sprintf("terraform-test-secret-scope-%s", + acctest.RandStringFromCharSet(5, acctest.CharSetAlphaNum)) + workspaceLevel(t, step{ + Template: fmt.Sprintf(` + resource "databricks_secret_scope" "scope" { + name = "%s" + } + + resource "databricks_secret" "key" { + key = "api_key" + string_value = "fake-secret" + scope = databricks_secret_scope.scope.id + } + + resource "databricks_model_serving" "endpoint" { + name = "%s" + config { + served_entities { + name = "prod_model" + external_model { + provider = "anthropic" + name = "claude-2.0" + task = "llm/v1/chat" + anthropic_config { + anthropic_api_key = databricks_secret.key.config_reference + } + } + } + traffic_config { + routes { + served_model_name = "prod_model" + traffic_percentage = 100 + } + } + } + } + `, scope_name, name), + }, + step{ + Template: fmt.Sprintf(` + resource "databricks_secret_scope" "scope" { + name = "%s" + } + + resource "databricks_secret" "key" { + key = "api_key" + string_value = "fake-secret" + scope = databricks_secret_scope.scope.id + } + + resource "databricks_model_serving" "endpoint" { + name = "%s" + config { + served_entities { + name = "prod_model" + external_model { + provider = "openai" + name = "gpt-4o" + task = "llm/v1/chat" + openai_config { + openai_api_key = databricks_secret.key.config_reference + } + } + } + traffic_config { + routes { + served_model_name = "prod_model" + traffic_percentage = 100 + } + } + } + } + `, scope_name, name), + }, + ) +} diff --git a/serving/resource_model_serving.go b/serving/resource_model_serving.go index b0abbf70bc..183910de42 100644 --- a/serving/resource_model_serving.go +++ b/serving/resource_model_serving.go @@ -2,10 +2,7 @@ package serving import ( "context" - "fmt" "log" - "slices" - "strings" "time" "github.com/databricks/databricks-sdk-go/retries" @@ -24,33 +21,24 @@ func ResourceModelServing() common.Resource { m["name"].ForceNew = true common.MustSchemaPath(m, "config", "served_models").ConflictsWith = []string{"config.served_entities"} common.MustSchemaPath(m, "config", "served_entities").ConflictsWith = []string{"config.served_models"} + + common.MustSchemaPath(m, "config", "traffic_config").Computed = true + common.MustSchemaPath(m, "config", "auto_capture_config", "table_name_prefix").Computed = true + common.MustSchemaPath(m, "config", "auto_capture_config", "enabled").Computed = true + common.MustSchemaPath(m, "config", "auto_capture_config", "catalog_name").ForceNew = true + common.MustSchemaPath(m, "config", "auto_capture_config", "schema_name").ForceNew = true + common.MustSchemaPath(m, "config", "auto_capture_config", "table_name_prefix").ForceNew = true + + common.MustSchemaPath(m, "config", "served_models", "name").Computed = true + common.MustSchemaPath(m, "config", "served_models", "workload_type").Computed = true common.MustSchemaPath(m, "config", "served_models", "scale_to_zero_enabled").Required = false common.MustSchemaPath(m, "config", "served_models", "scale_to_zero_enabled").Optional = true common.MustSchemaPath(m, "config", "served_models", "scale_to_zero_enabled").Default = true - common.MustSchemaPath(m, "config", "served_models", "name").Computed = true - common.MustSchemaPath(m, "config", "served_models", "workload_type").Default = "CPU" - // TODO: `config.served_models.workload_type` should be a `Optional+Computed` field. Also consider this for other similar fields. - // In this scenario, if a workspace does not have GPU serving, specifying `workload_type` = 'CPU' will get empty response from API. - common.MustSchemaPath(m, "config", "served_models", "workload_type").DiffSuppressFunc = func(k, old, new string, d *schema.ResourceData) bool { - return old == "" && new == "CPU" - } - common.MustSchemaPath(m, "config", "traffic_config").Computed = true common.MustSchemaPath(m, "config", "served_models").Deprecated = "Please use 'config.served_entities' instead of 'config.served_models'." - common.MustSchemaPath(m, "config", "served_entities", "scale_to_zero_enabled").Required = false - common.MustSchemaPath(m, "config", "served_entities", "scale_to_zero_enabled").Optional = true - common.MustSchemaPath(m, "config", "served_entities", "scale_to_zero_enabled").Default = false common.MustSchemaPath(m, "config", "served_entities", "name").Computed = true - common.MustSchemaPath(m, "config", "served_entities", "workload_size").Optional = true common.MustSchemaPath(m, "config", "served_entities", "workload_size").Computed = true - common.MustSchemaPath(m, "config", "served_entities", "workload_type").Optional = true common.MustSchemaPath(m, "config", "served_entities", "workload_type").Computed = true - common.MustSchemaPath(m, "config", "served_entities", "workload_type").DiffSuppressFunc = func(k, old, new string, d *schema.ResourceData) bool { - return old == "" && new == "CPU" - } - common.MustSchemaPath(m, "config", "auto_capture_config", "catalog_name").ForceNew = true - common.MustSchemaPath(m, "config", "auto_capture_config", "schema_name").ForceNew = true - common.MustSchemaPath(m, "config", "auto_capture_config", "table_name_prefix").ForceNew = true m["serving_endpoint_id"] = &schema.Schema{ Computed: true, @@ -60,17 +48,6 @@ func ResourceModelServing() common.Resource { }) return common.Resource{ - CustomizeDiff: func(ctx context.Context, d *schema.ResourceDiff) error { - old, new := d.GetChange("config.0.auto_capture_config.0.enabled") - if old != nil && old == false && new == true { - d.ForceNew("config.0.auto_capture_config.0.enabled") - } - err := validateExternalModelConfig(d) - if err != nil { - return err - } - return nil - }, Create: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error { w, err := c.WorkspaceClient() if err != nil { @@ -78,9 +55,6 @@ func ResourceModelServing() common.Resource { } var e serving.CreateServingEndpoint common.DataToStructPointer(d, s, &e) - for i := range e.Config.ServedEntities { - e.Config.ServedEntities[i].ForceSendFields = append(e.Config.ServedEntities[i].ForceSendFields, "ScaleToZeroEnabled", "MinProvisionedThroughput") - } wait, err := w.ServingEndpoints.Create(ctx, e) if err != nil { return err @@ -133,9 +107,6 @@ func ResourceModelServing() common.Resource { } var e serving.CreateServingEndpoint common.DataToStructPointer(d, s, &e) - for i := range e.Config.ServedEntities { - e.Config.ServedEntities[i].ForceSendFields = append(e.Config.ServedEntities[i].ForceSendFields, "ScaleToZeroEnabled") - } e.Config.Name = e.Name _, err = w.ServingEndpoints.UpdateConfigAndWait(ctx, e.Config, retries.Timeout[serving.ServingEndpointDetailed](d.Timeout(schema.TimeoutUpdate))) return err @@ -156,34 +127,3 @@ func ResourceModelServing() common.Resource { }, } } - -func validateExternalModelConfig(d *schema.ResourceDiff) error { - _, e := d.GetOk("config.0.served_entities.0.external_model") - provider, p := d.GetOk("config.0.served_entities.0.external_model.0.provider") - - if !e || !p { - return nil - } - - name := strings.ReplaceAll(provider.(string), "-", "_") - config := d.Get(fmt.Sprintf("config.0.served_entities.0.external_model.0.%s_config", name)).([]interface{}) - - if len(config) == 0 { - return fmt.Errorf("external_model provider is set to \"%s\" but \"%s_config\" block is missing", name, name) - } - - if configBlock, ok := d.Get("config.0.served_entities.0.external_model.0").(map[string]interface{}); ok { - var found []string - for key, value := range configBlock { - if strings.HasSuffix(key, "_config") && len(value.([]interface{})) > 0 { - found = append(found, key) - } - } - slices.Sort(found) - if len(found) > 1 { - msg := strings.Join(found, ", ") - return fmt.Errorf("only one external_model config block is allowed. Found: %s", msg) - } - } - return nil -} diff --git a/serving/resource_model_serving_test.go b/serving/resource_model_serving_test.go index fc494b2f35..c53d17c49f 100644 --- a/serving/resource_model_serving_test.go +++ b/serving/resource_model_serving_test.go @@ -632,57 +632,3 @@ func TestModelServingDelete_Error(t *testing.T) { ID: "test-endpoint", }.ExpectError(t, "Internal error happened") } - -func TestModelServingExternalModelNoConfig(t *testing.T) { - qa.ResourceFixture{ - Resource: ResourceModelServing(), - HCL: ` - name = "test-endpoint" - config { - served_entities { - name = "prod_model" - entity_name = "ads1" - entity_version = "2" - external_model { - name = "prod_external_model" - provider = "ai21labs" - task = "llm/v1/embeddings" - } - workload_size = "Small" - scale_to_zero_enabled = true - } - } - `, - Create: true, - }.ExpectError(t, "external_model provider is set to \"ai21labs\" but \"ai21labs_config\" block is missing") -} - -func TestModelServingExternalModelMultipleConfig(t *testing.T) { - qa.ResourceFixture{ - Resource: ResourceModelServing(), - HCL: ` - name = "test-endpoint" - config { - served_entities { - name = "prod_model" - entity_name = "ads1" - entity_version = "2" - external_model { - name = "prod_external_model" - provider = "ai21labs" - task = "llm/v1/embeddings" - ai21labs_config { - ai21labs_api_key = "{{secrets/databricks/ai21labs_api_key}}" - } - anthropic_config { - anthropic_api_key = "{{secrets/databricks/anthropic_api_key}}" - } - } - workload_size = "Small" - scale_to_zero_enabled = true - } - } - `, - Create: true, - }.ExpectError(t, "only one external_model config block is allowed. Found: ai21labs_config, anthropic_config") -}