Skip to content

Commit

Permalink
Make crypto functionality tree-shakable
Browse files Browse the repository at this point in the history
We move the crypto functionality into a tree-shakable Crypto module.

Then, we split the decodeMessage* functions introduced in 601b46b in
two; we introduce new decodeEncryptedMessage* variants which import the
Crypto module and are hence able to decrypt encrypted messages, and we
change the existing functions to not import the Crypto module and to
fail if they are given cipher options.

Resolves #1396.
  • Loading branch information
lawrence-forooghian committed Nov 6, 2023
1 parent 601b46b commit 63f5524
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 31 deletions.
42 changes: 38 additions & 4 deletions scripts/moduleReport.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
const esbuild = require('esbuild');

// List of all modules accepted in ModulesMap
const moduleNames = ['Rest'];
const moduleNames = ['Rest', 'Crypto'];

// List of all free-standing functions exported by the library
const functionNames = ['generateRandomKey', 'getDefaultCryptoParams', 'decodeMessage', 'decodeMessages'];
// List of all free-standing functions exported by the library along with the
// ModulesMap entries that we expect them to transitively import
const functions = [
{ name: 'generateRandomKey', transitiveImports: ['Crypto'] },
{ name: 'getDefaultCryptoParams', transitiveImports: ['Crypto'] },
{ name: 'decodeMessage', transitiveImports: [] },
{ name: 'decodeEncryptedMessage', transitiveImports: ['Crypto'] },
{ name: 'decodeMessages', transitiveImports: [] },
{ name: 'decodeEncryptedMessages', transitiveImports: ['Crypto'] },
];

function formatBytes(bytes) {
const kibibytes = bytes / 1024;
Expand Down Expand Up @@ -39,7 +47,7 @@ const errors = [];
console.log(`${baseClient}: ${formatBytes(baseClientSize)}`);

// Then display the size of each export together with the base client
[...moduleNames, ...functionNames].forEach((exportName) => {
[...moduleNames, ...Object.values(functions).map((functionData) => functionData.name)].forEach((exportName) => {
const size = getImportSize([baseClient, exportName]);
console.log(`${baseClient} + ${exportName}: ${formatBytes(size)}`);

Expand All @@ -51,6 +59,32 @@ const errors = [];
});
});

for (const functionData of functions) {
const { name: functionName, transitiveImports } = functionData;

// First display the size of the function
const standaloneSize = getImportSize([functionName]);
console.log(`${functionName}: ${formatBytes(standaloneSize)}`);

// Then display the size of the function together with the modules we expect
// it to transitively import
if (transitiveImports.length > 0) {
const withTransitiveImportsSize = getImportSize([functionName, ...transitiveImports]);
console.log(`${functionName} + ${transitiveImports.join(' + ')}: ${formatBytes(withTransitiveImportsSize)}`);

if (withTransitiveImportsSize > standaloneSize) {
// Emit an error if the bundle size is increased by adding the modules
// that we expect this function to have transitively imported anyway.
// This seemed like a useful sense check, but it might need tweaking in
// the future if we make future optimisations that mean that the
// standalone functions don’t necessarily import the whole module.
errors.push(
new Error(`Adding ${transitiveImports.join(' + ')} to ${functionName} unexpectedly increases the bundle size.`)
);
}
}
}

if (errors.length > 0) {
for (const error of errors) {
console.log(error.message);
Expand Down
6 changes: 5 additions & 1 deletion src/common/lib/client/baseclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import Platform from '../../platform';
import PresenceMessage from '../types/presencemessage';
import { ModulesMap } from './modulesmap';
import { Rest } from './rest';
import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic';
import { throwMissingModuleError } from '../util/utils';

type BatchResult<T> = API.Types.BatchResult<T>;
type BatchPublishSpec = API.Types.BatchPublishSpec;
Expand All @@ -37,6 +39,7 @@ class BaseClient {
auth: Auth;

private readonly _rest: Rest | null;
readonly _Crypto: IUntypedCryptoStatic | null;

constructor(options: ClientOptions | string, modules: ModulesMap) {
if (!options) {
Expand Down Expand Up @@ -87,11 +90,12 @@ class BaseClient {
this.auth = new Auth(this, normalOptions);

this._rest = modules.Rest ? new modules.Rest(this) : null;
this._Crypto = modules.Crypto ?? null;
}

private get rest(): Rest {
if (!this._rest) {
throw new ErrorInfo('Rest module not provided', 400, 40000);
throwMissingModuleError('Rest');
}
return this._rest;
}
Expand Down
7 changes: 3 additions & 4 deletions src/common/lib/client/channel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import { ChannelOptions } from '../../types/channel';
import { PaginatedResultCallback, StandardCallback } from '../../types/utils';
import BaseClient from './baseclient';
import * as API from '../../../../ably';
import Platform from 'common/platform';
import Defaults from '../util/defaults';
import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic';

Expand All @@ -34,7 +33,7 @@ function allEmptyIds(messages: Array<Message>) {
function normaliseChannelOptions(Crypto: IUntypedCryptoStatic | null, options?: ChannelOptions) {
const channelOptions = options || {};
if (channelOptions.cipher) {
if (!Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead');
if (!Crypto) Utils.throwMissingModuleError('Crypto');
const cipher = Crypto.getCipher(channelOptions.cipher);
channelOptions.cipher = cipher.cipherParams;
channelOptions.channelCipher = cipher.cipher;
Expand All @@ -61,11 +60,11 @@ class Channel extends EventEmitter {
this.name = name;
this.basePath = '/channels/' + encodeURIComponent(name);
this.presence = new Presence(this);
this.channelOptions = normaliseChannelOptions(Platform.Crypto, channelOptions);
this.channelOptions = normaliseChannelOptions(client._Crypto ?? null, channelOptions);
}

setOptions(options?: ChannelOptions): void {
this.channelOptions = normaliseChannelOptions(Platform.Crypto, options);
this.channelOptions = normaliseChannelOptions(this.client._Crypto ?? null, options);
}

history(
Expand Down
2 changes: 1 addition & 1 deletion src/common/lib/client/defaultrealtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { DefaultMessage } from '../types/defaultmessage';
*/
export class DefaultRealtime extends BaseRealtime {
constructor(options: ClientOptions) {
super(options, allCommonModules);
super(options, { ...allCommonModules, Crypto: DefaultRealtime.Crypto ?? undefined });
}

static Utils = Utils;
Expand Down
2 changes: 1 addition & 1 deletion src/common/lib/client/defaultrest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { DefaultMessage } from '../types/defaultmessage';
*/
export class DefaultRest extends BaseRest {
constructor(options: ClientOptions | string) {
super(options, allCommonModules);
super(options, { ...allCommonModules, Crypto: DefaultRest.Crypto ?? undefined });
}

private static _Crypto: typeof Platform.Crypto = null;
Expand Down
2 changes: 2 additions & 0 deletions src/common/lib/client/modulesmap.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { Rest } from './rest';
import { IUntypedCryptoStatic } from '../../types/ICryptoStatic';

export interface ModulesMap {
Rest?: typeof Rest;
Crypto?: IUntypedCryptoStatic;
}

export const allCommonModules: ModulesMap = { Rest };
2 changes: 1 addition & 1 deletion src/common/lib/types/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function normalizeCipherOptions(
options: API.Types.ChannelOptions | null
): ChannelOptions {
if (options && options.cipher) {
if (!Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead');
if (!Crypto) Utils.throwMissingModuleError('Crypto');
const cipher = Crypto.getCipher(options.cipher);
return {
cipher: cipher.cipherParams,
Expand Down
5 changes: 5 additions & 0 deletions src/common/lib/util/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Platform from 'common/platform';
import ErrorInfo, { PartialErrorInfo } from 'common/lib/types/errorinfo';
import { ModulesMap } from '../client/modulesmap';

function randomPosn(arrOrStr: Array<unknown> | string) {
return Math.floor(Math.random() * arrOrStr.length);
Expand Down Expand Up @@ -551,3 +552,7 @@ export function arrEquals(a: any[], b: any[]) {
})
);
}

export function throwMissingModuleError(moduleName: keyof ModulesMap): never {
throw new ErrorInfo(`${moduleName} module not provided`, 400, 40000);
}
4 changes: 0 additions & 4 deletions src/platform/web/modules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import ErrorInfo from '../../common/lib/types/errorinfo';
// Platform Specific
import BufferUtils from './lib/util/bufferutils';
// @ts-ignore
import { createCryptoClass } from './lib/util/crypto';
import Http from './lib/util/http';
import Config from './config';
// @ts-ignore
Expand All @@ -17,9 +16,6 @@ import { getDefaults } from '../../common/lib/util/defaults';
import WebStorage from './lib/util/webstorage';
import PlatformDefaults from './lib/util/defaults';

const Crypto = createCryptoClass(Config, BufferUtils);

Platform.Crypto = Crypto;
Platform.BufferUtils = BufferUtils;
Platform.Http = Http;
Platform.Config = Config;
Expand Down
10 changes: 7 additions & 3 deletions src/platform/web/modules/crypto.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import BufferUtils from '../lib/util/bufferutils';
import { createCryptoClass } from '../lib/util/crypto';
import Config from '../config';
import * as API from '../../../../ably';
import Platform from 'common/platform';

export const Crypto = /* @__PURE__@ */ createCryptoClass(Config, BufferUtils);

export const generateRandomKey: API.Types.Crypto['generateRandomKey'] = (keyLength) => {
return Platform.Crypto!.generateRandomKey(keyLength);
return Crypto.generateRandomKey(keyLength);
};

export const getDefaultCryptoParams: API.Types.Crypto['getDefaultParams'] = (params) => {
return Platform.Crypto!.getDefaultParams(params);
return Crypto.getDefaultParams(params);
};
14 changes: 11 additions & 3 deletions src/platform/web/modules/message.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import * as API from '../../../../ably';
import Platform from 'common/platform';
import { Crypto } from './crypto';
import { fromEncoded, fromEncodedArray } from '../../../common/lib/types/message';

// The type assertions for the decode* functions below are due to https://github.com/ably/ably-js/issues/1421

export const decodeMessage = ((obj, options) => {
return fromEncoded(Platform.Crypto, obj, options);
return fromEncoded(null, obj, options);
}) as API.Types.MessageStatic['fromEncoded'];

export const decodeEncryptedMessage = ((obj, options) => {
return fromEncoded(Crypto, obj, options);
}) as API.Types.MessageStatic['fromEncoded'];

export const decodeMessages = ((obj, options) => {
return fromEncodedArray(Platform.Crypto, obj, options);
return fromEncodedArray(null, obj, options);
}) as API.Types.MessageStatic['fromEncodedArray'];

export const decodeEncryptedMessages = ((obj, options) => {
return fromEncodedArray(Crypto, obj, options);
}) as API.Types.MessageStatic['fromEncodedArray'];
Loading

0 comments on commit 63f5524

Please sign in to comment.