Skip to content

Commit

Permalink
workspaceID is no longer part of tool input
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville committed Nov 5, 2024
1 parent 5dab330 commit 0bb7bbf
Showing 1 changed file with 21 additions and 48 deletions.
69 changes: 21 additions & 48 deletions pkg/sdkserver/datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@ import (

type datasetRequest struct {
Input string `json:"input"`
WorkspaceID string `json:"workspaceID"`
DatasetTool string `json:"datasetTool"`
Env []string `json:"env"`
}

func (r datasetRequest) validate() error {
if r.Input == "" {
func (r datasetRequest) validate(requireInput bool) error {
if requireInput && r.Input == "" {
return fmt.Errorf("input is required")
} else if r.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
} else if len(r.Env) == 0 {
return fmt.Errorf("env is required")
}
Expand All @@ -27,9 +30,10 @@ func (r datasetRequest) validate() error {

func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
opts := gptscript.Options{
Cache: o.Cache,
Monitor: o.Monitor,
Runner: o.Runner,
Cache: o.Cache,
Monitor: o.Monitor,
Runner: o.Runner,
Workspace: r.WorkspaceID,
}
return opts
}
Expand All @@ -41,17 +45,6 @@ func (r datasetRequest) getToolRepo() string {
return "github.com/otto8-ai/datasets"
}

type listDatasetsArgs struct {
WorkspaceID string `json:"workspaceID"`
}

func (a listDatasetsArgs) validate() error {
if a.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
}
return nil
}

func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())

Expand All @@ -61,7 +54,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
return
}

if err := req.validate(); err != nil {
if err := req.validate(false); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}
Expand All @@ -72,17 +65,6 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
return
}

var args listDatasetsArgs
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
return
}

if err := args.validate(); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{
Cache: g.Cache,
})
Expand All @@ -102,9 +84,8 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
}

type addDatasetElementsArgs struct {
WorkspaceID string `json:"workspaceID"`
DatasetID string `json:"datasetID"`
Elements []struct {
DatasetID string `json:"datasetID"`
Elements []struct {
Name string `json:"name"`
Description string `json:"description"`
Contents string `json:"contents"`
Expand All @@ -113,9 +94,7 @@ type addDatasetElementsArgs struct {
}

func (a addDatasetElementsArgs) validate() error {
if a.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
} else if len(a.Elements) == 0 {
if len(a.Elements) == 0 {
return fmt.Errorf("elements is required")
}
return nil
Expand All @@ -130,7 +109,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
return
}

if err := req.validate(); err != nil {
if err := req.validate(true); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -170,14 +149,11 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
}

type listDatasetElementsArgs struct {
WorkspaceID string `json:"workspaceID"`
DatasetID string `json:"datasetID"`
DatasetID string `json:"datasetID"`
}

func (a listDatasetElementsArgs) validate() error {
if a.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
} else if a.DatasetID == "" {
if a.DatasetID == "" {
return fmt.Errorf("datasetID is required")
}
return nil
Expand All @@ -192,7 +168,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
return
}

if err := req.validate(); err != nil {
if err := req.validate(true); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -232,15 +208,12 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
}

type getDatasetElementArgs struct {
WorkspaceID string `json:"workspaceID"`
DatasetID string `json:"datasetID"`
Name string `json:"name"`
DatasetID string `json:"datasetID"`
Name string `json:"name"`
}

func (a getDatasetElementArgs) validate() error {
if a.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
} else if a.DatasetID == "" {
if a.DatasetID == "" {
return fmt.Errorf("datasetID is required")
} else if a.Name == "" {
return fmt.Errorf("name is required")
Expand All @@ -257,7 +230,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
return
}

if err := req.validate(); err != nil {
if err := req.validate(true); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}
Expand Down

0 comments on commit 0bb7bbf

Please sign in to comment.