diff --git a/pkg/download/download.go b/pkg/download/download.go index 90f28ef..3a4bef5 100644 --- a/pkg/download/download.go +++ b/pkg/download/download.go @@ -28,6 +28,8 @@ import ( type RequestOption = func(*http.Request) +type ResponseChecker = func(*http.Response) error + func ApplyURLTransformer(urlTransformer URLTransformer, baseURLs ...string) ([]string, error) { transformedURLs := make([]string, len(baseURLs)) for index, baseURL := range baseURLs { @@ -42,7 +44,7 @@ func ApplyURLTransformer(urlTransformer URLTransformer, baseURLs ...string) ([]s return transformedURLs, nil } -func Bytes(ctx context.Context, url string, display func(string), requestOptions ...RequestOption) ([]byte, error) { +func Bytes(ctx context.Context, url string, display func(string), checker ResponseChecker, requestOptions ...RequestOption) ([]byte, error) { display("Downloading " + url) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) @@ -60,11 +62,15 @@ func Bytes(ctx context.Context, url string, display func(string), requestOptions } defer response.Body.Close() + if err = checker(response); err != nil { + return nil, err + } + return io.ReadAll(response.Body) } -func JSON(ctx context.Context, url string, display func(string), requestOptions ...RequestOption) (any, error) { - data, err := Bytes(ctx, url, display, requestOptions...) +func JSON(ctx context.Context, url string, display func(string), checker ResponseChecker, requestOptions ...RequestOption) (any, error) { + data, err := Bytes(ctx, url, display, checker, requestOptions...) if err != nil { return nil, err } @@ -103,3 +109,7 @@ func WithBasicAuth(username string, password string) RequestOption { func NoTransform(value string) (string, error) { return value, nil } + +func NoCheck(*http.Response) error { + return nil +} diff --git a/pkg/github/github.go b/pkg/github/github.go index bba96ba..1ab6862 100644 --- a/pkg/github/github.go +++ b/pkg/github/github.go @@ -20,13 +20,13 @@ package github import ( "context" - "encoding/json" "errors" "net/http" "net/url" "strconv" "github.com/tofuutils/tenv/v3/pkg/apimsg" + "github.com/tofuutils/tenv/v3/pkg/download" versionfinder "github.com/tofuutils/tenv/v3/versionmanager/semantic/finder" ) @@ -113,28 +113,13 @@ func ListReleases(ctx context.Context, githubReleaseURL string, githubToken stri } func apiGetRequest(ctx context.Context, callURL string, authorizationHeader string) (any, error) { - resp, err := downloadWithHeaders(ctx, callURL, func(request *http.Request) { + return download.JSON(ctx, callURL, download.NoDisplay, checkRateLimit, func(request *http.Request) { request.Header.Set("Accept", "application/vnd.github+json") if authorizationHeader != "" { request.Header.Set("Authorization", authorizationHeader) } request.Header.Set("X-GitHub-Api-Version", "2022-11-28") //nolint }) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if err := checkRateLimit(resp); err != nil { - return nil, err - } - - var result any - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, apimsg.ErrReturn - } - - return result, nil } func checkRateLimit(resp *http.Response) error { @@ -146,23 +131,6 @@ func checkRateLimit(resp *http.Response) error { return nil } -func downloadWithHeaders(ctx context.Context, url string, modifyRequest func(*http.Request)) (*http.Response, error) { - client := &http.Client{} - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, err - } - if modifyRequest != nil { - modifyRequest(req) - } - resp, err := client.Do(req) - if err != nil { - return nil, err - } - - return resp, nil -} - func buildAuthorizationHeader(token string) string { if token == "" { return "" diff --git a/pkg/htmlquery/html.go b/pkg/htmlquery/html.go index dd595be..e67a0dd 100644 --- a/pkg/htmlquery/html.go +++ b/pkg/htmlquery/html.go @@ -29,7 +29,7 @@ import ( ) func Request(ctx context.Context, callURL string, selector string, extractor func(*goquery.Selection) string, ro ...download.RequestOption) ([]string, error) { - data, err := download.Bytes(ctx, callURL, download.NoDisplay, ro...) + data, err := download.Bytes(ctx, callURL, download.NoDisplay, download.NoCheck, ro...) if err != nil { return nil, err } diff --git a/versionmanager/retriever/atmos/atmosretriever.go b/versionmanager/retriever/atmos/atmosretriever.go index cce829a..e4963a2 100644 --- a/versionmanager/retriever/atmos/atmosretriever.go +++ b/versionmanager/retriever/atmos/atmosretriever.go @@ -97,12 +97,12 @@ func (r AtmosRetriever) Install(ctx context.Context, versionStr string, targetPa } requestOptions := config.GetBasicAuthOption(r.conf.Getenv, config.AtmosRemoteUserEnvName, config.AtmosRemotePassEnvName) - data, err := download.Bytes(ctx, assetURLs[0], r.conf.Displayer.Display, requestOptions...) + data, err := download.Bytes(ctx, assetURLs[0], r.conf.Displayer.Display, download.NoCheck, requestOptions...) if err != nil { return err } - dataSums, err := download.Bytes(ctx, assetURLs[1], r.conf.Displayer.Display, requestOptions...) + dataSums, err := download.Bytes(ctx, assetURLs[1], r.conf.Displayer.Display, download.NoCheck, requestOptions...) if err != nil { return err } diff --git a/versionmanager/retriever/terraform/terraformretriever.go b/versionmanager/retriever/terraform/terraformretriever.go index 7107924..5373bca 100644 --- a/versionmanager/retriever/terraform/terraformretriever.go +++ b/versionmanager/retriever/terraform/terraformretriever.go @@ -95,7 +95,7 @@ func (r TerraformRetriever) Install(ctx context.Context, version string, targetP r.conf.Displayer.Display(apimsg.MsgFetchRelease + versionURL) - value, err := download.JSON(ctx, versionURL, download.NoDisplay, requestOptions...) + value, err := download.JSON(ctx, versionURL, download.NoDisplay, download.NoCheck, requestOptions...) if err != nil { return err } @@ -124,7 +124,7 @@ func (r TerraformRetriever) Install(ctx context.Context, version string, targetP return err } - data, err := download.Bytes(ctx, assetURLs[0], r.conf.Displayer.Display, requestOptions...) + data, err := download.Bytes(ctx, assetURLs[0], r.conf.Displayer.Display, download.NoCheck, requestOptions...) if err != nil { return err } @@ -162,7 +162,7 @@ func (r TerraformRetriever) ListVersions(ctx context.Context) ([]string, error) r.conf.Displayer.Display(apimsg.MsgFetchAllReleases + releasesURL) - value, err := download.JSON(ctx, releasesURL, download.NoDisplay, requestOptions...) + value, err := download.JSON(ctx, releasesURL, download.NoDisplay, download.NoCheck, requestOptions...) if err != nil { return nil, err } @@ -174,7 +174,7 @@ func (r TerraformRetriever) ListVersions(ctx context.Context) ([]string, error) } func (r TerraformRetriever) checkSumAndSig(ctx context.Context, fileName string, data []byte, downloadSumsURL string, downloadSumsSigURL string, options []download.RequestOption) error { - dataSums, err := download.Bytes(ctx, downloadSumsURL, r.conf.Displayer.Display, options...) + dataSums, err := download.Bytes(ctx, downloadSumsURL, r.conf.Displayer.Display, download.NoCheck, options...) if err != nil { return err } @@ -187,14 +187,14 @@ func (r TerraformRetriever) checkSumAndSig(ctx context.Context, fileName string, return nil } - dataSumsSig, err := download.Bytes(ctx, downloadSumsSigURL, r.conf.Displayer.Display, options...) + dataSumsSig, err := download.Bytes(ctx, downloadSumsSigURL, r.conf.Displayer.Display, download.NoCheck, options...) if err != nil { return err } var dataPublicKey []byte if r.conf.TfKeyPath == "" { - dataPublicKey, err = download.Bytes(ctx, publicKeyURL, r.conf.Displayer.Display) + dataPublicKey, err = download.Bytes(ctx, publicKeyURL, r.conf.Displayer.Display, download.NoCheck) } else { dataPublicKey, err = os.ReadFile(r.conf.TfKeyPath) } diff --git a/versionmanager/retriever/terragrunt/terragruntretriever.go b/versionmanager/retriever/terragrunt/terragruntretriever.go index 42efa36..4088655 100644 --- a/versionmanager/retriever/terragrunt/terragruntretriever.go +++ b/versionmanager/retriever/terragrunt/terragruntretriever.go @@ -94,12 +94,12 @@ func (r TerragruntRetriever) Install(ctx context.Context, versionStr string, tar } requestOptions := config.GetBasicAuthOption(r.conf.Getenv, config.TgRemoteUserEnvName, config.TgRemotePassEnvName) - data, err := download.Bytes(ctx, assetURLs[0], r.conf.Displayer.Display, requestOptions...) + data, err := download.Bytes(ctx, assetURLs[0], r.conf.Displayer.Display, download.NoCheck, requestOptions...) if err != nil { return err } - dataSums, err := download.Bytes(ctx, assetURLs[1], r.conf.Displayer.Display, requestOptions...) + dataSums, err := download.Bytes(ctx, assetURLs[1], r.conf.Displayer.Display, download.NoCheck, requestOptions...) if err != nil { return err } diff --git a/versionmanager/retriever/tofu/tofuretriever.go b/versionmanager/retriever/tofu/tofuretriever.go index bef4091..9683179 100644 --- a/versionmanager/retriever/tofu/tofuretriever.go +++ b/versionmanager/retriever/tofu/tofuretriever.go @@ -131,7 +131,7 @@ func (r TofuRetriever) Install(ctx context.Context, versionStr string, targetPat } requestOptions := config.GetBasicAuthOption(r.conf.Getenv, config.TofuRemoteUserEnvName, config.TofuRemotePassEnvName) - data, err := download.Bytes(ctx, assetURLs[0], r.conf.Displayer.Display, requestOptions...) + data, err := download.Bytes(ctx, assetURLs[0], r.conf.Displayer.Display, download.NoCheck, requestOptions...) if err != nil { return err } @@ -173,7 +173,7 @@ func (r TofuRetriever) ListVersions(ctx context.Context) ([]string, error) { r.conf.Displayer.Display(apimsg.MsgFetchAllReleases + listURL) - value, err := download.JSON(ctx, listURL, download.NoDisplay, requestOptions...) + value, err := download.JSON(ctx, listURL, download.NoDisplay, download.NoCheck, requestOptions...) if err != nil { return nil, err } @@ -185,7 +185,7 @@ func (r TofuRetriever) ListVersions(ctx context.Context) ([]string, error) { } func (r TofuRetriever) checkSumAndSig(ctx context.Context, version *version.Version, stable bool, data []byte, fileName string, assetURLs []string, options []download.RequestOption) error { - dataSums, err := download.Bytes(ctx, assetURLs[1], r.conf.Displayer.Display, options...) + dataSums, err := download.Bytes(ctx, assetURLs[1], r.conf.Displayer.Display, download.NoCheck, options...) if err != nil { return err } @@ -198,12 +198,12 @@ func (r TofuRetriever) checkSumAndSig(ctx context.Context, version *version.Vers return nil } - dataSumsSig, err := download.Bytes(ctx, assetURLs[3], r.conf.Displayer.Display, options...) + dataSumsSig, err := download.Bytes(ctx, assetURLs[3], r.conf.Displayer.Display, download.NoCheck, options...) if err != nil { return err } - dataSumsCert, err := download.Bytes(ctx, assetURLs[2], r.conf.Displayer.Display, options...) + dataSumsCert, err := download.Bytes(ctx, assetURLs[2], r.conf.Displayer.Display, download.NoCheck, options...) if err != nil { return err } @@ -222,14 +222,14 @@ func (r TofuRetriever) checkSumAndSig(ctx context.Context, version *version.Vers r.conf.Displayer.Display("cosign executable not found, fallback to pgp check") - dataSumsSig, err = download.Bytes(ctx, assetURLs[4], r.conf.Displayer.Display, options...) + dataSumsSig, err = download.Bytes(ctx, assetURLs[4], r.conf.Displayer.Display, download.NoCheck, options...) if err != nil { return err } var dataPublicKey []byte if r.conf.TofuKeyPath == "" { - dataPublicKey, err = download.Bytes(ctx, publicKeyURL, r.conf.Displayer.Display) + dataPublicKey, err = download.Bytes(ctx, publicKeyURL, r.conf.Displayer.Display, download.NoCheck) } else { dataPublicKey, err = os.ReadFile(r.conf.TofuKeyPath) }