From 63f5524e85807221e6d702e0fa2d150d8cc5ddc4 Mon Sep 17 00:00:00 2001 From: Lawrence Forooghian Date: Tue, 1 Aug 2023 14:58:32 -0300 Subject: [PATCH] Make crypto functionality tree-shakable We move the crypto functionality into a tree-shakable Crypto module. Then, we split the decodeMessage* functions introduced in 601b46b in two; we introduce new decodeEncryptedMessage* variants which import the Crypto module and are hence able to decrypt encrypted messages, and we change the existing functions to not import the Crypto module and to fail if they are given cipher options. Resolves #1396. --- scripts/moduleReport.js | 42 +++++++- src/common/lib/client/baseclient.ts | 6 +- src/common/lib/client/channel.ts | 7 +- src/common/lib/client/defaultrealtime.ts | 2 +- src/common/lib/client/defaultrest.ts | 2 +- src/common/lib/client/modulesmap.ts | 2 + src/common/lib/types/message.ts | 2 +- src/common/lib/util/utils.ts | 5 + src/platform/web/modules.ts | 4 - src/platform/web/modules/crypto.ts | 10 +- src/platform/web/modules/message.ts | 14 ++- test/browser/modules.test.js | 126 +++++++++++++++++++++-- 12 files changed, 191 insertions(+), 31 deletions(-) diff --git a/scripts/moduleReport.js b/scripts/moduleReport.js index 54c50af390..1937afc34d 100644 --- a/scripts/moduleReport.js +++ b/scripts/moduleReport.js @@ -1,10 +1,18 @@ const esbuild = require('esbuild'); // List of all modules accepted in ModulesMap -const moduleNames = ['Rest']; +const moduleNames = ['Rest', 'Crypto']; -// List of all free-standing functions exported by the library -const functionNames = ['generateRandomKey', 'getDefaultCryptoParams', 'decodeMessage', 'decodeMessages']; +// List of all free-standing functions exported by the library along with the +// ModulesMap entries that we expect them to transitively import +const functions = [ + { name: 'generateRandomKey', transitiveImports: ['Crypto'] }, + { name: 'getDefaultCryptoParams', transitiveImports: ['Crypto'] }, + { name: 'decodeMessage', transitiveImports: [] }, + { name: 'decodeEncryptedMessage', transitiveImports: ['Crypto'] }, + { name: 'decodeMessages', transitiveImports: [] }, + { name: 'decodeEncryptedMessages', transitiveImports: ['Crypto'] }, +]; function formatBytes(bytes) { const kibibytes = bytes / 1024; @@ -39,7 +47,7 @@ const errors = []; console.log(`${baseClient}: ${formatBytes(baseClientSize)}`); // Then display the size of each export together with the base client - [...moduleNames, ...functionNames].forEach((exportName) => { + [...moduleNames, ...Object.values(functions).map((functionData) => functionData.name)].forEach((exportName) => { const size = getImportSize([baseClient, exportName]); console.log(`${baseClient} + ${exportName}: ${formatBytes(size)}`); @@ -51,6 +59,32 @@ const errors = []; }); }); +for (const functionData of functions) { + const { name: functionName, transitiveImports } = functionData; + + // First display the size of the function + const standaloneSize = getImportSize([functionName]); + console.log(`${functionName}: ${formatBytes(standaloneSize)}`); + + // Then display the size of the function together with the modules we expect + // it to transitively import + if (transitiveImports.length > 0) { + const withTransitiveImportsSize = getImportSize([functionName, ...transitiveImports]); + console.log(`${functionName} + ${transitiveImports.join(' + ')}: ${formatBytes(withTransitiveImportsSize)}`); + + if (withTransitiveImportsSize > standaloneSize) { + // Emit an error if the bundle size is increased by adding the modules + // that we expect this function to have transitively imported anyway. + // This seemed like a useful sense check, but it might need tweaking in + // the future if we make future optimisations that mean that the + // standalone functions don’t necessarily import the whole module. + errors.push( + new Error(`Adding ${transitiveImports.join(' + ')} to ${functionName} unexpectedly increases the bundle size.`) + ); + } + } +} + if (errors.length > 0) { for (const error of errors) { console.log(error.message); diff --git a/src/common/lib/client/baseclient.ts b/src/common/lib/client/baseclient.ts index 1fe6aeb7b5..da5ee730e3 100644 --- a/src/common/lib/client/baseclient.ts +++ b/src/common/lib/client/baseclient.ts @@ -13,6 +13,8 @@ import Platform from '../../platform'; import PresenceMessage from '../types/presencemessage'; import { ModulesMap } from './modulesmap'; import { Rest } from './rest'; +import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic'; +import { throwMissingModuleError } from '../util/utils'; type BatchResult = API.Types.BatchResult; type BatchPublishSpec = API.Types.BatchPublishSpec; @@ -37,6 +39,7 @@ class BaseClient { auth: Auth; private readonly _rest: Rest | null; + readonly _Crypto: IUntypedCryptoStatic | null; constructor(options: ClientOptions | string, modules: ModulesMap) { if (!options) { @@ -87,11 +90,12 @@ class BaseClient { this.auth = new Auth(this, normalOptions); this._rest = modules.Rest ? new modules.Rest(this) : null; + this._Crypto = modules.Crypto ?? null; } private get rest(): Rest { if (!this._rest) { - throw new ErrorInfo('Rest module not provided', 400, 40000); + throwMissingModuleError('Rest'); } return this._rest; } diff --git a/src/common/lib/client/channel.ts b/src/common/lib/client/channel.ts index 67d824a6fc..370c32993e 100644 --- a/src/common/lib/client/channel.ts +++ b/src/common/lib/client/channel.ts @@ -10,7 +10,6 @@ import { ChannelOptions } from '../../types/channel'; import { PaginatedResultCallback, StandardCallback } from '../../types/utils'; import BaseClient from './baseclient'; import * as API from '../../../../ably'; -import Platform from 'common/platform'; import Defaults from '../util/defaults'; import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic'; @@ -34,7 +33,7 @@ function allEmptyIds(messages: Array) { function normaliseChannelOptions(Crypto: IUntypedCryptoStatic | null, options?: ChannelOptions) { const channelOptions = options || {}; if (channelOptions.cipher) { - if (!Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead'); + if (!Crypto) Utils.throwMissingModuleError('Crypto'); const cipher = Crypto.getCipher(channelOptions.cipher); channelOptions.cipher = cipher.cipherParams; channelOptions.channelCipher = cipher.cipher; @@ -61,11 +60,11 @@ class Channel extends EventEmitter { this.name = name; this.basePath = '/channels/' + encodeURIComponent(name); this.presence = new Presence(this); - this.channelOptions = normaliseChannelOptions(Platform.Crypto, channelOptions); + this.channelOptions = normaliseChannelOptions(client._Crypto ?? null, channelOptions); } setOptions(options?: ChannelOptions): void { - this.channelOptions = normaliseChannelOptions(Platform.Crypto, options); + this.channelOptions = normaliseChannelOptions(this.client._Crypto ?? null, options); } history( diff --git a/src/common/lib/client/defaultrealtime.ts b/src/common/lib/client/defaultrealtime.ts index 1cfbab3ab7..e8b66f878e 100644 --- a/src/common/lib/client/defaultrealtime.ts +++ b/src/common/lib/client/defaultrealtime.ts @@ -12,7 +12,7 @@ import { DefaultMessage } from '../types/defaultmessage'; */ export class DefaultRealtime extends BaseRealtime { constructor(options: ClientOptions) { - super(options, allCommonModules); + super(options, { ...allCommonModules, Crypto: DefaultRealtime.Crypto ?? undefined }); } static Utils = Utils; diff --git a/src/common/lib/client/defaultrest.ts b/src/common/lib/client/defaultrest.ts index 324f6d8756..1b7df607a6 100644 --- a/src/common/lib/client/defaultrest.ts +++ b/src/common/lib/client/defaultrest.ts @@ -9,7 +9,7 @@ import { DefaultMessage } from '../types/defaultmessage'; */ export class DefaultRest extends BaseRest { constructor(options: ClientOptions | string) { - super(options, allCommonModules); + super(options, { ...allCommonModules, Crypto: DefaultRest.Crypto ?? undefined }); } private static _Crypto: typeof Platform.Crypto = null; diff --git a/src/common/lib/client/modulesmap.ts b/src/common/lib/client/modulesmap.ts index c56c2d98e7..4133bcc3a3 100644 --- a/src/common/lib/client/modulesmap.ts +++ b/src/common/lib/client/modulesmap.ts @@ -1,7 +1,9 @@ import { Rest } from './rest'; +import { IUntypedCryptoStatic } from '../../types/ICryptoStatic'; export interface ModulesMap { Rest?: typeof Rest; + Crypto?: IUntypedCryptoStatic; } export const allCommonModules: ModulesMap = { Rest }; diff --git a/src/common/lib/types/message.ts b/src/common/lib/types/message.ts index 35697038b9..692c5e069f 100644 --- a/src/common/lib/types/message.ts +++ b/src/common/lib/types/message.ts @@ -48,7 +48,7 @@ function normalizeCipherOptions( options: API.Types.ChannelOptions | null ): ChannelOptions { if (options && options.cipher) { - if (!Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead'); + if (!Crypto) Utils.throwMissingModuleError('Crypto'); const cipher = Crypto.getCipher(options.cipher); return { cipher: cipher.cipherParams, diff --git a/src/common/lib/util/utils.ts b/src/common/lib/util/utils.ts index 03d546b0a8..788f54d502 100644 --- a/src/common/lib/util/utils.ts +++ b/src/common/lib/util/utils.ts @@ -1,5 +1,6 @@ import Platform from 'common/platform'; import ErrorInfo, { PartialErrorInfo } from 'common/lib/types/errorinfo'; +import { ModulesMap } from '../client/modulesmap'; function randomPosn(arrOrStr: Array | string) { return Math.floor(Math.random() * arrOrStr.length); @@ -551,3 +552,7 @@ export function arrEquals(a: any[], b: any[]) { }) ); } + +export function throwMissingModuleError(moduleName: keyof ModulesMap): never { + throw new ErrorInfo(`${moduleName} module not provided`, 400, 40000); +} diff --git a/src/platform/web/modules.ts b/src/platform/web/modules.ts index ed8456d789..ed91109bfe 100644 --- a/src/platform/web/modules.ts +++ b/src/platform/web/modules.ts @@ -7,7 +7,6 @@ import ErrorInfo from '../../common/lib/types/errorinfo'; // Platform Specific import BufferUtils from './lib/util/bufferutils'; // @ts-ignore -import { createCryptoClass } from './lib/util/crypto'; import Http from './lib/util/http'; import Config from './config'; // @ts-ignore @@ -17,9 +16,6 @@ import { getDefaults } from '../../common/lib/util/defaults'; import WebStorage from './lib/util/webstorage'; import PlatformDefaults from './lib/util/defaults'; -const Crypto = createCryptoClass(Config, BufferUtils); - -Platform.Crypto = Crypto; Platform.BufferUtils = BufferUtils; Platform.Http = Http; Platform.Config = Config; diff --git a/src/platform/web/modules/crypto.ts b/src/platform/web/modules/crypto.ts index 7c9935b05a..2c65d23eb1 100644 --- a/src/platform/web/modules/crypto.ts +++ b/src/platform/web/modules/crypto.ts @@ -1,10 +1,14 @@ +import BufferUtils from '../lib/util/bufferutils'; +import { createCryptoClass } from '../lib/util/crypto'; +import Config from '../config'; import * as API from '../../../../ably'; -import Platform from 'common/platform'; + +export const Crypto = /* @__PURE__@ */ createCryptoClass(Config, BufferUtils); export const generateRandomKey: API.Types.Crypto['generateRandomKey'] = (keyLength) => { - return Platform.Crypto!.generateRandomKey(keyLength); + return Crypto.generateRandomKey(keyLength); }; export const getDefaultCryptoParams: API.Types.Crypto['getDefaultParams'] = (params) => { - return Platform.Crypto!.getDefaultParams(params); + return Crypto.getDefaultParams(params); }; diff --git a/src/platform/web/modules/message.ts b/src/platform/web/modules/message.ts index b908dce825..de8b2ab4f9 100644 --- a/src/platform/web/modules/message.ts +++ b/src/platform/web/modules/message.ts @@ -1,13 +1,21 @@ import * as API from '../../../../ably'; -import Platform from 'common/platform'; +import { Crypto } from './crypto'; import { fromEncoded, fromEncodedArray } from '../../../common/lib/types/message'; // The type assertions for the decode* functions below are due to https://github.com/ably/ably-js/issues/1421 export const decodeMessage = ((obj, options) => { - return fromEncoded(Platform.Crypto, obj, options); + return fromEncoded(null, obj, options); +}) as API.Types.MessageStatic['fromEncoded']; + +export const decodeEncryptedMessage = ((obj, options) => { + return fromEncoded(Crypto, obj, options); }) as API.Types.MessageStatic['fromEncoded']; export const decodeMessages = ((obj, options) => { - return fromEncodedArray(Platform.Crypto, obj, options); + return fromEncodedArray(null, obj, options); +}) as API.Types.MessageStatic['fromEncodedArray']; + +export const decodeEncryptedMessages = ((obj, options) => { + return fromEncodedArray(Crypto, obj, options); }) as API.Types.MessageStatic['fromEncodedArray']; diff --git a/test/browser/modules.test.js b/test/browser/modules.test.js index 8d13564e74..72aabe2c8e 100644 --- a/test/browser/modules.test.js +++ b/test/browser/modules.test.js @@ -5,7 +5,10 @@ import { generateRandomKey, getDefaultCryptoParams, decodeMessage, + decodeEncryptedMessage, decodeMessages, + decodeEncryptedMessages, + Crypto, } from '../../build/modules/index.js'; describe('browser/modules', function () { @@ -80,14 +83,40 @@ describe('browser/modules', function () { }); describe('Message standalone functions', () => { + async function testDecodesMessageData(functionUnderTest) { + const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + + const item = testData.items[1]; + const decoded = await functionUnderTest(item.encoded); + + expect(decoded.data).to.be.an('ArrayBuffer'); + } + describe('decodeMessage', () => { it('decodes a message’s data', async () => { + testDecodesMessageData(decodeMessage); + }); + + it('throws an error when given channel options with a cipher', async () => { const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + const key = BufferUtils.base64Decode(testData.key); + const iv = BufferUtils.base64Decode(testData.iv); - const item = testData.items[1]; - const decoded = await decodeMessage(item.encoded); + let thrownError = null; + try { + await decodeMessage(testData.items[0].encrypted, { cipher: { key, iv } }); + } catch (error) { + thrownError = error; + } + + expect(thrownError).not.to.be.null; + expect(thrownError.message).to.equal('Crypto module not provided'); + }); + }); - expect(decoded.data).to.be.an('ArrayBuffer'); + describe('decodeEncryptedMessage', async () => { + it('decodes a message’s data', async () => { + testDecodesMessageData(decodeEncryptedMessage); }); it('decrypts a message', async () => { @@ -99,7 +128,7 @@ describe('browser/modules', function () { for (const item of testData.items) { const [decodedFromEncoded, decodedFromEncrypted] = await Promise.all([ decodeMessage(item.encoded), - decodeMessage(item.encrypted, { cipher: { key, iv } }), + decodeEncryptedMessage(item.encrypted, { cipher: { key, iv } }), ]); testMessageEquality(decodedFromEncoded, decodedFromEncrypted); @@ -107,15 +136,44 @@ describe('browser/modules', function () { }); }); + async function testDecodesMessagesData(functionUnderTest) { + const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + + const items = [testData.items[1], testData.items[3]]; + const decoded = await functionUnderTest(items.map((item) => item.encoded)); + + expect(decoded[0].data).to.be.an('ArrayBuffer'); + expect(decoded[1].data).to.be.an('array'); + } + describe('decodeMessages', () => { it('decodes messages’ data', async () => { + testDecodesMessagesData(decodeMessages); + }); + + it('throws an error when given channel options with a cipher', async () => { const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + const key = BufferUtils.base64Decode(testData.key); + const iv = BufferUtils.base64Decode(testData.iv); + + let thrownError = null; + try { + await decodeMessages( + testData.items.map((item) => item.encrypted), + { cipher: { key, iv } } + ); + } catch (error) { + thrownError = error; + } - const items = [testData.items[1], testData.items[3]]; - const decoded = await decodeMessages(items.map((item) => item.encoded)); + expect(thrownError).not.to.be.null; + expect(thrownError.message).to.equal('Crypto module not provided'); + }); + }); - expect(decoded[0].data).to.be.an('ArrayBuffer'); - expect(decoded[1].data).to.be.an('array'); + describe('decodeEncryptedMessages', () => { + it('decodes messages’ data', async () => { + testDecodesMessagesData(decodeEncryptedMessages); }); it('decrypts messages', async () => { @@ -126,7 +184,7 @@ describe('browser/modules', function () { const [decodedFromEncoded, decodedFromEncrypted] = await Promise.all([ decodeMessages(testData.items.map((item) => item.encoded)), - decodeMessages( + decodeEncryptedMessages( testData.items.map((item) => item.encrypted), { cipher: { key, iv } } ), @@ -138,4 +196,54 @@ describe('browser/modules', function () { }); }); }); + + describe('Crypto', () => { + describe('without Crypto', () => { + for (const clientClass of [BaseRest, BaseRealtime]) { + describe(clientClass.name, () => { + it('throws an error when given channel options with a cipher', async () => { + const client = new clientClass(ablyClientOptions(), {}); + const key = await generateRandomKey(); + expect(() => client.channels.get('channel', { cipher: { key } })).to.throw('Crypto module not provided'); + }); + }); + } + }); + + describe('with Crypto', () => { + for (const clientClass of [BaseRest, BaseRealtime]) { + describe(clientClass.name, () => { + it('is able to publish encrypted messages', async () => { + const clientOptions = ablyClientOptions(); + + const key = await generateRandomKey(); + + // Publish the message on a channel configured to use encryption, and receive it on one not configured to use encryption + + const rxClient = new BaseRealtime(clientOptions, {}); + const rxChannel = rxClient.channels.get('channel'); + await rxChannel.attach(); + + const rxMessagePromise = new Promise((resolve, _) => rxChannel.subscribe((message) => resolve(message))); + + const encryptionChannelOptions = { cipher: { key } }; + + const txMessage = { name: 'message', data: 'data' }; + const txClient = new clientClass(clientOptions, { Crypto }); + const txChannel = txClient.channels.get('channel', encryptionChannelOptions); + await txChannel.publish(txMessage); + + const rxMessage = await rxMessagePromise; + + // Verify that the message was published with encryption + expect(rxMessage.encoding).to.equal('utf-8/cipher+aes-256-cbc'); + + // Verify that the message was correctly encrypted + const rxMessageDecrypted = await decodeEncryptedMessage(rxMessage, encryptionChannelOptions); + testMessageEquality(rxMessageDecrypted, txMessage); + }); + }); + } + }); + }); });