Skip to content

Commit

Permalink
Okta Dep Update (#28121)
Browse files Browse the repository at this point in the history
Update okta to use v5 sdk instead of v2
---------

Co-authored-by: Theron Voran <[email protected]>
  • Loading branch information
kpcraig and tvoran authored Nov 26, 2024
1 parent a2c467c commit 71c2121
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 70 deletions.
16 changes: 7 additions & 9 deletions builtin/credential/okta/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/cidrutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v5/okta"
"github.com/patrickmn/go-cache"
)

Expand Down Expand Up @@ -118,6 +118,7 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username, pas
StateToken string `json:"stateToken"`
}

// The okta-sdk-golang API says to construct your own requests for auth, and the Request Executor is gone, so
authReq, err := shim.NewRequest("POST", "authn", map[string]interface{}{
"username": username,
"password": password,
Expand All @@ -129,9 +130,6 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username, pas
var result authResult
rsp, err := shim.Do(authReq, &result)
if err != nil {
if oe, ok := err.(*okta.Error); ok {
return nil, logical.ErrorResponse("Okta auth failed: %v (code=%v)", err, oe.ErrorCode), nil, nil
}
return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil, nil
}
if rsp == nil {
Expand Down Expand Up @@ -370,23 +368,23 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username, pas
return policies, oktaResponse, allGroups, nil
}

func (b *backend) getOktaGroups(ctx context.Context, client *okta.Client, user *okta.User) ([]string, error) {
groups, resp, err := client.User.ListUserGroups(ctx, user.Id)
func (b *backend) getOktaGroups(ctx context.Context, client *okta.APIClient, user *okta.User) ([]string, error) {
groups, resp, err := client.UserAPI.ListUserGroups(ctx, user.GetId()).Execute()
if err != nil {
return nil, err
}
oktaGroups := make([]string, 0, len(groups))
for _, group := range groups {
oktaGroups = append(oktaGroups, group.Profile.Name)
oktaGroups = append(oktaGroups, group.Profile.GetName())
}
for resp.HasNextPage() {
var nextGroups []*okta.Group
resp, err = resp.Next(ctx, &nextGroups)
resp, err = resp.Next(&nextGroups)
if err != nil {
return nil, err
}
for _, group := range nextGroups {
oktaGroups = append(oktaGroups, group.Profile.Name)
oktaGroups = append(oktaGroups, group.Profile.GetName())
}
}
if b.Logger().IsDebug() {
Expand Down
34 changes: 16 additions & 18 deletions builtin/credential/okta/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
"github.com/okta/okta-sdk-golang/v5/okta"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -115,15 +114,15 @@ func TestBackend_Config(t *testing.T) {

func createOktaGroups(t *testing.T, username string, token string, org string) []string {
orgURL := "https://" + org + "." + previewBaseURL
ctx, client, err := okta.NewClient(context.Background(), okta.WithOrgUrl(orgURL), okta.WithToken(token))
cfg, err := okta.NewConfiguration(okta.WithOrgUrl(orgURL), okta.WithToken(token))
require.Nil(t, err)
client := okta.NewAPIClient(cfg)
ctx := context.Background()

users, _, err := client.User.ListUsers(ctx, &query.Params{
Q: username,
})
users, _, err := client.UserAPI.ListUsers(ctx).Q(username).Execute()
require.Nil(t, err)
require.Len(t, users, 1)
userID := users[0].Id
userID := users[0].GetId()
var groupIDs []string

// Verify that login's call to list the groups of the user logging in will page
Expand All @@ -133,38 +132,37 @@ func createOktaGroups(t *testing.T, username string, token string, org string) [
// only 200 results are returned for most orgs."
for i := 0; i < 201; i++ {
name := fmt.Sprintf("TestGroup%d", i)
groups, _, err := client.Group.ListGroups(ctx, &query.Params{
Q: name,
})
groups, _, err := client.GroupAPI.ListGroups(ctx).Q(name).Execute()
require.Nil(t, err)

var groupID string
if len(groups) == 0 {
group, _, err := client.Group.CreateGroup(ctx, okta.Group{
group, _, err := client.GroupAPI.CreateGroup(ctx).Group(okta.Group{
Profile: &okta.GroupProfile{
Name: fmt.Sprintf("TestGroup%d", i),
Name: okta.PtrString(fmt.Sprintf("TestGroup%d", i)),
},
})
}).Execute()
require.Nil(t, err)
groupID = group.Id
groupID = group.GetId()
} else {
groupID = groups[0].Id
groupID = groups[0].GetId()
}
groupIDs = append(groupIDs, groupID)

_, err = client.Group.AddUserToGroup(ctx, groupID, userID)
_, err = client.GroupAPI.AssignUserToGroup(ctx, groupID, userID).Execute()
require.Nil(t, err)
}
return groupIDs
}

func deleteOktaGroups(t *testing.T, token string, org string, groupIDs []string) {
orgURL := "https://" + org + "." + previewBaseURL
ctx, client, err := okta.NewClient(context.Background(), okta.WithOrgUrl(orgURL), okta.WithToken(token))
cfg, err := okta.NewConfiguration(okta.WithOrgUrl(orgURL), okta.WithToken(token))
require.Nil(t, err)
client := okta.NewAPIClient(cfg)

for _, groupID := range groupIDs {
_, err := client.Group.DeleteGroup(ctx, groupID)
_, err := client.GroupAPI.DeleteGroup(context.Background(), groupID).Execute()
require.Nil(t, err)
}
}
Expand Down
145 changes: 135 additions & 10 deletions builtin/credential/okta/path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
package okta

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

gocache "github.com/patrickmn/go-cache"

oktaold "github.com/chrismalek/oktasdk-go/okta"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
oktanew "github.com/okta/okta-sdk-golang/v2/okta"
oktanew "github.com/okta/okta-sdk-golang/v5/okta"
)

const (
Expand Down Expand Up @@ -290,36 +295,133 @@ func (b *backend) pathConfigExistenceCheck(ctx context.Context, req *logical.Req
}

type oktaShim interface {
Client() (*oktanew.Client, context.Context)
Client() (*oktanew.APIClient, context.Context)
NewRequest(method string, url string, body interface{}) (*http.Request, error)
Do(req *http.Request, v interface{}) (interface{}, error)
}

type oktaShimNew struct {
client *oktanew.Client
cfg *oktanew.Configuration
client *oktanew.APIClient
ctx context.Context
cache *gocache.Cache // cache used to hold authorization values created for NewRequests
}

func (new *oktaShimNew) Client() (*oktanew.Client, context.Context) {
func (new *oktaShimNew) Client() (*oktanew.APIClient, context.Context) {
return new.client, new.ctx
}

func (new *oktaShimNew) NewRequest(method string, url string, body interface{}) (*http.Request, error) {
if !strings.HasPrefix(url, "/") {
url = "/api/v1/" + url
}
return new.client.GetRequestExecutor().NewRequest(method, url, body)

// reimplementation of RequestExecutor.NewRequest() in v2 of okta-golang-sdk
var buff io.ReadWriter
if body != nil {
switch v := body.(type) {
case []byte:
buff = bytes.NewBuffer(v)
case *bytes.Buffer:
buff = v
default:
buff = &bytes.Buffer{}
// need to create an encoder specifically to disable html escaping
encoder := json.NewEncoder(buff)
encoder.SetEscapeHTML(false)
err := encoder.Encode(body)
if err != nil {
return nil, err
}
}
}

url = new.cfg.Okta.Client.OrgUrl + url
req, err := http.NewRequest(method, url, buff)
if err != nil {
return nil, err
}

// construct an authorization header for the request using our okta config
var auth oktanew.Authorization
// I think the only usage of the shim is in credential/okta/backend.go, and in that case, the
// AuthorizationMode is only ever SSWS (since OktaClient() below never overrides the default authorization
// mode. This function will faithfully replicate the old RequestExecutor code, though.
switch new.cfg.Okta.Client.AuthorizationMode {
case "SSWS":
auth = oktanew.NewSSWSAuth(new.cfg.Okta.Client.Token, req)
case "Bearer":
auth = oktanew.NewBearerAuth(new.cfg.Okta.Client.Token, req)
case "PrivateKey":
auth = oktanew.NewPrivateKeyAuth(oktanew.PrivateKeyAuthConfig{
TokenCache: new.cache,
HttpClient: new.cfg.HTTPClient,
PrivateKeySigner: new.cfg.PrivateKeySigner,
PrivateKey: new.cfg.Okta.Client.PrivateKey,
PrivateKeyId: new.cfg.Okta.Client.PrivateKeyId,
ClientId: new.cfg.Okta.Client.ClientId,
OrgURL: new.cfg.Okta.Client.OrgUrl,
Scopes: new.cfg.Okta.Client.Scopes,
MaxRetries: new.cfg.Okta.Client.RateLimit.MaxRetries,
MaxBackoff: new.cfg.Okta.Client.RateLimit.MaxBackoff,
Req: req,
})
case "JWT":
auth = oktanew.NewJWTAuth(oktanew.JWTAuthConfig{
TokenCache: new.cache,
HttpClient: new.cfg.HTTPClient,
OrgURL: new.cfg.Okta.Client.OrgUrl,
Scopes: new.cfg.Okta.Client.Scopes,
ClientAssertion: new.cfg.Okta.Client.ClientAssertion,
MaxRetries: new.cfg.Okta.Client.RateLimit.MaxRetries,
MaxBackoff: new.cfg.Okta.Client.RateLimit.MaxBackoff,
Req: req,
})
default:
return nil, fmt.Errorf("unknown authorization mode %v", new.cfg.Okta.Client.AuthorizationMode)
}

// Authorize adds a header based on the contents of the Authorization struct
err = auth.Authorize("POST", url)
if err != nil {
return nil, err
}

req.Header.Add("Accept", "application/json")

if body != nil {
req.Header.Set("Content-Type", "application/json")
}

return req, nil
}

func (new *oktaShimNew) Do(req *http.Request, v interface{}) (interface{}, error) {
return new.client.GetRequestExecutor().Do(new.ctx, req, v)
resp, err := new.cfg.HTTPClient.Do(req)
if err != nil {
return nil, err
}

if resp.Body == nil {
return nil, nil
}
defer resp.Body.Close()

bt, err := io.ReadAll(resp.Body)
err = json.Unmarshal(bt, v)
if err != nil {
return nil, err
}

// as far as i can tell, we only use the first return to check if it is nil, and assume that means an error happened.
return resp, nil
}

type oktaShimOld struct {
client *oktaold.Client
}

func (new *oktaShimOld) Client() (*oktanew.Client, context.Context) {
func (new *oktaShimOld) Client() (*oktanew.APIClient, context.Context) {
return nil, nil
}

Expand All @@ -331,7 +433,25 @@ func (new *oktaShimOld) Do(req *http.Request, v interface{}) (interface{}, error
return new.client.Do(req, v)
}

// OktaClient creates a basic okta client connection
func (c *ConfigEntry) OktaConfiguration(ctx context.Context) (*oktanew.Configuration, error) {
baseURL := defaultBaseURL
if c.Production != nil {
if !*c.Production {
baseURL = previewBaseURL
}
}
if c.BaseURL != "" {
baseURL = c.BaseURL
}

cfg, err := oktanew.NewConfiguration(oktanew.WithOrgUrl("https://"+c.Org+"."+baseURL), oktanew.WithToken(c.Token))
if err != nil {
return nil, err
}
return cfg, nil
}

// OktaClient returns an OktaShim, based on the presence of a token in the ConfigEntry.
func (c *ConfigEntry) OktaClient(ctx context.Context) (oktaShim, error) {
baseURL := defaultBaseURL
if c.Production != nil {
Expand All @@ -344,13 +464,18 @@ func (c *ConfigEntry) OktaClient(ctx context.Context) (oktaShim, error) {
}

if c.Token != "" {
ctx, client, err := oktanew.NewClient(ctx,
cfg, err := oktanew.NewConfiguration(
oktanew.WithOrgUrl("https://"+c.Org+"."+baseURL),
oktanew.WithToken(c.Token))
if err != nil {
return nil, err
}
return &oktaShimNew{client, ctx}, nil
return &oktaShimNew{
cfg: cfg,
client: oktanew.NewAPIClient(cfg),
ctx: ctx,
cache: gocache.New(gocache.DefaultExpiration, 1*time.Second),
}, nil
}
client, err := oktaold.NewClientWithDomain(cleanhttp.DefaultClient(), c.Org, baseURL, "")
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions changelog/28121.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
auth/okta: update to okta sdk v5
```
9 changes: 8 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ require (
github.com/mitchellh/reflectwalk v1.0.2
github.com/ncw/swift v1.0.47
github.com/oklog/run v1.1.0
github.com/okta/okta-sdk-golang/v2 v2.20.0
github.com/okta/okta-sdk-golang/v5 v5.0.2
github.com/oracle/oci-go-sdk v24.3.0+incompatible
github.com/ory/dockertest v3.3.5+incompatible
github.com/ory/dockertest/v3 v3.10.0
Expand Down Expand Up @@ -230,11 +230,18 @@ require (
require (
cel.dev/expr v0.15.0 // indirect
cloud.google.com/go/longrunning v0.6.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-viper/mapstructure/v2 v2.1.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
github.com/hashicorp/go-secure-stdlib/httputil v0.1.0 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/jwx v1.2.29 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
Expand Down
Loading

0 comments on commit 71c2121

Please sign in to comment.