Skip to content

Commit

Permalink
Add support for provider namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
minamijoyo committed Oct 8, 2023
1 parent b2896e0 commit 53895e5
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 22 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,12 @@ $ tfupdate provider --help
Usage: tfupdate provider [options] <PROVIDER_NAME> <PATH>
Arguments
PROVIDER_NAME A name of provider (e.g. aws, google, azurerm)
PROVIDER_NAME A name of provider (e.g. aws or integrations/github)
PATH A path of file or directory to update
Options:
-v --version A new version constraint (default: latest)
If the version is omitted, the latest version is automatically checked and set.
Getting the latest version automatically is supported only for official providers.
If you have an unofficial provider, use release latest command.
-r --recursive Check a directory recursively (default: false)
-i --ignore-path A regular expression for path to ignore
If you want to ignore multiple directories, set the flag multiple times.
Expand Down
12 changes: 8 additions & 4 deletions command/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ func (c *ProviderCommand) Run(args []string) int {

v := c.version
if v == "latest" {
source := fmt.Sprintf("terraform-providers/terraform-provider-%s", c.name)
source := ""
if strings.Contains(c.name, "/") {
namespace, name, _ := strings.Cut(c.name, "/")
source = fmt.Sprintf("%s/terraform-provider-%s", namespace, name)
} else {
source = fmt.Sprintf("hashicorp/terraform-provider-%s", c.name)
}
r, err := newRelease("github", source)
if err != nil {
c.UI.Error(err.Error())
Expand Down Expand Up @@ -86,14 +92,12 @@ func (c *ProviderCommand) Help() string {
Usage: tfupdate provider [options] <PROVIDER_NAME> <PATH>
Arguments
PROVIDER_NAME A name of provider (e.g. aws, google, azurerm)
PROVIDER_NAME A name of provider (e.g. aws or integrations/github)
PATH A path of file or directory to update
Options:
-v --version A new version constraint (default: latest)
If the version is omitted, the latest version is automatically checked and set.
Getting the latest version automatically is supported only for official providers.
If you have an unofficial provider, use release latest command.
-r --recursive Check a directory recursively (default: false)
-i --ignore-path A regular expression for path to ignore
If you want to ignore multiple directories, set the flag multiple times.
Expand Down
13 changes: 13 additions & 0 deletions tfupdate/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,16 @@ func selectVersion(constraints []string) string {
}
return ""
}

// ResolveProviderShortNameFromSource is a helper function to resolve provider
// short names from the source address.
// If not found, return an empty string.
func (mc *ModuleContext) ResolveProviderShortNameFromSource(source string) string {
for k, v := range mc.requiredProviders {
if v.Source == source {
return k
}
}

return ""
}
94 changes: 94 additions & 0 deletions tfupdate/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,97 @@ func TestSelectVersion(t *testing.T) {
})
}
}

func TestModuleContextResolveProviderShortNameFromSource(t *testing.T) {
cases := []struct {
desc string
src string
source string
want string
}{
{
desc: "simple",
src: `
terraform {
required_providers {
github = {
source = "integrations/github"
version = "5.38.0"
}
}
}
`,
source: "integrations/github",
want: "github",
},
{
desc: "multiple forks",
src: `
terraform {
required_providers {
petoju = {
source = "petoju/mysql"
version = "3.0.41"
}
winebarrel = {
source = "winebarrel/mysql"
version = "1.10.5"
}
}
}
`,
source: "winebarrel/mysql",
want: "winebarrel",
},
{
desc: "not found",
src: `
terraform {
required_providers {
petoju = {
source = "petoju/mysql"
version = "3.0.41"
}
winebarrel = {
source = "winebarrel/mysql"
version = "1.10.5"
}
}
}
`,
source: "foo/mysql",
want: "",
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
fs := afero.NewMemMapFs()
dirname := "test"
err := fs.MkdirAll(dirname, os.ModePerm)
if err != nil {
t.Fatalf("failed to create dir: %s", err)
}
err = afero.WriteFile(fs, filepath.Join(dirname, "main.tf"), []byte(tc.src), 0644)
if err != nil {
t.Fatalf("failed to write file: %s", err)
}

gc := &GlobalContext{
fs: fs,
}
mc, err := NewModuleContext(dirname, gc)
if err != nil {
t.Fatalf("failed to new ModuleContext: %s", err)
}

got := mc.ResolveProviderShortNameFromSource(tc.source)

if got != tc.want {
t.Errorf("got: %s, want = %s", got, tc.want)
}
})
}
}
24 changes: 17 additions & 7 deletions tfupdate/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"path/filepath"
"strings"

"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hclsyntax"
Expand Down Expand Up @@ -36,29 +37,38 @@ func NewProviderUpdater(name string, version string) (Updater, error) {

// Update updates the provider version constraint.
// Note that this method will rewrite the AST passed as an argument.
func (u *ProviderUpdater) Update(_ context.Context, _ *ModuleContext, filename string, f *hclwrite.File) error {
func (u *ProviderUpdater) Update(_ context.Context, mc *ModuleContext, filename string, f *hclwrite.File) error {
if filepath.Base(filename) == ".terraform.lock.hcl" {
// skip a lock file.
return nil
}

if err := u.updateTerraformBlock(f); err != nil {
if err := u.updateTerraformBlock(mc, f); err != nil {
return err
}

return u.updateProviderBlock(f)
}

func (u *ProviderUpdater) updateTerraformBlock(f *hclwrite.File) error {
func (u *ProviderUpdater) updateTerraformBlock(mc *ModuleContext, f *hclwrite.File) error {
for _, tf := range allMatchingBlocks(f.Body(), "terraform", []string{}) {
p := tf.Body().FirstMatchingBlock("required_providers", []string{})
if p == nil {
continue
}

name := u.name
// If the name contains /, assume that a namespace is intended and check the source.
if strings.Contains(u.name, "/") {
name = mc.ResolveProviderShortNameFromSource(u.name)
if name == "" {
continue
}
}

// The hclwrite.Attribute doesn't have enough AST for object type to check.
// Get the attribute as a native hcl.Attribute as a compromise.
hclAttr, err := getHCLNativeAttribute(p.Body(), u.name)
hclAttr, err := getHCLNativeAttribute(p.Body(), name)
if err != nil {
return err
}
Expand All @@ -72,7 +82,7 @@ func (u *ProviderUpdater) updateTerraformBlock(f *hclwrite.File) error {
u.updateTerraformRequiredProvidersBlockAsString(p)
} else {
// Otherwise, it's an object syntax.
if err := u.updateTerraformRequiredProvidersBlockAsObject(p, hclAttr); err != nil {
if err := u.updateTerraformRequiredProvidersBlockAsObject(p, name, hclAttr); err != nil {
return err
}
}
Expand All @@ -82,7 +92,7 @@ func (u *ProviderUpdater) updateTerraformBlock(f *hclwrite.File) error {
return nil
}

func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwrite.Block, hclAttr *hcl.Attribute) error {
func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwrite.Block, name string, hclAttr *hcl.Attribute) error {
// terraform {
// required_providers {
// aws = {
Expand Down Expand Up @@ -115,7 +125,7 @@ func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwr
// tokens in order, updating the bytes directly.
// It's apparently a fragile dirty hack, but I didn't come up with the better
// way to do this.
attr := p.Body().GetAttribute(u.name)
attr := p.Body().GetAttribute(name)
tokens := attr.Expr().BuildTokens(nil)

i := 0
Expand Down
105 changes: 97 additions & 8 deletions tfupdate/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package tfupdate

import (
"context"
"os"
"path/filepath"
"reflect"
"testing"

"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hclwrite"
"github.com/spf13/afero"
)

func TestNewProviderUpdater(t *testing.T) {
Expand Down Expand Up @@ -336,8 +339,8 @@ terraform {
terraform {
required_providers {
aws = {
"version" = "2.65.0"
"source" = "hashicorp/aws"
version = "2.65.0"
source = "hashicorp/aws"
}
}
}
Expand All @@ -348,8 +351,8 @@ terraform {
terraform {
required_providers {
aws = {
"version" = "2.66.0"
"source" = "hashicorp/aws"
version = "2.66.0"
source = "hashicorp/aws"
}
}
}
Expand Down Expand Up @@ -389,6 +392,68 @@ terraform {
}
}
}
`,
ok: true,
},
{
filename: "main.tf",
src: `
terraform {
required_providers {
github = {
source = "integrations/github"
version = "5.38.0"
}
}
}
`,
name: "integrations/github",
version: "5.39.0",
want: `
terraform {
required_providers {
github = {
source = "integrations/github"
version = "5.39.0"
}
}
}
`,
ok: true,
},
{
filename: "main.tf",
src: `
terraform {
required_providers {
petoju = {
source = "petoju/mysql"
version = "3.0.41"
}
winebarrel = {
source = "winebarrel/mysql"
version = "1.10.5"
}
}
}
`,
name: "winebarrel/mysql",
version: "1.10.6",
want: `
terraform {
required_providers {
petoju = {
source = "petoju/mysql"
version = "3.0.41"
}
winebarrel = {
source = "winebarrel/mysql"
version = "1.10.6"
}
}
}
`,
ok: true,
},
Expand Down Expand Up @@ -429,16 +494,40 @@ provider "registry.terraform.io/hashicorp/null" {
}

for _, tc := range cases {
u := &ProviderUpdater{
name: tc.name,
version: tc.version,
fs := afero.NewMemMapFs()
dirname := "test"
err := fs.MkdirAll(dirname, os.ModePerm)
if err != nil {
t.Fatalf("failed to create dir: %s", err)
}

err = afero.WriteFile(fs, filepath.Join(dirname, "main.tf"), []byte(tc.src), 0644)
if err != nil {
t.Fatalf("failed to write file: %s", err)
}

o := Option{
updateType: "provider",
name: tc.name,
version: tc.version,
}
gc, err := NewGlobalContext(fs, o)
if err != nil {
t.Fatalf("failed to new global context: %s", err)
}

mc, err := NewModuleContext(dirname, gc)
if err != nil {
t.Fatalf("failed to new module context: %s", err)
}

u := gc.updater
f, diags := hclwrite.ParseConfig([]byte(tc.src), tc.filename, hcl.Pos{Line: 1, Column: 1})
if diags.HasErrors() {
t.Fatalf("unexpected diagnostics: %s", diags)
}

err := u.Update(context.Background(), nil, tc.filename, f)
err = u.Update(context.Background(), mc, tc.filename, f)
if tc.ok && err != nil {
t.Errorf("Update() with src = %s, name = %s, version = %s returns unexpected err: %+v", tc.src, tc.name, tc.version, err)
}
Expand Down

0 comments on commit 53895e5

Please sign in to comment.