diff --git a/README.md b/README.md index 3901161..bc7272d 100644 --- a/README.md +++ b/README.md @@ -152,14 +152,12 @@ $ tfupdate provider --help Usage: tfupdate provider [options] Arguments - PROVIDER_NAME A name of provider (e.g. aws, google, azurerm) + PROVIDER_NAME A name of provider (e.g. aws or integrations/github) PATH A path of file or directory to update Options: -v --version A new version constraint (default: latest) If the version is omitted, the latest version is automatically checked and set. - Getting the latest version automatically is supported only for official providers. - If you have an unofficial provider, use release latest command. -r --recursive Check a directory recursively (default: false) -i --ignore-path A regular expression for path to ignore If you want to ignore multiple directories, set the flag multiple times. diff --git a/command/provider.go b/command/provider.go index 525bfa9..31479c2 100644 --- a/command/provider.go +++ b/command/provider.go @@ -44,7 +44,13 @@ func (c *ProviderCommand) Run(args []string) int { v := c.version if v == "latest" { - source := fmt.Sprintf("terraform-providers/terraform-provider-%s", c.name) + source := "" + if strings.Contains(c.name, "/") { + namespace, name, _ := strings.Cut(c.name, "/") + source = fmt.Sprintf("%s/terraform-provider-%s", namespace, name) + } else { + source = fmt.Sprintf("hashicorp/terraform-provider-%s", c.name) + } r, err := newRelease("github", source) if err != nil { c.UI.Error(err.Error()) @@ -86,14 +92,12 @@ func (c *ProviderCommand) Help() string { Usage: tfupdate provider [options] Arguments - PROVIDER_NAME A name of provider (e.g. aws, google, azurerm) + PROVIDER_NAME A name of provider (e.g. aws or integrations/github) PATH A path of file or directory to update Options: -v --version A new version constraint (default: latest) If the version is omitted, the latest version is automatically checked and set. - Getting the latest version automatically is supported only for official providers. - If you have an unofficial provider, use release latest command. -r --recursive Check a directory recursively (default: false) -i --ignore-path A regular expression for path to ignore If you want to ignore multiple directories, set the flag multiple times. diff --git a/tfupdate/context.go b/tfupdate/context.go index 7d49bd9..f6a6dfb 100644 --- a/tfupdate/context.go +++ b/tfupdate/context.go @@ -181,3 +181,16 @@ func selectVersion(constraints []string) string { } return "" } + +// ResolveProviderShortNameFromSource is a helper function to resolve provider +// short names from the source address. +// If not found, return an empty string. +func (mc *ModuleContext) ResolveProviderShortNameFromSource(source string) string { + for k, v := range mc.requiredProviders { + if v.Source == source { + return k + } + } + + return "" +} diff --git a/tfupdate/context_test.go b/tfupdate/context_test.go index ab8f6a9..6aa5e6b 100644 --- a/tfupdate/context_test.go +++ b/tfupdate/context_test.go @@ -216,3 +216,97 @@ func TestSelectVersion(t *testing.T) { }) } } + +func TestModuleContextResolveProviderShortNameFromSource(t *testing.T) { + cases := []struct { + desc string + src string + source string + want string + }{ + { + desc: "simple", + src: ` +terraform { + required_providers { + github = { + source = "integrations/github" + version = "5.38.0" + } + } +} +`, + source: "integrations/github", + want: "github", + }, + { + desc: "multiple forks", + src: ` +terraform { + required_providers { + petoju = { + source = "petoju/mysql" + version = "3.0.41" + } + + winebarrel = { + source = "winebarrel/mysql" + version = "1.10.5" + } + } +} +`, + source: "winebarrel/mysql", + want: "winebarrel", + }, + { + desc: "not found", + src: ` +terraform { + required_providers { + petoju = { + source = "petoju/mysql" + version = "3.0.41" + } + + winebarrel = { + source = "winebarrel/mysql" + version = "1.10.5" + } + } +} +`, + source: "foo/mysql", + want: "", + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + fs := afero.NewMemMapFs() + dirname := "test" + err := fs.MkdirAll(dirname, os.ModePerm) + if err != nil { + t.Fatalf("failed to create dir: %s", err) + } + err = afero.WriteFile(fs, filepath.Join(dirname, "main.tf"), []byte(tc.src), 0644) + if err != nil { + t.Fatalf("failed to write file: %s", err) + } + + gc := &GlobalContext{ + fs: fs, + } + mc, err := NewModuleContext(dirname, gc) + if err != nil { + t.Fatalf("failed to new ModuleContext: %s", err) + } + + got := mc.ResolveProviderShortNameFromSource(tc.source) + + if got != tc.want { + t.Errorf("got: %s, want = %s", got, tc.want) + } + }) + } +} diff --git a/tfupdate/provider.go b/tfupdate/provider.go index 9b5d3c4..940a708 100644 --- a/tfupdate/provider.go +++ b/tfupdate/provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "path/filepath" + "strings" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/hclsyntax" @@ -36,29 +37,38 @@ func NewProviderUpdater(name string, version string) (Updater, error) { // Update updates the provider version constraint. // Note that this method will rewrite the AST passed as an argument. -func (u *ProviderUpdater) Update(_ context.Context, _ *ModuleContext, filename string, f *hclwrite.File) error { +func (u *ProviderUpdater) Update(_ context.Context, mc *ModuleContext, filename string, f *hclwrite.File) error { if filepath.Base(filename) == ".terraform.lock.hcl" { // skip a lock file. return nil } - if err := u.updateTerraformBlock(f); err != nil { + if err := u.updateTerraformBlock(mc, f); err != nil { return err } return u.updateProviderBlock(f) } -func (u *ProviderUpdater) updateTerraformBlock(f *hclwrite.File) error { +func (u *ProviderUpdater) updateTerraformBlock(mc *ModuleContext, f *hclwrite.File) error { for _, tf := range allMatchingBlocks(f.Body(), "terraform", []string{}) { p := tf.Body().FirstMatchingBlock("required_providers", []string{}) if p == nil { continue } + name := u.name + // If the name contains /, assume that a namespace is intended and check the source. + if strings.Contains(u.name, "/") { + name = mc.ResolveProviderShortNameFromSource(u.name) + if name == "" { + continue + } + } + // The hclwrite.Attribute doesn't have enough AST for object type to check. // Get the attribute as a native hcl.Attribute as a compromise. - hclAttr, err := getHCLNativeAttribute(p.Body(), u.name) + hclAttr, err := getHCLNativeAttribute(p.Body(), name) if err != nil { return err } @@ -72,7 +82,7 @@ func (u *ProviderUpdater) updateTerraformBlock(f *hclwrite.File) error { u.updateTerraformRequiredProvidersBlockAsString(p) } else { // Otherwise, it's an object syntax. - if err := u.updateTerraformRequiredProvidersBlockAsObject(p, hclAttr); err != nil { + if err := u.updateTerraformRequiredProvidersBlockAsObject(p, name, hclAttr); err != nil { return err } } @@ -82,7 +92,7 @@ func (u *ProviderUpdater) updateTerraformBlock(f *hclwrite.File) error { return nil } -func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwrite.Block, hclAttr *hcl.Attribute) error { +func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwrite.Block, name string, hclAttr *hcl.Attribute) error { // terraform { // required_providers { // aws = { @@ -115,7 +125,7 @@ func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwr // tokens in order, updating the bytes directly. // It's apparently a fragile dirty hack, but I didn't come up with the better // way to do this. - attr := p.Body().GetAttribute(u.name) + attr := p.Body().GetAttribute(name) tokens := attr.Expr().BuildTokens(nil) i := 0 diff --git a/tfupdate/provider_test.go b/tfupdate/provider_test.go index 964aafd..b8d7ecc 100644 --- a/tfupdate/provider_test.go +++ b/tfupdate/provider_test.go @@ -2,11 +2,14 @@ package tfupdate import ( "context" + "os" + "path/filepath" "reflect" "testing" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/hclwrite" + "github.com/spf13/afero" ) func TestNewProviderUpdater(t *testing.T) { @@ -336,8 +339,8 @@ terraform { terraform { required_providers { aws = { - "version" = "2.65.0" - "source" = "hashicorp/aws" + version = "2.65.0" + source = "hashicorp/aws" } } } @@ -348,8 +351,8 @@ terraform { terraform { required_providers { aws = { - "version" = "2.66.0" - "source" = "hashicorp/aws" + version = "2.66.0" + source = "hashicorp/aws" } } } @@ -389,6 +392,68 @@ terraform { } } } +`, + ok: true, + }, + { + filename: "main.tf", + src: ` +terraform { + required_providers { + github = { + source = "integrations/github" + version = "5.38.0" + } + } +} +`, + name: "integrations/github", + version: "5.39.0", + want: ` +terraform { + required_providers { + github = { + source = "integrations/github" + version = "5.39.0" + } + } +} +`, + ok: true, + }, + { + filename: "main.tf", + src: ` +terraform { + required_providers { + petoju = { + source = "petoju/mysql" + version = "3.0.41" + } + + winebarrel = { + source = "winebarrel/mysql" + version = "1.10.5" + } + } +} +`, + name: "winebarrel/mysql", + version: "1.10.6", + want: ` +terraform { + required_providers { + petoju = { + source = "petoju/mysql" + version = "3.0.41" + } + + winebarrel = { + source = "winebarrel/mysql" + version = "1.10.6" + } + } +} `, ok: true, }, @@ -429,16 +494,40 @@ provider "registry.terraform.io/hashicorp/null" { } for _, tc := range cases { - u := &ProviderUpdater{ - name: tc.name, - version: tc.version, + fs := afero.NewMemMapFs() + dirname := "test" + err := fs.MkdirAll(dirname, os.ModePerm) + if err != nil { + t.Fatalf("failed to create dir: %s", err) + } + + err = afero.WriteFile(fs, filepath.Join(dirname, "main.tf"), []byte(tc.src), 0644) + if err != nil { + t.Fatalf("failed to write file: %s", err) } + + o := Option{ + updateType: "provider", + name: tc.name, + version: tc.version, + } + gc, err := NewGlobalContext(fs, o) + if err != nil { + t.Fatalf("failed to new global context: %s", err) + } + + mc, err := NewModuleContext(dirname, gc) + if err != nil { + t.Fatalf("failed to new module context: %s", err) + } + + u := gc.updater f, diags := hclwrite.ParseConfig([]byte(tc.src), tc.filename, hcl.Pos{Line: 1, Column: 1}) if diags.HasErrors() { t.Fatalf("unexpected diagnostics: %s", diags) } - err := u.Update(context.Background(), nil, tc.filename, f) + err = u.Update(context.Background(), mc, tc.filename, f) if tc.ok && err != nil { t.Errorf("Update() with src = %s, name = %s, version = %s returns unexpected err: %+v", tc.src, tc.name, tc.version, err) }