Skip to content

Commit

Permalink
[PM-11764] Implement account switching and sdk initialization (#11472)
Browse files Browse the repository at this point in the history
* feat: update sdk service abstraction with documentation and new `userClient$` function

* feat: add uninitialized user client with cache

* feat: initialize user crypto

* feat: initialize org keys

* fix: org crypto not initializing properly

* feat: avoid creating clients unnecessarily

* chore: remove dev print/subscription

* fix: clean up cache

* chore: update sdk version

* feat: implement clean-up logic (#11504)

* chore: bump sdk version to fix build issues

* chore: bump sdk version to fix build issues

* fix: missing constructor parameters

* refactor: simplify free() and delete() calls

* refactor: use a named function for client creation

* fix: client never freeing after refactor

* fix: broken impl and race condition in tests
  • Loading branch information
coroiu authored Oct 18, 2024
1 parent cdd5bd4 commit c787ecd
Show file tree
Hide file tree
Showing 12 changed files with 355 additions and 21 deletions.
3 changes: 3 additions & 0 deletions apps/browser/src/background/main.background.ts
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,9 @@ export default class MainBackground {
sdkClientFactory,
this.environmentService,
this.platformUtilsService,
this.accountService,
this.kdfConfigService,
this.cryptoService,
this.apiService,
);

Expand Down
3 changes: 3 additions & 0 deletions apps/cli/src/service-container/service-container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,9 @@ export class ServiceContainer {
sdkClientFactory,
this.environmentService,
this.platformUtilsService,
this.accountService,
this.kdfConfigService,
this.cryptoService,
this.apiService,
customUserAgent,
);
Expand Down
3 changes: 3 additions & 0 deletions libs/angular/src/services/jslib-services.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,9 @@ const safeProviders: SafeProvider[] = [
SdkClientFactory,
EnvironmentService,
PlatformUtilsServiceAbstraction,
AccountServiceAbstraction,
KdfConfigServiceAbstraction,
CryptoServiceAbstraction,
ApiServiceAbstraction,
],
}),
Expand Down
7 changes: 5 additions & 2 deletions libs/common/src/auth/abstractions/kdf-config.service.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { Observable } from "rxjs";

import { UserId } from "../../types/guid";
import { KdfConfig } from "../models/domain/kdf-config";

export abstract class KdfConfigService {
setKdfConfig: (userId: UserId, KdfConfig: KdfConfig) => Promise<void>;
getKdfConfig: () => Promise<KdfConfig>;
abstract setKdfConfig(userId: UserId, KdfConfig: KdfConfig): Promise<void>;
abstract getKdfConfig(): Promise<KdfConfig>;
abstract getKdfConfig$(userId: UserId): Observable<KdfConfig>;
}
6 changes: 5 additions & 1 deletion libs/common/src/auth/services/kdf-config.service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { firstValueFrom } from "rxjs";
import { firstValueFrom, Observable } from "rxjs";

import { KdfType } from "../../platform/enums/kdf-type.enum";
import { KDF_CONFIG_DISK, StateProvider, UserKeyDefinition } from "../../platform/state";
Expand Down Expand Up @@ -38,4 +38,8 @@ export class KdfConfigService implements KdfConfigServiceAbstraction {
}
return state;
}

getKdfConfig$(userId: UserId): Observable<KdfConfig> {
return this.stateProvider.getUser(userId, KDF_CONFIG).state$;
}
}
26 changes: 25 additions & 1 deletion libs/common/src/platform/abstractions/crypto.service.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Observable } from "rxjs";

import { EncryptedOrganizationKeyData } from "../../admin-console/models/data/encrypted-organization-key.data";
import { ProfileOrganizationResponse } from "../../admin-console/models/response/profile-organization.response";
import { ProfileProviderOrganizationResponse } from "../../admin-console/models/response/profile-provider-organization.response";
import { ProfileProviderResponse } from "../../admin-console/models/response/profile-provider.response";
Expand All @@ -15,7 +16,7 @@ import {
UserPublicKey,
} from "../../types/key";
import { KeySuffixOptions, HashPurpose } from "../enums";
import { EncString } from "../models/domain/enc-string";
import { EncryptedString, EncString } from "../models/domain/enc-string";
import { SymmetricCryptoKey } from "../models/domain/symmetric-crypto-key";

export class UserPrivateKeyDecryptionFailedError extends Error {
Expand Down Expand Up @@ -288,6 +289,17 @@ export abstract class CryptoService {
*/
abstract userPrivateKey$(userId: UserId): Observable<UserPrivateKey>;

/**
* Gets an observable stream of the given users encrypted private key, will emit null if the user
* doesn't have an encrypted private key at all.
*
* @param userId The user id of the user to get the data for.
*
* @deprecated Temporary function to allow the SDK to be initialized after the login process, it
* will be removed when auth has been migrated to the SDK.
*/
abstract userEncryptedPrivateKey$(userId: UserId): Observable<EncryptedString>;

/**
* Gets an observable stream of the given users decrypted private key with legacy support,
* will emit null if the user doesn't have a UserKey to decrypt the encrypted private key
Expand Down Expand Up @@ -381,6 +393,18 @@ export abstract class CryptoService {
*/
abstract orgKeys$(userId: UserId): Observable<Record<OrganizationId, OrgKey> | null>;

/**
* Gets an observable stream of the given users encrypted organisation keys.
*
* @param userId The user id of the user to get the data for.
*
* @deprecated Temporary function to allow the SDK to be initialized after the login process, it
* will be removed when auth has been migrated to the SDK.
*/
abstract encryptedOrgKeys$(
userId: UserId,
): Observable<Record<OrganizationId, EncryptedOrganizationKeyData>>;

/**
* Gets an observable stream of the users public key. If the user is does not have
* a {@link UserKey} or {@link UserPrivateKey} that is decryptable, this will emit null.
Expand Down
20 changes: 19 additions & 1 deletion libs/common/src/platform/abstractions/sdk/sdk.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,27 @@ import { Observable } from "rxjs";

import { BitwardenClient } from "@bitwarden/sdk-internal";

import { UserId } from "../../../types/guid";

export abstract class SdkService {
client$: Observable<BitwardenClient>;
/**
* Check if the SDK is supported in the current environment.
*/
supported$: Observable<boolean>;

/**
* Retrieve a client initialized without a user.
* This client can only be used for operations that don't require a user context.
*/
client$: Observable<BitwardenClient | undefined>;

/**
* Retrieve a client initialized for a specific user.
* This client can be used for operations that require a user context, such as retrieving ciphers
* and operations involving crypto. It can also be used for operations that don't require a user context.
* @param userId
*/
abstract userClient$(userId: UserId): Observable<BitwardenClient>;

abstract failedToInitialize(): Promise<void>;
}
10 changes: 10 additions & 0 deletions libs/common/src/platform/services/crypto.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,10 @@ export class CryptoService implements CryptoServiceAbstraction {
return this.userPrivateKeyHelper$(userId, false).pipe(map((keys) => keys?.userPrivateKey));
}

userEncryptedPrivateKey$(userId: UserId): Observable<EncryptedString> {
return this.stateProvider.getUser(userId, USER_ENCRYPTED_PRIVATE_KEY).state$;
}

userPrivateKeyWithLegacySupport$(userId: UserId): Observable<UserPrivateKey> {
return this.userPrivateKeyHelper$(userId, true).pipe(map((keys) => keys?.userPrivateKey));
}
Expand Down Expand Up @@ -929,6 +933,12 @@ export class CryptoService implements CryptoServiceAbstraction {
return this.cipherDecryptionKeys$(userId, true).pipe(map((keys) => keys?.orgKeys));
}

encryptedOrgKeys$(
userId: UserId,
): Observable<Record<OrganizationId, EncryptedOrganizationKeyData>> {
return this.stateProvider.getUser(userId, USER_ENCRYPTED_ORGANIZATION_KEYS).state$;
}

cipherDecryptionKeys$(
userId: UserId,
legacySupport: boolean = false,
Expand Down
132 changes: 132 additions & 0 deletions libs/common/src/platform/services/sdk/default-sdk.service.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import { mock, MockProxy } from "jest-mock-extended";
import { BehaviorSubject, firstValueFrom, of } from "rxjs";

import { BitwardenClient } from "@bitwarden/sdk-internal";

import { ApiService } from "../../../abstractions/api.service";
import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service";
import { KdfConfigService } from "../../../auth/abstractions/kdf-config.service";
import { PBKDF2KdfConfig } from "../../../auth/models/domain/kdf-config";
import { UserId } from "../../../types/guid";
import { UserKey } from "../../../types/key";
import { CryptoService } from "../../abstractions/crypto.service";
import { Environment, EnvironmentService } from "../../abstractions/environment.service";
import { PlatformUtilsService } from "../../abstractions/platform-utils.service";
import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory";
import { EncryptedString } from "../../models/domain/enc-string";
import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key";

import { DefaultSdkService } from "./default-sdk.service";

describe("DefaultSdkService", () => {
describe("userClient$", () => {
let sdkClientFactory!: MockProxy<SdkClientFactory>;
let environmentService!: MockProxy<EnvironmentService>;
let platformUtilsService!: MockProxy<PlatformUtilsService>;
let accountService!: MockProxy<AccountService>;
let kdfConfigService!: MockProxy<KdfConfigService>;
let cryptoService!: MockProxy<CryptoService>;
let apiService!: MockProxy<ApiService>;
let service!: DefaultSdkService;

let mockClient!: MockProxy<BitwardenClient>;

beforeEach(() => {
sdkClientFactory = mock<SdkClientFactory>();
environmentService = mock<EnvironmentService>();
platformUtilsService = mock<PlatformUtilsService>();
accountService = mock<AccountService>();
kdfConfigService = mock<KdfConfigService>();
cryptoService = mock<CryptoService>();
apiService = mock<ApiService>();

// Can't use `of(mock<Environment>())` for some reason
environmentService.environment$ = new BehaviorSubject(mock<Environment>());

service = new DefaultSdkService(
sdkClientFactory,
environmentService,
platformUtilsService,
accountService,
kdfConfigService,
cryptoService,
apiService,
);

mockClient = mock<BitwardenClient>();
mockClient.crypto.mockReturnValue(mock());
sdkClientFactory.createSdkClient.mockResolvedValue(mockClient);
});

describe("given the user is logged in", () => {
const userId = "user-id" as UserId;

beforeEach(() => {
accountService.accounts$ = of({
[userId]: { email: "email", emailVerified: true, name: "name" } as AccountInfo,
});
kdfConfigService.getKdfConfig$
.calledWith(userId)
.mockReturnValue(of(new PBKDF2KdfConfig()));
cryptoService.userKey$
.calledWith(userId)
.mockReturnValue(of(new SymmetricCryptoKey(new Uint8Array(64)) as UserKey));
cryptoService.userEncryptedPrivateKey$
.calledWith(userId)
.mockReturnValue(of("private-key" as EncryptedString));
cryptoService.encryptedOrgKeys$.calledWith(userId).mockReturnValue(of({}));
});

it("creates an SDK client when called the first time", async () => {
const result = await firstValueFrom(service.userClient$(userId));

expect(result).toBe(mockClient);
expect(sdkClientFactory.createSdkClient).toHaveBeenCalled();
});

it("does not create an SDK client when called the second time with same userId", async () => {
const subject_1 = new BehaviorSubject(undefined);
const subject_2 = new BehaviorSubject(undefined);

// Use subjects to ensure the subscription is kept alive
service.userClient$(userId).subscribe(subject_1);
service.userClient$(userId).subscribe(subject_2);

// Wait for the next tick to ensure all async operations are done
await new Promise(process.nextTick);

expect(subject_1.value).toBe(mockClient);
expect(subject_2.value).toBe(mockClient);
expect(sdkClientFactory.createSdkClient).toHaveBeenCalledTimes(1);
});

it("destroys the SDK client when all subscriptions are closed", async () => {
const subject_1 = new BehaviorSubject(undefined);
const subject_2 = new BehaviorSubject(undefined);
const subscription_1 = service.userClient$(userId).subscribe(subject_1);
const subscription_2 = service.userClient$(userId).subscribe(subject_2);
await new Promise(process.nextTick);

subscription_1.unsubscribe();
subscription_2.unsubscribe();

expect(mockClient.free).toHaveBeenCalledTimes(1);
});

it("destroys the SDK client when the userKey is unset (i.e. lock or logout)", async () => {
const userKey$ = new BehaviorSubject(new SymmetricCryptoKey(new Uint8Array(64)) as UserKey);
cryptoService.userKey$.calledWith(userId).mockReturnValue(userKey$);

const subject = new BehaviorSubject(undefined);
service.userClient$(userId).subscribe(subject);
await new Promise(process.nextTick);

userKey$.next(undefined);
await new Promise(process.nextTick);

expect(mockClient.free).toHaveBeenCalledTimes(1);
expect(subject.value).toBe(undefined);
});
});
});
});
Loading

0 comments on commit c787ecd

Please sign in to comment.