diff --git a/pkg/cmd/ollama/ollama.go b/pkg/cmd/ollama/ollama.go index 0b891d47..0ebfc99a 100644 --- a/pkg/cmd/ollama/ollama.go +++ b/pkg/cmd/ollama/ollama.go @@ -16,6 +16,7 @@ import ( "github.com/brevdev/brev-cli/pkg/collections" "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/instancetypes" "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" @@ -69,6 +70,7 @@ func validateModelType(input string) (bool, error) { func NewCmdOllama(t *terminal.Terminal, ollamaStore OllamaStore) *cobra.Command { var model string + var gpu string cmd := &cobra.Command{ Use: "ollama", @@ -91,10 +93,20 @@ func NewCmdOllama(t *terminal.Terminal, ollamaStore OllamaStore) *cobra.Command if !isValid { return fmt.Errorf("invalid model type: %s", model) } + if gpu != "" { + isValid := instancetypes.ValidateInstanceType(gpu) + if !isValid { + err := fmt.Errorf("invalid GPU instance type: %s, see https://brev.dev/docs/reference/gpu for a list of valid GPU instance types", gpu) + return breverrors.WrapAndTrace(err) + } + } // Start the network call in a goroutine res := collections.Async(func() (any, error) { - err := runOllamaWorkspace(t, model, ollamaStore) + err := runOllamaWorkspace(t, RunOptions{ + Model: model, + GPUType: gpu, + }, ollamaStore) if err != nil { return nil, breverrors.WrapAndTrace(err) } @@ -114,10 +126,16 @@ func NewCmdOllama(t *terminal.Terminal, ollamaStore OllamaStore) *cobra.Command }, } cmd.Flags().StringVarP(&model, "model", "m", "", "AI/ML model type (e.g., llama2, llama3, mistral7b)") + cmd.Flags().StringVarP(&gpu, "gpu", "g", "g5.xlarge", "GPU instance type. See https://brev.dev/docs/reference/gpu for details") return cmd } -func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaStore OllamaStore) error { //nolint:funlen, gocyclo // todo +type RunOptions struct { + Model string + GPUType string +} + +func runOllamaWorkspace(t *terminal.Terminal, opts RunOptions, ollamaStore OllamaStore) error { //nolint:funlen, gocyclo // todo _, err := ollamaStore.GetCurrentUser() if err != nil { return breverrors.WrapAndTrace(err) } @@ -128,14 +146,13 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaSt return breverrors.WrapAndTrace(err) } - // Placeholder for instance type, to be updated later - instanceType := "g4dn.xlarge" + instanceType := opts.GPUType clusterID := config.GlobalConfig.GetDefaultClusterID() uuid := uuid.New().String() instanceName := fmt.Sprintf("ollama-%s", uuid) cwOptions := store.NewCreateWorkspacesOptions(clusterID, instanceName).WithInstanceType(instanceType) - hello.TypeItToMeUnskippable27(fmt.Sprintf("Creating Ollama server %s with model %s in org %s\n", t.Green(cwOptions.Name), t.Green(model), t.Green(org.ID))) + hello.TypeItToMeUnskippable27(fmt.Sprintf("Creating Ollama server %s with model %s in org %s\n", t.Green(cwOptions.Name), t.Green(opts.Model), t.Green(org.ID))) s := t.NewSpinner() @@ -223,17 +240,17 @@ func runOllamaWorkspace(t *terminal.Terminal, model string, ollamaSt s.Suffix = "Pulling the %s model, just a bit more! 🏄" // shell in and run ollama pull: - if err := runSSHExec(instanceName, []string{"ollama", "pull", model}, false); err != nil { + if err := runSSHExec(instanceName, []string{"ollama", "pull", opts.Model}, false); err != nil { return breverrors.WrapAndTrace(err) } - if err := runSSHExec(instanceName, []string{"ollama", "run", model, "hello world"}, true); err != nil { + if err := runSSHExec(instanceName, []string{"ollama", "run", opts.Model, "hello world"}, true); err != nil { return breverrors.WrapAndTrace(err) } s.Stop() fmt.Print("\n") t.Vprint(t.Green("Ollama is ready to go!\n")) - displayOllamaConnectBreadCrumb(t, link, model) + displayOllamaConnectBreadCrumb(t, link, opts.Model) return nil } diff --git a/pkg/cmd/start/start.go b/pkg/cmd/start/start.go index 257abe98..2ce046ef 100644 --- a/pkg/cmd/start/start.go +++ b/pkg/cmd/start/start.go @@ -13,6 +13,7 @@ import ( "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" "github.com/brevdev/brev-cli/pkg/featureflag" + "github.com/brevdev/brev-cli/pkg/instancetypes" "github.com/brevdev/brev-cli/pkg/mergeshells" "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" @@ -29,11 +30,6 @@ var ( brev start brev start --org myFancyOrg ` - instanceTypes = []string{ - "p4d.24xlarge", "p3.2xlarge", "p3.8xlarge", "p3.16xlarge", "p3dn.24xlarge", "p2.xlarge", "p2.8xlarge", "p2.16xlarge", "g5.xlarge", "g5.2xlarge", "g5.4xlarge", "g5.8xlarge", "g5.16xlarge", "g5.12xlarge", "g5.24xlarge", "g5.48xlarge", "g5g.xlarge", "g5g.2xlarge", "g5g.4xlarge", "g5g.8xlarge", "g5g.16xlarge", "g5g.metal", "g4dn.xlarge", "g4dn.2xlarge", "g4dn.4xlarge", "g4dn.8xlarge", "g4dn.16xlarge", "g4dn.12xlarge", "g4dn.metal", "g4ad.xlarge", "g4ad.2xlarge", "g4ad.4xlarge", "g4ad.8xlarge", "g4ad.16xlarge", "g3s.xlarge", "g3.4xlarge", "g3.8xlarge", "g3.16xlarge", - "g5g.xlarge", "g5g.2xlarge", "g5g.4xlarge", "g5g.8xlarge", "g5g.16xlarge", "g5g.metal", "g4dn.xlarge", "g4dn.2xlarge", "g4dn.4xlarge", 
"g4dn.8xlarge", "g4dn.16xlarge", "g4dn.12xlarge", "g4dn.metal", - "n1-standard-1:nvidia-tesla-t4:1", "n1-highcpu-2:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:1", "n1-highmem-2:nvidia-tesla-t4:1", "n1-highcpu-4:nvidia-tesla-t4:1", "n1-standard-4:nvidia-tesla-t4:1", "n1-highmem-4:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:1", "n1-standard-1:nvidia-tesla-t4:2", "n1-highcpu-2:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:2", "n1-highmem-2:nvidia-tesla-t4:2", "n1-highcpu-4:nvidia-tesla-t4:2", "n1-highmem-8:nvidia-tesla-t4:1", "n1-standard-4:nvidia-tesla-t4:2", "n1-highmem-4:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:2", "n1-standard-16:nvidia-tesla-t4:1", "n1-highmem-8:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:2", "n1-standard-1:nvidia-tesla-t4:4", "n1-highcpu-2:nvidia-tesla-t4:4", "n1-highmem-16:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:4", "n1-highmem-2:nvidia-tesla-t4:4", "n1-highcpu-4:nvidia-tesla-t4:4", "n1-standard-16:nvidia-tesla-t4:2", "n1-standard-4:nvidia-tesla-t4:4", "n1-highmem-4:nvidia-tesla-t4:4", "n1-highcpu-32:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:4", "n1-highmem-16:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:4", "n1-highmem-8:nvidia-tesla-t4:4", "n1-highcpu-32:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:4", "n1-standard-32:nvidia-tesla-t4:1", "n1-standard-16:nvidia-tesla-t4:4", "n1-standard-32:nvidia-tesla-t4:2", "n1-highmem-16:nvidia-tesla-t4:4", "n1-highmem-32:nvidia-tesla-t4:1", "n1-highcpu-32:nvidia-tesla-t4:4", "n1-highmem-32:nvidia-tesla-t4:2", "n1-highcpu-64:nvidia-tesla-t4:1", "n1-standard-32:nvidia-tesla-t4:4", "n1-highcpu-64:nvidia-tesla-t4:2", "n1-highmem-32:nvidia-tesla-t4:4", "n1-standard-64:nvidia-tesla-t4:1", "n1-highcpu-64:nvidia-tesla-t4:4", "n1-standard-64:nvidia-tesla-t4:2", "n1-highcpu-96:nvidia-tesla-t4:1", "n1-highcpu-96:nvidia-tesla-t4:2", 
"n1-highmem-64:nvidia-tesla-t4:1", "n1-standard-64:nvidia-tesla-t4:4", "n1-highmem-64:nvidia-tesla-t4:2", "n1-highcpu-96:nvidia-tesla-t4:4", "n1-standard-96:nvidia-tesla-t4:1", "n1-highmem-64:nvidia-tesla-t4:4", "n1-standard-96:nvidia-tesla-t4:2", "n1-ultramem-40:nvidia-tesla-t4:1", "n1-standard-96:nvidia-tesla-t4:4", "n1-ultramem-40:nvidia-tesla-t4:2", "n1-highmem-96:nvidia-tesla-t4:1", "n1-highmem-96:nvidia-tesla-t4:2", "n1-ultramem-40:nvidia-tesla-t4:4", "n1-highmem-96:nvidia-tesla-t4:4", "n1-megamem-96:nvidia-tesla-t4:1", "n1-megamem-96:nvidia-tesla-t4:2", "n1-megamem-96:nvidia-tesla-t4:4", "n1-ultramem-80:nvidia-tesla-t4:1", "n1-ultramem-80:nvidia-tesla-t4:2", "n1-ultramem-80:nvidia-tesla-t4:4", "n1-ultramem-160:nvidia-tesla-t4:1", "n1-ultramem-160:nvidia-tesla-t4:2", "n1-ultramem-160:nvidia-tesla-t4:4", - } ) type StartStore interface { @@ -49,15 +45,6 @@ type StartStore interface { GetFileAsString(path string) (string, error) } -func validateInstanceType(instanceType string) bool { - for _, v := range instanceTypes { - if instanceType == v { - return true - } - } - return false -} - func NewCmdStart(t *terminal.Terminal, startStore StartStore, noLoginStartStore StartStore) *cobra.Command { var org string var name string @@ -84,9 +71,9 @@ func NewCmdStart(t *terminal.Terminal, startStore StartStore, noLoginStartStore } if gpu != "" { - isValid := validateInstanceType(gpu) + isValid := instancetypes.ValidateInstanceType(gpu) if !isValid { - err := fmt.Errorf("invalid GPU instance type: %s", gpu) + err := fmt.Errorf("invalid GPU instance type: %s, see https://brev.dev/docs/reference/gpu for a list of valid GPU instance types", gpu) return breverrors.WrapAndTrace(err) } } diff --git a/pkg/instancetypes/instancetypes.go b/pkg/instancetypes/instancetypes.go new file mode 100644 index 00000000..9e93b906 --- /dev/null +++ b/pkg/instancetypes/instancetypes.go @@ -0,0 +1,16 @@ +package instancetypes + +var InstanceTypes = []string{ + "p4d.24xlarge", "p3.2xlarge", 
"p3.8xlarge", "p3.16xlarge", "p3dn.24xlarge", "p2.xlarge", "p2.8xlarge", "p2.16xlarge", "g5.xlarge", "g5.2xlarge", "g5.4xlarge", "g5.8xlarge", "g5.16xlarge", "g5.12xlarge", "g5.24xlarge", "g5.48xlarge", "g5g.xlarge", "g5g.2xlarge", "g5g.4xlarge", "g5g.8xlarge", "g5g.16xlarge", "g5g.metal", "g4dn.xlarge", "g4dn.2xlarge", "g4dn.4xlarge", "g4dn.8xlarge", "g4dn.16xlarge", "g4dn.12xlarge", "g4dn.metal", "g4ad.xlarge", "g4ad.2xlarge", "g4ad.4xlarge", "g4ad.8xlarge", "g4ad.16xlarge", "g3s.xlarge", "g3.4xlarge", "g3.8xlarge", "g3.16xlarge", + "g5g.xlarge", "g5g.2xlarge", "g5g.4xlarge", "g5g.8xlarge", "g5g.16xlarge", "g5g.metal", "g4dn.xlarge", "g4dn.2xlarge", "g4dn.4xlarge", "g4dn.8xlarge", "g4dn.16xlarge", "g4dn.12xlarge", "g4dn.metal", + "n1-standard-1:nvidia-tesla-t4:1", "n1-highcpu-2:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:1", "n1-highmem-2:nvidia-tesla-t4:1", "n1-highcpu-4:nvidia-tesla-t4:1", "n1-standard-4:nvidia-tesla-t4:1", "n1-highmem-4:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:1", "n1-standard-1:nvidia-tesla-t4:2", "n1-highcpu-2:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:2", "n1-highmem-2:nvidia-tesla-t4:2", "n1-highcpu-4:nvidia-tesla-t4:2", "n1-highmem-8:nvidia-tesla-t4:1", "n1-standard-4:nvidia-tesla-t4:2", "n1-highmem-4:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:2", "n1-standard-16:nvidia-tesla-t4:1", "n1-highmem-8:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:2", "n1-standard-1:nvidia-tesla-t4:4", "n1-highcpu-2:nvidia-tesla-t4:4", "n1-highmem-16:nvidia-tesla-t4:1", "n1-standard-2:nvidia-tesla-t4:4", "n1-highmem-2:nvidia-tesla-t4:4", "n1-highcpu-4:nvidia-tesla-t4:4", "n1-standard-16:nvidia-tesla-t4:2", "n1-standard-4:nvidia-tesla-t4:4", "n1-highmem-4:nvidia-tesla-t4:4", "n1-highcpu-32:nvidia-tesla-t4:1", "n1-highcpu-8:nvidia-tesla-t4:4", "n1-highmem-16:nvidia-tesla-t4:2", "n1-standard-8:nvidia-tesla-t4:4", 
"n1-highmem-8:nvidia-tesla-t4:4", "n1-highcpu-32:nvidia-tesla-t4:2", "n1-highcpu-16:nvidia-tesla-t4:4", "n1-standard-32:nvidia-tesla-t4:1", "n1-standard-16:nvidia-tesla-t4:4", "n1-standard-32:nvidia-tesla-t4:2", "n1-highmem-16:nvidia-tesla-t4:4", "n1-highmem-32:nvidia-tesla-t4:1", "n1-highcpu-32:nvidia-tesla-t4:4", "n1-highmem-32:nvidia-tesla-t4:2", "n1-highcpu-64:nvidia-tesla-t4:1", "n1-standard-32:nvidia-tesla-t4:4", "n1-highcpu-64:nvidia-tesla-t4:2", "n1-highmem-32:nvidia-tesla-t4:4", "n1-standard-64:nvidia-tesla-t4:1", "n1-highcpu-64:nvidia-tesla-t4:4", "n1-standard-64:nvidia-tesla-t4:2", "n1-highcpu-96:nvidia-tesla-t4:1", "n1-highcpu-96:nvidia-tesla-t4:2", "n1-highmem-64:nvidia-tesla-t4:1", "n1-standard-64:nvidia-tesla-t4:4", "n1-highmem-64:nvidia-tesla-t4:2", "n1-highcpu-96:nvidia-tesla-t4:4", "n1-standard-96:nvidia-tesla-t4:1", "n1-highmem-64:nvidia-tesla-t4:4", "n1-standard-96:nvidia-tesla-t4:2", "n1-ultramem-40:nvidia-tesla-t4:1", "n1-standard-96:nvidia-tesla-t4:4", "n1-ultramem-40:nvidia-tesla-t4:2", "n1-highmem-96:nvidia-tesla-t4:1", "n1-highmem-96:nvidia-tesla-t4:2", "n1-ultramem-40:nvidia-tesla-t4:4", "n1-highmem-96:nvidia-tesla-t4:4", "n1-megamem-96:nvidia-tesla-t4:1", "n1-megamem-96:nvidia-tesla-t4:2", "n1-megamem-96:nvidia-tesla-t4:4", "n1-ultramem-80:nvidia-tesla-t4:1", "n1-ultramem-80:nvidia-tesla-t4:2", "n1-ultramem-80:nvidia-tesla-t4:4", "n1-ultramem-160:nvidia-tesla-t4:1", "n1-ultramem-160:nvidia-tesla-t4:2", "n1-ultramem-160:nvidia-tesla-t4:4", +} + +func ValidateInstanceType(instanceType string) bool { + for _, v := range InstanceTypes { + if instanceType == v { + return true + } + } + return false +}