diff --git a/pkg/cmd/ollama/ollama.go b/pkg/cmd/ollama/ollama.go index d36934b4..aac0a72c 100644 --- a/pkg/cmd/ollama/ollama.go +++ b/pkg/cmd/ollama/ollama.go @@ -22,7 +22,7 @@ import ( var ( ollamaLong = "Start an AI/ML model workspace with specified model types" ollamaExample = ` - brev ollama --model llama2 + brev ollama --model llama3 brev ollama --model mistral7b ` modelTypes = []string{"llama2", "llama3", "mistral7b"} @@ -38,6 +38,7 @@ type OllamaStore interface { CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) GetWorkspace(workspaceID string) (*entity.Workspace, error) BuildVerbContainer(workspaceID string, verbYaml string) (*store.BuildVerbRes, error) + ModifyPublicity(workspace *entity.Workspace, publicity bool) (*entity.Tunnel, error) } func validateModelType(modelType string) bool { @@ -120,6 +121,7 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt w := &entity.Workspace{} workspaceCh := make(chan *entity.Workspace) + s := t.NewSpinner() go func() { w, err = ollamaStore.CreateWorkspace(org.ID, cwOptions) @@ -131,7 +133,6 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt workspaceCh <- w }() - s := t.NewSpinner() s.Suffix = " Creating your workspace. Hang tight 🤙" s.Start() @@ -166,6 +167,7 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt verbCh <- lf }() + s.Start() s.Suffix = " Building your verb container. Hang tight 🤙" lf = <-verbCh @@ -186,16 +188,25 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt s.Stop() + link, err := getOllamaTunnelLink(w, ollamaStore) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + _, err = makeTunnelPublic(w, ollamaStore) + if err != nil { + return breverrors.WrapAndTrace(err) + } + fmt.Print("\n") t.Vprint(t.Green("Your AI/ML workspace is ready!\n")) - displayConnectBreadCrumb(t, w) + displayOllamaConnectBreadCrumb(t, link) return nil } -func displayConnectBreadCrumb(t *terminal.Terminal, workspace *entity.Workspace) { - t.Vprintf(t.Green("Connect to the Ollama server:\n")) - t.Vprintf(t.Yellow(fmt.Sprintf("\tbrev open %s\t# brev open -> open workspace in VS Code\n", workspace.Name))) - t.Vprintf(t.Yellow(fmt.Sprintf("\tbrev shell %s\t# brev shell -> ssh into workspace (shortcut)\n", workspace.Name))) +func displayOllamaConnectBreadCrumb(t *terminal.Terminal, link string) { + t.Vprintf(t.Green("Query the Ollama API with the following command:\n")) + t.Vprintf(t.Yellow(fmt.Sprintf("curl %s/api/chat -d '{\n \"model\": \"llama3\",\n \"messages\": [\n {\n \"role\": \"user\",\n \"content\": \"why is the sky blue?\"\n }\n ]\n}'", link))) } func pollInstanceUntilVMReady(workspace *entity.Workspace, interval time.Duration, timeout time.Duration, ollamaStore OllamaStore) (bool, error) { @@ -203,9 +214,6 @@ func pollInstanceUntilVMReady(workspace *entity.Workspace, interval time.Duratio for elapsedTime < timeout { w, err := ollamaStore.GetWorkspace(workspace.ID) - fmt.Println(workspace.ID) - fmt.Println(w.ID) - fmt.Println(w.Status) if err != nil { return false, breverrors.WrapAndTrace(err) } else if w.Status == "RUNNING" { @@ -224,9 +232,6 @@ func pollInstanceUntilVerbContainerReady(workspace *entity.Workspace, interval t for elapsedTime < timeout { w, err := ollamaStore.GetWorkspace(workspace.ID) - fmt.Println(workspace.ID) - fmt.Println(w.ID) - fmt.Println(w.VerbBuildStatus) if err != nil { return false, breverrors.WrapAndTrace(err) } else if w.VerbBuildStatus == entity.Completed { @@ -238,8 +243,35 @@ func pollInstanceUntilVerbContainerReady(workspace *entity.Workspace, interval t return false, breverrors.New("timeout waiting for instance to start") } -func generateCloudflareAPIKeys(workspace *entity.Workspace, ollamaStore OllamaStore) (bool, error) { +func getOllamaTunnelLink(workspace *entity.Workspace, ollamaStore OllamaStore) (string, error) { + w, err := ollamaStore.GetWorkspace(workspace.ID) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + for _, v := range w.Tunnel.Applications { + if v.Port == 11434 { + return v.Hostname, nil + } + } + return "", breverrors.New("Could not find Ollama tunnel") } +// TODO: stubs for granular permissioning +// func generateCloudflareAPIKeys(workspace *entity.Workspace, ollamaStore OllamaStore) (bool, error) { +// return false, nil +// } + func makeTunnelPublic(workspace *entity.Workspace, ollamaStore OllamaStore) (bool, error) { + t, err := ollamaStore.ModifyPublicity(workspace, true) + if err != nil { + return false, breverrors.WrapAndTrace(err) + } + for _, v := range t.Applications { + if v.Port == 11434 { + if v.Policy.AllowEveryone { + return true, nil + } + } + } + return false, breverrors.New("Could not find Ollama tunnel") } diff --git a/pkg/entity/entity.go b/pkg/entity/entity.go index ee5b0ea3..45f9ab23 100644 --- a/pkg/entity/entity.go +++ b/pkg/entity/entity.go @@ -336,10 +336,10 @@ type Policy struct { } type Tunnel struct { - TunnelID string `json:"tunnelID"` - Applications []Application `json:"applications"` - TunnelSetupBash string `json:"tunnelSetupBash"` - TunnelStatus string `json:"tunnelStatus"` + TunnelID string `json:"tunnelID"` + Applications []CFApplication `json:"applications"` + TunnelSetupBash string `json:"tunnelSetupBash"` + TunnelStatus string `json:"tunnelStatus"` } // TODO Change this to Application. Theres an older application struct that should be removed diff --git a/pkg/store/workspace.go b/pkg/store/workspace.go index 7b1eae2b..2b7cc594 100644 --- a/pkg/store/workspace.go +++ b/pkg/store/workspace.go @@ -616,15 +616,9 @@ func (f FileStore) GetSetupScriptPath() string { return setupScriptPath } -// modifyPublicity: (wsId: string, args: ModifyApplicationPublicityRequest) => -// api.post( -// `/api/applications/modifypublicity/${wsId}`, -// args -// ) as Promise, - var ( modifyCloudflareAccessPattern = "api/applications/modifypublicity/%s" - modifyCloudflareAccessPath = fmt.Sprint(modifyCloudflareAccessPattern, fmt.Sprintf("{%s}", workspaceIDParamName)) + modifyCloudflareAccessPath = fmt.Sprintf(modifyCloudflareAccessPattern, fmt.Sprintf("{%s}", workspaceIDParamName)) ) type ModifyApplicationPublicityRequest struct {