From 76042bd382685a8f536a7e48d666ae7f5a8229d1 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Tue, 30 Jul 2024 14:41:31 -0400 Subject: [PATCH] feat: better handling of OpenAPI tools (#667) Signed-off-by: Grant Linville --- go.mod | 3 + go.sum | 7 + pkg/engine/openapi.go | 498 +++++------------- pkg/engine/openapi_test.go | 31 +- pkg/loader/loader.go | 124 ++--- pkg/loader/loader_test.go | 7 +- pkg/loader/openapi.go | 147 +++++- pkg/loader/openapi_test.go | 39 ++ .../testdata/openapi/TestOpenAPIv2.golden | 6 +- .../openapi/TestOpenAPIv2Revamp.golden | 116 ++++ .../testdata/openapi/TestOpenAPIv3.golden | 6 +- .../TestOpenAPIv3NoOperationIDs.golden | 6 +- .../TestOpenAPIv3NoOperationIDsRevamp.golden | 116 ++++ .../openapi/TestOpenAPIv3Revamp.golden | 116 ++++ pkg/openapi/getschema.go | 285 ++++++++++ pkg/openapi/list.go | 68 +++ pkg/openapi/load.go | 121 +++++ pkg/openapi/run.go | 451 ++++++++++++++++ pkg/openapi/security.go | 56 ++ 19 files changed, 1704 insertions(+), 499 deletions(-) create mode 100644 pkg/loader/testdata/openapi/TestOpenAPIv2Revamp.golden create mode 100644 pkg/loader/testdata/openapi/TestOpenAPIv3NoOperationIDsRevamp.golden create mode 100644 pkg/loader/testdata/openapi/TestOpenAPIv3Revamp.golden create mode 100644 pkg/openapi/getschema.go create mode 100644 pkg/openapi/list.go create mode 100644 pkg/openapi/load.go create mode 100644 pkg/openapi/run.go create mode 100644 pkg/openapi/security.go diff --git a/go.mod b/go.mod index 9feaef4f..de545abd 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.17.1 + github.com/xeipuuv/gojsonschema v1.2.0 golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc golang.org/x/sync v0.7.0 golang.org/x/term v0.20.0 @@ -101,6 +102,8 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/ulikunitz/xz v0.5.10 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yuin/goldmark v1.5.4 // indirect github.com/yuin/goldmark-emoji v1.0.2 // indirect diff --git a/go.sum b/go.sum index 9c288064..5a6ce6cf 100644 --- a/go.sum +++ b/go.sum @@ -317,6 +317,7 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -336,6 +337,12 @@ github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95 github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8= github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= diff --git a/pkg/engine/openapi.go b/pkg/engine/openapi.go index 2e338ca4..0bd5f599 100644 --- a/pkg/engine/openapi.go +++ b/pkg/engine/openapi.go @@ -8,83 +8,148 @@ import ( "mime/multipart" "net/http" "net/url" + "os" "strings" "github.com/gptscript-ai/gptscript/pkg/env" + "github.com/gptscript-ai/gptscript/pkg/openapi" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/tidwall/gjson" - "golang.org/x/exp/maps" ) -var ( - SupportedMIMETypes = []string{"application/json", "text/plain", "multipart/form-data"} - SupportedSecurityTypes = []string{"apiKey", "http"} -) +func (e *Engine) runOpenAPIRevamp(tool types.Tool, input string) (*Return, error) { + envMap := make(map[string]string, len(e.Env)) + for _, env := range e.Env { + k, v, _ := strings.Cut(env, "=") + envMap[k] = v + } -type Parameter struct { - Name string `json:"name"` - Style string `json:"style"` - Explode *bool `json:"explode"` -} + _, inst, _ := strings.Cut(tool.Instructions, types.OpenAPIPrefix+" ") + args := strings.Fields(inst) -// A SecurityInfo represents a security scheme in OpenAPI. -type SecurityInfo struct { - Name string `json:"name"` // name as defined in the security schemes - Type string `json:"type"` // http or apiKey - Scheme string `json:"scheme"` // bearer or basic, for type==http - APIKeyName string `json:"apiKeyName"` // name of the API key, for type==apiKey - In string `json:"in"` // header, query, or cookie, for type==apiKey -} + if len(args) != 3 { + return nil, fmt.Errorf("expected 3 arguments to %s", types.OpenAPIPrefix) + } -func (i SecurityInfo) GetCredentialToolStrings(hostname string) []string { - vars := i.getCredentialNamesAndEnvVars(hostname) - var tools []string - - for cred, v := range vars { - field := "value" - switch i.Type { - case "apiKey": - field = i.APIKeyName - case "http": - if i.Scheme == "bearer" { - field = "bearer token" - } else { - if strings.Contains(v, "PASSWORD") { - field = "password" - } else { - field = "username" - } + command := args[0] + source := args[1] + filter := args[2] + + var res *Return + switch command { + case openapi.ListTool: + t, err := openapi.Load(source) + if err != nil { + return nil, fmt.Errorf("failed to load OpenAPI file %s: %w", source, err) + } + + opList, err := openapi.List(t, filter) + if err != nil { + return nil, fmt.Errorf("failed to list operations: %w", err) + } + + opListJSON, err := json.MarshalIndent(opList, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal operation list: %w", err) + } + + res = &Return{ + Result: ptr(string(opListJSON)), + } + case openapi.GetSchemaTool: + operation := gjson.Get(input, "operation").String() + + if filter != "" && filter != openapi.NoFilter { + match, err := openapi.MatchFilters(strings.Split(filter, "|"), operation) + if err != nil { + return nil, err + } else if !match { + // Report to the LLM that the operation was not found + return &Return{ + Result: ptr(fmt.Sprintf("operation %s not found", operation)), + }, nil } } - tools = append(tools, fmt.Sprintf("github.com/gptscript-ai/credential as %s with %s as env and %q as message and %q as field", - cred, v, "Please provide a value for the "+v+" environment variable", field)) - } - return tools -} + t, err := openapi.Load(source) + if err != nil { + return nil, fmt.Errorf("failed to load OpenAPI file %s: %w", source, err) + } -func (i SecurityInfo) getCredentialNamesAndEnvVars(hostname string) map[string]string { - if i.Type == "http" && i.Scheme == "basic" { - return map[string]string{ - hostname + i.Name + "Username": "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name) + "_USERNAME", - hostname + i.Name + "Password": "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name) + "_PASSWORD", + var defaultHost string + if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { + u, err := url.Parse(source) + if err != nil { + return nil, fmt.Errorf("failed to parse server URL %s: %w", source, err) + } + defaultHost = u.Scheme + "://" + u.Hostname() + } + + schema, _, found, err := openapi.GetSchema(operation, defaultHost, t) + if err != nil { + return nil, fmt.Errorf("failed to get schema: %w", err) + } + if !found { + // Report to the LLM that the operation was not found + return &Return{ + Result: ptr(fmt.Sprintf("operation %s not found", operation)), + }, nil + } + + schemaJSON, err := json.MarshalIndent(schema, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal schema: %w", err) + } + + res = &Return{ + Result: ptr(string(schemaJSON)), + } + case openapi.RunTool: + operation := gjson.Get(input, "operation").String() + args := gjson.Get(input, "args").String() + + if filter != "" && filter != openapi.NoFilter { + match, err := openapi.MatchFilters(strings.Split(filter, "|"), operation) + if err != nil { + return nil, err + } else if !match { + // Report to the LLM that the operation was not found + return &Return{ + Result: ptr(fmt.Sprintf("operation %s not found", operation)), + }, nil + } + } + + t, err := openapi.Load(source) + if err != nil { + return nil, fmt.Errorf("failed to load OpenAPI file %s: %w", source, err) + } + + var defaultHost string + if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { + u, err := url.Parse(source) + if err != nil { + return nil, fmt.Errorf("failed to parse server URL %s: %w", source, err) + } + defaultHost = u.Scheme + "://" + u.Hostname() + } + + result, found, err := openapi.Run(operation, defaultHost, args, t, e.Env) + if err != nil { + return nil, fmt.Errorf("failed to run operation %s: %w", operation, err) + } else if !found { + // Report to the LLM that the operation was not found + return &Return{ + Result: ptr(fmt.Sprintf("operation %s not found", operation)), + }, nil + } + + res = &Return{ + Result: ptr(result), } } - return map[string]string{ - hostname + i.Name: "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name), - } -} -type OpenAPIInstructions struct { - Server string `json:"server"` - Path string `json:"path"` - Method string `json:"method"` - BodyContentMIME string `json:"bodyContentMIME"` - SecurityInfos [][]SecurityInfo `json:"apiKeyInfos"` - QueryParameters []Parameter `json:"queryParameters"` - PathParameters []Parameter `json:"pathParameters"` - HeaderParameters []Parameter `json:"headerParameters"` - CookieParameters []Parameter `json:"cookieParameters"` + return res, nil } // runOpenAPI runs a tool that was generated from an OpenAPI definition. @@ -92,6 +157,10 @@ type OpenAPIInstructions struct { // The tools Instructions field will be in the format "#!sys.openapi '{Instructions JSON}'", // where {Instructions JSON} is a JSON string of type OpenAPIInstructions. func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { + if os.Getenv("GPTSCRIPT_OPENAPI_REVAMP") == "true" { + return e.runOpenAPIRevamp(tool, input) + } + envMap := map[string]string{} for _, env := range e.Env { @@ -100,7 +169,7 @@ func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { } // Extract the instructions from the tool to determine server, path, method, etc. - var instructions OpenAPIInstructions + var instructions openapi.OperationInfo _, inst, _ := strings.Cut(tool.Instructions, types.OpenAPIPrefix+" ") inst = strings.TrimPrefix(inst, "'") inst = strings.TrimSuffix(inst, "'") @@ -109,7 +178,7 @@ func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { } // Handle path parameters - instructions.Path = handlePathParameters(instructions.Path, instructions.PathParameters, input) + instructions.Path = openapi.HandlePathParameters(instructions.Path, instructions.PathParams, input) // Parse the URL path, err := url.JoinPath(instructions.Server, instructions.Path) @@ -131,7 +200,7 @@ func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { // Check for authentication (only if using HTTPS or localhost) if u.Scheme == "https" || u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1" { if len(instructions.SecurityInfos) > 0 { - if err := handleAuths(req, envMap, instructions.SecurityInfos); err != nil { + if err := openapi.HandleAuths(req, envMap, instructions.SecurityInfos); err != nil { return nil, fmt.Errorf("error setting up authentication: %w", err) } } @@ -145,11 +214,11 @@ func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { } // Handle query parameters - req.URL.RawQuery = handleQueryParameters(req.URL.Query(), instructions.QueryParameters, input).Encode() + req.URL.RawQuery = openapi.HandleQueryParameters(req.URL.Query(), instructions.QueryParams, input).Encode() // Handle header and cookie parameters - handleHeaderParameters(req, instructions.HeaderParameters, input) - handleCookieParameters(req, instructions.CookieParameters, input) + openapi.HandleHeaderParameters(req, instructions.HeaderParams, input) + openapi.HandleCookieParameters(req, instructions.CookieParams, input) // Handle request body if instructions.BodyContentMIME != "" { @@ -217,299 +286,6 @@ func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { }, nil } -// handleAuths will set up the request with the necessary authentication information. -// A set of sets of SecurityInfo is passed in, where each represents a possible set of security options. -func handleAuths(req *http.Request, envMap map[string]string, infoSets [][]SecurityInfo) error { - var missingVariables [][]string - - // We need to find a set of infos where we have all the needed environment variables. - for _, infoSet := range infoSets { - var missing []string // Keep track of any missing environment variables - for _, info := range infoSet { - vars := info.getCredentialNamesAndEnvVars(req.URL.Hostname()) - - for _, envName := range vars { - if _, ok := envMap[envName]; !ok { - missing = append(missing, envName) - } - } - } - if len(missing) > 0 { - missingVariables = append(missingVariables, missing) - continue - } - - // We're using this info set, because no environment variables were missing. - // Set up the request as needed. - for _, info := range infoSet { - envNames := maps.Values(info.getCredentialNamesAndEnvVars(req.URL.Hostname())) - switch info.Type { - case "apiKey": - switch info.In { - case "header": - req.Header.Set(info.APIKeyName, envMap[envNames[0]]) - case "query": - v := url.Values{} - v.Add(info.APIKeyName, envMap[envNames[0]]) - req.URL.RawQuery = v.Encode() - case "cookie": - req.AddCookie(&http.Cookie{ - Name: info.APIKeyName, - Value: envMap[envNames[0]], - }) - } - case "http": - switch info.Scheme { - case "bearer": - req.Header.Set("Authorization", "Bearer "+envMap[envNames[0]]) - case "basic": - req.SetBasicAuth(envMap[envNames[0]], envMap[envNames[1]]) - } - } - } - return nil - } - - return fmt.Errorf("did not find the needed environment variables for any of the security options. "+ - "At least one of these sets of environment variables must be provided: %v", missingVariables) -} - -// handleQueryParameters extracts each query parameter from the input JSON and adds it to the URL query. -func handleQueryParameters(q url.Values, params []Parameter, input string) url.Values { - for _, param := range params { - res := gjson.Get(input, param.Name) - if res.Exists() { - // If it's an array or object, handle the serialization style - if res.IsArray() { - switch param.Style { - case "form", "": // form is the default style for query parameters - if param.Explode == nil || *param.Explode { // default is to explode - for _, item := range res.Array() { - q.Add(param.Name, item.String()) - } - } else { - var strs []string - for _, item := range res.Array() { - strs = append(strs, item.String()) - } - q.Add(param.Name, strings.Join(strs, ",")) - } - case "spaceDelimited": - if param.Explode == nil || *param.Explode { - for _, item := range res.Array() { - q.Add(param.Name, item.String()) - } - } else { - var strs []string - for _, item := range res.Array() { - strs = append(strs, item.String()) - } - q.Add(param.Name, strings.Join(strs, " ")) - } - case "pipeDelimited": - if param.Explode == nil || *param.Explode { - for _, item := range res.Array() { - q.Add(param.Name, item.String()) - } - } else { - var strs []string - for _, item := range res.Array() { - strs = append(strs, item.String()) - } - q.Add(param.Name, strings.Join(strs, "|")) - } - } - } else if res.IsObject() { - switch param.Style { - case "form", "": // form is the default style for query parameters - if param.Explode == nil || *param.Explode { // default is to explode - for k, v := range res.Map() { - q.Add(k, v.String()) - } - } else { - var strs []string - for k, v := range res.Map() { - strs = append(strs, k, v.String()) - } - q.Add(param.Name, strings.Join(strs, ",")) - } - case "deepObject": - for k, v := range res.Map() { - q.Add(param.Name+"["+k+"]", v.String()) - } - } - } else { - q.Add(param.Name, res.String()) - } - } - } - return q -} - -// handlePathParameters extracts each path parameter from the input JSON and replaces its placeholder in the URL path. -func handlePathParameters(path string, params []Parameter, input string) string { - for _, param := range params { - res := gjson.Get(input, param.Name) - if res.Exists() { - // If it's an array or object, handle the serialization style - if res.IsArray() { - switch param.Style { - case "simple", "": // simple is the default style for path parameters - // simple looks the same regardless of whether explode is true - strs := make([]string, len(res.Array())) - for i, item := range res.Array() { - strs[i] = item.String() - } - path = strings.Replace(path, "{"+param.Name+"}", strings.Join(strs, ","), 1) - case "label": - strs := make([]string, len(res.Array())) - for i, item := range res.Array() { - strs[i] = item.String() - } - - if param.Explode == nil || !*param.Explode { // default is to not explode - path = strings.Replace(path, "{"+param.Name+"}", "."+strings.Join(strs, ","), 1) - } else { - path = strings.Replace(path, "{"+param.Name+"}", "."+strings.Join(strs, "."), 1) - } - case "matrix": - strs := make([]string, len(res.Array())) - for i, item := range res.Array() { - strs[i] = item.String() - } - - if param.Explode == nil || !*param.Explode { // default is to not explode - path = strings.Replace(path, "{"+param.Name+"}", ";"+param.Name+"="+strings.Join(strs, ","), 1) - } else { - s := "" - for _, str := range strs { - s += ";" + param.Name + "=" + str - } - path = strings.Replace(path, "{"+param.Name+"}", s, 1) - } - } - } else if res.IsObject() { - switch param.Style { - case "simple", "": - if param.Explode == nil || !*param.Explode { // default is to not explode - var strs []string - for k, v := range res.Map() { - strs = append(strs, k, v.String()) - } - path = strings.Replace(path, "{"+param.Name+"}", strings.Join(strs, ","), 1) - } else { - var strs []string - for k, v := range res.Map() { - strs = append(strs, k+"="+v.String()) - } - path = strings.Replace(path, "{"+param.Name+"}", strings.Join(strs, ","), 1) - } - case "label": - if param.Explode == nil || !*param.Explode { // default is to not explode - var strs []string - for k, v := range res.Map() { - strs = append(strs, k, v.String()) - } - path = strings.Replace(path, "{"+param.Name+"}", "."+strings.Join(strs, ","), 1) - } else { - s := "" - for k, v := range res.Map() { - s += "." + k + "=" + v.String() - } - path = strings.Replace(path, "{"+param.Name+"}", s, 1) - } - case "matrix": - if param.Explode == nil || !*param.Explode { // default is to not explode - var strs []string - for k, v := range res.Map() { - strs = append(strs, k, v.String()) - } - path = strings.Replace(path, "{"+param.Name+"}", ";"+param.Name+"="+strings.Join(strs, ","), 1) - } else { - s := "" - for k, v := range res.Map() { - s += ";" + k + "=" + v.String() - } - path = strings.Replace(path, "{"+param.Name+"}", s, 1) - } - } - } else { - // Serialization is handled slightly differently even for basic types. - // Explode doesn't do anything though. - switch param.Style { - case "simple", "": - path = strings.Replace(path, "{"+param.Name+"}", res.String(), 1) - case "label": - path = strings.Replace(path, "{"+param.Name+"}", "."+res.String(), 1) - case "matrix": - path = strings.Replace(path, "{"+param.Name+"}", ";"+param.Name+"="+res.String(), 1) - } - } - } - } - return path -} - -// handleHeaderParameters extracts each header parameter from the input JSON and adds it to the request headers. -func handleHeaderParameters(req *http.Request, params []Parameter, input string) { - for _, param := range params { - res := gjson.Get(input, param.Name) - if res.Exists() { - if res.IsArray() { - strs := make([]string, len(res.Array())) - for i, item := range res.Array() { - strs[i] = item.String() - } - req.Header.Add(param.Name, strings.Join(strs, ",")) - } else if res.IsObject() { - // Handle explosion - var strs []string - if param.Explode == nil || !*param.Explode { // default is to not explode - for k, v := range res.Map() { - strs = append(strs, k, v.String()) - } - } else { - for k, v := range res.Map() { - strs = append(strs, k+"="+v.String()) - } - } - req.Header.Add(param.Name, strings.Join(strs, ",")) - } else { // basic type - req.Header.Add(param.Name, res.String()) - } - } - } -} - -// handleCookieParameters extracts each cookie parameter from the input JSON and adds it to the request cookies. -func handleCookieParameters(req *http.Request, params []Parameter, input string) { - for _, param := range params { - res := gjson.Get(input, param.Name) - if res.Exists() { - if res.IsArray() { - strs := make([]string, len(res.Array())) - for i, item := range res.Array() { - strs[i] = item.String() - } - req.AddCookie(&http.Cookie{ - Name: param.Name, - Value: strings.Join(strs, ","), - }) - } else if res.IsObject() { - var strs []string - for k, v := range res.Map() { - strs = append(strs, k, v.String()) - } - req.AddCookie(&http.Cookie{ - Name: param.Name, - Value: strings.Join(strs, ","), - }) - } else { // basic type - req.AddCookie(&http.Cookie{ - Name: param.Name, - Value: res.String(), - }) - } - } - } +func ptr[T any](t T) *T { + return &t } diff --git a/pkg/engine/openapi_test.go b/pkg/engine/openapi_test.go index df1e00fc..9fd5d34e 100644 --- a/pkg/engine/openapi_test.go +++ b/pkg/engine/openapi_test.go @@ -5,6 +5,7 @@ import ( "net/url" "testing" + "github.com/gptscript-ai/gptscript/pkg/openapi" "github.com/stretchr/testify/require" ) @@ -89,7 +90,7 @@ func TestPathParameterSerialization(t *testing.T) { t.Run(test.name, func(t *testing.T) { path := path params := getParameters(test.style, test.explode) - path = handlePathParameters(path, params, string(inputStr)) + path = openapi.HandlePathParameters(path, params, string(inputStr)) require.Contains(t, test.expectedPaths, path) }) } @@ -111,13 +112,13 @@ func TestQueryParameterSerialization(t *testing.T) { tests := []struct { name string input string - param Parameter + param openapi.Parameter expectedQueries []string // We use multiple expected queries due to randomness in map iteration }{ { name: "value", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "v", }, expectedQueries: []string{"v=42"}, @@ -125,7 +126,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "array form + explode", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "a", Style: "form", Explode: boolPointer(true), @@ -135,7 +136,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "array form + no explode", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "a", Style: "form", Explode: boolPointer(false), @@ -145,7 +146,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "array spaceDelimited + explode", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "a", Style: "spaceDelimited", Explode: boolPointer(true), @@ -155,7 +156,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "array spaceDelimited + no explode", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "a", Style: "spaceDelimited", Explode: boolPointer(false), @@ -165,7 +166,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "array pipeDelimited + explode", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "a", Style: "pipeDelimited", Explode: boolPointer(true), @@ -175,7 +176,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "array pipeDelimited + no explode", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "a", Style: "pipeDelimited", Explode: boolPointer(false), @@ -185,7 +186,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "object form + explode", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "o", Style: "form", Explode: boolPointer(true), @@ -198,7 +199,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "object form + no explode", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "o", Style: "form", Explode: boolPointer(false), @@ -211,7 +212,7 @@ func TestQueryParameterSerialization(t *testing.T) { { name: "object deepObject", input: string(inputStr), - param: Parameter{ + param: openapi.Parameter{ Name: "o", Style: "deepObject", }, @@ -224,14 +225,14 @@ func TestQueryParameterSerialization(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - q := handleQueryParameters(url.Values{}, []Parameter{test.param}, test.input) + q := openapi.HandleQueryParameters(url.Values{}, []openapi.Parameter{test.param}, test.input) require.Contains(t, test.expectedQueries, q.Encode()) }) } } -func getParameters(style string, explode bool) []Parameter { - return []Parameter{ +func getParameters(style string, explode bool) []openapi.Parameter { + return []openapi.Parameter{ { Name: "v", Style: style, diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index d7634058..3d2ae8ed 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -8,26 +8,23 @@ import ( "fmt" "io" "io/fs" + "os" "path" "path/filepath" - "strconv" "strings" "time" "unicode/utf8" - "github.com/getkin/kin-openapi/openapi2" - "github.com/getkin/kin-openapi/openapi2conv" "github.com/getkin/kin-openapi/openapi3" "github.com/gptscript-ai/gptscript/internal" "github.com/gptscript-ai/gptscript/pkg/assemble" "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/hash" + "github.com/gptscript-ai/gptscript/pkg/openapi" "github.com/gptscript-ai/gptscript/pkg/parser" "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" - "gopkg.in/yaml.v3" - kyaml "sigs.k8s.io/yaml" ) const CacheTimeout = time.Hour @@ -157,33 +154,8 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T { prg.OpenAPICache = map[string]any{} } - switch isOpenAPI(data) { - case 2: - // Convert OpenAPI v2 to v3 - jsondata := data - if !json.Valid(data) { - jsondata, err = kyaml.YAMLToJSON(data) - if err != nil { - return nil - } - } - - doc := &openapi2.T{} - if err := doc.UnmarshalJSON(jsondata); err != nil { - return nil - } - - openAPIDocument, err = openapi2conv.ToV3(doc) - if err != nil { - return nil - } - case 3: - // Use OpenAPI v3 as is - openAPIDocument, err = openapi3.NewLoader().LoadFromData(data) - if err != nil { - return nil - } - default: + openAPIDocument, err = openapi.LoadFromBytes(data) + if err != nil { return nil } @@ -202,14 +174,18 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base return []types.Tool{tool}, nil } - var tools []types.Tool + var ( + tools []types.Tool + isOpenAPI bool + ) if openAPIDocument := loadOpenAPI(prg, data); openAPIDocument != nil { + isOpenAPI = true var err error if base.Remote { - tools, err = getOpenAPITools(openAPIDocument, base.Location) + tools, err = getOpenAPITools(openAPIDocument, base.Location, base.Location, targetToolName) } else { - tools, err = getOpenAPITools(openAPIDocument, "") + tools, err = getOpenAPITools(openAPIDocument, "", base.Name, targetToolName) } if err != nil { return nil, fmt.Errorf("error parsing OpenAPI definition: %w", err) @@ -257,10 +233,6 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base // Probably a better way to come up with an ID tool.ID = tool.Source.Location + ":" + tool.Name - if i == 0 && targetToolName == "" { - targetTools = append(targetTools, tool) - } - if i != 0 && tool.Parameters.Name == "" { return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name")) } @@ -273,16 +245,35 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools")) } - if targetToolName != "" && tool.Parameters.Name != "" { - if strings.EqualFold(tool.Parameters.Name, targetToolName) { + // Determine targetTools + if isOpenAPI && os.Getenv("GPTSCRIPT_OPENAPI_REVAMP") == "true" { + targetTools = append(targetTools, tool) + } else { + if i == 0 && targetToolName == "" { targetTools = append(targetTools, tool) - } else if strings.Contains(targetToolName, "*") { - match, err := filepath.Match(strings.ToLower(targetToolName), strings.ToLower(tool.Parameters.Name)) - if err != nil { - return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, err) - } - if match { + } + + if targetToolName != "" && tool.Parameters.Name != "" { + if strings.EqualFold(tool.Parameters.Name, targetToolName) { targetTools = append(targetTools, tool) + } else if strings.Contains(targetToolName, "*") { + var patterns []string + if strings.Contains(targetToolName, "|") { + patterns = strings.Split(targetToolName, "|") + } else { + patterns = []string{targetToolName} + } + + for _, pattern := range patterns { + match, err := filepath.Match(strings.ToLower(pattern), strings.ToLower(tool.Parameters.Name)) + if err != nil { + return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, err) + } + if match { + targetTools = append(targetTools, tool) + break + } + } } } } @@ -491,42 +482,3 @@ func input(ctx context.Context, cache *cache.Client, base *source, name string) return nil, fmt.Errorf("can not load tools path=%s name=%s", base.Path, name) } - -// isOpenAPI checks if the data is an OpenAPI definition and returns the version if it is. -func isOpenAPI(data []byte) int { - var fragment struct { - Paths map[string]any `json:"paths,omitempty"` - Swagger string `json:"swagger,omitempty"` - OpenAPI string `json:"openapi,omitempty"` - } - - if err := json.Unmarshal(data, &fragment); err != nil { - if err := yaml.Unmarshal(data, &fragment); err != nil { - return 0 - } - } - if len(fragment.Paths) == 0 { - return 0 - } - - if v, _, _ := strings.Cut(fragment.OpenAPI, "."); v != "" { - ver, err := strconv.Atoi(v) - if err != nil { - log.Debugf("invalid OpenAPI version: openapi=%q", fragment.OpenAPI) - return 0 - } - return ver - } - - if v, _, _ := strings.Cut(fragment.Swagger, "."); v != "" { - ver, err := strconv.Atoi(v) - if err != nil { - log.Debugf("invalid Swagger version: swagger=%q", fragment.Swagger) - return 0 - } - return ver - } - - log.Debugf("no OpenAPI version found in input data: openapi=%q, swagger=%q", fragment.OpenAPI, fragment.Swagger) - return 0 -} diff --git a/pkg/loader/loader_test.go b/pkg/loader/loader_test.go index d70c45f5..7c480034 100644 --- a/pkg/loader/loader_test.go +++ b/pkg/loader/loader_test.go @@ -10,6 +10,7 @@ import ( "path/filepath" "testing" + "github.com/gptscript-ai/gptscript/pkg/openapi" "github.com/hexops/autogold/v2" "github.com/stretchr/testify/require" ) @@ -53,17 +54,17 @@ Stuff func TestIsOpenAPI(t *testing.T) { datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - v := isOpenAPI(datav2) + v := openapi.IsOpenAPI(datav2) require.Equal(t, 2, v, "(yaml) expected openapi v2") datav2, err = os.ReadFile("testdata/openapi_v2.json") require.NoError(t, err) - v = isOpenAPI(datav2) + v = openapi.IsOpenAPI(datav2) require.Equal(t, 2, v, "(json) expected openapi v2") datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - v = isOpenAPI(datav3) + v = openapi.IsOpenAPI(datav3) require.Equal(t, 3, v, "(json) expected openapi v3") } diff --git a/pkg/loader/openapi.go b/pkg/loader/openapi.go index 45254c9d..bc469a4e 100644 --- a/pkg/loader/openapi.go +++ b/pkg/loader/openapi.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/url" + "os" "regexp" "slices" "sort" @@ -11,7 +12,7 @@ import ( "time" "github.com/getkin/kin-openapi/openapi3" - "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/openapi" "github.com/gptscript-ai/gptscript/pkg/types" ) @@ -20,8 +21,12 @@ var toolNameRegex = regexp.MustCompile(`[^a-zA-Z0-9_-]+`) // getOpenAPITools parses an OpenAPI definition and generates a set of tools from it. // Each operation will become a tool definition. // The tool's Instructions will be in the format "#!sys.openapi '{JSON Instructions}'", -// where the JSON Instructions are a JSON-serialized engine.OpenAPIInstructions struct. -func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { +// where the JSON Instructions are a JSON-serialized openapi.OperationInfo struct. +func getOpenAPITools(t *openapi3.T, defaultHost, source, targetToolName string) ([]types.Tool, error) { + if os.Getenv("GPTSCRIPT_OPENAPI_REVAMP") == "true" { + return getOpenAPIToolsRevamp(t, source, targetToolName) + } + if log.IsDebug() { start := time.Now() defer func() { @@ -51,7 +56,7 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { for _, item := range t.Security { current := map[string]struct{}{} for name := range item { - if scheme, ok := t.Components.SecuritySchemes[name]; ok && slices.Contains(engine.SupportedSecurityTypes, scheme.Value.Type) { + if scheme, ok := t.Components.SecuritySchemes[name]; ok && slices.Contains(openapi.GetSupportedSecurityTypes(), scheme.Value.Type) { current[name] = struct{}{} } } @@ -134,10 +139,10 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { // - C // D auths []map[string]struct{} - queryParameters []engine.Parameter - pathParameters []engine.Parameter - headerParameters []engine.Parameter - cookieParameters []engine.Parameter + queryParameters []openapi.Parameter + pathParameters []openapi.Parameter + headerParameters []openapi.Parameter + cookieParameters []openapi.Parameter bodyMIME string ) tool := types.Tool{ @@ -177,7 +182,7 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { } // Add the parameter to the appropriate list for the tool's instructions - p := engine.Parameter{ + p := openapi.Parameter{ Name: param.Value.Name, Style: param.Value.Style, Explode: param.Value.Explode, @@ -199,7 +204,7 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { for mime, content := range operation.RequestBody.Value.Content { // Each MIME type needs to be handled individually, so we // keep a list of the ones we support. - if !slices.Contains(engine.SupportedMIMETypes, mime) { + if !slices.Contains(openapi.GetSupportedMIMETypes(), mime) { continue } bodyMIME = mime @@ -250,18 +255,18 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { } // For each set of auths, turn them into SecurityInfos, and drop ones that contain unsupported types. - var infos [][]engine.SecurityInfo + var infos [][]openapi.SecurityInfo outer: for _, auth := range auths { - var current []engine.SecurityInfo + var current []openapi.SecurityInfo for name := range auth { if scheme, ok := t.Components.SecuritySchemes[name]; ok { - if !slices.Contains(engine.SupportedSecurityTypes, scheme.Value.Type) { + if !slices.Contains(openapi.GetSupportedSecurityTypes(), scheme.Value.Type) { // There is an unsupported type in this auth, so move on to the next one. continue outer } - current = append(current, engine.SecurityInfo{ + current = append(current, openapi.SecurityInfo{ Type: scheme.Value.Type, Name: name, In: scheme.Value.In, @@ -324,17 +329,17 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { return tools, nil } -func instructionString(server, method, path, bodyMIME string, queryParameters, pathParameters, headerParameters, cookieParameters []engine.Parameter, infos [][]engine.SecurityInfo) (string, error) { - inst := engine.OpenAPIInstructions{ - Server: server, - Path: path, - Method: method, - BodyContentMIME: bodyMIME, - SecurityInfos: infos, - QueryParameters: queryParameters, - PathParameters: pathParameters, - HeaderParameters: headerParameters, - CookieParameters: cookieParameters, +func instructionString(server, method, path, bodyMIME string, queryParameters, pathParameters, headerParameters, cookieParameters []openapi.Parameter, infos [][]openapi.SecurityInfo) (string, error) { + inst := openapi.OperationInfo{ + Server: server, + Path: path, + Method: method, + BodyContentMIME: bodyMIME, + SecurityInfos: infos, + QueryParams: queryParameters, + PathParams: pathParameters, + HeaderParams: headerParameters, + CookieParams: cookieParameters, } instBytes, err := json.Marshal(inst) if err != nil { @@ -362,3 +367,95 @@ func parseServer(server *openapi3.Server) (string, error) { } return s, nil } + +func getOpenAPIToolsRevamp(t *openapi3.T, source, targetToolName string) ([]types.Tool, error) { + if t == nil { + return nil, fmt.Errorf("OpenAPI spec is nil") + } else if t.Info == nil { + return nil, fmt.Errorf("OpenAPI spec is missing info field") + } + + if targetToolName == "" { + targetToolName = openapi.NoFilter + } + + list := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: types.ToolNormalizer("list-operations-" + t.Info.Title), + Description: fmt.Sprintf("List available operations for %s. Each of these operations is an OpenAPI operation. Run this tool before you do anything else.", t.Info.Title), + }, + Instructions: fmt.Sprintf("%s %s %s %s", types.OpenAPIPrefix, openapi.ListTool, source, targetToolName), + }, + Source: types.ToolSource{ + LineNo: 0, + }, + } + + getSchema := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: types.ToolNormalizer("get-schema-" + t.Info.Title), + Description: fmt.Sprintf("Get the JSONSchema for the arguments for an operation for %s. You must do this before you run the operation.", t.Info.Title), + Arguments: &openapi3.Schema{ + Type: &openapi3.Types{openapi3.TypeObject}, + Properties: openapi3.Schemas{ + "operation": { + Value: &openapi3.Schema{ + Type: &openapi3.Types{openapi3.TypeString}, + Title: "operation", + Description: "the name of the operation to get the schema for", + Required: []string{"operation"}, + }, + }, + }, + }, + }, + Instructions: fmt.Sprintf("%s %s %s %s", types.OpenAPIPrefix, openapi.GetSchemaTool, source, targetToolName), + }, + Source: types.ToolSource{ + LineNo: 1, + }, + } + + run := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: types.ToolNormalizer("run-operation-" + t.Info.Title), + Description: fmt.Sprintf("Run an operation for %s. You MUST call %s for the operation before you use this tool.", t.Info.Title, openapi.GetSchemaTool), + Arguments: &openapi3.Schema{ + Type: &openapi3.Types{openapi3.TypeObject}, + Properties: openapi3.Schemas{ + "operation": { + Value: &openapi3.Schema{ + Type: &openapi3.Types{openapi3.TypeString}, + Title: "operation", + Description: "the name of the operation to run", + Required: []string{"operation"}, + }, + }, + "args": { + Value: &openapi3.Schema{ + Type: &openapi3.Types{openapi3.TypeString}, + Title: "args", + Description: "the JSON string containing arguments; must match the JSONSchema for the operation", + Required: []string{"args"}, + }, + }, + }, + }, + }, + Instructions: fmt.Sprintf("%s %s %s %s", types.OpenAPIPrefix, openapi.RunTool, source, targetToolName), + }, + } + + exportTool := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Export: []string{list.Parameters.Name, getSchema.Parameters.Name, run.Parameters.Name}, + }, + }, + } + + return []types.Tool{exportTool, list, getSchema, run}, nil +} diff --git a/pkg/loader/openapi_test.go b/pkg/loader/openapi_test.go index d00ffcca..1a7eaa76 100644 --- a/pkg/loader/openapi_test.go +++ b/pkg/loader/openapi_test.go @@ -86,3 +86,42 @@ func TestOpenAPIv2(t *testing.T) { autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) } + +func TestOpenAPIv3Revamp(t *testing.T) { + os.Setenv("GPTSCRIPT_OPENAPI_REVAMP", "true") + prgv3 := types.Program{ + ToolSet: types.ToolSet{}, + } + datav3, err := os.ReadFile("testdata/openapi_v3.yaml") + require.NoError(t, err) + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + require.NoError(t, err) + + autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) +} + +func TestOpenAPIv3NoOperationIDsRevamp(t *testing.T) { + os.Setenv("GPTSCRIPT_OPENAPI_REVAMP", "true") + prgv3 := types.Program{ + ToolSet: types.ToolSet{}, + } + datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") + require.NoError(t, err) + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + require.NoError(t, err) + + autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) +} + +func TestOpenAPIv2Revamp(t *testing.T) { + os.Setenv("GPTSCRIPT_OPENAPI_REVAMP", "true") + prgv2 := types.Program{ + ToolSet: types.ToolSet{}, + } + datav2, err := os.ReadFile("testdata/openapi_v2.yaml") + require.NoError(t, err) + _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "") + require.NoError(t, err) + + autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) +} diff --git a/pkg/loader/testdata/openapi/TestOpenAPIv2.golden b/pkg/loader/testdata/openapi/TestOpenAPIv2.golden index 90dd1967..39b0b2c1 100644 --- a/pkg/loader/testdata/openapi/TestOpenAPIv2.golden +++ b/pkg/loader/testdata/openapi/TestOpenAPIv2.golden @@ -38,7 +38,7 @@ types.ToolSet{ Description: "Create a pet", ModelName: "gpt-4o", }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"POST","bodyContentMIME":"","apiKeyInfos":null,"queryParameters":null,"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"POST","bodyContentMIME":"","securityInfos":null,"queryParameters":null,"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, }, ID: ":createPets", ToolMapping: map[string][]types.ToolReference{}, @@ -68,7 +68,7 @@ types.ToolSet{ }}}, }, }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"GET","bodyContentMIME":"","apiKeyInfos":null,"queryParameters":[{"name":"limit","style":"","explode":null}],"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"GET","bodyContentMIME":"","securityInfos":null,"queryParameters":[{"name":"limit","style":"","explode":null}],"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, }, ID: ":listPets", ToolMapping: map[string][]types.ToolReference{}, @@ -95,7 +95,7 @@ types.ToolSet{ }}}, }, }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets/{petId}","method":"GET","bodyContentMIME":"","apiKeyInfos":null,"queryParameters":null,"pathParameters":[{"name":"petId","style":"","explode":null}],"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets/{petId}","method":"GET","bodyContentMIME":"","securityInfos":null,"queryParameters":null,"pathParameters":[{"name":"petId","style":"","explode":null}],"headerParameters":null,"cookieParameters":null}'`, }, ID: ":showPetById", ToolMapping: map[string][]types.ToolReference{}, diff --git a/pkg/loader/testdata/openapi/TestOpenAPIv2Revamp.golden b/pkg/loader/testdata/openapi/TestOpenAPIv2Revamp.golden new file mode 100644 index 00000000..ebe68cc2 --- /dev/null +++ b/pkg/loader/testdata/openapi/TestOpenAPIv2Revamp.golden @@ -0,0 +1,116 @@ +types.ToolSet{ + ":": types.Tool{ + ToolDef: types.ToolDef{Parameters: types.Parameters{ + ModelName: "gpt-4o", + Export: []string{ + "listOperationsSwaggerPetstore", + "getSchemaSwaggerPetstore", + "runOperationSwaggerPetstore", + }, + }}, + ID: ":", + ToolMapping: map[string][]types.ToolReference{ + "getSchemaSwaggerPetstore": {{ + Reference: "getSchemaSwaggerPetstore", + ToolID: ":getSchemaSwaggerPetstore", + }}, + "listOperationsSwaggerPetstore": {{ + Reference: "listOperationsSwaggerPetstore", + ToolID: ":listOperationsSwaggerPetstore", + }}, + "runOperationSwaggerPetstore": {{ + Reference: "runOperationSwaggerPetstore", + ToolID: ":runOperationSwaggerPetstore", + }}, + }, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, + ":getSchemaSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "getSchemaSwaggerPetstore", + Description: "Get the JSONSchema for the arguments for an operation for Swagger Petstore. You must do this before you run the operation.", + ModelName: "gpt-4o", + Arguments: &openapi3.Schema{ + Type: &openapi3.Types{ + "object", + }, + Properties: openapi3.Schemas{"operation": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "operation", + Description: "the name of the operation to get the schema for", + Required: []string{"operation"}, + }}}, + }, + }, + Instructions: "#!sys.openapi get-schema ", + }, + ID: ":getSchemaSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + Source: types.ToolSource{LineNo: 1}, + }, + ":listOperationsSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "listOperationsSwaggerPetstore", + Description: "List available operations for Swagger Petstore. Each of these operations is an OpenAPI operation. Run this tool before you do anything else.", + ModelName: "gpt-4o", + }, + Instructions: "#!sys.openapi list ", + }, + ID: ":listOperationsSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, + ":runOperationSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "runOperationSwaggerPetstore", + Description: "Run an operation for Swagger Petstore. You MUST call get-schema for the operation before you use this tool.", + ModelName: "gpt-4o", + Arguments: &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "args": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "args", + Description: "the JSON string containing arguments; must match the JSONSchema for the operation", + Required: []string{"args"}, + }}, + "operation": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "operation", + Description: "the name of the operation to run", + Required: []string{"operation"}, + }}, + }, + }, + }, + Instructions: "#!sys.openapi run ", + }, + ID: ":runOperationSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, +} diff --git a/pkg/loader/testdata/openapi/TestOpenAPIv3.golden b/pkg/loader/testdata/openapi/TestOpenAPIv3.golden index 72ccafae..37ac2fe2 100644 --- a/pkg/loader/testdata/openapi/TestOpenAPIv3.golden +++ b/pkg/loader/testdata/openapi/TestOpenAPIv3.golden @@ -63,7 +63,7 @@ types.ToolSet{ }}}, }, }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"POST","bodyContentMIME":"application/json","apiKeyInfos":null,"queryParameters":null,"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"POST","bodyContentMIME":"application/json","securityInfos":null,"queryParameters":null,"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, }, ID: ":createPets", ToolMapping: map[string][]types.ToolReference{}, @@ -92,7 +92,7 @@ types.ToolSet{ }}}, }, }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"GET","bodyContentMIME":"","apiKeyInfos":null,"queryParameters":[{"name":"limit","style":"","explode":null}],"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"GET","bodyContentMIME":"","securityInfos":null,"queryParameters":[{"name":"limit","style":"","explode":null}],"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, }, ID: ":listPets", ToolMapping: map[string][]types.ToolReference{}, @@ -119,7 +119,7 @@ types.ToolSet{ }}}, }, }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets/{petId}","method":"GET","bodyContentMIME":"","apiKeyInfos":null,"queryParameters":null,"pathParameters":[{"name":"petId","style":"","explode":null}],"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets/{petId}","method":"GET","bodyContentMIME":"","securityInfos":null,"queryParameters":null,"pathParameters":[{"name":"petId","style":"","explode":null}],"headerParameters":null,"cookieParameters":null}'`, }, ID: ":showPetById", ToolMapping: map[string][]types.ToolReference{}, diff --git a/pkg/loader/testdata/openapi/TestOpenAPIv3NoOperationIDs.golden b/pkg/loader/testdata/openapi/TestOpenAPIv3NoOperationIDs.golden index 3bcfd9e5..e950e19c 100644 --- a/pkg/loader/testdata/openapi/TestOpenAPIv3NoOperationIDs.golden +++ b/pkg/loader/testdata/openapi/TestOpenAPIv3NoOperationIDs.golden @@ -50,7 +50,7 @@ types.ToolSet{ }}}, }, }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"GET","bodyContentMIME":"","apiKeyInfos":null,"queryParameters":[{"name":"limit","style":"","explode":null}],"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"GET","bodyContentMIME":"","securityInfos":null,"queryParameters":[{"name":"limit","style":"","explode":null}],"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, }, ID: ":get_pets", ToolMapping: map[string][]types.ToolReference{}, @@ -77,7 +77,7 @@ types.ToolSet{ }}}, }, }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets/{petId}","method":"GET","bodyContentMIME":"","apiKeyInfos":null,"queryParameters":null,"pathParameters":[{"name":"petId","style":"","explode":null}],"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets/{petId}","method":"GET","bodyContentMIME":"","securityInfos":null,"queryParameters":null,"pathParameters":[{"name":"petId","style":"","explode":null}],"headerParameters":null,"cookieParameters":null}'`, }, ID: ":get_pets_petId", ToolMapping: map[string][]types.ToolReference{}, @@ -119,7 +119,7 @@ types.ToolSet{ }}}, }, }, - Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"POST","bodyContentMIME":"application/json","apiKeyInfos":null,"queryParameters":null,"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, + Instructions: `#!sys.openapi '{"server":"http://petstore.swagger.io/v1","path":"/pets","method":"POST","bodyContentMIME":"application/json","securityInfos":null,"queryParameters":null,"pathParameters":null,"headerParameters":null,"cookieParameters":null}'`, }, ID: ":post_pets", ToolMapping: map[string][]types.ToolReference{}, diff --git a/pkg/loader/testdata/openapi/TestOpenAPIv3NoOperationIDsRevamp.golden b/pkg/loader/testdata/openapi/TestOpenAPIv3NoOperationIDsRevamp.golden new file mode 100644 index 00000000..ebe68cc2 --- /dev/null +++ b/pkg/loader/testdata/openapi/TestOpenAPIv3NoOperationIDsRevamp.golden @@ -0,0 +1,116 @@ +types.ToolSet{ + ":": types.Tool{ + ToolDef: types.ToolDef{Parameters: types.Parameters{ + ModelName: "gpt-4o", + Export: []string{ + "listOperationsSwaggerPetstore", + "getSchemaSwaggerPetstore", + "runOperationSwaggerPetstore", + }, + }}, + ID: ":", + ToolMapping: map[string][]types.ToolReference{ + "getSchemaSwaggerPetstore": {{ + Reference: "getSchemaSwaggerPetstore", + ToolID: ":getSchemaSwaggerPetstore", + }}, + "listOperationsSwaggerPetstore": {{ + Reference: "listOperationsSwaggerPetstore", + ToolID: ":listOperationsSwaggerPetstore", + }}, + "runOperationSwaggerPetstore": {{ + Reference: "runOperationSwaggerPetstore", + ToolID: ":runOperationSwaggerPetstore", + }}, + }, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, + ":getSchemaSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "getSchemaSwaggerPetstore", + Description: "Get the JSONSchema for the arguments for an operation for Swagger Petstore. You must do this before you run the operation.", + ModelName: "gpt-4o", + Arguments: &openapi3.Schema{ + Type: &openapi3.Types{ + "object", + }, + Properties: openapi3.Schemas{"operation": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "operation", + Description: "the name of the operation to get the schema for", + Required: []string{"operation"}, + }}}, + }, + }, + Instructions: "#!sys.openapi get-schema ", + }, + ID: ":getSchemaSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + Source: types.ToolSource{LineNo: 1}, + }, + ":listOperationsSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "listOperationsSwaggerPetstore", + Description: "List available operations for Swagger Petstore. Each of these operations is an OpenAPI operation. Run this tool before you do anything else.", + ModelName: "gpt-4o", + }, + Instructions: "#!sys.openapi list ", + }, + ID: ":listOperationsSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, + ":runOperationSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "runOperationSwaggerPetstore", + Description: "Run an operation for Swagger Petstore. You MUST call get-schema for the operation before you use this tool.", + ModelName: "gpt-4o", + Arguments: &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "args": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "args", + Description: "the JSON string containing arguments; must match the JSONSchema for the operation", + Required: []string{"args"}, + }}, + "operation": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "operation", + Description: "the name of the operation to run", + Required: []string{"operation"}, + }}, + }, + }, + }, + Instructions: "#!sys.openapi run ", + }, + ID: ":runOperationSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, +} diff --git a/pkg/loader/testdata/openapi/TestOpenAPIv3Revamp.golden b/pkg/loader/testdata/openapi/TestOpenAPIv3Revamp.golden new file mode 100644 index 00000000..ebe68cc2 --- /dev/null +++ b/pkg/loader/testdata/openapi/TestOpenAPIv3Revamp.golden @@ -0,0 +1,116 @@ +types.ToolSet{ + ":": types.Tool{ + ToolDef: types.ToolDef{Parameters: types.Parameters{ + ModelName: "gpt-4o", + Export: []string{ + "listOperationsSwaggerPetstore", + "getSchemaSwaggerPetstore", + "runOperationSwaggerPetstore", + }, + }}, + ID: ":", + ToolMapping: map[string][]types.ToolReference{ + "getSchemaSwaggerPetstore": {{ + Reference: "getSchemaSwaggerPetstore", + ToolID: ":getSchemaSwaggerPetstore", + }}, + "listOperationsSwaggerPetstore": {{ + Reference: "listOperationsSwaggerPetstore", + ToolID: ":listOperationsSwaggerPetstore", + }}, + "runOperationSwaggerPetstore": {{ + Reference: "runOperationSwaggerPetstore", + ToolID: ":runOperationSwaggerPetstore", + }}, + }, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, + ":getSchemaSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "getSchemaSwaggerPetstore", + Description: "Get the JSONSchema for the arguments for an operation for Swagger Petstore. You must do this before you run the operation.", + ModelName: "gpt-4o", + Arguments: &openapi3.Schema{ + Type: &openapi3.Types{ + "object", + }, + Properties: openapi3.Schemas{"operation": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "operation", + Description: "the name of the operation to get the schema for", + Required: []string{"operation"}, + }}}, + }, + }, + Instructions: "#!sys.openapi get-schema ", + }, + ID: ":getSchemaSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + Source: types.ToolSource{LineNo: 1}, + }, + ":listOperationsSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "listOperationsSwaggerPetstore", + Description: "List available operations for Swagger Petstore. Each of these operations is an OpenAPI operation. Run this tool before you do anything else.", + ModelName: "gpt-4o", + }, + Instructions: "#!sys.openapi list ", + }, + ID: ":listOperationsSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, + ":runOperationSwaggerPetstore": types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "runOperationSwaggerPetstore", + Description: "Run an operation for Swagger Petstore. You MUST call get-schema for the operation before you use this tool.", + ModelName: "gpt-4o", + Arguments: &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "args": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "args", + Description: "the JSON string containing arguments; must match the JSONSchema for the operation", + Required: []string{"args"}, + }}, + "operation": &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: &openapi3.Types{"string"}, + Title: "operation", + Description: "the name of the operation to run", + Required: []string{"operation"}, + }}, + }, + }, + }, + Instructions: "#!sys.openapi run ", + }, + ID: ":runOperationSwaggerPetstore", + ToolMapping: map[string][]types.ToolReference{}, + LocalTools: map[string]string{ + "": ":", + "getschemaswaggerpetstore": ":getSchemaSwaggerPetstore", + "listoperationsswaggerpetstore": ":listOperationsSwaggerPetstore", + "runoperationswaggerpetstore": ":runOperationSwaggerPetstore", + }, + }, +} diff --git a/pkg/openapi/getschema.go b/pkg/openapi/getschema.go new file mode 100644 index 00000000..3550afcf --- /dev/null +++ b/pkg/openapi/getschema.go @@ -0,0 +1,285 @@ +package openapi + +import ( + "encoding/json" + "fmt" + "slices" + "strings" + + "github.com/getkin/kin-openapi/openapi3" +) + +type Parameter struct { + Name string `json:"name"` + Style string `json:"style"` + Explode *bool `json:"explode"` +} + +type OperationInfo struct { + Server string `json:"server"` + Path string `json:"path"` + Method string `json:"method"` + BodyContentMIME string `json:"bodyContentMIME"` + SecurityInfos [][]SecurityInfo `json:"securityInfos"` + QueryParams []Parameter `json:"queryParameters"` + PathParams []Parameter `json:"pathParameters"` + HeaderParams []Parameter `json:"headerParameters"` + CookieParams []Parameter `json:"cookieParameters"` +} + +var ( + supportedMIMETypes = []string{"application/json", "application/x-www-form-urlencoded", "multipart/form-data"} + supportedSecurityTypes = []string{"apiKey", "http"} +) + +const GetSchemaTool = "get-schema" + +func GetSupportedMIMETypes() []string { + return supportedMIMETypes +} + +func GetSupportedSecurityTypes() []string { + return supportedSecurityTypes +} + +// GetSchema returns the JSONSchema and OperationInfo for a particular OpenAPI operation. +// Return values in order: JSONSchema (string), OperationInfo, found (bool), error. +func GetSchema(operationID, defaultHost string, t *openapi3.T) (string, OperationInfo, bool, error) { + arguments := &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{}, + Required: []string{}, + } + + info := OperationInfo{} + + // Determine the default server. + var ( + defaultServer = defaultHost + err error + ) + if len(t.Servers) > 0 { + defaultServer, err = parseServer(t.Servers[0]) + if err != nil { + return "", OperationInfo{}, false, err + } + } + + var globalSecurity []map[string]struct{} + if t.Security != nil { + for _, item := range t.Security { + current := map[string]struct{}{} + for name := range item { + if scheme, ok := t.Components.SecuritySchemes[name]; ok && slices.Contains(supportedSecurityTypes, scheme.Value.Type) { + current[name] = struct{}{} + } + } + if len(current) > 0 { + globalSecurity = append(globalSecurity, current) + } + } + } + + for path, pathItem := range t.Paths.Map() { + // Handle path-level server override, if one exists. + pathServer := defaultServer + if pathItem.Servers != nil && len(pathItem.Servers) > 0 { + pathServer, err = parseServer(pathItem.Servers[0]) + if err != nil { + return "", OperationInfo{}, false, err + } + } + + for method, operation := range pathItem.Operations() { + if operation.OperationID == operationID { + // Handle operation-level server override, if one exists. + operationServer := pathServer + if operation.Servers != nil && len(*operation.Servers) > 0 { + operationServer, err = parseServer((*operation.Servers)[0]) + if err != nil { + return "", OperationInfo{}, false, err + } + } + + info.Server = operationServer + info.Path = path + info.Method = method + + // We found our operation. Now we need to process it and build the arguments. + // Handle query, path, header, and cookie parameters first. + for _, param := range append(operation.Parameters, pathItem.Parameters...) { + removeRefs(param.Value.Schema) + arg := param.Value.Schema.Value + + if arg.Description == "" { + arg.Description = param.Value.Description + } + + // Store the arg + arguments.Properties[param.Value.Name] = &openapi3.SchemaRef{Value: arg} + + // Check whether it is required + if param.Value.Required { + arguments.Required = append(arguments.Required, param.Value.Name) + } + + // Save the parameter to the correct set of params. + p := Parameter{ + Name: param.Value.Name, + Style: param.Value.Style, + Explode: param.Value.Explode, + } + switch param.Value.In { + case "query": + info.QueryParams = append(info.QueryParams, p) + case "path": + info.PathParams = append(info.PathParams, p) + case "header": + info.HeaderParams = append(info.HeaderParams, p) + case "cookie": + info.CookieParams = append(info.CookieParams, p) + } + } + + // Next, handle the request body, if one exists. + if operation.RequestBody != nil { + for mime, content := range operation.RequestBody.Value.Content { + // Each MIME type needs to be handled individually, so we keep a list of the ones we support. + if !slices.Contains(supportedMIMETypes, mime) { + continue + } + info.BodyContentMIME = mime + + removeRefs(content.Schema) + + arg := content.Schema.Value + if arg.Description == "" { + arg.Description = content.Schema.Value.Description + } + + // Read Only cannot be sent in the request body, so we remove it + for key, property := range arg.Properties { + if property.Value.ReadOnly { + delete(arg.Properties, key) + } + } + + // Unfortunately, the request body doesn't contain any good descriptor for it, + // so we just use "requestBodyContent" as the name of the arg. + arguments.Properties["requestBodyContent"] = &openapi3.SchemaRef{Value: arg} + arguments.Required = append(arguments.Required, "requestBodyContent") + break + } + + if info.BodyContentMIME == "" { + return "", OperationInfo{}, false, fmt.Errorf("no supported MIME type found for request body in operation %s", operationID) + } + } + + // See if there is any auth defined for this operation + var ( + noAuth bool + auths []map[string]struct{} + ) + if operation.Security != nil { + if len(*operation.Security) == 0 { + noAuth = true + } + for _, req := range *operation.Security { + current := map[string]struct{}{} + for name := range req { + current[name] = struct{}{} + } + if len(current) > 0 { + auths = append(auths, current) + } + } + } + + // Use the global security if it was not overridden for this operation + if !noAuth && len(auths) == 0 { + auths = append(auths, globalSecurity...) + } + + // For each set of auths, turn them into SecurityInfos, and drop ones that contain unsupported types. + outer: + for _, auth := range auths { + var current []SecurityInfo + for name := range auth { + if scheme, ok := t.Components.SecuritySchemes[name]; ok { + if !slices.Contains(supportedSecurityTypes, scheme.Value.Type) { + // There is an unsupported type in this auth, so move on to the next one. + continue outer + } + + current = append(current, SecurityInfo{ + Type: scheme.Value.Type, + Name: name, + In: scheme.Value.In, + Scheme: scheme.Value.Scheme, + APIKeyName: scheme.Value.Name, + }) + } + } + + if len(current) > 0 { + info.SecurityInfos = append(info.SecurityInfos, current) + } + } + + argumentsJSON, err := json.MarshalIndent(arguments, "", " ") + if err != nil { + return "", OperationInfo{}, false, err + } + return string(argumentsJSON), info, true, nil + } + } + } + + return "", OperationInfo{}, false, nil +} + +func parseServer(server *openapi3.Server) (string, error) { + s := server.URL + for name, variable := range server.Variables { + if variable == nil { + continue + } + + if variable.Default != "" { + s = strings.Replace(s, "{"+name+"}", variable.Default, 1) + } else if len(variable.Enum) > 0 { + s = strings.Replace(s, "{"+name+"}", variable.Enum[0], 1) + } + } + + if !strings.HasPrefix(s, "http") { + return "", fmt.Errorf("invalid server URL: %s (must use HTTP or HTTPS; relative URLs not supported)", s) + } + return s, nil +} + +func removeRefs(r *openapi3.SchemaRef) { + if r == nil { + return + } + + r.Ref = "" + r.Value.Discriminator = nil // Discriminators are not very useful and can junk up the schema. + + for i := range r.Value.OneOf { + removeRefs(r.Value.OneOf[i]) + } + for i := range r.Value.AnyOf { + removeRefs(r.Value.AnyOf[i]) + } + for i := range r.Value.AllOf { + removeRefs(r.Value.AllOf[i]) + } + removeRefs(r.Value.Not) + removeRefs(r.Value.Items) + + for i := range r.Value.Properties { + removeRefs(r.Value.Properties[i]) + } +} diff --git a/pkg/openapi/list.go b/pkg/openapi/list.go new file mode 100644 index 00000000..857c7014 --- /dev/null +++ b/pkg/openapi/list.go @@ -0,0 +1,68 @@ +package openapi + +import ( + "path/filepath" + "strings" + + "github.com/getkin/kin-openapi/openapi3" +) + +type OperationList struct { + Operations map[string]Operation `json:"operations"` +} + +type Operation struct { + Description string `json:"description,omitempty"` + Summary string `json:"summary,omitempty"` +} + +const ( + ListTool = "list" + NoFilter = "" +) + +func List(t *openapi3.T, filter string) (OperationList, error) { + operations := make(map[string]Operation) + for _, pathItem := range t.Paths.Map() { + for _, operation := range pathItem.Operations() { + var ( + match bool + err error + ) + if filter != "" && filter != NoFilter { + if strings.Contains(filter, "*") { + match, err = MatchFilters(strings.Split(filter, "|"), operation.OperationID) + if err != nil { + return OperationList{}, err + } + } else { + match = operation.OperationID == filter + } + } else { + match = true + } + + if match { + operations[operation.OperationID] = Operation{ + Description: operation.Description, + Summary: operation.Summary, + } + } + } + } + + return OperationList{Operations: operations}, nil +} + +func MatchFilters(filters []string, operationID string) (bool, error) { + for _, filter := range filters { + match, err := filepath.Match(filter, operationID) + if err != nil { + return false, err + } + if match { + return true, nil + } + } + return false, nil +} diff --git a/pkg/openapi/load.go b/pkg/openapi/load.go new file mode 100644 index 00000000..0ff82fdb --- /dev/null +++ b/pkg/openapi/load.go @@ -0,0 +1,121 @@ +package openapi + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strconv" + "strings" + + "github.com/getkin/kin-openapi/openapi2" + "github.com/getkin/kin-openapi/openapi2conv" + "github.com/getkin/kin-openapi/openapi3" + "gopkg.in/yaml.v3" + kyaml "sigs.k8s.io/yaml" +) + +func Load(source string) (*openapi3.T, error) { + if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { + return loadFromURL(source) + } + return loadFromFile(source) +} + +func loadFromURL(source string) (*openapi3.T, error) { + resp, err := http.DefaultClient.Get(source) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + contents, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return LoadFromBytes(contents) +} + +func loadFromFile(source string) (*openapi3.T, error) { + contents, err := os.ReadFile(source) + if err != nil { + return nil, err + } + + return LoadFromBytes(contents) +} + +func LoadFromBytes(content []byte) (*openapi3.T, error) { + var ( + openAPIDocument *openapi3.T + err error + ) + + switch IsOpenAPI(content) { + case 2: + // Convert OpenAPI v2 to v3 + if !json.Valid(content) { + content, err = kyaml.YAMLToJSON(content) + if err != nil { + return nil, err + } + } + + doc := &openapi2.T{} + if err := doc.UnmarshalJSON(content); err != nil { + return nil, fmt.Errorf("failed to unmarshal OpenAPI v2 document: %w", err) + } + + openAPIDocument, err = openapi2conv.ToV3(doc) + if err != nil { + return nil, fmt.Errorf("failed to convert OpenAPI v2 to v3: %w", err) + } + case 3: + openAPIDocument, err = openapi3.NewLoader().LoadFromData(content) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported OpenAPI version") + } + + return openAPIDocument, nil +} + +// IsOpenAPI checks if the data is an OpenAPI definition and returns the version if it is. +func IsOpenAPI(data []byte) int { + var fragment struct { + Paths map[string]any `json:"paths,omitempty"` + Swagger string `json:"swagger,omitempty"` + OpenAPI string `json:"openapi,omitempty"` + } + + if err := json.Unmarshal(data, &fragment); err != nil { + if err := yaml.Unmarshal(data, &fragment); err != nil { + return 0 + } + } + if len(fragment.Paths) == 0 { + return 0 + } + + if v, _, _ := strings.Cut(fragment.OpenAPI, "."); v != "" { + ver, err := strconv.Atoi(v) + if err != nil { + return 0 + } + return ver + } + + if v, _, _ := strings.Cut(fragment.Swagger, "."); v != "" { + ver, err := strconv.Atoi(v) + if err != nil { + return 0 + } + return ver + } + + return 0 +} diff --git a/pkg/openapi/run.go b/pkg/openapi/run.go new file mode 100644 index 00000000..17199851 --- /dev/null +++ b/pkg/openapi/run.go @@ -0,0 +1,451 @@ +package openapi + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "strings" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/gptscript-ai/gptscript/pkg/env" + "github.com/tidwall/gjson" + "github.com/xeipuuv/gojsonschema" + "golang.org/x/exp/maps" +) + +const RunTool = "run" + +func Run(operationID, defaultHost, args string, t *openapi3.T, envs []string) (string, bool, error) { + envMap := make(map[string]string, len(envs)) + for _, e := range envs { + k, v, _ := strings.Cut(e, "=") + envMap[k] = v + } + + if args == "" { + args = "{}" + } + schemaJSON, opInfo, found, err := GetSchema(operationID, defaultHost, t) + if err != nil || !found { + return "", false, err + } + + // Validate args against the schema. + validationResult, err := gojsonschema.Validate(gojsonschema.NewStringLoader(schemaJSON), gojsonschema.NewStringLoader(args)) + if err != nil { + return "", false, err + } + + if !validationResult.Valid() { + return "", false, fmt.Errorf("invalid arguments for operation %s: %s", operationID, validationResult.Errors()) + } + + // Construct and execute the HTTP request. + + // Handle path parameters. + opInfo.Path = HandlePathParameters(opInfo.Path, opInfo.PathParams, args) + + // Parse the URL + path, err := url.JoinPath(opInfo.Server, opInfo.Path) + if err != nil { + return "", false, fmt.Errorf("failed to join server and path: %w", err) + } + + u, err := url.Parse(path) + if err != nil { + return "", false, fmt.Errorf("failed to parse server URL %s: %w", opInfo.Server+opInfo.Path, err) + } + + // Set up the request + req, err := http.NewRequest(opInfo.Method, u.String(), nil) + if err != nil { + return "", false, fmt.Errorf("failed to create request: %w", err) + } + + // Check for authentication (only if using HTTPS or localhost) + if u.Scheme == "https" || u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1" { + if len(opInfo.SecurityInfos) > 0 { + if err := HandleAuths(req, envMap, opInfo.SecurityInfos); err != nil { + return "", false, fmt.Errorf("error setting up authentication: %w", err) + } + } + + // If there is a bearer token set for the whole server, and no Authorization header has been defined, use it. + if token, ok := envMap["GPTSCRIPT_"+env.ToEnvLike(u.Hostname())+"_BEARER_TOKEN"]; ok { + if req.Header.Get("Authorization") == "" { + req.Header.Set("Authorization", "Bearer "+token) + } + } + } else { + fmt.Fprintf(os.Stderr, "no auth") + } + + // Handle query parameters + req.URL.RawQuery = HandleQueryParameters(req.URL.Query(), opInfo.QueryParams, args).Encode() + + // Handle header and cookie parameters + HandleHeaderParameters(req, opInfo.HeaderParams, args) + HandleCookieParameters(req, opInfo.CookieParams, args) + + // Handle request body + if opInfo.BodyContentMIME != "" { + res := gjson.Get(args, "requestBodyContent") + var body bytes.Buffer + switch opInfo.BodyContentMIME { + case "application/json": + var reqBody any = struct{}{} + if res.Exists() { + reqBody = res.Value() + } + if err := json.NewEncoder(&body).Encode(reqBody); err != nil { + return "", false, fmt.Errorf("failed to encode JSON: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + case "text/plain": + reqBody := "" + if res.Exists() { + reqBody = res.String() + } + body.WriteString(reqBody) + + req.Header.Set("Content-Type", "text/plain") + + case "multipart/form-data": + multiPartWriter := multipart.NewWriter(&body) + req.Header.Set("Content-Type", multiPartWriter.FormDataContentType()) + if res.Exists() && res.IsObject() { + for k, v := range res.Map() { + if err := multiPartWriter.WriteField(k, v.String()); err != nil { + return "", false, fmt.Errorf("failed to write multipart field: %w", err) + } + } + } else { + return "", false, fmt.Errorf("multipart/form-data requires an object as the requestBodyContent") + } + if err := multiPartWriter.Close(); err != nil { + return "", false, fmt.Errorf("failed to close multipart writer: %w", err) + } + + default: + return "", false, fmt.Errorf("unsupported MIME type: %s", opInfo.BodyContentMIME) + } + req.Body = io.NopCloser(&body) + } + + // Make the request + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", false, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + result, err := io.ReadAll(resp.Body) + if err != nil { + return "", false, fmt.Errorf("failed to read response: %w", err) + } + + return string(result), true, nil +} + +// HandleAuths will set up the request with the necessary authentication information. +// A set of sets of SecurityInfo is passed in, where each represents a possible set of security options. +func HandleAuths(req *http.Request, envMap map[string]string, infoSets [][]SecurityInfo) error { + var missingVariables [][]string + + // We need to find a set of infos where we have all the needed environment variables. + for _, infoSet := range infoSets { + var missing []string // Keep track of any missing environment variables + for _, info := range infoSet { + vars := info.getCredentialNamesAndEnvVars(req.URL.Hostname()) + + for _, envName := range vars { + if _, ok := envMap[envName]; !ok { + missing = append(missing, envName) + } + } + } + if len(missing) > 0 { + missingVariables = append(missingVariables, missing) + continue + } + + // We're using this info set, because no environment variables were missing. + // Set up the request as needed. + for _, info := range infoSet { + envNames := maps.Values(info.getCredentialNamesAndEnvVars(req.URL.Hostname())) + switch info.Type { + case "apiKey": + switch info.In { + case "header": + req.Header.Set(info.APIKeyName, envMap[envNames[0]]) + case "query": + v := url.Values{} + v.Add(info.APIKeyName, envMap[envNames[0]]) + req.URL.RawQuery = v.Encode() + case "cookie": + req.AddCookie(&http.Cookie{ + Name: info.APIKeyName, + Value: envMap[envNames[0]], + }) + } + case "http": + switch info.Scheme { + case "bearer": + req.Header.Set("Authorization", "Bearer "+envMap[envNames[0]]) + case "basic": + req.SetBasicAuth(envMap[envNames[0]], envMap[envNames[1]]) + } + } + } + return nil + } + + return fmt.Errorf("did not find the needed environment variables for any of the security options. "+ + "At least one of these sets of environment variables must be provided: %v", missingVariables) +} + +// HandlePathParameters extracts each path parameter from the input JSON and replaces its placeholder in the URL path. +func HandlePathParameters(path string, params []Parameter, input string) string { + for _, param := range params { + res := gjson.Get(input, param.Name) + if res.Exists() { + // If it's an array or object, handle the serialization style + if res.IsArray() { + switch param.Style { + case "simple", "": // simple is the default style for path parameters + // simple looks the same regardless of whether explode is true + strs := make([]string, len(res.Array())) + for i, item := range res.Array() { + strs[i] = item.String() + } + path = strings.Replace(path, "{"+param.Name+"}", strings.Join(strs, ","), 1) + case "label": + strs := make([]string, len(res.Array())) + for i, item := range res.Array() { + strs[i] = item.String() + } + + if param.Explode == nil || !*param.Explode { // default is to not explode + path = strings.Replace(path, "{"+param.Name+"}", "."+strings.Join(strs, ","), 1) + } else { + path = strings.Replace(path, "{"+param.Name+"}", "."+strings.Join(strs, "."), 1) + } + case "matrix": + strs := make([]string, len(res.Array())) + for i, item := range res.Array() { + strs[i] = item.String() + } + + if param.Explode == nil || !*param.Explode { // default is to not explode + path = strings.Replace(path, "{"+param.Name+"}", ";"+param.Name+"="+strings.Join(strs, ","), 1) + } else { + s := "" + for _, str := range strs { + s += ";" + param.Name + "=" + str + } + path = strings.Replace(path, "{"+param.Name+"}", s, 1) + } + } + } else if res.IsObject() { + switch param.Style { + case "simple", "": + if param.Explode == nil || !*param.Explode { // default is to not explode + var strs []string + for k, v := range res.Map() { + strs = append(strs, k, v.String()) + } + path = strings.Replace(path, "{"+param.Name+"}", strings.Join(strs, ","), 1) + } else { + var strs []string + for k, v := range res.Map() { + strs = append(strs, k+"="+v.String()) + } + path = strings.Replace(path, "{"+param.Name+"}", strings.Join(strs, ","), 1) + } + case "label": + if param.Explode == nil || !*param.Explode { // default is to not explode + var strs []string + for k, v := range res.Map() { + strs = append(strs, k, v.String()) + } + path = strings.Replace(path, "{"+param.Name+"}", "."+strings.Join(strs, ","), 1) + } else { + s := "" + for k, v := range res.Map() { + s += "." + k + "=" + v.String() + } + path = strings.Replace(path, "{"+param.Name+"}", s, 1) + } + case "matrix": + if param.Explode == nil || !*param.Explode { // default is to not explode + var strs []string + for k, v := range res.Map() { + strs = append(strs, k, v.String()) + } + path = strings.Replace(path, "{"+param.Name+"}", ";"+param.Name+"="+strings.Join(strs, ","), 1) + } else { + s := "" + for k, v := range res.Map() { + s += ";" + k + "=" + v.String() + } + path = strings.Replace(path, "{"+param.Name+"}", s, 1) + } + } + } else { + // Serialization is handled slightly differently even for basic types. + // Explode doesn't do anything though. + switch param.Style { + case "simple", "": + path = strings.Replace(path, "{"+param.Name+"}", res.String(), 1) + case "label": + path = strings.Replace(path, "{"+param.Name+"}", "."+res.String(), 1) + case "matrix": + path = strings.Replace(path, "{"+param.Name+"}", ";"+param.Name+"="+res.String(), 1) + } + } + } + } + return path +} + +// HandleQueryParameters extracts each query parameter from the input JSON and adds it to the URL query. +func HandleQueryParameters(q url.Values, params []Parameter, input string) url.Values { + for _, param := range params { + res := gjson.Get(input, param.Name) + if res.Exists() { + // If it's an array or object, handle the serialization style + if res.IsArray() { + switch param.Style { + case "form", "": // form is the default style for query parameters + if param.Explode == nil || *param.Explode { // default is to explode + for _, item := range res.Array() { + q.Add(param.Name, item.String()) + } + } else { + var strs []string + for _, item := range res.Array() { + strs = append(strs, item.String()) + } + q.Add(param.Name, strings.Join(strs, ",")) + } + case "spaceDelimited": + if param.Explode == nil || *param.Explode { + for _, item := range res.Array() { + q.Add(param.Name, item.String()) + } + } else { + var strs []string + for _, item := range res.Array() { + strs = append(strs, item.String()) + } + q.Add(param.Name, strings.Join(strs, " ")) + } + case "pipeDelimited": + if param.Explode == nil || *param.Explode { + for _, item := range res.Array() { + q.Add(param.Name, item.String()) + } + } else { + var strs []string + for _, item := range res.Array() { + strs = append(strs, item.String()) + } + q.Add(param.Name, strings.Join(strs, "|")) + } + } + } else if res.IsObject() { + switch param.Style { + case "form", "": // form is the default style for query parameters + if param.Explode == nil || *param.Explode { // default is to explode + for k, v := range res.Map() { + q.Add(k, v.String()) + } + } else { + var strs []string + for k, v := range res.Map() { + strs = append(strs, k, v.String()) + } + q.Add(param.Name, strings.Join(strs, ",")) + } + case "deepObject": + for k, v := range res.Map() { + q.Add(param.Name+"["+k+"]", v.String()) + } + } + } else { + q.Add(param.Name, res.String()) + } + } + } + return q +} + +// HandleHeaderParameters extracts each header parameter from the input JSON and adds it to the request headers. +func HandleHeaderParameters(req *http.Request, params []Parameter, input string) { + for _, param := range params { + res := gjson.Get(input, param.Name) + if res.Exists() { + if res.IsArray() { + strs := make([]string, len(res.Array())) + for i, item := range res.Array() { + strs[i] = item.String() + } + req.Header.Add(param.Name, strings.Join(strs, ",")) + } else if res.IsObject() { + // Handle explosion + var strs []string + if param.Explode == nil || !*param.Explode { // default is to not explode + for k, v := range res.Map() { + strs = append(strs, k, v.String()) + } + } else { + for k, v := range res.Map() { + strs = append(strs, k+"="+v.String()) + } + } + req.Header.Add(param.Name, strings.Join(strs, ",")) + } else { // basic type + req.Header.Add(param.Name, res.String()) + } + } + } +} + +// HandleCookieParameters extracts each cookie parameter from the input JSON and adds it to the request cookies. +func HandleCookieParameters(req *http.Request, params []Parameter, input string) { + for _, param := range params { + res := gjson.Get(input, param.Name) + if res.Exists() { + if res.IsArray() { + strs := make([]string, len(res.Array())) + for i, item := range res.Array() { + strs[i] = item.String() + } + req.AddCookie(&http.Cookie{ + Name: param.Name, + Value: strings.Join(strs, ","), + }) + } else if res.IsObject() { + var strs []string + for k, v := range res.Map() { + strs = append(strs, k, v.String()) + } + req.AddCookie(&http.Cookie{ + Name: param.Name, + Value: strings.Join(strs, ","), + }) + } else { // basic type + req.AddCookie(&http.Cookie{ + Name: param.Name, + Value: res.String(), + }) + } + } + } +} diff --git a/pkg/openapi/security.go b/pkg/openapi/security.go new file mode 100644 index 00000000..dd4521fc --- /dev/null +++ b/pkg/openapi/security.go @@ -0,0 +1,56 @@ +package openapi + +import ( + "fmt" + "strings" + + "github.com/gptscript-ai/gptscript/pkg/env" +) + +// A SecurityInfo represents a security scheme in OpenAPI. +type SecurityInfo struct { + Name string `json:"name"` // name as defined in the security schemes + Type string `json:"type"` // http or apiKey + Scheme string `json:"scheme"` // bearer or basic, for type==http + APIKeyName string `json:"apiKeyName"` // name of the API key, for type==apiKey + In string `json:"in"` // header, query, or cookie, for type==apiKey +} + +func (i SecurityInfo) GetCredentialToolStrings(hostname string) []string { + vars := i.getCredentialNamesAndEnvVars(hostname) + var tools []string + + for cred, v := range vars { + field := "value" + switch i.Type { + case "apiKey": + field = i.APIKeyName + case "http": + if i.Scheme == "bearer" { + field = "bearer token" + } else { + if strings.Contains(v, "PASSWORD") { + field = "password" + } else { + field = "username" + } + } + } + + tools = append(tools, fmt.Sprintf("github.com/gptscript-ai/credential as %s with %s as env and %q as message and %q as field", + cred, v, "Please provide a value for the "+v+" environment variable", field)) + } + return tools +} + +func (i SecurityInfo) getCredentialNamesAndEnvVars(hostname string) map[string]string { + if i.Type == "http" && i.Scheme == "basic" { + return map[string]string{ + hostname + i.Name + "Username": "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name) + "_USERNAME", + hostname + i.Name + "Password": "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name) + "_PASSWORD", + } + } + return map[string]string{ + hostname + i.Name: "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name), + } +}