From 2676b35eff90ea05599c828e10bd9b4f7720668a Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Fri, 12 Jul 2024 22:21:47 -0700 Subject: [PATCH] bug: fix relative references when defaulting files from dirs --- pkg/loader/url.go | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/pkg/loader/url.go b/pkg/loader/url.go index bc4d5c9f..2035469e 100644 --- a/pkg/loader/url.go +++ b/pkg/loader/url.go @@ -111,11 +111,20 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string req.Header.Set("Authorization", "Bearer "+bearerToken) } - data, err := getWithDefaults(req) + data, defaulted, err := getWithDefaults(req) if err != nil { return nil, false, fmt.Errorf("error loading %s: %v", url, err) } + if defaulted != "" { + pathString = url + name = defaulted + if repo != nil { + repo.Path = path.Join(repo.Path, repo.Name) + repo.Name = defaulted + } + } + log.Debugf("opened %s", url) result := &source{ @@ -137,31 +146,32 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string return result, true, nil } -func getWithDefaults(req *http.Request) ([]byte, error) { +func getWithDefaults(req *http.Request) ([]byte, string, error) { originalPath := req.URL.Path // First, try to get the original path as is. It might be an OpenAPI definition. resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, err + return nil, "", err } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { - if toolBytes, err := io.ReadAll(resp.Body); err == nil && isOpenAPI(toolBytes) != 0 { - return toolBytes, nil - } + toolBytes, err := io.ReadAll(resp.Body) + return toolBytes, "", err + } + + base := path.Base(originalPath) + if strings.Contains(base, ".") { + return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status) } for i, def := range types.DefaultFiles { - base := path.Base(originalPath) - if !strings.Contains(base, ".") { - req.URL.Path = path.Join(originalPath, def) - } + req.URL.Path = path.Join(originalPath, def) resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, err + return nil, "", err } defer resp.Body.Close() @@ -170,11 +180,13 @@ func getWithDefaults(req *http.Request) ([]byte, error) { } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status) + return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status) } - return io.ReadAll(resp.Body) + data, err := io.ReadAll(resp.Body) + return data, def, err } + panic("unreachable") }