Skip to content

Commit

Permalink
auth: refactor and extend authentication structures
Browse files Browse the repository at this point in the history
Refactor Auth0 authentication, add new structs and methods for improved
auth handling. Update related command and entity files.
  • Loading branch information
tmc authored and theFong committed Nov 13, 2024
1 parent df9874c commit 02a9931
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 25 deletions.
4 changes: 4 additions & 0 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ type LoginAuth struct {
Auth
}

// assert that LoginAuth implements store.Auth
// var _ store.Auth = (*LoginAuth)(nil)

func NewLoginAuth(authStore AuthStore, oauth OAuth) *LoginAuth {
return &LoginAuth{
Auth: *NewAuth(authStore, oauth),
Expand Down Expand Up @@ -257,6 +260,7 @@ type LoginTokens struct {

func (t Auth) getSavedTokensOrNil() (*entity.AuthTokens, error) {
tokens, err := t.authStore.GetAuthTokens()
fmt.Fprintf(os.Stderr, "AuthTokens: %+v\n", tokens)
if err != nil {
switch err.(type) { //nolint:gocritic // like the ability to extend
case *breverrors.CredentialsFileNotFound:
Expand Down
48 changes: 24 additions & 24 deletions pkg/auth/auth0.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ var requiredScopes = []string{
"create:organizations", "delete:organizations", "read:organizations", "update:organizations",
}

type Authenticator struct {
type Auth0Authenticator struct {
Audience string
ClientID string
DeviceCodeEndpoint string
OauthTokenEndpoint string
}

var _ OAuth = Authenticator{}
var _ OAuth = Auth0Authenticator{}

type Result struct {
type Auth0Result struct {
Tenant string
Domain string
RefreshToken string
Expand All @@ -61,7 +61,7 @@ type Result struct {
ExpiresIn int64
}

type State struct {
type Auth0State struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri_complete"`
Expand All @@ -73,11 +73,11 @@ type State struct {

// RequiredScopesMin returns minimum scopes used for login in integration tests.

func (s *State) IntervalDuration() time.Duration {
func (s *Auth0State) IntervalDuration() time.Duration {
return time.Duration(s.Interval+waitThresholdInSeconds) * time.Second
}

func (a Authenticator) DoDeviceAuthFlow(onStateRetrieved func(url string, code string)) (*LoginTokens, error) {
func (a Auth0Authenticator) DoDeviceAuthFlow(onStateRetrieved func(url string, code string)) (*LoginTokens, error) {
ctx := context.Background()

state, err := a.Start(ctx)
Expand All @@ -104,10 +104,10 @@ func (a Authenticator) DoDeviceAuthFlow(onStateRetrieved func(url string, code s
// Start kicks-off the device authentication flow
// by requesting a device code from Auth0,
// The returned state contains the URI for the next step of the flow.
func (a *Authenticator) Start(ctx context.Context) (State, error) {
func (a *Auth0Authenticator) Start(ctx context.Context) (Auth0State, error) {
s, err := a.getDeviceCode(ctx)
if err != nil {
return State{}, breverrors.WrapAndTrace(breverrors.Errorf("cannot get device code %w", err))
return Auth0State{}, breverrors.WrapAndTrace(breverrors.Errorf("cannot get device code %w", err))
}
return s, nil
}
Expand All @@ -129,12 +129,12 @@ func postFormWithContext(ctx context.Context, url string, data url.Values) (*htt
}

// Wait waits until the user is logged in on the browser.
func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) {
func (a *Auth0Authenticator) Wait(ctx context.Context, state Auth0State) (Auth0Result, error) {
t := time.NewTicker(state.IntervalDuration())
for {
select {
case <-ctx.Done():
return Result{}, breverrors.WrapAndTrace(ctx.Err())
return Auth0Result{}, breverrors.WrapAndTrace(ctx.Err())
case <-t.C:
data := url.Values{
"client_id": {a.ClientID},
Expand All @@ -143,11 +143,11 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) {
}
r, err := postFormWithContext(ctx, a.OauthTokenEndpoint, data)
if err != nil {
return Result{}, breverrors.WrapAndTrace(breverrors.Errorf("%w %w", err, breverrors.NetworkErrorMessage))
return Auth0Result{}, breverrors.WrapAndTrace(breverrors.Errorf("%w %w", err, breverrors.NetworkErrorMessage))
}
err = ErrorIfBadHTTP(r, http.StatusForbidden)
if err != nil {
return Result{}, breverrors.WrapAndTrace(err)
return Auth0Result{}, breverrors.WrapAndTrace(err)
}

var res struct {
Expand All @@ -163,25 +163,25 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) {

err = json.NewDecoder(r.Body).Decode(&res)
if err != nil {
return Result{}, breverrors.Wrap(err, "cannot decode response")
return Auth0Result{}, breverrors.Wrap(err, "cannot decode response")
}

if res.Error != nil {
if *res.Error == "authorization_pending" {
continue
}
return Result{}, breverrors.WrapAndTrace(errors.New(res.ErrorDescription))
return Auth0Result{}, breverrors.WrapAndTrace(errors.New(res.ErrorDescription))
}

ten, domain, err := parseTenant(res.AccessToken)
if err != nil {
return Result{}, breverrors.Wrap(err, "cannot parse tenant from the given access token")
return Auth0Result{}, breverrors.Wrap(err, "cannot parse tenant from the given access token")
}

if err = r.Body.Close(); err != nil {
return Result{}, breverrors.WrapAndTrace(err)
return Auth0Result{}, breverrors.WrapAndTrace(err)
}
return Result{
return Auth0Result{
RefreshToken: res.RefreshToken,
AccessToken: res.AccessToken,
ExpiresIn: res.ExpiresIn,
Expand All @@ -193,31 +193,31 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) {
}
}

func (a *Authenticator) getDeviceCode(ctx context.Context) (State, error) {
func (a *Auth0Authenticator) getDeviceCode(ctx context.Context) (Auth0State, error) {
data := url.Values{
"client_id": {a.ClientID},
"scope": {strings.Join(requiredScopes, " ")},
"audience": {a.Audience},
}
r, err := postFormWithContext(ctx, a.DeviceCodeEndpoint, data)
if err != nil {
return State{}, breverrors.Wrap(err, breverrors.NetworkErrorMessage)
return Auth0State{}, breverrors.Wrap(err, breverrors.NetworkErrorMessage)
}
err = ErrorIfBadHTTP(r)
if err != nil {
return State{}, breverrors.WrapAndTrace(err)
return Auth0State{}, breverrors.WrapAndTrace(err)
}

var res State
var res Auth0State
err = json.NewDecoder(r.Body).Decode(&res)
if err != nil {
return State{}, breverrors.Wrap(err, "cannot decode response")
return Auth0State{}, breverrors.Wrap(err, "cannot decode response")
}
// TODO 9 if status code > 399 handle errors
// {"error":"unauthorized_client","error_description":"Grant type 'urn:ietf:params:oauth:grant-type:device_code' not allowed for the client.","error_uri":"https://auth0.com/docs/clients/client-grant-types"}

if err = r.Body.Close(); err != nil {
return State{}, breverrors.WrapAndTrace(err)
return Auth0State{}, breverrors.WrapAndTrace(err)
}
return res, nil
}
Expand Down Expand Up @@ -255,7 +255,7 @@ type AuthError struct {
}

// https://auth0.com/docs/get-started/authentication-and-authorization-flow/call-your-api-using-the-authorization-code-flow#example-post-to-token-url
func (a Authenticator) GetNewAuthTokensWithRefresh(refreshToken string) (*entity.AuthTokens, error) {
func (a Auth0Authenticator) GetNewAuthTokensWithRefresh(refreshToken string) (*entity.AuthTokens, error) {
payload := url.Values{
"grant_type": {"refresh_token"},
"client_id": {a.ClientID},
Expand Down
6 changes: 5 additions & 1 deletion pkg/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,16 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin

conf := config.NewConstants()
fs := files.AppFs
authenticator := auth.Authenticator{

// TODO: this is inappropriately tightly bound to auth0, we should bifurcate our path near here.
// auth.Multi.. auth.Dynamic.. ?
var authenticator auth.OAuth = auth.Auth0Authenticator{
Audience: "https://brevdev.us.auth0.com/api/v2/",
ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk",
DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code",
OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token",
}

// super annoying. this is needed to make the import stay
_ = color.New(color.FgYellow, color.Bold).SprintFunc()

Expand Down
18 changes: 18 additions & 0 deletions pkg/cmd/refresh/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ func NewCmdRefresh(t *terminal.Terminal, store RefreshStore) *cobra.Command {
return cmd
}

func RunRefreshBetter(store RefreshStore) error {
if err := GetCloudflare(store).DownloadCloudflaredBinaryIfItDNE(); err != nil {
return breverrors.WrapAndTrace(err)
}

cu, err := GetConfigUpdater(store)
if err != nil {
return breverrors.WrapAndTrace(err)
}

err = cu.Run()
if err != nil {
return breverrors.WrapAndTrace(err)
}

return nil
}

func RunRefresh(store RefreshStore) error {
cl := GetCloudflare(store)
err := cl.DownloadCloudflaredBinaryIfItDNE()
Expand Down
12 changes: 12 additions & 0 deletions pkg/entity/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ import (
"github.com/brevdev/brev-cli/pkg/collections"
)

// CredentialProvider describes which authentication system is resposnible for auth tokens.
type CredentialProvider string

const (
CredientialProviderUnspecified CredentialProvider = ""
CredentialProviderAuth0 CredentialProvider = "auth0"
CredentialProviderKAS CredentialProvider = "kas"
// CredentialProviderStarfleet CredentialProvider = "starfleet"
)

const WorkspaceGroupDevPlane = "devplane-brev-1"

var LegacyWorkspaceGroups = map[string]bool{
Expand All @@ -21,6 +31,8 @@ var LegacyWorkspaceGroups = map[string]bool{
type AuthTokens struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`

CredentialProvider CredentialProvider `json:"credential_provider"`
}

type IDEConfig struct {
Expand Down
23 changes: 23 additions & 0 deletions pkg/store/authtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ func (f FileStore) SaveAuthTokens(token entity.AuthTokens) error {
return nil
}

/* tmc rant : this could be way cleaner
// authtokens/authtokens.go
func ReadAuthTokensFromDisk() (*entity.AuthTokens, error) {
home, _ := os.UserHomeDir()
f, err := os.Open(filepath.Join(home, brevDirectory, brevCredentialsFile))
if err != nil {
return nil, err
}
var tok entity.AuthTokens
return &tok, json.NewDecoder(f).Decode(&tok)
}
tokens, err := authtokens.ReadAuthTokensFromDisk()
*/

func (f FileStore) GetAuthTokens() (*entity.AuthTokens, error) {
serviceToken, err := f.GetCurrentWorkspaceServiceToken()
if err == nil && serviceToken != "" {
Expand Down Expand Up @@ -63,6 +77,15 @@ func (f FileStore) GetAuthTokens() (*entity.AuthTokens, error) {
return &token, nil
}

// func (f FileStore) getBrevCredentialsFile() (*string, error) {
// home, err := f.UserHomeDir()
// if err != nil {
// return nil, breverrors.WrapAndTrace(err)
// }
// brevCredentialsFile := path.Join(home, brevDirectory, brevCredentialsFile)
// return &brevCredentialsFile, nil
// }

func (f FileStore) GetCurrentWorkspaceServiceToken() (string, error) {
saTokenFilePath := getServiceTokenFilePath()
// safely check if file exists
Expand Down

0 comments on commit 02a9931

Please sign in to comment.