diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 2363c896..c53f49af 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -19,6 +19,9 @@ type LoginAuth struct { Auth } +// assert that LoginAuth implements store.Auth +// var _ store.Auth = (*LoginAuth)(nil) + func NewLoginAuth(authStore AuthStore, oauth OAuth) *LoginAuth { return &LoginAuth{ Auth: *NewAuth(authStore, oauth), @@ -257,6 +260,7 @@ type LoginTokens struct { func (t Auth) getSavedTokensOrNil() (*entity.AuthTokens, error) { tokens, err := t.authStore.GetAuthTokens() + fmt.Fprintf(os.Stderr, "AuthTokens: %+v\n", tokens) if err != nil { switch err.(type) { //nolint:gocritic // like the ability to extend case *breverrors.CredentialsFileNotFound: diff --git a/pkg/auth/auth0.go b/pkg/auth/auth0.go index 21b9a9b6..ca8a7fcd 100644 --- a/pkg/auth/auth0.go +++ b/pkg/auth/auth0.go @@ -43,16 +43,16 @@ var requiredScopes = []string{ "create:organizations", "delete:organizations", "read:organizations", "update:organizations", } -type Authenticator struct { +type Auth0Authenticator struct { Audience string ClientID string DeviceCodeEndpoint string OauthTokenEndpoint string } -var _ OAuth = Authenticator{} +var _ OAuth = Auth0Authenticator{} -type Result struct { +type Auth0Result struct { Tenant string Domain string RefreshToken string @@ -61,7 +61,7 @@ type Result struct { ExpiresIn int64 } -type State struct { +type Auth0State struct { DeviceCode string `json:"device_code"` UserCode string `json:"user_code"` VerificationURI string `json:"verification_uri_complete"` @@ -73,11 +73,11 @@ type State struct { // RequiredScopesMin returns minimum scopes used for login in integration tests. -func (s *State) IntervalDuration() time.Duration { +func (s *Auth0State) IntervalDuration() time.Duration { return time.Duration(s.Interval+waitThresholdInSeconds) * time.Second } -func (a Authenticator) DoDeviceAuthFlow(onStateRetrieved func(url string, code string)) (*LoginTokens, error) { +func (a Auth0Authenticator) DoDeviceAuthFlow(onStateRetrieved func(url string, code string)) (*LoginTokens, error) { ctx := context.Background() state, err := a.Start(ctx) @@ -104,10 +104,10 @@ func (a Authenticator) DoDeviceAuthFlow(onStateRetrieved func(url string, code s // Start kicks-off the device authentication flow // by requesting a device code from Auth0, // The returned state contains the URI for the next step of the flow. -func (a *Authenticator) Start(ctx context.Context) (State, error) { +func (a *Auth0Authenticator) Start(ctx context.Context) (Auth0State, error) { s, err := a.getDeviceCode(ctx) if err != nil { - return State{}, breverrors.WrapAndTrace(breverrors.Errorf("cannot get device code %w", err)) + return Auth0State{}, breverrors.WrapAndTrace(breverrors.Errorf("cannot get device code %w", err)) } return s, nil } @@ -129,12 +129,12 @@ func postFormWithContext(ctx context.Context, url string, data url.Values) (*htt } // Wait waits until the user is logged in on the browser. -func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) { +func (a *Auth0Authenticator) Wait(ctx context.Context, state Auth0State) (Auth0Result, error) { t := time.NewTicker(state.IntervalDuration()) for { select { case <-ctx.Done(): - return Result{}, breverrors.WrapAndTrace(ctx.Err()) + return Auth0Result{}, breverrors.WrapAndTrace(ctx.Err()) case <-t.C: data := url.Values{ "client_id": {a.ClientID}, @@ -143,11 +143,11 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) { } r, err := postFormWithContext(ctx, a.OauthTokenEndpoint, data) if err != nil { - return Result{}, breverrors.WrapAndTrace(breverrors.Errorf("%w %w", err, breverrors.NetworkErrorMessage)) + return Auth0Result{}, breverrors.WrapAndTrace(breverrors.Errorf("%w %w", err, breverrors.NetworkErrorMessage)) } err = ErrorIfBadHTTP(r, http.StatusForbidden) if err != nil { - return Result{}, breverrors.WrapAndTrace(err) + return Auth0Result{}, breverrors.WrapAndTrace(err) } var res struct { @@ -163,25 +163,25 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) { err = json.NewDecoder(r.Body).Decode(&res) if err != nil { - return Result{}, breverrors.Wrap(err, "cannot decode response") + return Auth0Result{}, breverrors.Wrap(err, "cannot decode response") } if res.Error != nil { if *res.Error == "authorization_pending" { continue } - return Result{}, breverrors.WrapAndTrace(errors.New(res.ErrorDescription)) + return Auth0Result{}, breverrors.WrapAndTrace(errors.New(res.ErrorDescription)) } ten, domain, err := parseTenant(res.AccessToken) if err != nil { - return Result{}, breverrors.Wrap(err, "cannot parse tenant from the given access token") + return Auth0Result{}, breverrors.Wrap(err, "cannot parse tenant from the given access token") } if err = r.Body.Close(); err != nil { - return Result{}, breverrors.WrapAndTrace(err) + return Auth0Result{}, breverrors.WrapAndTrace(err) } - return Result{ + return Auth0Result{ RefreshToken: res.RefreshToken, AccessToken: res.AccessToken, ExpiresIn: res.ExpiresIn, @@ -193,7 +193,7 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) { } } -func (a *Authenticator) getDeviceCode(ctx context.Context) (State, error) { +func (a *Auth0Authenticator) getDeviceCode(ctx context.Context) (Auth0State, error) { data := url.Values{ "client_id": {a.ClientID}, "scope": {strings.Join(requiredScopes, " ")}, @@ -201,23 +201,23 @@ func (a *Authenticator) getDeviceCode(ctx context.Context) (State, error) { } r, err := postFormWithContext(ctx, a.DeviceCodeEndpoint, data) if err != nil { - return State{}, breverrors.Wrap(err, breverrors.NetworkErrorMessage) + return Auth0State{}, breverrors.Wrap(err, breverrors.NetworkErrorMessage) } err = ErrorIfBadHTTP(r) if err != nil { - return State{}, breverrors.WrapAndTrace(err) + return Auth0State{}, breverrors.WrapAndTrace(err) } - var res State + var res Auth0State err = json.NewDecoder(r.Body).Decode(&res) if err != nil { - return State{}, breverrors.Wrap(err, "cannot decode response") + return Auth0State{}, breverrors.Wrap(err, "cannot decode response") } // TODO 9 if status code > 399 handle errors // {"error":"unauthorized_client","error_description":"Grant type 'urn:ietf:params:oauth:grant-type:device_code' not allowed for the client.","error_uri":"https://auth0.com/docs/clients/client-grant-types"} if err = r.Body.Close(); err != nil { - return State{}, breverrors.WrapAndTrace(err) + return Auth0State{}, breverrors.WrapAndTrace(err) } return res, nil } @@ -255,7 +255,7 @@ type AuthError struct { } // https://auth0.com/docs/get-started/authentication-and-authorization-flow/call-your-api-using-the-authorization-code-flow#example-post-to-token-url -func (a Authenticator) GetNewAuthTokensWithRefresh(refreshToken string) (*entity.AuthTokens, error) { +func (a Auth0Authenticator) GetNewAuthTokensWithRefresh(refreshToken string) (*entity.AuthTokens, error) { payload := url.Values{ "grant_type": {"refresh_token"}, "client_id": {a.ClientID}, diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index ddecd49d..498b9e24 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -73,12 +73,16 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin conf := config.NewConstants() fs := files.AppFs - authenticator := auth.Authenticator{ + + // TODO: this is inappropriately tightly bound to auth0, we should bifurcate our path near here. + // auth.Multi.. auth.Dynamic.. ? + var authenticator auth.OAuth = auth.Auth0Authenticator{ Audience: "https://brevdev.us.auth0.com/api/v2/", ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk", DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code", OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token", } + // super annoying. this is needed to make the import stay _ = color.New(color.FgYellow, color.Bold).SprintFunc() diff --git a/pkg/cmd/refresh/refresh.go b/pkg/cmd/refresh/refresh.go index 59c078d1..2a056bfb 100644 --- a/pkg/cmd/refresh/refresh.go +++ b/pkg/cmd/refresh/refresh.go @@ -57,6 +57,24 @@ func NewCmdRefresh(t *terminal.Terminal, store RefreshStore) *cobra.Command { return cmd } +func RunRefreshBetter(store RefreshStore) error { + if err := GetCloudflare(store).DownloadCloudflaredBinaryIfItDNE(); err != nil { + return breverrors.WrapAndTrace(err) + } + + cu, err := GetConfigUpdater(store) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + err = cu.Run() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + return nil +} + func RunRefresh(store RefreshStore) error { cl := GetCloudflare(store) err := cl.DownloadCloudflaredBinaryIfItDNE() diff --git a/pkg/entity/entity.go b/pkg/entity/entity.go index daa871d5..1214ae44 100644 --- a/pkg/entity/entity.go +++ b/pkg/entity/entity.go @@ -11,6 +11,16 @@ import ( "github.com/brevdev/brev-cli/pkg/collections" ) +// CredentialProvider describes which authentication system is resposnible for auth tokens. +type CredentialProvider string + +const ( + CredientialProviderUnspecified CredentialProvider = "" + CredentialProviderAuth0 CredentialProvider = "auth0" + CredentialProviderKAS CredentialProvider = "kas" + // CredentialProviderStarfleet CredentialProvider = "starfleet" +) + const WorkspaceGroupDevPlane = "devplane-brev-1" var LegacyWorkspaceGroups = map[string]bool{ @@ -21,6 +31,8 @@ var LegacyWorkspaceGroups = map[string]bool{ type AuthTokens struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` + + CredentialProvider CredentialProvider `json:"credential_provider"` } type IDEConfig struct { diff --git a/pkg/store/authtoken.go b/pkg/store/authtoken.go index 2b7e1ede..c7a19cbb 100644 --- a/pkg/store/authtoken.go +++ b/pkg/store/authtoken.go @@ -34,6 +34,20 @@ func (f FileStore) SaveAuthTokens(token entity.AuthTokens) error { return nil } +/* tmc rant : this could be way cleaner +// authtokens/authtokens.go +func ReadAuthTokensFromDisk() (*entity.AuthTokens, error) { + home, _ := os.UserHomeDir() + f, err := os.Open(filepath.Join(home, brevDirectory, brevCredentialsFile)) + if err != nil { + return nil, err + } + var tok entity.AuthTokens + return &tok, json.NewDecoder(f).Decode(&tok) +} +tokens, err := authtokens.ReadAuthTokensFromDisk() +*/ + func (f FileStore) GetAuthTokens() (*entity.AuthTokens, error) { serviceToken, err := f.GetCurrentWorkspaceServiceToken() if err == nil && serviceToken != "" { @@ -63,6 +77,15 @@ func (f FileStore) GetAuthTokens() (*entity.AuthTokens, error) { return &token, nil } +// func (f FileStore) getBrevCredentialsFile() (*string, error) { +// home, err := f.UserHomeDir() +// if err != nil { +// return nil, breverrors.WrapAndTrace(err) +// } +// brevCredentialsFile := path.Join(home, brevDirectory, brevCredentialsFile) +// return &brevCredentialsFile, nil +// } + func (f FileStore) GetCurrentWorkspaceServiceToken() (string, error) { saTokenFilePath := getServiceTokenFilePath() // safely check if file exists