Skip to content

Commit

Permalink
Merge pull request #629 from ibuildthecloud/main
Browse files Browse the repository at this point in the history
chore: add location option to loading scripts
  • Loading branch information
ibuildthecloud authored Jul 13, 2024
2 parents 0c73f4b + 2676b35 commit 9ca6e93
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 16 deletions.
18 changes: 16 additions & 2 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,20 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
}
opt := complete(opts...)

var locationPath, locationName string
if opt.Location != "" {
locationPath = path.Dir(opt.Location)
locationName = path.Base(opt.Location)
}

prg := types.Program{
ToolSet: types.ToolSet{},
}
tools, err := readTool(ctx, opt.Cache, &prg, &source{
Content: []byte(content),
Location: "inline",
Path: locationPath,
Name: locationName,
Location: opt.Location,
}, subToolName)
if err != nil {
return types.Program{}, err
Expand All @@ -388,12 +396,18 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
}

type Options struct {
Cache *cache.Client
Cache *cache.Client
Location string
}

func complete(opts ...Options) (result Options) {
for _, opt := range opts {
result.Cache = types.FirstSet(opt.Cache, result.Cache)
result.Location = types.FirstSet(opt.Location, result.Location)
}

if result.Location == "" {
result.Location = "inline"
}

return
Expand Down
38 changes: 25 additions & 13 deletions pkg/loader/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,20 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
req.Header.Set("Authorization", "Bearer "+bearerToken)
}

data, err := getWithDefaults(req)
data, defaulted, err := getWithDefaults(req)
if err != nil {
return nil, false, fmt.Errorf("error loading %s: %v", url, err)
}

if defaulted != "" {
pathString = url
name = defaulted
if repo != nil {
repo.Path = path.Join(repo.Path, repo.Name)
repo.Name = defaulted
}
}

log.Debugf("opened %s", url)

result := &source{
Expand All @@ -137,31 +146,32 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
return result, true, nil
}

func getWithDefaults(req *http.Request) ([]byte, error) {
func getWithDefaults(req *http.Request) ([]byte, string, error) {
originalPath := req.URL.Path

// First, try to get the original path as is. It might be an OpenAPI definition.
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
return nil, "", err
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusOK {
if toolBytes, err := io.ReadAll(resp.Body); err == nil && isOpenAPI(toolBytes) != 0 {
return toolBytes, nil
}
toolBytes, err := io.ReadAll(resp.Body)
return toolBytes, "", err
}

base := path.Base(originalPath)
if strings.Contains(base, ".") {
return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
}

for i, def := range types.DefaultFiles {
base := path.Base(originalPath)
if !strings.Contains(base, ".") {
req.URL.Path = path.Join(originalPath, def)
}
req.URL.Path = path.Join(originalPath, def)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
return nil, "", err
}
defer resp.Body.Close()

Expand All @@ -170,11 +180,13 @@ func getWithDefaults(req *http.Request) ([]byte, error) {
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
}

return io.ReadAll(resp.Body)
data, err := io.ReadAll(resp.Body)
return data, def, err
}

panic("unreachable")
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) {
logger.Debugf("executing tool: %+v", reqObject)
var (
def fmt.Stringer = &reqObject.ToolDefs
programLoader loaderFunc = loader.ProgramFromSource
programLoader = loaderWithLocation(loader.ProgramFromSource, reqObject.Location)
)
if reqObject.Content != "" {
def = &reqObject.content
Expand Down
8 changes: 8 additions & 0 deletions pkg/sdkserver/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ import (

type loaderFunc func(context.Context, string, string, ...loader.Options) (types.Program, error)

func loaderWithLocation(f loaderFunc, loc string) loaderFunc {
return func(ctx context.Context, s string, s2 string, options ...loader.Options) (types.Program, error) {
return f(ctx, s, s2, append(options, loader.Options{
Location: loc,
})...)
}
}

func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, logger mvl.Logger, w http.ResponseWriter, opts gptscript.Options, chatState, input, subTool string, toolDef fmt.Stringer) {
g, err := gptscript.New(ctx, s.gptscriptOpts, opts)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions pkg/sdkserver/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type toolOrFileRequest struct {
CredentialContext string `json:"credentialContext"`
CredentialOverrides []string `json:"credentialOverrides"`
Confirm bool `json:"confirm"`
Location string `json:"location,omitempty"`
}

type content struct {
Expand Down

0 comments on commit 9ca6e93

Please sign in to comment.