Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test Case Provider Switch #4135

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions internal/providers/pluginfw/pluginfw.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@ import (
"github.com/databricks/terraform-provider-databricks/commands"
"github.com/databricks/terraform-provider-databricks/common"
providercommon "github.com/databricks/terraform-provider-databricks/internal/providers/common"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/cluster"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/library"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/notificationdestinations"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/qualitymonitor"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/registered_model"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/volume"

"github.com/hashicorp/terraform-plugin-framework/datasource"
"github.com/hashicorp/terraform-plugin-framework/diag"
Expand All @@ -44,19 +38,11 @@ type DatabricksProviderPluginFramework struct {
var _ provider.Provider = (*DatabricksProviderPluginFramework)(nil)

// Resources returns the resource factories served by the plugin framework.
// Registration is delegated to the rollout helper, which honors the
// USE_SDK_V2_RESOURCES opt-out (env var and context value).
func (p *DatabricksProviderPluginFramework) Resources(ctx context.Context) []func() resource.Resource {
	return getPluginFrameworkResourcesToRegister(ctx)
}

// DataSources returns the data source factories served by the plugin
// framework. Registration is delegated to the rollout helper, which honors
// the USE_SDK_V2_DATA_SOURCES opt-out (env var and context value).
func (p *DatabricksProviderPluginFramework) DataSources(ctx context.Context) []func() datasource.DataSource {
	return getPluginFrameworkDataSourcesToRegister(ctx)
}

func (p *DatabricksProviderPluginFramework) Schema(ctx context.Context, req provider.SchemaRequest, resp *provider.SchemaResponse) {
Expand Down
159 changes: 159 additions & 0 deletions internal/providers/pluginfw/pluginfw_rollout_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package pluginfw

// This file contains all of the utils for controlling the plugin framework rollout.
// For migrated resources and data sources, we can add them to the two maps below to have them registered with the plugin framework.
// Users can manually specify resources and data sources to use SDK V2 instead of the plugin framework by setting the USE_SDK_V2_RESOURCES and USE_SDK_V2_DATA_SOURCES environment variables.
//
// Example: USE_SDK_V2_RESOURCES="databricks_library" would force the library resource to use SDK V2 instead of the plugin framework.

import (
"context"
"os"
"strings"

"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/library"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/notificationdestinations"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/qualitymonitor"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/registered_model"
"github.com/databricks/terraform-provider-databricks/internal/providers/pluginfw/resources/volume"
"github.com/hashicorp/terraform-plugin-framework/datasource"
"github.com/hashicorp/terraform-plugin-framework/resource"
)

// List of resources that have been migrated from SDK V2 to plugin framework.
// Each entry here can be individually forced back onto SDK V2 via the
// USE_SDK_V2_RESOURCES override (see getUseSdkV2Resources).
var migratedResources = []func() resource.Resource{
	qualitymonitor.ResourceQualityMonitor,
	library.ResourceLibrary,
}

// List of data sources that have been migrated from SDK V2 to plugin framework.
// Each entry here can be individually forced back onto SDK V2 via the
// USE_SDK_V2_DATA_SOURCES override (see getUseSdkV2DataSources).
var migratedDataSources = []func() datasource.DataSource{
	// TODO: Add DataSourceCluster back in after fixing unit tests.
	// cluster.DataSourceCluster,
	volume.DataSourceVolumes,
}

// List of resources that have been onboarded to the plugin framework - not migrated from sdkv2.
// These are always registered; the SDK V2 override does not apply to them.
var onboardedResources = []func() resource.Resource{
	// TODO Add resources here
}

// List of data sources that have been onboarded to the plugin framework - not migrated from sdkv2.
// These are always registered; the SDK V2 override does not apply to them.
var onboardedDataSources = []func() datasource.DataSource{
	registered_model.DataSourceRegisteredModel,
	notificationdestinations.DataSourceNotificationDestinations,
}

// GetUseSdkV2DataSources is a helper function to get name of resources that should use SDK V2 instead of plugin framework
func getUseSdkV2Resources(ctx context.Context) []string {
useSdkV2 := os.Getenv("USE_SDK_V2_RESOURCES")
useSdkV2Ctx := ctx.Value("USE_SDK_V2_RESOURCES")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

key should be a separate type

combinedNames := ""
if useSdkV2 != "" && useSdkV2Ctx != "" {
combinedNames = useSdkV2 + "," + useSdkV2Ctx.(string)
} else {
combinedNames = useSdkV2 + useSdkV2Ctx.(string)
}
return strings.Split(combinedNames, ",")
}

// GetUseSdkV2DataSources is a helper function to get name of data sources that should use SDK V2 instead of plugin framework
func getUseSdkV2DataSources(ctx context.Context) []string {
useSdkV2 := os.Getenv("USE_SDK_V2_DATA_SOURCES")
useSdkV2Ctx := ctx.Value("USE_SDK_V2_DATA_SOURCES")
combinedNames := ""
if useSdkV2 != "" && useSdkV2Ctx != "" {
combinedNames = useSdkV2 + "," + useSdkV2Ctx.(string)
} else {
combinedNames = useSdkV2 + useSdkV2Ctx.(string)
}
return strings.Split(combinedNames, ",")
}

// shouldUseSdkV2Resource reports whether resourceName has been explicitly
// opted back into SDK V2 via the USE_SDK_V2_RESOURCES override list.
func shouldUseSdkV2Resource(ctx context.Context, resourceName string) bool {
	for _, overridden := range getUseSdkV2Resources(ctx) {
		if overridden == resourceName {
			return true
		}
	}
	return false
}

// shouldUseSdkV2DataSource reports whether dataSourceName has been explicitly
// opted back into SDK V2 via the USE_SDK_V2_DATA_SOURCES override list.
func shouldUseSdkV2DataSource(ctx context.Context, dataSourceName string) bool {
	for _, overridden := range getUseSdkV2DataSources(ctx) {
		if overridden == dataSourceName {
			return true
		}
	}
	return false
}

// getPluginFrameworkResourcesToRegister returns every resource factory the
// plugin framework should register: migrated resources that have not been
// opted back into SDK V2, followed by all natively onboarded resources.
func getPluginFrameworkResourcesToRegister(ctx context.Context) []func() resource.Resource {
	var resources []func() resource.Resource
	for _, mkResource := range migratedResources {
		// Skip any resource explicitly pinned to SDK V2 by the override.
		if shouldUseSdkV2Resource(ctx, getResourceName(mkResource)) {
			continue
		}
		resources = append(resources, mkResource)
	}
	return append(resources, onboardedResources...)
}

// getPluginFrameworkDataSourcesToRegister returns every data source factory
// the plugin framework should register: migrated data sources that have not
// been opted back into SDK V2, followed by all natively onboarded ones.
func getPluginFrameworkDataSourcesToRegister(ctx context.Context) []func() datasource.DataSource {
	var dataSources []func() datasource.DataSource
	for _, mkDataSource := range migratedDataSources {
		// Skip any data source explicitly pinned to SDK V2 by the override.
		if shouldUseSdkV2DataSource(ctx, getDataSourceName(mkDataSource)) {
			continue
		}
		dataSources = append(dataSources, mkDataSource)
	}
	return append(dataSources, onboardedDataSources...)
}

// getResourceName returns the Terraform type name a resource factory
// registers under, obtained by invoking its Metadata method with the
// "databricks" provider type name.
func getResourceName(resourceFunc func() resource.Resource) string {
	var metadata resource.MetadataResponse
	request := resource.MetadataRequest{ProviderTypeName: "databricks"}
	resourceFunc().Metadata(context.Background(), request, &metadata)
	return metadata.TypeName
}

// getDataSourceName returns the Terraform type name a data source factory
// registers under, obtained by invoking its Metadata method with the
// "databricks" provider type name.
func getDataSourceName(dataSourceFunc func() datasource.DataSource) string {
	var metadata datasource.MetadataResponse
	request := datasource.MetadataRequest{ProviderTypeName: "databricks"}
	dataSourceFunc().Metadata(context.Background(), request, &metadata)
	return metadata.TypeName
}

// GetSdkV2ResourcesToRemove is a helper function to get the list of resources that are migrated away from sdkv2 to plugin framework
func GetSdkV2ResourcesToRemove(ctx context.Context) []string {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the context passed through for getting the names

resourcesToRemove := []string{}
for _, resourceFunc := range migratedResources {
name := getResourceName(resourceFunc)
if !shouldUseSdkV2Resource(ctx, name) {
resourcesToRemove = append(resourcesToRemove, name)
}
}
return resourcesToRemove
}

// GetSdkV2DataSourcesToRemove returns the names of data sources whose SDK V2
// implementations must be removed because their plugin framework counterparts
// will be registered (i.e. migrated data sources not opted back into SDK V2).
func GetSdkV2DataSourcesToRemove(ctx context.Context) []string {
	dataSourcesToRemove := []string{}
	for _, mkDataSource := range migratedDataSources {
		if name := getDataSourceName(mkDataSource); !shouldUseSdkV2DataSource(ctx, name) {
			dataSourcesToRemove = append(dataSourcesToRemove, name)
		}
	}
	return dataSourcesToRemove
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type ClusterInfo struct {
}

// Metadata sets the data source's full type name via the production naming
// helper (no staging suffix) — presumably "databricks_" + dataSourceName;
// confirm against pluginfwcommon.GetDatabricksProductionName.
func (d *ClusterDataSource) Metadata(ctx context.Context, req datasource.MetadataRequest, resp *datasource.MetadataResponse) {
	resp.TypeName = pluginfwcommon.GetDatabricksProductionName(dataSourceName)
}

func (d *ClusterDataSource) Schema(ctx context.Context, req datasource.SchemaRequest, resp *datasource.SchemaResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

// dataClusterTemplateById is the HCL template used by the acceptance tests
// below to look up the default test cluster by its ID.
// Fixed: the data source type was written `"databricks_cluster "by_id"` with
// a misplaced quote, producing invalid HCL.
const dataClusterTemplateById = `
	data "databricks_cluster" "by_id" {
		cluster_id = "{env.TEST_DEFAULT_CLUSTER_ID}"
	}
`
Expand All @@ -21,8 +21,8 @@ func TestAccDataSourceClusterByID(t *testing.T) {
// TestAccDataSourceClusterByName verifies that the databricks_cluster data
// source resolves a cluster by cluster_name, chaining off the by-ID lookup
// template to obtain a known cluster name first.
func TestAccDataSourceClusterByName(t *testing.T) {
	acceptance.WorkspaceLevel(t, acceptance.Step{
		Template: dataClusterTemplateById + `
		data "databricks_cluster" "by_name" {
			cluster_name = data.databricks_cluster.by_id.cluster_name
		}`,
	})
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type LibraryResource struct {
}

// Metadata sets the resource's full type name via the production naming
// helper (no staging suffix) — presumably "databricks_" + resourceName;
// confirm against pluginfwcommon.GetDatabricksProductionName.
func (r *LibraryResource) Metadata(ctx context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse) {
	resp.TypeName = pluginfwcommon.GetDatabricksProductionName(resourceName)
}

func (r *LibraryResource) Schema(ctx context.Context, req resource.SchemaRequest, resp *resource.SchemaResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestAccLibraryCreationPluginFramework(t *testing.T) {
"ResourceClass" = "SingleNode"
}
}
resource "databricks_library_pluginframework" "new_library" {
resource "databricks_library" "new_library" {
cluster_id = databricks_cluster.this.id
pypi {
repo = "https://pypi.org/dummy"
Expand Down Expand Up @@ -54,7 +54,7 @@ func TestAccLibraryUpdatePluginFramework(t *testing.T) {
"ResourceClass" = "SingleNode"
}
}
resource "databricks_library_pluginframework" "new_library" {
resource "databricks_library" "new_library" {
cluster_id = databricks_cluster.this.id
pypi {
repo = "https://pypi.org/simple"
Expand All @@ -80,7 +80,7 @@ func TestAccLibraryUpdatePluginFramework(t *testing.T) {
"ResourceClass" = "SingleNode"
}
}
resource "databricks_library_pluginframework" "new_library" {
resource "databricks_library" "new_library" {
cluster_id = databricks_cluster.this.id
pypi {
package = "networkx"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type QualityMonitorResource struct {
}

// Metadata sets the resource's full type name via the production naming
// helper (no staging suffix) — presumably "databricks_" + resourceName;
// confirm against pluginfwcommon.GetDatabricksProductionName.
func (r *QualityMonitorResource) Metadata(ctx context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse) {
	resp.TypeName = pluginfwcommon.GetDatabricksProductionName(resourceName)
}

func (r *QualityMonitorResource) Schema(ctx context.Context, req resource.SchemaRequest, resp *resource.SchemaResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestUcAccQualityMonitor(t *testing.T) {
acceptance.UnityWorkspaceLevel(t, acceptance.Step{
Template: commonPartQualityMonitoring + `

resource "databricks_quality_monitor_pluginframework" "testMonitorInference" {
resource "databricks_quality_monitor" "testMonitorInference" {
table_name = databricks_sql_table.myInferenceTable.id
assets_dir = "/Shared/provider-test/databricks_quality_monitoring/${databricks_sql_table.myInferenceTable.name}"
output_schema_name = databricks_schema.things.id
Expand All @@ -81,7 +81,7 @@ func TestUcAccQualityMonitor(t *testing.T) {
}
}

resource "databricks_quality_monitor_pluginframework" "testMonitorTimeseries" {
resource "databricks_quality_monitor" "testMonitorTimeseries" {
table_name = databricks_sql_table.myTimeseries.id
assets_dir = "/Shared/provider-test/databricks_quality_monitoring/${databricks_sql_table.myTimeseries.name}"
output_schema_name = databricks_schema.things.id
Expand All @@ -104,7 +104,7 @@ func TestUcAccQualityMonitor(t *testing.T) {
}
}

resource "databricks_quality_monitor_pluginframework" "testMonitorSnapshot" {
resource "databricks_quality_monitor" "testMonitorSnapshot" {
table_name = databricks_sql_table.mySnapshot.id
assets_dir = "/Shared/provider-test/databricks_quality_monitoring/${databricks_sql_table.myTimeseries.name}"
output_schema_name = databricks_schema.things.id
Expand All @@ -121,7 +121,7 @@ func TestUcAccUpdateQualityMonitor(t *testing.T) {
}
acceptance.UnityWorkspaceLevel(t, acceptance.Step{
Template: commonPartQualityMonitoring + `
resource "databricks_quality_monitor_pluginframework" "testMonitorInference" {
resource "databricks_quality_monitor" "testMonitorInference" {
table_name = databricks_sql_table.myInferenceTable.id
assets_dir = "/Shared/provider-test/databricks_quality_monitoring/${databricks_sql_table.myInferenceTable.name}"
output_schema_name = databricks_schema.things.id
Expand All @@ -136,7 +136,7 @@ func TestUcAccUpdateQualityMonitor(t *testing.T) {
`,
}, acceptance.Step{
Template: commonPartQualityMonitoring + `
resource "databricks_quality_monitor_pluginframework" "testMonitorInference" {
resource "databricks_quality_monitor" "testMonitorInference" {
table_name = databricks_sql_table.myInferenceTable.id
assets_dir = "/Shared/provider-test/databricks_quality_monitoring/${databricks_sql_table.myInferenceTable.name}"
output_schema_name = databricks_schema.things.id
Expand All @@ -160,7 +160,7 @@ func TestUcAccQualityMonitorImportPluginFramework(t *testing.T) {
acceptance.Step{
Template: commonPartQualityMonitoring + `

resource "databricks_quality_monitor_pluginframework" "testMonitorInference" {
resource "databricks_quality_monitor" "testMonitorInference" {
table_name = databricks_sql_table.myInferenceTable.id
assets_dir = "/Shared/provider-test/databricks_quality_monitoring/${databricks_sql_table.myInferenceTable.name}"
output_schema_name = databricks_schema.things.id
Expand All @@ -176,8 +176,8 @@ func TestUcAccQualityMonitorImportPluginFramework(t *testing.T) {
},
acceptance.Step{
ImportState: true,
ResourceName: "databricks_quality_monitor_pluginframework.testMonitorInference",
ImportStateIdFunc: acceptance.BuildImportStateIdFunc("databricks_quality_monitor_pluginframework.testMonitorInference", "table_name"),
ResourceName: "databricks_quality_monitor.testMonitorInference",
ImportStateIdFunc: acceptance.BuildImportStateIdFunc("databricks_quality_monitor.testMonitorInference", "table_name"),
ImportStateVerify: true,
ImportStateVerifyIdentifierAttribute: "table_name",
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type VolumesList struct {
}

// Metadata sets the data source's full type name via the production naming
// helper (no staging suffix) — presumably "databricks_" + dataSourceName;
// confirm against pluginfwcommon.GetDatabricksProductionName.
func (d *VolumesDataSource) Metadata(ctx context.Context, req datasource.MetadataRequest, resp *datasource.MetadataResponse) {
	resp.TypeName = pluginfwcommon.GetDatabricksProductionName(dataSourceName)
}

func (d *VolumesDataSource) Schema(ctx context.Context, req datasource.SchemaRequest, resp *datasource.SchemaResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (

func checkDataSourceVolumesPopulated(t *testing.T) func(s *terraform.State) error {
return func(s *terraform.State) error {
_, ok := s.Modules[0].Resources["data.databricks_volumes_pluginframework.this"]
require.True(t, ok, "data.databricks_volumes_pluginframework.this has to be there")
_, ok := s.Modules[0].Resources["data.databricks_volumes.this"]
require.True(t, ok, "data.databricks_volumes.this has to be there")
num_volumes, _ := strconv.Atoi(s.Modules[0].Outputs["volumes"].Value.(string))
assert.GreaterOrEqual(t, num_volumes, 1)
return nil
Expand Down Expand Up @@ -45,13 +45,13 @@ func TestUcAccDataSourceVolumes(t *testing.T) {
schema_name = databricks_schema.things.name
volume_type = "MANAGED"
}
data "databricks_volumes_pluginframework" "this" {
data "databricks_volumes" "this" {
catalog_name = databricks_catalog.sandbox.name
schema_name = databricks_schema.things.name
depends_on = [ databricks_volume.this ]
}
output "volumes" {
value = length(data.databricks_volumes_pluginframework.this.ids)
value = length(data.databricks_volumes.this.ids)
}
`,
Check: checkDataSourceVolumesPopulated(t),
Expand Down
4 changes: 2 additions & 2 deletions internal/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ func GetProviderServer(ctx context.Context, options ...ServerOption) (tfprotov6.
}
sdkPluginProvider := serverOptions.sdkV2Provider
if sdkPluginProvider == nil {
sdkPluginProvider = sdkv2.DatabricksProvider()
sdkPluginProvider = sdkv2.DatabricksProvider(ctx)
}
pluginFrameworkProvider := serverOptions.pluginFrameworkProvider
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can provide own options

if pluginFrameworkProvider == nil {
pluginFrameworkProvider = pluginfw.GetDatabricksProviderPluginFramework()
}

upgradedSdkPluginProvider, err := tf5to6server.UpgradeServer(
context.Background(),
ctx,
sdkPluginProvider.GRPCProvider,
)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/providers/providers_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (pf providerFixture) configureProviderAndReturnClient_SDKv2(t *testing.T) (
for k, v := range pf.env {
t.Setenv(k, v)
}
p := sdkv2.DatabricksProvider()
p := sdkv2.DatabricksProvider(context.Background())
ctx := context.Background()
diags := p.Configure(ctx, terraform.NewResourceConfigRaw(pf.rawConfigSDKv2()))
if len(diags) > 0 {
Expand Down
Loading
Loading