From 7ec448b4e70b1544feb3bd64c43b1cdd2f0158b7 Mon Sep 17 00:00:00 2001 From: Alec Fong Date: Thu, 14 Nov 2024 20:58:09 +0000 Subject: [PATCH] login works with kas --- .vscode/launch.json | 16 +++++++- pkg/auth/auth.go | 15 +++++--- pkg/auth/kas.go | 78 +++++++++++++++++++++++++++++--------- pkg/cmd/cmd.go | 24 ++++++++---- pkg/cmd/hello/steps.go | 7 ++-- pkg/cmd/invite/invite.go | 4 +- pkg/cmd/ls/ls.go | 3 +- pkg/cmd/org/org.go | 3 +- pkg/cmd/profile/profile.go | 3 +- pkg/config/config.go | 2 + 10 files changed, 114 insertions(+), 41 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 8a46d31f..80a4b962 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -5,13 +5,25 @@ // prod_lite.env "version": "0.2.0", "configurations": [ + { + "name": "login-kas", + "type": "go", + "request": "launch", + "mode": "debug", + "program": "${workspaceFolder}/main.go", + // "envFile": "${workspaceFolder}/local.env", + "args": [ + "login", + "--auth", + "kas" + ], + }, { "name": "login", "type": "go", "request": "launch", "mode": "debug", "program": "${workspaceFolder}/main.go", - "envFile": "${workspaceFolder}/prod_lite.env", // "envFile": "${workspaceFolder}/local.env", "args": [ "login", @@ -301,4 +313,4 @@ ], } ] -} +} \ No newline at end of file diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index c53f49af..c5843e6d 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -7,6 +7,7 @@ import ( "os" "strings" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/terminal" @@ -189,14 +190,16 @@ func (t Auth) LoginWithToken(token string) error { func defaultAuthFunc(url, code string) { codeType := color.New(color.FgWhite, color.Bold).SprintFunc() + if code != "" { + fmt.Println("Your Device Confirmation Code is 👉", codeType(code), "👈") + fmt.Print("\n") + } + urlType := color.New(color.FgCyan, color.Bold).SprintFunc() + fmt.Println("Browser link: " + urlType(url) + "\n") + fmt.Println("Alternatively, get CLI Command (\"Login via CLI\"): ", urlType(fmt.Sprintf("%s/profile?login=cli", config.ConsoleBaseURL))) fmt.Print("\n") - fmt.Println("Your Device Confirmation Code is 👉", codeType(code), "👈") caretType := color.New(color.FgGreen, color.Bold).SprintFunc() enterType := color.New(color.FgGreen, color.Bold).SprintFunc() - urlType := color.New(color.FgCyan, color.Bold).SprintFunc() - fmt.Println("\n" + "Browser link: " + urlType(url) + "\n") - fmt.Println("Alternatively, get CLI Command (\"Login via CLI\"): ", urlType("https://console.brev.dev/profile?login=cli")) - fmt.Print("\n") _ = terminal.PromptGetInput(terminal.PromptContent{ Label: " " + caretType("▸") + " Press " + enterType("Enter") + " to login via browser", ErrorMsg: "error", @@ -215,7 +218,7 @@ func defaultAuthFunc(url, code string) { func skipBrowserAuthFunc(url, _ string) { urlType := color.New(color.FgCyan, color.Bold).SprintFunc() fmt.Println("Please copy", urlType(url), "and paste it in your browser.") - fmt.Println("Alternatively, get CLI Command (\"Login via CLI\"): ", urlType("https://console.brev.dev/profile?login=cli")) + fmt.Println("Alternatively, get CLI Command (\"Login via CLI\"): ", urlType(fmt.Sprintf("%s/profile?login=cli", config.ConsoleBaseURL))) fmt.Println("Waiting for login to complete in browser... Ctrl+C to use CLI command instead.") } diff --git a/pkg/auth/kas.go b/pkg/auth/kas.go index 1aea9831..e6da7900 100644 --- a/pkg/auth/kas.go +++ b/pkg/auth/kas.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "strings" + "time" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" @@ -16,8 +17,17 @@ import ( var _ OAuth = KasAuthenticator{} type KasAuthenticator struct { - Email string - BaseURL string + Email string + BaseURL string + PollTimeout time.Duration +} + +func NewKasAuthenticator(email, baseURL string) KasAuthenticator { + return KasAuthenticator{ + Email: email, + BaseURL: baseURL, + PollTimeout: 5 * time.Minute, + } } func (a KasAuthenticator) GetNewAuthTokensWithRefresh(refreshToken string) (*entity.AuthTokens, error) { @@ -38,7 +48,7 @@ func (a KasAuthenticator) GetNewAuthTokensWithRefresh(refreshToken string) (*ent } type LoginCallResponse struct { - LoginUrl string `json:"loginUrl"` + LoginURL string `json:"loginUrl"` SessionKey string `json:"sessionKey"` } @@ -75,6 +85,10 @@ func (a KasAuthenticator) MakeLoginCall(id, email string) (LoginCallResponse, er return LoginCallResponse{}, breverrors.WrapAndTrace(err) } + if resp.StatusCode >= 400 { + return LoginCallResponse{}, fmt.Errorf("error making login call, status code: %d, body: %s", resp.StatusCode, string(body)) + } + var response LoginCallResponse if err := json.Unmarshal(body, &response); err != nil { return LoginCallResponse{}, breverrors.WrapAndTrace(err) @@ -83,36 +97,60 @@ func (a KasAuthenticator) MakeLoginCall(id, email string) (LoginCallResponse, er } func (a KasAuthenticator) DoDeviceAuthFlow(userLoginFlow func(url string, code string)) (*LoginTokens, error) { - id := uuid.New() + id := uuid.New().String() email := a.Email if a.Email == "" { fmt.Print("Enter your email: ") - fmt.Scanln(&email) + _, err := fmt.Scanln(&email) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } } - loginResp, err := a.MakeLoginCall(id.String(), email) + loginResp, err := a.MakeLoginCall(id, email) if err != nil { return nil, breverrors.WrapAndTrace(err) } - userLoginFlow(loginResp.LoginUrl, id.String()) + userLoginFlow(loginResp.LoginURL, "") - idToken, err := a.retrieveIDToken(loginResp.SessionKey, id.String()) - if err != nil { - return nil, breverrors.WrapAndTrace(err) + return a.pollForTokens(loginResp.SessionKey, id) +} + +func (a KasAuthenticator) pollForTokens(sessionKey, id string) (*LoginTokens, error) { + // Try to retrieve tokens for up to 5 minutes + timeout := time.After(a.PollTimeout) + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-timeout: + return nil, breverrors.WrapAndTrace(fmt.Errorf("timed out waiting for login")) + case <-ticker.C: + idToken, err := a.retrieveIDToken(sessionKey, id) + if err == nil { + fmt.Println(idToken) + return &LoginTokens{ + AuthTokens: entity.AuthTokens{ + AccessToken: idToken, + RefreshToken: fmt.Sprintf("%s:%s", sessionKey, id), + }, + IDToken: idToken, + }, nil + } + // Continue polling on error + } } - return &LoginTokens{ - AuthTokens: entity.AuthTokens{ - AccessToken: idToken, - RefreshToken: fmt.Sprintf("%s:%s", loginResp.SessionKey, id.String()), - }, - IDToken: idToken, - }, nil } type RetrieveIDTokenResponse struct { - IDToken string `json:"token"` + IDToken string `json:"token"` + RequestStatus struct { + StatusCode string `json:"statusCode"` + StatusDescription string `json:"statusDescription"` + RequestID string `json:"requestId"` + } `json:"requestStatus"` } // retrieveIDToken retrieves the ID token from BASE_API_URL + "/token". @@ -140,6 +178,10 @@ func (a KasAuthenticator) retrieveIDToken(sessionKey, deviceID string) (string, return "", fmt.Errorf("error reading token response: %v", err) } + if tokenResp.StatusCode >= 400 { + return "", fmt.Errorf("error retrieving token, status code: %d, body: %s", tokenResp.StatusCode, string(tokenBody)) + } + var tokenResponse RetrieveIDTokenResponse if err := json.Unmarshal(tokenBody, &tokenResponse); err != nil { return "", fmt.Errorf("error parsing token JSON response: %v", err) diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index a2737207..ca671bdd 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -2,6 +2,7 @@ package cmd import ( + "flag" "fmt" "github.com/brevdev/brev-cli/pkg/auth" @@ -65,11 +66,17 @@ var ( authProvider string ) +func init() { + flag.StringVar(&email, "email", "", "email to use for authentication") + flag.StringVar(&authProvider, "auth", "", "authentication provider to use (auth0 or kas, default is auth0)") + flag.Parse() +} + func NewDefaultBrevCommand() *cobra.Command { cmd := NewBrevCommand() cmd.PersistentFlags().StringVar(&user, "user", "", "non root user to use for per user configuration of commands run as root") - cmd.PersistentFlags().StringVar(&email, "email", "", "email address to use for authentication") - cmd.PersistentFlags().StringVar(&authProvider, "auth", "", "authentication provider to use (auth0 or kas, defaults to auth0)") + cmd.PersistentFlags().StringVar(&email, "email", "", "email to use for authentication") + cmd.PersistentFlags().StringVar(&authProvider, "auth", "", "authentication provider to use (auth0 or kas, default is auth0)") return cmd } @@ -90,13 +97,16 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin fmt.Printf("%v\n", err) } + authP := tokens.GetCredentialProvider() + if authProvider != "" { + authP = entity.CredentialProvider(authProvider) + } + var authenticator auth.OAuth - switch tokens.GetCredentialProvider() { + switch authP { case entity.CredentialProviderKAS: - authenticator = auth.KasAuthenticator{ - BaseURL: "https://api.ngc.nvidia.com", - Email: email, - } + config.ConsoleBaseURL = "https://brev.nvidia.com" + authenticator = auth.NewKasAuthenticator(email, "https://api.ngc.nvidia.com") default: authenticator = auth.Auth0Authenticator{ Audience: "https://brevdev.us.auth0.com/api/v2/", diff --git a/pkg/cmd/hello/steps.go b/pkg/cmd/hello/steps.go index c76a79cd..e8562300 100644 --- a/pkg/cmd/hello/steps.go +++ b/pkg/cmd/hello/steps.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/terminal" @@ -40,7 +41,7 @@ func GetTextBasedONStatus(status string, t *terminal.Terminal) string { s += "\n\nRun " + t.Yellow("brev hello") + " to resume this walk through when your instance is ready\n" default: s += t.Red("Please create a running instance for this walk through. ") - s += "\n\tYou can do that here: " + t.Yellow("https://console.brev.dev/environments/new") + s += "\n\tYou can do that here: " + t.Yellow(fmt.Sprintf("%s/environments/new", config.ConsoleBaseURL)) s += "\n\nRun " + t.Yellow("brev hello") + " to resume this walk through when your instance is ready\n" } return s @@ -61,7 +62,7 @@ func GetDevEnvOrStall(t *terminal.Terminal, workspaces []entity.Workspace) *enti if noneFound { s := t.Red("Please create a running instance for this walk through. ") - s += "\n\tYou can do that here: " + t.Yellow("https://console.brev.dev/environments/new") + s += "\n\tYou can do that here: " + t.Yellow(fmt.Sprintf("%s/environments/new", config.ConsoleBaseURL)) s += "\n\nRun: " + t.Yellow("brev hello") + " to resume this walk through when your instance is ready\n" TypeItToMe(s) return nil @@ -102,7 +103,7 @@ func printBrevOpen(t *terminal.Terminal, firstWorkspace entity.Workspace) { func printCompletedOnboarding(t *terminal.Terminal) { s := "\n\nI think I'm done here. Now you know how to open an instance and start coding." - s += "\n\nUse the console " + t.Yellow("(https://console.brev.dev)") + " to create a new instance or share it with people" + s += "\n\nUse the console " + t.Yellow(fmt.Sprintf("(%s)"), config.ConsoleBaseURL) + " to create a new instance or share it with people" s += "\nand use this CLI to code the way you would normally 🤙" s += "\n\nCheck out the docs at " + t.Yellow("https://brev.dev") + " and let us know if we can help!\n" s += "\n\nIn case you missed it, my cell is " + t.Yellow("(415) 237-2247") + "\n\t-Nader\n" diff --git a/pkg/cmd/invite/invite.go b/pkg/cmd/invite/invite.go index 4dc773b8..32b14003 100644 --- a/pkg/cmd/invite/invite.go +++ b/pkg/cmd/invite/invite.go @@ -7,6 +7,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmdcontext" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/store" @@ -96,9 +97,8 @@ func RunInvite(t *terminal.Terminal, inviteStore InviteStore, orgflag string) er } t.Vprintf("Share this link to add someone to %s. It will expire in 7 days.", t.Green(org.Name)) - // t.Vprintf("\n\n\t%s", t.White("https://console.brev.dev/invite?token=%s\n\n", token)) t.Vprintf("\n\n %s", t.Green("▸")) - t.Vprintf(" %s", t.White("https://console.brev.dev/invite?token=%s\n\n", token)) + t.Vprintf(" %s", t.White("%sinvite?token=%s\n\n", config.ConsoleBaseURL, token)) return nil } diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index e4324c10..f10eb269 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -11,6 +11,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/hello" utilities "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/cmdcontext" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" "github.com/brevdev/brev-cli/pkg/entity/virtualproject" breverrors "github.com/brevdev/brev-cli/pkg/errors" @@ -235,7 +236,7 @@ func (ls Ls) RunOrgs() error { return breverrors.WrapAndTrace(err) } if len(orgs) == 0 { - ls.terminal.Vprint(ls.terminal.Yellow("You don't have any orgs. Create one! https://console.brev.dev")) + ls.terminal.Vprint(ls.terminal.Yellow(fmt.Sprintf("You don't have any orgs. Create one! %s", config.ConsoleBaseURL))) return nil } diff --git a/pkg/cmd/org/org.go b/pkg/cmd/org/org.go index 82a47b4f..98a1bf12 100644 --- a/pkg/cmd/org/org.go +++ b/pkg/cmd/org/org.go @@ -8,6 +8,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/invite" "github.com/brevdev/brev-cli/pkg/cmdcontext" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/store" @@ -74,7 +75,7 @@ func RunOrgs(t *terminal.Terminal, store OrgCmdStore) error { return breverrors.WrapAndTrace(err) } if len(orgs) == 0 { - t.Vprint(t.Yellow("You don't have any orgs. Create one! https://console.brev.dev")) + t.Vprint(t.Yellow(fmt.Sprintf("You don't have any orgs. Create one! %s", config.ConsoleBaseURL))) return nil } diff --git a/pkg/cmd/profile/profile.go b/pkg/cmd/profile/profile.go index 030dc76d..2319c242 100644 --- a/pkg/cmd/profile/profile.go +++ b/pkg/cmd/profile/profile.go @@ -7,6 +7,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/start" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/terminal" @@ -54,7 +55,7 @@ func NewCmdProfile(t *terminal.Terminal, loginProfileStore ProfileStore, noLogin } func goToProfileInConsole() { - url := "https://console.brev.dev/profile" + url := fmt.Sprintf("%s/profile", config.ConsoleBaseURL) caretType := color.New(color.FgGreen, color.Bold).SprintFunc() enterType := color.New(color.FgGreen, color.Bold).SprintFunc() urlType := color.New(color.FgWhite, color.Bold).SprintFunc() diff --git a/pkg/config/config.go b/pkg/config/config.go index f5538733..6a2db58f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -18,6 +18,8 @@ const ( ollamaAPIURL EnvVarName = "OLLAMA_API_URL" ) +var ConsoleBaseURL = "https://console.brev.dev" + type ConstantsConfig struct{} func NewConstants() *ConstantsConfig {