Skip to content

Commit

Permalink
Adding registry url discovery for Terragrunt Provider Cache (#3299)
Browse files Browse the repository at this point in the history
* chore: registry discovery url

* chore: default urls for registry without well-known endpoint

* chore: building url if registry download url is a filename

* Update cli/provider_cache.go

Co-authored-by: Yousif Akbar <[email protected]>

---------

Co-authored-by: Yousif Akbar <[email protected]>
  • Loading branch information
levkohimins and yhakbar authored Jul 26, 2024
1 parent d7423b8 commit e951157
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 43 deletions.
31 changes: 25 additions & 6 deletions cli/provider_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package cli

import (
"context"
liberrors "errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -100,6 +102,7 @@ func InitProviderCacheServer(opts *options.TerragruntOptions) (*ProviderCache, e
var (
providerHandlers []handlers.ProviderHandler
excludeAddrs []string
directIsdefined bool
)

for _, registryName := range opts.ProviderCacheRegistryNames {
Expand All @@ -114,10 +117,15 @@ func InitProviderCacheServer(opts *options.TerragruntOptions) (*ProviderCache, e
providerHandlers = append(providerHandlers, handlers.NewProviderNetworkMirrorHandler(providerService, cacheProviderHTTPStatusCode, method))
case *cliconfig.ProviderInstallationDirect:
providerHandlers = append(providerHandlers, handlers.NewProviderDirectHandler(providerService, cacheProviderHTTPStatusCode, method))
directIsdefined = true
}
method.AppendExclude(excludeAddrs)
}
providerHandlers = append(providerHandlers, handlers.NewProviderDirectHandler(providerService, cacheProviderHTTPStatusCode, new(cliconfig.ProviderInstallationDirect)))

if !directIsdefined {
// In a case if none of direct provider installation methods `cliCfg.ProviderInstallation.Methods` are specified.
providerHandlers = append(providerHandlers, handlers.NewProviderDirectHandler(providerService, cacheProviderHTTPStatusCode, new(cliconfig.ProviderInstallationDirect)))
}

cache := cache.NewServer(
cache.WithHostname(opts.ProviderCacheHostname),
Expand Down Expand Up @@ -160,7 +168,7 @@ func (cache *ProviderCache) TerraformCommandHook(ctx context.Context, opts *opti
}

// Create terraform cli config file that enables provider caching and does not use provider cache dir
if err := cache.createLocalCLIConfig(opts, cliConfigFilename, cacheRequestID); err != nil {
if err := cache.createLocalCLIConfig(ctx, opts, cliConfigFilename, cacheRequestID); err != nil {
return nil, err
}

Expand All @@ -184,7 +192,7 @@ func (cache *ProviderCache) TerraformCommandHook(ctx context.Context, opts *opti
}

// Create terraform cli config file that uses provider cache dir
if err := cache.createLocalCLIConfig(opts, cliConfigFilename, ""); err != nil {
if err := cache.createLocalCLIConfig(ctx, opts, cliConfigFilename, ""); err != nil {
return nil, err
}

Expand Down Expand Up @@ -227,7 +235,7 @@ func (cache *ProviderCache) TerraformCommandHook(ctx context.Context, opts *opti
// It creates two types of configuration depending on the `cacheRequestID` variable set.
// 1. If `cacheRequestID` is set, `terraform init` does _not_ use the provider cache directory, the cache server creates a cache for requested providers and returns HTTP status 423. Since for each module we create the CLI config, using `cacheRequestID` we have the opportunity later retrieve from the cache server exactly those cached providers that were requested by `terraform init` using this configuration.
// 2. If `cacheRequestID` is empty, 'terraform init` uses provider cache directory, the cache server acts as a proxy.
func (cache *ProviderCache) createLocalCLIConfig(opts *options.TerragruntOptions, filename string, cacheRequestID string) error {
func (cache *ProviderCache) createLocalCLIConfig(ctx context.Context, opts *options.TerragruntOptions, filename string, cacheRequestID string) error {
cfg := cache.cliCfg.Clone()
cfg.PluginCacheDir = ""

Expand All @@ -236,10 +244,21 @@ func (cache *ProviderCache) createLocalCLIConfig(opts *options.TerragruntOptions
for _, registryName := range opts.ProviderCacheRegistryNames {
providerInstallationIncludes = append(providerInstallationIncludes, fmt.Sprintf("%s/*/*", registryName))

urls, err := DiscoveryURL(ctx, registryName)
if err != nil {
if !liberrors.As(err, &NotFoundWellKnownURL{}) {
return err
}
urls = DefaultRegistryURLs
opts.Logger.Debugf("Unable to discover %q registry URLs, reason: %q, use default URLs: %s", registryName, err, urls)
} else {
opts.Logger.Debugf("Discovered %q registry URLs: %s", registryName, urls)
}

cfg.AddHost(registryName, map[string]string{
"providers.v1": fmt.Sprintf("%s/%s/%s/", cache.ProviderController.URL(), cacheRequestID, registryName),
"providers.v1": fmt.Sprintf("%s/%s/%s/%s/", cache.ProviderController.URL(), cacheRequestID, url.PathEscape(urls.ProvidersV1), registryName),
// Since Terragrunt Provider Cache only caches providers, we need to route module requests to the original registry.
"modules.v1": fmt.Sprintf("https://%s/v1/modules", registryName),
"modules.v1": fmt.Sprintf("https://%s%s", registryName, urls.ModulesV1),
})
}

Expand Down
22 changes: 11 additions & 11 deletions cli/provider_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -40,14 +41,13 @@ func TestProviderCache(t *testing.T) {

token := fmt.Sprintf("%s:%s", apiKeyAuth, uuid.New().String())

providerCacheDir, err := os.MkdirTemp("", "*")
require.NoError(t, err)

pluginCacheDir, err := os.MkdirTemp("", "*")
require.NoError(t, err)
providerCacheDir := t.TempDir()
pluginCacheDir := t.TempDir()

opts := []cache.Option{cache.WithToken(token)}

registryPrefix := url.PathEscape("/v1/providers/")

testCases := []struct {
opts []cache.Option
urlPath string
Expand All @@ -63,36 +63,36 @@ func TestProviderCache(t *testing.T) {
},
{
opts: append(opts, cache.WithToken("")),
urlPath: "/v1/providers/cache/registry.terraform.io/hashicorp/aws/versions",
urlPath: "/v1/providers/cache/" + registryPrefix + "/registry.terraform.io/hashicorp/aws/versions",
expectedStatusCode: http.StatusUnauthorized,
},
{
opts: opts,
urlPath: "/v1/providers/cache/registry.terraform.io/hashicorp/aws/versions",
urlPath: "/v1/providers/cache/" + registryPrefix + "/registry.terraform.io/hashicorp/aws/versions",
expectedStatusCode: http.StatusOK,
expectedBodyReg: regexp.MustCompile(regexp.QuoteMeta(`"version":"5.36.0","protocols":["5.0"],"platforms"`)),
},
{
opts: opts,
urlPath: "/v1/providers/cache/registry.terraform.io/hashicorp/aws/5.36.0/download/darwin/arm64",
urlPath: "/v1/providers/cache/" + registryPrefix + "/registry.terraform.io/hashicorp/aws/5.36.0/download/darwin/arm64",
expectedStatusCode: http.StatusLocked,
expectedCachePath: "registry.terraform.io/hashicorp/aws/5.36.0/darwin_arm64/terraform-provider-aws_v5.36.0_x5",
},
{
opts: opts,
urlPath: "/v1/providers/cache/registry.terraform.io/hashicorp/template/2.2.0/download/linux/amd64",
urlPath: "/v1/providers/cache/" + registryPrefix + "/registry.terraform.io/hashicorp/template/2.2.0/download/linux/amd64",
expectedStatusCode: http.StatusLocked,
expectedCachePath: "registry.terraform.io/hashicorp/template/2.2.0/linux_amd64/terraform-provider-template_v2.2.0_x4",
},
{
opts: opts,
urlPath: fmt.Sprintf("/v1/providers/cache/registry.terraform.io/hashicorp/template/1234.5678.9/download/%s/%s", runtime.GOOS, runtime.GOARCH),
urlPath: fmt.Sprintf("/v1/providers/cache/%s/registry.terraform.io/hashicorp/template/1234.5678.9/download/%s/%s", registryPrefix, runtime.GOOS, runtime.GOARCH),
expectedStatusCode: http.StatusLocked,
expectedCachePath: createFakeProvider(t, pluginCacheDir, fmt.Sprintf("registry.terraform.io/hashicorp/template/1234.5678.9/%s_%s/terraform-provider-template_1234.5678.9_x5", runtime.GOOS, runtime.GOARCH)),
},
{
opts: opts,
urlPath: "/v1/providers//registry.terraform.io/hashicorp/aws/5.36.0/download/darwin/arm64",
urlPath: "/v1/providers//" + registryPrefix + "/registry.terraform.io/hashicorp/aws/5.36.0/download/darwin/arm64",
expectedStatusCode: http.StatusOK,
expectedBodyReg: regexp.MustCompile(`\{.*` + regexp.QuoteMeta(`"download_url":"http://127.0.0.1:`) + `\d+` + regexp.QuoteMeta(`/downloads/releases.hashicorp.com/terraform-provider-aws/5.36.0/terraform-provider-aws_5.36.0_darwin_arm64.zip"`) + `.*\}`),
},
Expand Down
75 changes: 75 additions & 0 deletions cli/registry_urls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package cli

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/gruntwork-io/go-commons/errors"
)

// wellKnownURL is the path, relative to the registry host, of the discovery
// document that DiscoveryURL requests.
const wellKnownURL = ".well-known/terraform.json"

// DefaultRegistryURLs holds the conventional modules and providers API paths.
// It is shared as a pointer, so callers should treat it as read-only.
var DefaultRegistryURLs = &RegistryURLs{
	ModulesV1:   "/v1/modules",
	ProvidersV1: "/v1/providers",
}

// RegistryURLs is the part of a registry discovery document that this package
// decodes: the locations of the modules and providers APIs.
type RegistryURLs struct {
	// ModulesV1 is the value stored under the "modules.v1" JSON key.
	ModulesV1 string `json:"modules.v1"`
	// ProvidersV1 is the value stored under the "providers.v1" JSON key.
	ProvidersV1 string `json:"providers.v1"`
}

// String returns the JSON encoding of urls, which makes the value convenient
// to pass to log formatting verbs such as %s.
func (urls *RegistryURLs) String() string {
	// The marshalling error is deliberately dropped: on failure the encoded
	// slice is nil and the result is simply the empty string.
	encoded, _ := json.Marshal(urls) //nolint:errcheck

	return string(encoded)
}

// DiscoveryURL requests https://<registryName>/.well-known/terraform.json and
// decodes the "modules.v1" and "providers.v1" entries of the response.
//
// Results:
//   - HTTP 200: the decoded URLs; any entry the document leaves empty is
//     filled in from DefaultRegistryURLs so callers never receive a blank path.
//   - HTTP 404: a NotFoundWellKnownURL error (wrapped with a stack trace)
//     carrying the URL that was requested.
//   - any other status, a transport failure, or an undecodable body: an error
//     wrapped with a stack trace.
//
// The request is bound to ctx, so cancellation and deadlines are honoured.
func DiscoveryURL(ctx context.Context, registryName string) (*RegistryURLs, error) {
	// Discovery documents are tiny; cap the read so a misbehaving server
	// cannot make us buffer an arbitrarily large body.
	const maxDiscoveryBodySize = 1 << 20 // 1 MiB

	discoveryURL := fmt.Sprintf("https://%s/%s", registryName, wellKnownURL)

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
	if err != nil {
		return nil, errors.WithStackTrace(err)
	}

	// Use the shared default client rather than allocating one per call.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, errors.WithStackTrace(err)
	}
	defer resp.Body.Close() //nolint:errcheck

	switch resp.StatusCode {
	case http.StatusOK:
		// Fall through to decoding below.
	case http.StatusNotFound:
		return nil, errors.WithStackTrace(NotFoundWellKnownURL{url: discoveryURL})
	default:
		return nil, errors.WithStackTrace(fmt.Errorf("%s returned %s", discoveryURL, resp.Status))
	}

	content, err := io.ReadAll(io.LimitReader(resp.Body, maxDiscoveryBodySize))
	if err != nil {
		return nil, errors.WithStackTrace(err)
	}

	urls := new(RegistryURLs)
	if err := json.Unmarshal(content, urls); err != nil {
		return nil, errors.WithStackTrace(err)
	}

	// A document may omit one of the services; an empty path would produce a
	// malformed URL downstream, so fall back to the conventional location.
	if urls.ModulesV1 == "" {
		urls.ModulesV1 = DefaultRegistryURLs.ModulesV1
	}
	if urls.ProvidersV1 == "" {
		urls.ProvidersV1 = DefaultRegistryURLs.ProvidersV1
	}

	return urls, nil
}

// NotFoundWellKnownURL is the error DiscoveryURL returns when the registry
// answers the discovery request with HTTP 404.
type NotFoundWellKnownURL struct {
	// url identifies the discovery document that could not be found.
	url string
}

// Error implements the error interface.
func (e NotFoundWellKnownURL) Error() string {
	const format = "%s not found"

	return fmt.Sprintf(format, e.url)
}
43 changes: 29 additions & 14 deletions terraform/cache/controllers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package controllers

import (
"net/http"
"net/url"

"github.com/gruntwork-io/terragrunt/terraform/cache/handlers"
"github.com/gruntwork-io/terragrunt/terraform/cache/models"
Expand Down Expand Up @@ -44,24 +45,31 @@ func (controller *ProviderController) Register(router *router.Router) {

// Get All Versions for a Single Provider
// https://developer.hashicorp.com/terraform/cloud-docs/api-docs/private-registry/provider-versions-platforms#get-all-versions-for-a-single-provider
controller.GET("/:cache_request_id/:registry_name/:namespace/:name/versions", controller.getVersionsAction)
controller.GET("/:cache_request_id/:registry_prefix/:registry_name/:namespace/:name/versions", controller.getVersionsAction)

// Get a Platform
// https://developer.hashicorp.com/terraform/cloud-docs/api-docs/private-registry/provider-versions-platforms#get-a-platform
controller.GET("/:cache_request_id/:registry_name/:namespace/:name/:version/download/:os/:arch", controller.getPlatformsAction)
controller.GET("/:cache_request_id/:registry_prefix/:registry_name/:namespace/:name/:version/download/:os/:arch", controller.getPlatformsAction)
}

func (controller *ProviderController) getVersionsAction(ctx echo.Context) error {
var (
registryName = ctx.Param("registry_name")
namespace = ctx.Param("namespace")
name = ctx.Param("name")
registryPrefix = ctx.Param("registry_prefix")
registryName = ctx.Param("registry_name")
namespace = ctx.Param("namespace")
name = ctx.Param("name")
)

registryPrefix, err := url.QueryUnescape(registryPrefix)
if err != nil {
return err
}

provider := &models.Provider{
RegistryName: registryName,
Namespace: namespace,
Name: name,
RegistryPrefix: registryPrefix,
RegistryName: registryName,
Namespace: namespace,
Name: name,
}

for _, handler := range controller.ProviderHandlers {
Expand All @@ -76,6 +84,7 @@ func (controller *ProviderController) getVersionsAction(ctx echo.Context) error

func (controller *ProviderController) getPlatformsAction(ctx echo.Context) (er error) {
var (
registryPrefix = ctx.Param("registry_prefix")
registryName = ctx.Param("registry_name")
namespace = ctx.Param("namespace")
name = ctx.Param("name")
Expand All @@ -85,13 +94,19 @@ func (controller *ProviderController) getPlatformsAction(ctx echo.Context) (er e
cacheRequestID = ctx.Param("cache_request_id")
)

registryPrefix, err := url.QueryUnescape(registryPrefix)
if err != nil {
return err
}

provider := &models.Provider{
RegistryName: registryName,
Namespace: namespace,
Name: name,
Version: version,
OS: os,
Arch: arch,
RegistryPrefix: registryPrefix,
RegistryName: registryName,
Namespace: namespace,
Name: name,
Version: version,
OS: os,
Arch: arch,
}

for _, handler := range controller.ProviderHandlers {
Expand Down
19 changes: 15 additions & 4 deletions terraform/cache/handlers/provider_direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"net/http"
"net/url"
"path"
"path/filepath"
"strconv"
"strings"

"github.com/gruntwork-io/terragrunt/pkg/log"
"github.com/gruntwork-io/terragrunt/terraform/cache/helpers"
Expand Down Expand Up @@ -42,7 +44,7 @@ type ProviderDirectHandler struct {
cacheProviderHTTPStatusCode int
}

func NewProviderDirectHandler(providerService *services.ProviderService, cacheProviderHTTPStatusCode int, method *cliconfig.ProviderInstallationDirect) ProviderHandler {
func NewProviderDirectHandler(providerService *services.ProviderService, cacheProviderHTTPStatusCode int, method *cliconfig.ProviderInstallationDirect) *ProviderDirectHandler {
return &ProviderDirectHandler{
CommonProviderHandler: NewCommonProviderHandler(method.Include, method.Exclude),
ReverseProxy: &ReverseProxy{},
Expand All @@ -61,7 +63,7 @@ func (handler *ProviderDirectHandler) GetVersions(ctx echo.Context, provider *mo
reqURL := &url.URL{
Scheme: "https",
Host: provider.RegistryName,
Path: path.Join("/v1/providers", provider.Namespace, provider.Name, "versions"),
Path: path.Join(provider.RegistryPrefix, provider.Namespace, provider.Name, "versions"),
}

return handler.ReverseProxy.NewRequest(ctx, reqURL)
Expand Down Expand Up @@ -101,12 +103,21 @@ func (handler *ProviderDirectHandler) Download(ctx echo.Context, provider *model
}
}

// check if the URL contains http scheme, it may just be a filename and we need to build the URL
if !strings.Contains(provider.DownloadURL, "://") {
downloadURL := &url.URL{
Scheme: "https",
Host: provider.RegistryName,
Path: filepath.Join(provider.RegistryPrefix, provider.RegistryName, provider.Namespace, provider.Name, provider.DownloadURL),
}
return handler.ReverseProxy.NewRequest(ctx, downloadURL)
}

downloadURL, err := url.Parse(provider.DownloadURL)
if err != nil {
return err
}
return handler.ReverseProxy.NewRequest(ctx, downloadURL)

}

// platformURL returns the URL used to query the all platforms for a single version.
Expand All @@ -115,7 +126,7 @@ func (handler *ProviderDirectHandler) platformURL(provider *models.Provider) *ur
return &url.URL{
Scheme: "https",
Host: provider.RegistryName,
Path: path.Join("/v1/providers", provider.Namespace, provider.Name, provider.Version, "download", provider.OS, provider.Arch),
Path: path.Join(provider.RegistryPrefix, provider.Namespace, provider.Name, provider.Version, "download", provider.OS, provider.Arch),
}
}

Expand Down
8 changes: 7 additions & 1 deletion terraform/cache/handlers/provider_filesystem_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"os"
"path/filepath"
"strings"

"github.com/gruntwork-io/go-commons/errors"
"github.com/gruntwork-io/terragrunt/terraform/cache/models"
Expand Down Expand Up @@ -83,9 +84,14 @@ func (handler *ProviderFilesystemMirrorHandler) GetPlatform(ctx echo.Context, pr
}

if archive, ok := mirrorData.Archives[provider.Platform()]; ok {
// check if the URL contains http scheme, it may just be a filename and we need to build the URL
if !strings.Contains(archive.URL, "://") {
archive.URL = filepath.Join(handler.filesystemMirrorPath, provider.RegistryName, provider.Namespace, provider.Name, archive.URL)
}

provider.ResponseBody = &models.ResponseBody{
Filename: filepath.Base(archive.URL),
DownloadURL: filepath.Join(handler.filesystemMirrorPath, provider.RegistryName, provider.Namespace, provider.Name, archive.URL),
DownloadURL: archive.URL,
}
} else {
return ctx.NoContent(http.StatusNotFound)
Expand Down
Loading

0 comments on commit e951157

Please sign in to comment.