Skip to content

Commit

Permalink
feat: Implement function to initialise correct adapter from env vars (#…
Browse files Browse the repository at this point in the history
…26)

- [x] GCP
- [x] AWS

Fixes #4.
  • Loading branch information
gnarea authored Apr 3, 2023
1 parent 2e95bd9 commit 234058f
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 16 deletions.
14 changes: 14 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"@aws-sdk/client-kms": "^3.303.0",
"@google-cloud/kms": "^3.5.0",
"@peculiar/webcrypto": "^1.4.3",
"env-var": "^7.3.0",
"fast-crc32c": "^2.0.0",
"uuid4": "^2.0.3",
"webcrypto-core": "^1.7.6"
Expand Down
4 changes: 4 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Do NOT import specific adapters here, because some SDKs do some heavy lifting on import (e.g.,
// call APIs).

export { initKmsProviderFromEnv } from './lib/init';
2 changes: 1 addition & 1 deletion src/lib/aws/AwsKmsRsaPssProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const SUPPORTED_MODULUS_LENGTHS: readonly number[] = [2048, 3072, 4096];
const REQUEST_OPTIONS = { requestTimeout: 3_000 };

export class AwsKmsRsaPssProvider extends KmsRsaPssProvider {
constructor(protected readonly client: KMSClient) {
constructor(public readonly client: KMSClient) {
super();

// See: https://docs.aws.amazon.com/kms/latest/developerguide/asymmetric-key-specs.html
Expand Down
29 changes: 14 additions & 15 deletions src/lib/gcp/GcpKmsRsaPssProvider.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { KeyManagementServiceClient } from '@google-cloud/kms';
import type { KeyManagementServiceClient } from '@google-cloud/kms';
import { calculate as calculateCRC32C } from 'fast-crc32c';
import { CryptoKey } from 'webcrypto-core';
import uuid4 from 'uuid4';
Expand Down Expand Up @@ -33,7 +33,7 @@ const DEFAULT_DESTROY_SCHEDULED_DURATION_SECONDS = 86_400; // One day; the minim
const REQUEST_OPTIONS = { timeout: 3_000, maxRetries: 10 };

export class GcpKmsRsaPssProvider extends KmsRsaPssProvider {
constructor(public kmsClient: KeyManagementServiceClient, protected kmsConfig: GcpKmsConfig) {
constructor(public client: KeyManagementServiceClient, public config: GcpKmsConfig) {
super();

// See: https://cloud.google.com/kms/docs/algorithms#rsa_signing_algorithms
Expand All @@ -50,10 +50,10 @@ export class GcpKmsRsaPssProvider extends KmsRsaPssProvider {
const cryptoKeyId = uuid4();
await this.createCryptoKey(algorithm, projectId, cryptoKeyId);

const kmsKeyVersionPath = this.kmsClient.cryptoKeyVersionPath(
const kmsKeyVersionPath = this.client.cryptoKeyVersionPath(
projectId,
this.kmsConfig.location,
this.kmsConfig.keyRing,
this.config.location,
this.config.keyRing,
cryptoKeyId,
'1',
);
Expand Down Expand Up @@ -90,7 +90,7 @@ export class GcpKmsRsaPssProvider extends KmsRsaPssProvider {

let keySerialised: ArrayBuffer;
if (format === 'spki') {
keySerialised = await retrieveKMSPublicKey(key.kmsKeyVersionPath, this.kmsClient);
keySerialised = await retrieveKMSPublicKey(key.kmsKeyVersionPath, this.client);
} else if (format === 'raw') {
const pathEncoded = Buffer.from(key.kmsKeyVersionPath);
keySerialised = bufferToArrayBuffer(pathEncoded);
Expand Down Expand Up @@ -122,7 +122,7 @@ export class GcpKmsRsaPssProvider extends KmsRsaPssProvider {

private async getGCPProjectId(): Promise<string> {
// GCP client library already caches the project id.
return this.kmsClient.getProjectId();
return this.client.getProjectId();
}

private async createCryptoKey(
Expand All @@ -131,31 +131,30 @@ export class GcpKmsRsaPssProvider extends KmsRsaPssProvider {
cryptoKeyId: string,
): Promise<void> {
const kmsAlgorithm = getKmsAlgorithm(algorithm);
const keyRingName = this.kmsClient.keyRingPath(
const keyRingName = this.client.keyRingPath(
projectId,
this.kmsConfig.location,
this.kmsConfig.keyRing,
this.config.location,
this.config.keyRing,
);
const destroyScheduledDuration = {
seconds:
this.kmsConfig.destroyScheduledDurationSeconds ??
DEFAULT_DESTROY_SCHEDULED_DURATION_SECONDS,
this.config.destroyScheduledDurationSeconds ?? DEFAULT_DESTROY_SCHEDULED_DURATION_SECONDS,
};
const creationOptions = {
cryptoKey: {
destroyScheduledDuration,
purpose: 'ASYMMETRIC_SIGN',
versionTemplate: {
algorithm: kmsAlgorithm as any,
protectionLevel: this.kmsConfig.protectionLevel,
protectionLevel: this.config.protectionLevel,
},
},
cryptoKeyId,
parent: keyRingName,
skipInitialVersionCreation: false,
} as const;
await wrapGCPCallError(
this.kmsClient.createCryptoKey(creationOptions, REQUEST_OPTIONS),
this.client.createCryptoKey(creationOptions, REQUEST_OPTIONS),
'Failed to create key',
);
}
Expand All @@ -171,7 +170,7 @@ export class GcpKmsRsaPssProvider extends KmsRsaPssProvider {
private async kmsSign(plaintext: Buffer, key: GcpKmsRsaPssPrivateKey): Promise<ArrayBuffer> {
const plaintextChecksum = calculateCRC32C(plaintext);
const [response] = await wrapGCPCallError(
this.kmsClient.asymmetricSign(
this.client.asymmetricSign(
{ data: plaintext, dataCrc32c: { value: plaintextChecksum }, name: key.kmsKeyVersionPath },
REQUEST_OPTIONS,
),
Expand Down
144 changes: 144 additions & 0 deletions src/lib/init.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/* tslint:disable:max-classes-per-file */
import { EnvVarError } from 'env-var';

import { configureMockEnvVars } from '../testUtils/envVars';
import { initKmsProviderFromEnv } from './init';
import { GcpKmsConfig } from './gcp/GcpKmsConfig';
import { KmsError } from './KmsError';

class MockGcpSdkClient {}

class MockAwsSdkClient {
constructor(public readonly config: any) {}
}

let gcpSdkImported = false;
jest.mock('@google-cloud/kms', () => {
gcpSdkImported = true;
return {
KeyManagementServiceClient: MockGcpSdkClient,
};
});
let awsSdkImported = false;
jest.mock('@aws-sdk/client-kms', () => {
awsSdkImported = true;
return { ...jest.requireActual('@aws-sdk/client-kms'), KMSClient: MockAwsSdkClient };
});
beforeEach(() => {
gcpSdkImported = false;
awsSdkImported = false;
});

describe('initKmsProviderFromEnv', () => {
const mockEnvVars = configureMockEnvVars();

const GCP_REQUIRED_ENV_VARS = {
GCP_KMS_LOCATION: 'westeros-3',
GCP_KMS_KEYRING: 'my-precious',
GCP_KMS_PROTECTION_LEVEL: 'HSM',
} as const;

test('Unknown adapter should be refused', async () => {
const invalidAdapter = 'potato';
await expect(() => initKmsProviderFromEnv(invalidAdapter as any)).rejects.toThrowWithMessage(
KmsError,
`Invalid adapter (${invalidAdapter})`,
);
});

test('Adapters should be imported lazily', async () => {
expect(gcpSdkImported).toBeFalse();
expect(awsSdkImported).toBeFalse();

mockEnvVars(GCP_REQUIRED_ENV_VARS);
await initKmsProviderFromEnv('GCP');
expect(gcpSdkImported).toBeTrue();
expect(awsSdkImported).toBeFalse();

await initKmsProviderFromEnv('AWS');
expect(awsSdkImported).toBeTrue();
});

describe('GPC', () => {
beforeEach(() => {
mockEnvVars(GCP_REQUIRED_ENV_VARS);
});

test.each(Object.getOwnPropertyNames(GCP_REQUIRED_ENV_VARS))(
'Environment variable %s should be present',
async (envVar) => {
mockEnvVars({ ...GCP_REQUIRED_ENV_VARS, [envVar]: undefined });

await expect(initKmsProviderFromEnv('GCP')).rejects.toThrowWithMessage(
EnvVarError,
new RegExp(envVar),
);
},
);

test('Provider should be returned if env vars are present', async () => {
const provider = await initKmsProviderFromEnv('GCP');

const { GcpKmsRsaPssProvider } = await import('./gcp/GcpKmsRsaPssProvider');
expect(provider).toBeInstanceOf(GcpKmsRsaPssProvider);
expect(provider).toHaveProperty('client', expect.any(MockGcpSdkClient));
expect(provider).toHaveProperty<GcpKmsConfig>('config', {
keyRing: GCP_REQUIRED_ENV_VARS.GCP_KMS_KEYRING,
location: GCP_REQUIRED_ENV_VARS.GCP_KMS_LOCATION,
protectionLevel: GCP_REQUIRED_ENV_VARS.GCP_KMS_PROTECTION_LEVEL,
});
});

test('GCP_KMS_DESTROY_SCHEDULED_DURATION_SECONDS should be honoured if set', async () => {
const seconds = 123;
mockEnvVars({
...GCP_REQUIRED_ENV_VARS,
GCP_KMS_DESTROY_SCHEDULED_DURATION_SECONDS: seconds.toString(),
});

const provider = await initKmsProviderFromEnv('GCP');

expect(provider).toHaveProperty('config.destroyScheduledDurationSeconds', seconds);
});

test('Invalid GCP_KMS_PROTECTION_LEVEL should be refused', async () => {
mockEnvVars({ ...GCP_REQUIRED_ENV_VARS, GCP_KMS_PROTECTION_LEVEL: 'potato' });

await expect(initKmsProviderFromEnv('GCP')).rejects.toThrowWithMessage(
EnvVarError,
/GCP_KMS_PROTECTION_LEVEL/,
);
});
});

describe('AWS', () => {
test('AWS KMS provider should be output', async () => {
const provider = await initKmsProviderFromEnv('AWS');

const { AwsKmsRsaPssProvider } = await import('./aws/AwsKmsRsaPssProvider');
expect(provider).toBeInstanceOf(AwsKmsRsaPssProvider);
expect(provider).toHaveProperty('client.config', {
endpoint: undefined,
region: undefined,
});
});

test('AWS_KMS_ENDPOINT should be honoured if present', async () => {
const endpoint = 'https://kms.example.com';
mockEnvVars({ AWS_KMS_ENDPOINT: endpoint });

const provider = await initKmsProviderFromEnv('AWS');

expect(provider).toHaveProperty('client.config.endpoint', endpoint);
});

test('AWS_KMS_REGION should be honoured if present', async () => {
const region = 'westeros-3';
mockEnvVars({ AWS_KMS_REGION: region });

const provider = await initKmsProviderFromEnv('AWS');

expect(provider).toHaveProperty('client.config.region', region);
});
});
});
47 changes: 47 additions & 0 deletions src/lib/init.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { get as getEnvVar } from 'env-var';

import { KmsError } from './KmsError';

import { GcpKmsConfig } from './gcp/GcpKmsConfig';
import { KmsRsaPssProvider } from './KmsRsaPssProvider';
import { GcpKmsRsaPssProvider } from './gcp/GcpKmsRsaPssProvider';

const INITIALISERS: { readonly [key: string]: () => Promise<KmsRsaPssProvider> } = {
AWS: initAwsProvider,
GCP: initGcpProvider,
};

export async function initKmsProviderFromEnv(adapter: string): Promise<KmsRsaPssProvider> {
const init = INITIALISERS[adapter];
if (!init) {
throw new KmsError(`Invalid adapter (${adapter})`);
}
return init();
}

export async function initAwsProvider(): Promise<KmsRsaPssProvider> {
// Avoid import-time side effects (e.g., expensive API calls)
const { AwsKmsRsaPssProvider } = await import('./aws/AwsKmsRsaPssProvider');
const { KMSClient } = await import('@aws-sdk/client-kms');
return new AwsKmsRsaPssProvider(
new KMSClient({
endpoint: getEnvVar('AWS_KMS_ENDPOINT').asString(),
region: getEnvVar('AWS_KMS_REGION').asString(),
}),
);
}

export async function initGcpProvider(): Promise<KmsRsaPssProvider> {
const kmsConfig: GcpKmsConfig = {
location: getEnvVar('GCP_KMS_LOCATION').required().asString(),
keyRing: getEnvVar('GCP_KMS_KEYRING').required().asString(),
protectionLevel: getEnvVar('GCP_KMS_PROTECTION_LEVEL').required().asEnum(['SOFTWARE', 'HSM']),
destroyScheduledDurationSeconds: getEnvVar(
'GCP_KMS_DESTROY_SCHEDULED_DURATION_SECONDS',
).asIntPositive(),
};

// Avoid import-time side effects (e.g., expensive API calls)
const { KeyManagementServiceClient } = await import('@google-cloud/kms');
return new GcpKmsRsaPssProvider(new KeyManagementServiceClient(), kmsConfig);
}
28 changes: 28 additions & 0 deletions src/testUtils/envVars.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import envVar from 'env-var';

interface EnvVarSet {
readonly [key: string]: string | undefined;
}

export function configureMockEnvVars(envVars: EnvVarSet = {}): (envVars: EnvVarSet) => void {
const mockEnvVarGet = jest.spyOn(envVar, 'get');

function setEnvVars(newEnvVars: EnvVarSet): void {
mockEnvVarGet.mockReset();
mockEnvVarGet.mockImplementation((...args: readonly any[]) => {
const originalEnvVar = jest.requireActual('env-var');
const env = originalEnvVar.from(newEnvVars);

return env.get(...args);
});
}

beforeAll(() => setEnvVars(envVars));
beforeEach(() => setEnvVars(envVars));

afterAll(() => {
mockEnvVarGet.mockRestore();
});

return (newEnvVars: EnvVarSet) => setEnvVars(newEnvVars);
}

0 comments on commit 234058f

Please sign in to comment.