Skip to content

Commit

Permalink
kas as authenticator
Browse files Browse the repository at this point in the history
  • Loading branch information
theFong committed Nov 14, 2024
1 parent 846d9dd commit 8666fae
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 75 deletions.
19 changes: 0 additions & 19 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,6 @@ import (
"github.com/pkg/browser"
)

type AuthChecker interface {
GetCredentialProvider() entity.CredentialProvider
}

func GetAuthenticator(ac AuthChecker) OAuth {
switch ac.GetCredentialProvider() {
case entity.CredentialProviderKAS:
return KasAuthenticator{}
case entity.CredentialProviderAuth0:
return Auth0Authenticator{
Audience: "https://brevdev.us.auth0.com/api/v2/",
ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk",
DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code",
OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token",
}
}
return nil
}

type LoginAuth struct {
Auth
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/auth/auth0.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ func (a Auth0Authenticator) DoDeviceAuthFlow(onStateRetrieved func(url string, c

return &LoginTokens{
AuthTokens: entity.AuthTokens{
AccessToken: res.AccessToken,
RefreshToken: res.RefreshToken,
AccessToken: res.AccessToken,
RefreshToken: res.RefreshToken,
CredentialProvider: entity.CredentialProviderAuth0,
},
IDToken: res.IDToken,
}, nil
Expand Down
75 changes: 25 additions & 50 deletions pkg/auth/kas.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"
"io"
"net/http"
"time"
"strings"

"github.com/brevdev/brev-cli/pkg/entity"
breverrors "github.com/brevdev/brev-cli/pkg/errors"
Expand All @@ -16,13 +16,25 @@ import (
var _ OAuth = KasAuthenticator{}

type KasAuthenticator struct {
BaseURL string
IDTokenExpiryMinutes float64
SessionIDExpiryHours float64
Email string
BaseURL string
}

func (a KasAuthenticator) GetNewAuthTokensWithRefresh(refreshToken string) (*entity.AuthTokens, error) {
return nil, nil
splitRefreshToken := strings.Split(refreshToken, ":")
if len(splitRefreshToken) != 2 {
return nil, fmt.Errorf("invalid refresh token")
}
sessionKey, deviceID := splitRefreshToken[0], splitRefreshToken[1]
token, err := a.retrieveIDToken(sessionKey, deviceID)
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}
return &entity.AuthTokens{
AccessToken: token,
RefreshToken: refreshToken,
CredentialProvider: entity.CredentialProviderKAS,
}, nil
}

type LoginCallResponse struct {
Expand Down Expand Up @@ -70,16 +82,21 @@ func (a KasAuthenticator) MakeLoginCall(id, email string) (LoginCallResponse, er
return response, nil
}

func (a KasAuthenticator) DoDeviceAuthFlow(onStateRetrieved func(url string, code string)) (*LoginTokens, error) {
func (a KasAuthenticator) DoDeviceAuthFlow(userLoginFlow func(url string, code string)) (*LoginTokens, error) {
id := uuid.New()
email := "[email protected]" // TODO: ask user for email
email := a.Email

if a.Email == "" {
fmt.Print("Enter your email: ")
fmt.Scanln(&email)

Check failure on line 91 in pkg/auth/kas.go

View workflow job for this annotation

GitHub Actions / ci (ubuntu-20.04)

G104: Errors unhandled. (gosec)
}

loginResp, err := a.MakeLoginCall(id.String(), email)
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}

onStateRetrieved(loginResp.LoginUrl, id.String())
userLoginFlow(loginResp.LoginUrl, id.String())

idToken, err := a.retrieveIDToken(loginResp.SessionKey, id.String())
if err != nil {
Expand Down Expand Up @@ -130,45 +147,3 @@ func (a KasAuthenticator) retrieveIDToken(sessionKey, deviceID string) (string,

return tokenResponse.IDToken, nil
}

// getIDToken reads the sessionConfig file and retrieves the idToken.
// If the idToken is expired (15 minutes after creation), a new token is fetched and set in the sessionConfig file.
func (a KasAuthenticator) getIDToken() (string, error) {
session, err := readSessionConfig()
if err != nil {
fmt.Println(err)
return "", err
}

idToken, ok := session["idToken"]
if !ok || idToken == "" {
return "", fmt.Errorf("idToken not found in session data. Please ensure you are logged in")
}

idTokenCreatedAtStr, ok := session["idTokenCreatedAt"]
if !ok || idTokenCreatedAtStr == "" {
return "", fmt.Errorf("idTokenCreatedAt not found in session data")
}

idTokenCreatedAt, err := time.Parse(time.RFC3339, idTokenCreatedAtStr)
if err != nil {
return "", fmt.Errorf("error parsing idTokenCreatedAt timestamp: %v", err)
}

if time.Since(idTokenCreatedAt).Minutes() >= a.IDTokenExpiryMinutes {
fmt.Println("ID token is expired. Creating a new ID token using the sessionKey.")

newIDToken, newIDTokenCreatedAt, err := a.retrieveIDToken()
if err != nil {
fmt.Println("error retrieving new ID token", err)
return "", fmt.Errorf("error retrieving new ID token: %v", err)
}

session["idToken"] = newIDToken
session["idTokenCreatedAt"] = newIDTokenCreatedAt

return newIDToken, updateSessionConfig(session)
}

return idToken, nil
}
32 changes: 29 additions & 3 deletions pkg/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/brevdev/brev-cli/pkg/cmd/workspacegroups"
"github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent"
"github.com/brevdev/brev-cli/pkg/config"
"github.com/brevdev/brev-cli/pkg/entity"
"github.com/brevdev/brev-cli/pkg/featureflag"
"github.com/brevdev/brev-cli/pkg/files"
"github.com/brevdev/brev-cli/pkg/remoteversion"
Expand All @@ -58,11 +59,17 @@ import (
breverrors "github.com/brevdev/brev-cli/pkg/errors"
)

var user string
var (
user string
email string
authProvider string
)

func NewDefaultBrevCommand() *cobra.Command {
cmd := NewBrevCommand()
cmd.PersistentFlags().StringVar(&user, "user", "", "non root user to use for per user configuration of commands run as root")
cmd.PersistentFlags().StringVar(&email, "email", "", "email address to use for authentication")
cmd.PersistentFlags().StringVar(&authProvider, "auth", "", "authentication provider to use (auth0 or kas, defaults to auth0)")
return cmd
}

Expand All @@ -78,7 +85,26 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin
NewBasicStore().
WithFileSystem(fs)

authenticator := auth.GetAuthenticator(nil) // todo
tokens, err := fsStore.GetAuthTokens()
if err != nil {
fmt.Printf("%v\n", err)
}

var authenticator auth.OAuth
switch tokens.GetCredentialProvider() {
case entity.CredentialProviderKAS:
authenticator = auth.KasAuthenticator{
BaseURL: "https://api.ngc.nvidia.com",
Email: email,
}
default:
authenticator = auth.Auth0Authenticator{
Audience: "https://brevdev.us.auth0.com/api/v2/",
ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk",
DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code",
OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token",
}
}

// super annoying. this is needed to make the import stay
_ = color.New(color.FgYellow, color.Bold).SprintFunc()
Expand All @@ -91,7 +117,7 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin
).
WithAuth(loginAuth, store.WithDebug(conf.GetDebugHTTP()))

err := loginCmdStore.SetForbiddenStatusRetryHandler(func() error {
err = loginCmdStore.SetForbiddenStatusRetryHandler(func() error {
_, err1 := loginAuth.GetAccessToken()
if err1 != nil {
return breverrors.WrapAndTrace(err1)
Expand Down
7 changes: 7 additions & 0 deletions pkg/entity/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ type AuthTokens struct {
CredentialProvider CredentialProvider `json:"credential_provider"`
}

func (a AuthTokens) GetCredentialProvider() CredentialProvider {
if a.CredentialProvider == CredientialProviderUnspecified {
return CredentialProviderAuth0
}
return a.CredentialProvider
}

type IDEConfig struct {
DefaultWorkingDir string `json:"defaultWorkingDir"`
VSCode VSCodeConfig `json:"vscode"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/cloudflared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func makeCloudflare() Cloudflared {
conf := config.NewConstants()
fs := files.AppFs
authenticator := auth.Authenticator{
authenticator := auth.Auth0Authenticator{
Audience: "https://brevdev.us.auth0.com/api/v2/",
ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk",
DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code",
Expand Down

0 comments on commit 8666fae

Please sign in to comment.