Skip to content

Commit

Permalink
Add Pagination for IdP Users Fetch (#1210)
Browse files Browse the repository at this point in the history
* Retrieve all workspace users via pagination, excluding custom user attributes

* Retrieve all authentik users via pagination

* Retrieve all Azure AD users via pagination

* Simplify user data appending operation

Reduced unnecessary iteration and used an efficient way to append all users to 'indexedUsers'

* Fix ineffectual assignment to reqURL

* Retrieve all Okta users via pagination

* Add missing GetAccount metrics

* Refactor

* minimize memory allocation

Refactored the memory allocation for the 'users' slice in the Okta IDP code. Previously, the slice was only initialized but not given a size. Now the size of userList is utilized to optimize memory allocation, reducing potential slice resizing and memory re-allocation costs while appending users.

* Add logging for entries received from IdP management

Added informative and debug logging statements in account.go file. Logging has been added to identify the number of entries received from Identity Provider (IdP) management. This will aid in tracking and debugging any potential data ingestion issues.
  • Loading branch information
bcmmbaga authored Oct 11, 2023
1 parent 3c485dc commit 4ad14cb
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 107 deletions.
2 changes: 2 additions & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
if err != nil {
return err
}
log.Infof("%d entries received from IdP management", len(userData))

// If the Identity Provider does not support writing AppMetadata,
// in cases like this, we expect it to return all users in an "unset" field.
Expand Down Expand Up @@ -1045,6 +1046,7 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf
if err != nil {
return nil, err
}
log.Debugf("%d entries received from IdP management", len(userData))

dataMap := make(map[string]*idp.UserData, len(userData))
for _, datum := range userData {
Expand Down
78 changes: 42 additions & 36 deletions management/server/idp/authentik.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,34 +251,18 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada

// GetAccount returns all the users for a given profile.
func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}

userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
users, err := am.getAllUsers()
if err != nil {
return nil, err
}
defer resp.Body.Close()

if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}

if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
}

users := make([]*UserData, 0)
for _, user := range userList.Results {
userData := parseAuthentikUser(user)
userData.AppMetadata.WTAccountID = accountID

users = append(users, userData)
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}

return users, nil
Expand All @@ -287,35 +271,57 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
ctx, err := am.authenticationContext()
users, err := am.getAllUsers()
if err != nil {
return nil, err
}

userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)

if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}

if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
return indexedUsers, nil
}

// getAllUsers returns all users in a Authentik account.
func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)

page := int32(1)
for {
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}

userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute()
if err != nil {
return nil, err
}
_ = resp.Body.Close()

if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}

for _, user := range userList.Results {
users = append(users, parseAuthentikUser(user))
}

page = int32(userList.GetPagination().Next)
if userList.GetPagination().Next == 0 {
break
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}

indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results {
userData := parseAuthentikUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}

return indexedUsers, nil
return users, nil
}

// CreateUser creates a new user in authentik Idp and sends an invitation.
Expand Down
82 changes: 49 additions & 33 deletions management/server/idp/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,7 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {

// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)

body, err := am.get("users", q)
users, err := am.getAllUsers()
if err != nil {
return nil, err
}
Expand All @@ -278,18 +275,9 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
am.appMetrics.IDPMetrics().CountGetAccount()
}

var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}

users := make([]*UserData, 0)
for _, profile := range profiles.Value {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID

users = append(users, userData)
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}

return users, nil
Expand All @@ -298,30 +286,18 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)

body, err := am.get("users", q)
users, err := am.getAllUsers()
if err != nil {
return nil, err
}

indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)

if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}

var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}

indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Value {
userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}

return indexedUsers, nil
}

Expand Down Expand Up @@ -373,14 +349,54 @@ func (am *AzureManager) DeleteUser(userID string) error {
return nil
}

// getAllUsers returns all users in an Azure AD account.
func (am *AzureManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)

q := url.Values{}
q.Add("$select", profileFields)
q.Add("$top", "500")

for nextLink := "users"; nextLink != ""; {
body, err := am.get(nextLink, q)
if err != nil {
return nil, err
}

var profiles struct {
Value []azureProfile
NextLink string `json:"@odata.nextLink"`
}
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}

for _, profile := range profiles.Value {
users = append(users, profile.userData())
}

nextLink = profiles.NextLink
}

return users, nil
}

// get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}

reqURL := fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
var reqURL string
if strings.HasPrefix(resource, "https") {
// Already an absolute URL for paging
reqURL = resource
} else {
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
}

req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
Expand Down
56 changes: 41 additions & 15 deletions management/server/idp/google_workspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata)

// GetUserDataByID requests user data from Google Workspace via ID.
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := gm.usersService.Get(userID).Projection("full").Do()
user, err := gm.usersService.Get(userID).Do()
if err != nil {
return nil, err
}
Expand All @@ -113,41 +113,67 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App

// GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
users, err := gm.getAllUsers()
if err != nil {
return nil, err
}

usersData := make([]*UserData, 0)
for _, user := range usersList.Users {
userData := parseGoogleWorkspaceUser(user)
userData.AppMetadata.WTAccountID = accountID
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAccount()
}

usersData = append(usersData, userData)
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}

return usersData, nil
return users, nil
}

// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
users, err := gm.getAllUsers()
if err != nil {
return nil, err
}

indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)

if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
}

indexedUsers := make(map[string][]*UserData)
for _, user := range usersList.Users {
userData := parseGoogleWorkspaceUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
return indexedUsers, nil
}

// getAllUsers returns all users in a Google Workspace account filtered by customer ID.
func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
pageToken := ""
for {
call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500)
if pageToken != "" {
call.PageToken(pageToken)
}

resp, err := call.Do()
if err != nil {
return nil, err
}

for _, user := range resp.Users {
users = append(users, parseGoogleWorkspaceUser(user))
}

pageToken = resp.NextPageToken
if pageToken == "" {
break
}
}

return indexedUsers, nil
return users, nil
}

// CreateUser creates a new user in Google Workspace and sends an invitation.
Expand All @@ -158,7 +184,7 @@ func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, erro
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
user, err := gm.usersService.Get(email).Projection("full").Do()
user, err := gm.usersService.Get(email).Do()
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 4ad14cb

Please sign in to comment.