Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ollama: Enable specification of GPU types for Ollama instance types #193

Merged
merged 2 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions pkg/cmd/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/brevdev/brev-cli/pkg/collections"
"github.com/brevdev/brev-cli/pkg/config"
"github.com/brevdev/brev-cli/pkg/entity"
"github.com/brevdev/brev-cli/pkg/instancetypes"
"github.com/brevdev/brev-cli/pkg/store"
"github.com/brevdev/brev-cli/pkg/terminal"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -69,6 +70,7 @@ func validateModelType(input string) (bool, error) {

func NewCmdOllama(t *terminal.Terminal, ollamaStore OllamaStore) *cobra.Command {
var model string
var gpu string

cmd := &cobra.Command{
Use: "ollama",
Expand All @@ -91,10 +93,20 @@ func NewCmdOllama(t *terminal.Terminal, ollamaStore OllamaStore) *cobra.Command
if !isValid {
return fmt.Errorf("invalid model type: %s", model)
}
if gpu != "" {
isValid := instancetypes.ValidateInstanceType(gpu)
if !isValid {
err := fmt.Errorf("invalid GPU instance type: %s, see https://brev.dev/docs/reference/gpu for a list of valid GPU instance types", gpu)
return breverrors.WrapAndTrace(err)
}
}

// Start the network call in a goroutine
res := collections.Async(func() (any, error) {
err := runOllamaWorkspace(t, model, ollamaStore)
err := runOllamaWorkspace(t, RunOptions{
Model: model,
GPUType: gpu,
}, ollamaStore)
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}
Expand All @@ -114,10 +126,16 @@ func NewCmdOllama(t *terminal.Terminal, ollamaStore OllamaStore) *cobra.Command
},
}
cmd.Flags().StringVarP(&model, "model", "m", "", "AI/ML model type (e.g., llama2, llama3, mistral7b)")
cmd.Flags().StringVarP(&gpu, "gpu", "g", "g5.xlarge", "GPU instance type. See https://brev.dev/docs/reference/gpu for details")
return cmd
}

func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaStore) error { //nolint:funlen, gocyclo // todo
// RunOptions bundles the user-configurable parameters passed to
// runOllamaWorkspace when launching an Ollama server instance.
type RunOptions struct {
	Model string // AI/ML model to pull and run on the instance (e.g., llama2, llama3, mistral7b)
	GPUType string // GPU instance type for the workspace (e.g., g5.xlarge); validated against instancetypes before use
}

func runOllamaWorkspace(t *terminal.Terminal, opts RunOptions, ollamaStore OllamaStore) error { //nolint:funlen, gocyclo // todo
_, err := ollamaStore.GetCurrentUser()
if err != nil {
return breverrors.WrapAndTrace(err)
Expand All @@ -128,14 +146,13 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt
return breverrors.WrapAndTrace(err)
}

// Placeholder for instance type, to be updated later
instanceType := "g4dn.xlarge"
instanceType := opts.GPUType
clusterID := config.GlobalConfig.GetDefaultClusterID()
uuid := uuid.New().String()
instanceName := fmt.Sprintf("ollama-%s", uuid)
cwOptions := store.NewCreateWorkspacesOptions(clusterID, instanceName).WithInstanceType(instanceType)

hello.TypeItToMeUnskippable27(fmt.Sprintf("Creating Ollama server %s with model %s in org %s\n", t.Green(cwOptions.Name), t.Green(model), t.Green(org.ID)))
hello.TypeItToMeUnskippable27(fmt.Sprintf("Creating Ollama server %s with model %s in org %s\n", t.Green(cwOptions.Name), t.Green(opts.Model), t.Green(org.ID)))

s := t.NewSpinner()

Expand Down Expand Up @@ -223,17 +240,17 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaSt
s.Suffix = "Pulling the %s model, just a bit more! 🏄"

// shell in and run ollama pull:
if err := runSSHExec(instanceName, []string{"ollama", "pull", model}, false); err != nil {
if err := runSSHExec(instanceName, []string{"ollama", "pull", opts.Model}, false); err != nil {
return breverrors.WrapAndTrace(err)
}
if err := runSSHExec(instanceName, []string{"ollama", "run", model, "hello world"}, true); err != nil {
if err := runSSHExec(instanceName, []string{"ollama", "run", opts.Model, "hello world"}, true); err != nil {
return breverrors.WrapAndTrace(err)
}
s.Stop()

fmt.Print("\n")
t.Vprint(t.Green("Ollama is ready to go!\n"))
displayOllamaConnectBreadCrumb(t, link, model)
displayOllamaConnectBreadCrumb(t, link, opts.Model)
return nil
}

Expand Down
19 changes: 3 additions & 16 deletions pkg/cmd/start/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/brevdev/brev-cli/pkg/config"
"github.com/brevdev/brev-cli/pkg/entity"
"github.com/brevdev/brev-cli/pkg/featureflag"
"github.com/brevdev/brev-cli/pkg/instancetypes"
"github.com/brevdev/brev-cli/pkg/mergeshells"
"github.com/brevdev/brev-cli/pkg/store"
"github.com/brevdev/brev-cli/pkg/terminal"
Expand All @@ -29,11 +30,6 @@ var (
brev start <git url>
brev start <git url> --org myFancyOrg
`
instanceTypes = []string{
"p4d.24xlarge", "p3.2xlarge", "p3.8xlarge", "p3.16xlarge", "p3dn.24xlarge", "p2.xlarge", "p2.8xlarge", "p2.16xlarge", "g5.xlarge", "g5.2xlarge", "g5.4xlarge", "g5.8xlarge", "g5.16xlarge", "g5.12xlarge", "g5.24xlarge", "g5.48xlarge", "g5g.xlarge", "g5g.2xlarge", "g5g.4xlarge", "g5g.8xlarge", "g5g.16xlarge", "g5g.metal", "g4dn.xlarge", "g4dn.2xlarge", "g4dn.4xlarge", "g4dn.8xlarge", "g4dn.16xlarge", "g4dn.12xlarge", "g4dn.metal", "g4ad.xlarge", "g4ad.2xlarge", "g4ad.4xlarge", "g4ad.8xlarge", "g4ad.16xlarge", "g3s.xlarge", "g3.4xlarge", "g3.8xlarge", "g3.16xlarge",
"g5g.xlarge", "g5g.2xlarge", "g5g.4xlarge", "g5g.8xlarge", "g5g.16xlarge", "g5g.metal", "g4dn.xlarge", "g4dn.2xlarge", "g4dn.4xlarge", "g4dn.8xlarge", "g4dn.16xlarge", "g4dn.12xlarge", "g4dn.metal",
"n1-standard-1:nvidia-tesla-t4:1", "n1-highcpu-2:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:1", "n1-highmem-2:nvidia-tesla-t4:1", "n1-highcpu-4:nvidia-tesla-t4:1", "n1-standard-4:nvidia-tesla-t4:1", "n1-highmem-4:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:1", "n1-standard-1:nvidia-tesla-t4:2", "n1-highcpu-2:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:2", "n1-highmem-2:nvidia-tesla-t4:2", "n1-highcpu-4:nvidia-tesla-t4:2", "n1-highmem-8:nvidia-tesla-t4:1", "n1-standard-4:nvidia-tesla-t4:2", "n1-highmem-4:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:2", "n1-standard-16:nvidia-tesla-t4:1", "n1-highmem-8:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:2", "n1-standard-1:nvidia-tesla-t4:4", "n1-highcpu-2:nvidia-tesla-t4:4", "n1-highmem-16:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:4", "n1-highmem-2:nvidia-tesla-t4:4", "n1-highcpu-4:nvidia-tesla-t4:4", "n1-standard-16:nvidia-tesla-t4:2", "n1-standard-4:nvidia-tesla-t4:4", "n1-highmem-4:nvidia-tesla-t4:4", "n1-highcpu-32:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:4", "n1-highmem-16:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:4", "n1-highmem-8:nvidia-tesla-t4:4", "n1-highcpu-32:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:4", "n1-standard-32:nvidia-tesla-t4:1", "n1-standard-16:nvidia-tesla-t4:4", "n1-standard-32:nvidia-tesla-t4:2", "n1-highmem-16:nvidia-tesla-t4:4", "n1-highmem-32:nvidia-tesla-t4:1", "n1-highcpu-32:nvidia-tesla-t4:4", "n1-highmem-32:nvidia-tesla-t4:2", "n1-highcpu-64:nvidia-tesla-t4:1", "n1-standard-32:nvidia-tesla-t4:4", "n1-highcpu-64:nvidia-tesla-t4:2", "n1-highmem-32:nvidia-tesla-t4:4", "n1-standard-64:nvidia-tesla-t4:1", "n1-highcpu-64:nvidia-tesla-t4:4", "n1-standard-64:nvidia-tesla-t4:2", "n1-highcpu-96:nvidia-tesla-t4:1", "n1-highcpu-96:nvidia-tesla-t4:2", "n1-highmem-64:nvidia-tesla-t4:1", "n1-standard-64:nvidia-tesla-t4:4", 
"n1-highmem-64:nvidia-tesla-t4:2", "n1-highcpu-96:nvidia-tesla-t4:4", "n1-standard-96:nvidia-tesla-t4:1", "n1-highmem-64:nvidia-tesla-t4:4", "n1-standard-96:nvidia-tesla-t4:2", "n1-ultramem-40:nvidia-tesla-t4:1", "n1-standard-96:nvidia-tesla-t4:4", "n1-ultramem-40:nvidia-tesla-t4:2", "n1-highmem-96:nvidia-tesla-t4:1", "n1-highmem-96:nvidia-tesla-t4:2", "n1-ultramem-40:nvidia-tesla-t4:4", "n1-highmem-96:nvidia-tesla-t4:4", "n1-megamem-96:nvidia-tesla-t4:1", "n1-megamem-96:nvidia-tesla-t4:2", "n1-megamem-96:nvidia-tesla-t4:4", "n1-ultramem-80:nvidia-tesla-t4:1", "n1-ultramem-80:nvidia-tesla-t4:2", "n1-ultramem-80:nvidia-tesla-t4:4", "n1-ultramem-160:nvidia-tesla-t4:1", "n1-ultramem-160:nvidia-tesla-t4:2", "n1-ultramem-160:nvidia-tesla-t4:4",
}
)

type StartStore interface {
Expand All @@ -49,15 +45,6 @@ type StartStore interface {
GetFileAsString(path string) (string, error)
}

func validateInstanceType(instanceType string) bool {
for _, v := range instanceTypes {
if instanceType == v {
return true
}
}
return false
}

func NewCmdStart(t *terminal.Terminal, startStore StartStore, noLoginStartStore StartStore) *cobra.Command {
var org string
var name string
Expand All @@ -84,9 +71,9 @@ func NewCmdStart(t *terminal.Terminal, startStore StartStore, noLoginStartStore
}

if gpu != "" {
isValid := validateInstanceType(gpu)
isValid := instancetypes.ValidateInstanceType(gpu)
if !isValid {
err := fmt.Errorf("invalid GPU instance type: %s", gpu)
err := fmt.Errorf("invalid GPU instance type: %s, see https://brev.dev/docs/reference/gpu for a list of valid GPU instance types", gpu)
return breverrors.WrapAndTrace(err)
}
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/instancetypes/instancetypes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package instancetypes

// InstanceTypes enumerates the GPU instance types accepted by the CLI.
// It covers AWS EC2 GPU families (p2/p3/p4d, g3/g4ad/g4dn/g5/g5g) and GCP n1
// machine types paired with NVIDIA Tesla T4 accelerators, encoded as
// "<machine-type>:<accelerator>:<count>".
//
// NOTE: the original list repeated the g5g.* and g4dn.* entries; the
// duplicates have been removed. Membership semantics are unchanged.
var InstanceTypes = []string{
	"p4d.24xlarge", "p3.2xlarge", "p3.8xlarge", "p3.16xlarge", "p3dn.24xlarge", "p2.xlarge", "p2.8xlarge", "p2.16xlarge", "g5.xlarge", "g5.2xlarge", "g5.4xlarge", "g5.8xlarge", "g5.16xlarge", "g5.12xlarge", "g5.24xlarge", "g5.48xlarge", "g5g.xlarge", "g5g.2xlarge", "g5g.4xlarge", "g5g.8xlarge", "g5g.16xlarge", "g5g.metal", "g4dn.xlarge", "g4dn.2xlarge", "g4dn.4xlarge", "g4dn.8xlarge", "g4dn.16xlarge", "g4dn.12xlarge", "g4dn.metal", "g4ad.xlarge", "g4ad.2xlarge", "g4ad.4xlarge", "g4ad.8xlarge", "g4ad.16xlarge", "g3s.xlarge", "g3.4xlarge", "g3.8xlarge", "g3.16xlarge",
	"n1-standard-1:nvidia-tesla-t4:1", "n1-highcpu-2:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:1", "n1-highmem-2:nvidia-tesla-t4:1", "n1-highcpu-4:nvidia-tesla-t4:1", "n1-standard-4:nvidia-tesla-t4:1", "n1-highmem-4:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:1", "n1-standard-1:nvidia-tesla-t4:2", "n1-highcpu-2:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:2", "n1-highmem-2:nvidia-tesla-t4:2", "n1-highcpu-4:nvidia-tesla-t4:2", "n1-highmem-8:nvidia-tesla-t4:1", "n1-standard-4:nvidia-tesla-t4:2", "n1-highmem-4:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:2", "n1-standard-16:nvidia-tesla-t4:1", "n1-highmem-8:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:2", "n1-standard-1:nvidia-tesla-t4:4", "n1-highcpu-2:nvidia-tesla-t4:4", "n1-highmem-16:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:4", "n1-highmem-2:nvidia-tesla-t4:4", "n1-highcpu-4:nvidia-tesla-t4:4", "n1-standard-16:nvidia-tesla-t4:2", "n1-standard-4:nvidia-tesla-t4:4", "n1-highmem-4:nvidia-tesla-t4:4", "n1-highcpu-32:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:4", "n1-highmem-16:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:4", "n1-highmem-8:nvidia-tesla-t4:4", "n1-highcpu-32:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:4", "n1-standard-32:nvidia-tesla-t4:1", "n1-standard-16:nvidia-tesla-t4:4", "n1-standard-32:nvidia-tesla-t4:2", "n1-highmem-16:nvidia-tesla-t4:4", "n1-highmem-32:nvidia-tesla-t4:1", "n1-highcpu-32:nvidia-tesla-t4:4", "n1-highmem-32:nvidia-tesla-t4:2", "n1-highcpu-64:nvidia-tesla-t4:1", "n1-standard-32:nvidia-tesla-t4:4", "n1-highcpu-64:nvidia-tesla-t4:2", "n1-highmem-32:nvidia-tesla-t4:4", "n1-standard-64:nvidia-tesla-t4:1", "n1-highcpu-64:nvidia-tesla-t4:4", "n1-standard-64:nvidia-tesla-t4:2", "n1-highcpu-96:nvidia-tesla-t4:1", "n1-highcpu-96:nvidia-tesla-t4:2", "n1-highmem-64:nvidia-tesla-t4:1", "n1-standard-64:nvidia-tesla-t4:4",
	"n1-highmem-64:nvidia-tesla-t4:2", "n1-highcpu-96:nvidia-tesla-t4:4", "n1-standard-96:nvidia-tesla-t4:1", "n1-highmem-64:nvidia-tesla-t4:4", "n1-standard-96:nvidia-tesla-t4:2", "n1-ultramem-40:nvidia-tesla-t4:1", "n1-standard-96:nvidia-tesla-t4:4", "n1-ultramem-40:nvidia-tesla-t4:2", "n1-highmem-96:nvidia-tesla-t4:1", "n1-highmem-96:nvidia-tesla-t4:2", "n1-ultramem-40:nvidia-tesla-t4:4", "n1-highmem-96:nvidia-tesla-t4:4", "n1-megamem-96:nvidia-tesla-t4:1", "n1-megamem-96:nvidia-tesla-t4:2", "n1-megamem-96:nvidia-tesla-t4:4", "n1-ultramem-80:nvidia-tesla-t4:1", "n1-ultramem-80:nvidia-tesla-t4:2", "n1-ultramem-80:nvidia-tesla-t4:4", "n1-ultramem-160:nvidia-tesla-t4:1", "n1-ultramem-160:nvidia-tesla-t4:2", "n1-ultramem-160:nvidia-tesla-t4:4",
}

// ValidateInstanceType reports whether instanceType is one of the known
// GPU instance types in InstanceTypes.
func ValidateInstanceType(instanceType string) bool {
	for _, v := range InstanceTypes {
		if instanceType == v {
			return true
		}
	}
	return false
}
Loading