Skip to content

Commit

Permalink
Optimize Cache and IDP Management (netbirdio#1147)
Browse files Browse the repository at this point in the history
This pull request modifies the IdP and cache manager(s) to prevent the sending of app metadata
 to the upstream IDP on self-hosted instances. 
As a result, the IdP will now load all users from the IdP without filtering based on accountID.

We disable user invites as the administrator's own IDP system manages them.
  • Loading branch information
bcmmbaga authored Oct 3, 2023
1 parent 96eaa52 commit 09477ce
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 1,670 deletions.
21 changes: 21 additions & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,27 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return err
}

// If the Identity Provider does not support writing AppMetadata,
// in cases like this, we expect it to return all users in an "unset" field.
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
// update their AppMetadata with the AccountID.
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
for _, user := range unsetData {
accountID, err := am.Store.GetAccountByUser(user.ID)
if err == nil {
data := userData[accountID.Id]
if data == nil {
data = make([]*idp.UserData, 0, 1)
}

user.AppMetadata.WTAccountID = accountID.Id

userData[accountID.Id] = append(data, user)
}
}
}
delete(userData, idp.UnsetAccountID)

for accountID, users := range userData {
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil {
Expand Down
178 changes: 16 additions & 162 deletions management/server/idp/authentik.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,47 +210,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
}

// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (am *AuthentikManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
ctx, err := am.authenticationContext()
if err != nil {
return err
}

userPk, err := strconv.ParseInt(userID, 10, 32)
if err != nil {
return err
}

var pendingInvite bool
if appMetadata.WTPendingInvite != nil {
pendingInvite = *appMetadata.WTPendingInvite
}

patchedUserReq := api.PatchedUserRequest{
Attributes: map[string]interface{}{
wtAccountID: appMetadata.WTAccountID,
wtPendingInvite: pendingInvite,
},
}
_, resp, err := am.apiClient.CoreApi.CoreUsersPartialUpdate(ctx, int32(userPk)).
PatchedUserRequest(patchedUserReq).
Execute()
if err != nil {
return err
}
defer resp.Body.Close()

if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}

if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user %s, statusCode %d", userID, resp.StatusCode)
}

func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}

Expand Down Expand Up @@ -283,7 +243,10 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
}

return parseAuthentikUser(*user)
userData := parseAuthentikUser(*user)
userData.AppMetadata = appMetadata

return userData, nil
}

// GetAccount returns all the users for a given profile.
Expand All @@ -293,8 +256,7 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
return nil, err
}

accountFilter := fmt.Sprintf("{%q:%q}", wtAccountID, accountID)
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Attributes(accountFilter).Execute()
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
if err != nil {
return nil, err
}
Expand All @@ -313,10 +275,9 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {

users := make([]*UserData, 0)
for _, user := range userList.Results {
userData, err := parseAuthentikUser(user)
if err != nil {
return nil, err
}
userData := parseAuthentikUser(user)
userData.AppMetadata.WTAccountID = accountID

users = append(users, userData)
}

Expand Down Expand Up @@ -350,65 +311,16 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {

indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results {
userData, err := parseAuthentikUser(user)
if err != nil {
return nil, err
}

accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
userData := parseAuthentikUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}

return indexedUsers, nil
}

// CreateUser creates a new user in authentik Idp and sends an invitation.
func (am *AuthentikManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}

groupID, err := am.getUserGroupByName("netbird")
if err != nil {
return nil, err
}

defaultBoolValue := true
createUserRequest := api.UserRequest{
Email: &email,
Name: name,
IsActive: &defaultBoolValue,
Groups: []string{groupID},
Username: email,
Attributes: map[string]interface{}{
wtAccountID: accountID,
wtPendingInvite: &defaultBoolValue,
},
}
user, resp, err := am.apiClient.CoreApi.CoreUsersCreate(ctx).UserRequest(createUserRequest).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()

if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}

if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}

return parseAuthentikUser(*user)
func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}

// GetUserByEmail searches users with a given email.
Expand Down Expand Up @@ -438,11 +350,7 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {

users := make([]*UserData, 0)
for _, user := range userList.Results {
userData, err := parseAuthentikUser(user)
if err != nil {
return nil, err
}
users = append(users, userData)
users = append(users, parseAuthentikUser(user))
}

return users, nil
Expand Down Expand Up @@ -501,64 +409,10 @@ func (am *AuthentikManager) authenticationContext() (context.Context, error) {
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
}

// getUserGroupByName retrieves the user group for assigning new users.
// If the group is not found, a new group with the specified name will be created.
func (am *AuthentikManager) getUserGroupByName(name string) (string, error) {
ctx, err := am.authenticationContext()
if err != nil {
return "", err
}

groupList, resp, err := am.apiClient.CoreApi.CoreGroupsList(ctx).Name(name).Execute()
if err != nil {
return "", err
}
defer resp.Body.Close()

if groupList != nil {
if len(groupList.Results) > 0 {
return groupList.Results[0].Pk, nil
}
}

createGroupRequest := api.GroupRequest{Name: name}
group, resp, err := am.apiClient.CoreApi.CoreGroupsCreate(ctx).GroupRequest(createGroupRequest).Execute()
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
return "", fmt.Errorf("unable to create user group, statusCode: %d", resp.StatusCode)
}

return group.Pk, nil
}

func parseAuthentikUser(user api.User) (*UserData, error) {
var attributes struct {
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
}

helper := JsonParser{}
buf, err := helper.Marshal(user.Attributes)
if err != nil {
return nil, err
}

err = helper.Unmarshal(buf, &attributes)
if err != nil {
return nil, err
}

func parseAuthentikUser(user api.User) *UserData {
return &UserData{
Email: *user.Email,
Name: user.Name,
ID: strconv.FormatInt(int64(user.Pk), 10),
AppMetadata: AppMetadata{
WTAccountID: attributes.AccountID,
WTPendingInvite: &attributes.PendingInvite,
},
}, nil
}
}
Loading

0 comments on commit 09477ce

Please sign in to comment.