Skip to content

Commit

Permalink
chore: sdkserver: update dataset methods for the rewrite
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville committed Nov 4, 2024
1 parent 4ce687f commit 6515d93
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 162 deletions.
209 changes: 49 additions & 160 deletions pkg/sdkserver/datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,13 @@ import (
)

type datasetRequest struct {
Input string `json:"input"`
WorkspaceID string `json:"workspaceID"`
DatasetToolRepo string `json:"datasetToolRepo"`
Env []string `json:"env"`
Input string `json:"input"`
DatasetTool string `json:"datasetTool"`
Env []string `json:"env"`
}

func (r datasetRequest) validate(requireInput bool) error {
if r.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
} else if requireInput && r.Input == "" {
func (r datasetRequest) validate() error {
if r.Input == "" {
return fmt.Errorf("input is required")
} else if len(r.Env) == 0 {
return fmt.Errorf("env is required")
Expand All @@ -30,72 +27,32 @@ func (r datasetRequest) validate(requireInput bool) error {

func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
opts := gptscript.Options{
Cache: o.Cache,
Monitor: o.Monitor,
Runner: o.Runner,
Workspace: r.WorkspaceID,
Cache: o.Cache,
Monitor: o.Monitor,
Runner: o.Runner,
}
return opts
}

func (r datasetRequest) getToolRepo() string {
if r.DatasetToolRepo != "" {
return r.DatasetToolRepo
if r.DatasetTool != "" {
return r.DatasetTool
}
return "github.com/otto8-ai/datasets"
return "github.com/g-linville/datasets@rewrite-as-daemon"
}

func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())

var req datasetRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
return
}

if err := req.validate(false); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}

g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{
Cache: g.Cache,
})

if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
return
}

result, err := g.Run(r.Context(), prg, req.Env, req.Input)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
return
}

writeResponse(logger, w, map[string]any{"stdout": result})
}

type createDatasetArgs struct {
Name string `json:"datasetName"`
Description string `json:"datasetDescription"`
type listDatasetsArgs struct {
WorkspaceID string `json:"workspaceID"`
}

func (a createDatasetArgs) validate() error {
if a.Name == "" {
return fmt.Errorf("datasetName is required")
func (a listDatasetsArgs) validate() error {
if a.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
}
return nil
}

func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())

var req datasetRequest
Expand All @@ -104,7 +61,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
return
}

if err := req.validate(true); err != nil {
if err := req.validate(); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}
Expand All @@ -115,7 +72,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
return
}

var args createDatasetArgs
var args listDatasetsArgs
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
return
Expand All @@ -126,7 +83,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "Create Dataset", loader.Options{
prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{
Cache: g.Cache,
})

Expand All @@ -144,88 +101,21 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
writeResponse(logger, w, map[string]any{"stdout": result})
}

type addDatasetElementArgs struct {
DatasetID string `json:"datasetID"`
ElementName string `json:"elementName"`
ElementDescription string `json:"elementDescription"`
ElementContent string `json:"elementContent"`
}

func (a addDatasetElementArgs) validate() error {
if a.DatasetID == "" {
return fmt.Errorf("datasetID is required")
}
if a.ElementName == "" {
return fmt.Errorf("elementName is required")
}
if a.ElementContent == "" {
return fmt.Errorf("elementContent is required")
}
return nil
}

func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())

var req datasetRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
return
}

if err := req.validate(true); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}

g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
return
}

var args addDatasetElementArgs
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
return
}

if err := args.validate(); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Element", loader.Options{
Cache: g.Cache,
})
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
return
}

result, err := g.Run(r.Context(), prg, req.Env, req.Input)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
return
}

writeResponse(logger, w, map[string]any{"stdout": result})
}

type addDatasetElementsArgs struct {
DatasetID string `json:"datasetID"`
Elements []struct {
Name string `json:"name"`
Description string `json:"description"`
Contents string `json:"contents"`
}
WorkspaceID string `json:"workspaceID"`
DatasetID string `json:"datasetID"`
Elements []struct {
Name string `json:"name"`
Description string `json:"description"`
Contents string `json:"contents"`
BinaryContents []byte `json:"binaryContents"`
} `json:"elements"`
}

func (a addDatasetElementsArgs) validate() error {
if a.DatasetID == "" {
return fmt.Errorf("datasetID is required")
}
if len(a.Elements) == 0 {
if a.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
} else if len(a.Elements) == 0 {
return fmt.Errorf("elements is required")
}
return nil
Expand All @@ -240,7 +130,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
return
}

if err := req.validate(true); err != nil {
if err := req.validate(); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -270,13 +160,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
return
}

elementsJSON, err := json.Marshal(args.Elements)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to marshal elements: %w", err))
return
}

result, err := g.Run(r.Context(), prg, req.Env, fmt.Sprintf(`{"datasetID":%q, "elements":%q}`, args.DatasetID, string(elementsJSON)))
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
return
Expand All @@ -286,11 +170,14 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
}

type listDatasetElementsArgs struct {
DatasetID string `json:"datasetID"`
WorkspaceID string `json:"workspaceID"`
DatasetID string `json:"datasetID"`
}

func (a listDatasetElementsArgs) validate() error {
if a.DatasetID == "" {
if a.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
} else if a.DatasetID == "" {
return fmt.Errorf("datasetID is required")
}
return nil
Expand All @@ -305,7 +192,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
return
}

if err := req.validate(true); err != nil {
if err := req.validate(); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -345,16 +232,18 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
}

type getDatasetElementArgs struct {
DatasetID string `json:"datasetID"`
Element string `json:"element"`
WorkspaceID string `json:"workspaceID"`
DatasetID string `json:"datasetID"`
Name string `json:"name"`
}

func (a getDatasetElementArgs) validate() error {
if a.DatasetID == "" {
if a.WorkspaceID == "" {
return fmt.Errorf("workspaceID is required")
} else if a.DatasetID == "" {
return fmt.Errorf("datasetID is required")
}
if a.Element == "" {
return fmt.Errorf("element is required")
} else if a.Name == "" {
return fmt.Errorf("name is required")
}
return nil
}
Expand All @@ -368,7 +257,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
return
}

if err := req.validate(true); err != nil {
if err := req.validate(); err != nil {
writeError(logger, w, http.StatusBadRequest, err)
return
}
Expand All @@ -390,7 +279,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "Get Element SDK", loader.Options{
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Get Element", loader.Options{
Cache: g.Cache,
})
if err != nil {
Expand Down
2 changes: 0 additions & 2 deletions pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,8 @@ func (s *server) addRoutes(mux *http.ServeMux) {
mux.HandleFunc("POST /credentials/delete", s.deleteCredential)

mux.HandleFunc("POST /datasets", s.listDatasets)
mux.HandleFunc("POST /datasets/create", s.createDataset)
mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)
mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement)
mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement)
mux.HandleFunc("POST /datasets/add-elements", s.addDatasetElements)

mux.HandleFunc("POST /workspaces/create", s.createWorkspace)
Expand Down

0 comments on commit 6515d93

Please sign in to comment.