Skip to content

Commit

Permalink
Merge pull request #1419 from ably/1396-Crypto-tree-shaking
Browse files Browse the repository at this point in the history
[SDK-3735] Create tree-shakable `Crypto` module
  • Loading branch information
lawrence-forooghian authored Nov 7, 2023
2 parents 38b5928 + 63f5524 commit 8dbfcc7
Show file tree
Hide file tree
Showing 24 changed files with 455 additions and 101 deletions.
51 changes: 44 additions & 7 deletions scripts/moduleReport.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
const esbuild = require('esbuild');

// List of all modules accepted in ModulesMap
const moduleNames = ['Rest'];
const moduleNames = ['Rest', 'Crypto'];

// List of all free-standing functions exported by the library along with the
// ModulesMap entries that we expect them to transitively import
const functions = [
{ name: 'generateRandomKey', transitiveImports: ['Crypto'] },
{ name: 'getDefaultCryptoParams', transitiveImports: ['Crypto'] },
{ name: 'decodeMessage', transitiveImports: [] },
{ name: 'decodeEncryptedMessage', transitiveImports: ['Crypto'] },
{ name: 'decodeMessages', transitiveImports: [] },
{ name: 'decodeEncryptedMessages', transitiveImports: ['Crypto'] },
];

function formatBytes(bytes) {
const kibibytes = bytes / 1024;
Expand Down Expand Up @@ -35,19 +46,45 @@ const errors = [];
// First display the size of the base client
console.log(`${baseClient}: ${formatBytes(baseClientSize)}`);

// Then display the size of each module together with the base client
moduleNames.forEach((moduleName) => {
const size = getImportSize([baseClient, moduleName]);
console.log(`${baseClient} + ${moduleName}: ${formatBytes(size)}`);
// Then display the size of each export together with the base client
[...moduleNames, ...Object.values(functions).map((functionData) => functionData.name)].forEach((exportName) => {
const size = getImportSize([baseClient, exportName]);
console.log(`${baseClient} + ${exportName}: ${formatBytes(size)}`);

if (!(baseClientSize < size) && !(baseClient === 'BaseRest' && moduleName === 'Rest')) {
if (!(baseClientSize < size) && !(baseClient === 'BaseRest' && exportName === 'Rest')) {
// Emit an error if adding the module does not increase the bundle size
// (this means that the module is not being tree-shaken correctly).
errors.push(new Error(`Adding ${moduleName} to ${baseClient} does not increase the bundle size.`));
errors.push(new Error(`Adding ${exportName} to ${baseClient} does not increase the bundle size.`));
}
});
});

for (const functionData of functions) {
const { name: functionName, transitiveImports } = functionData;

// First display the size of the function
const standaloneSize = getImportSize([functionName]);
console.log(`${functionName}: ${formatBytes(standaloneSize)}`);

// Then display the size of the function together with the modules we expect
// it to transitively import
if (transitiveImports.length > 0) {
const withTransitiveImportsSize = getImportSize([functionName, ...transitiveImports]);
console.log(`${functionName} + ${transitiveImports.join(' + ')}: ${formatBytes(withTransitiveImportsSize)}`);

if (withTransitiveImportsSize > standaloneSize) {
// Emit an error if the bundle size is increased by adding the modules
// that we expect this function to have transitively imported anyway.
// This seemed like a useful sense check, but it might need tweaking in
// the future if we make future optimisations that mean that the
// standalone functions don’t necessarily import the whole module.
errors.push(
new Error(`Adding ${transitiveImports.join(' + ')} to ${functionName} unexpectedly increases the bundle size.`)
);
}
}
}

if (errors.length > 0) {
for (const error of errors) {
console.log(error.message);
Expand Down
9 changes: 5 additions & 4 deletions src/common/lib/client/baseclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import ClientOptions, { NormalisedClientOptions } from '../../types/ClientOption
import * as API from '../../../../ably';

import Platform from '../../platform';
import Message from '../types/message';
import PresenceMessage from '../types/presencemessage';
import { ModulesMap } from './modulesmap';
import { Rest } from './rest';
import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic';
import { throwMissingModuleError } from '../util/utils';

type BatchResult<T> = API.Types.BatchResult<T>;
type BatchPublishSpec = API.Types.BatchPublishSpec;
Expand All @@ -38,6 +39,7 @@ class BaseClient {
auth: Auth;

private readonly _rest: Rest | null;
readonly _Crypto: IUntypedCryptoStatic | null;

constructor(options: ClientOptions | string, modules: ModulesMap) {
if (!options) {
Expand Down Expand Up @@ -88,11 +90,12 @@ class BaseClient {
this.auth = new Auth(this, normalOptions);

this._rest = modules.Rest ? new modules.Rest(this) : null;
this._Crypto = modules.Crypto ?? null;
}

private get rest(): Rest {
if (!this._rest) {
throw new ErrorInfo('Rest module not provided', 400, 40000);
throwMissingModuleError('Rest');
}
return this._rest;
}
Expand Down Expand Up @@ -147,8 +150,6 @@ class BaseClient {
}

static Platform = Platform;
static Crypto?: typeof Platform.Crypto;
static Message = Message;
static PresenceMessage = PresenceMessage;
}

Expand Down
12 changes: 6 additions & 6 deletions src/common/lib/client/channel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import { ChannelOptions } from '../../types/channel';
import { PaginatedResultCallback, StandardCallback } from '../../types/utils';
import BaseClient from './baseclient';
import * as API from '../../../../ably';
import Platform from 'common/platform';
import Defaults from '../util/defaults';
import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic';

interface RestHistoryParams {
start?: number;
Expand All @@ -30,11 +30,11 @@ function allEmptyIds(messages: Array<Message>) {
});
}

function normaliseChannelOptions(options?: ChannelOptions) {
function normaliseChannelOptions(Crypto: IUntypedCryptoStatic | null, options?: ChannelOptions) {
const channelOptions = options || {};
if (channelOptions.cipher) {
if (!Platform.Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead');
const cipher = Platform.Crypto.getCipher(channelOptions.cipher);
if (!Crypto) Utils.throwMissingModuleError('Crypto');
const cipher = Crypto.getCipher(channelOptions.cipher);
channelOptions.cipher = cipher.cipherParams;
channelOptions.channelCipher = cipher.cipher;
} else if ('cipher' in channelOptions) {
Expand All @@ -60,11 +60,11 @@ class Channel extends EventEmitter {
this.name = name;
this.basePath = '/channels/' + encodeURIComponent(name);
this.presence = new Presence(this);
this.channelOptions = normaliseChannelOptions(channelOptions);
this.channelOptions = normaliseChannelOptions(client._Crypto ?? null, channelOptions);
}

setOptions(options?: ChannelOptions): void {
this.channelOptions = normaliseChannelOptions(options);
this.channelOptions = normaliseChannelOptions(this.client._Crypto ?? null, options);
}

history(
Expand Down
18 changes: 17 additions & 1 deletion src/common/lib/client/defaultrealtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,32 @@ import { allCommonModules } from './modulesmap';
import * as Utils from '../util/utils';
import ConnectionManager from '../transport/connectionmanager';
import ProtocolMessage from '../types/protocolmessage';
import Platform from 'common/platform';
import { DefaultMessage } from '../types/defaultmessage';

/**
`DefaultRealtime` is the class that the non tree-shakable version of the SDK exports as `Realtime`. It ensures that this version of the SDK includes all of the functionality which is optionally available in the tree-shakable version.
*/
export class DefaultRealtime extends BaseRealtime {
constructor(options: ClientOptions) {
super(options, allCommonModules);
super(options, { ...allCommonModules, Crypto: DefaultRealtime.Crypto ?? undefined });
}

static Utils = Utils;
static ConnectionManager = ConnectionManager;
static ProtocolMessage = ProtocolMessage;

private static _Crypto: typeof Platform.Crypto = null;
static get Crypto() {
if (this._Crypto === null) {
throw new Error('Encryption not enabled; use ably.encryption.js instead');
}

return this._Crypto;
}
static set Crypto(newValue: typeof Platform.Crypto) {
this._Crypto = newValue;
}

static Message = DefaultMessage;
}
18 changes: 17 additions & 1 deletion src/common/lib/client/defaultrest.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import { BaseRest } from './baserest';
import ClientOptions from '../../types/ClientOptions';
import { allCommonModules } from './modulesmap';
import Platform from 'common/platform';
import { DefaultMessage } from '../types/defaultmessage';

/**
`DefaultRest` is the class that the non tree-shakable version of the SDK exports as `Rest`. It ensures that this version of the SDK includes all of the functionality which is optionally available in the tree-shakable version.
*/
export class DefaultRest extends BaseRest {
constructor(options: ClientOptions | string) {
super(options, allCommonModules);
super(options, { ...allCommonModules, Crypto: DefaultRest.Crypto ?? undefined });
}

private static _Crypto: typeof Platform.Crypto = null;
static get Crypto() {
if (this._Crypto === null) {
throw new Error('Encryption not enabled; use ably.encryption.js instead');
}

return this._Crypto;
}
static set Crypto(newValue: typeof Platform.Crypto) {
this._Crypto = newValue;
}

static Message = DefaultMessage;
}
2 changes: 2 additions & 0 deletions src/common/lib/client/modulesmap.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { Rest } from './rest';
import { IUntypedCryptoStatic } from '../../types/ICryptoStatic';

export interface ModulesMap {
Rest?: typeof Rest;
Crypto?: IUntypedCryptoStatic;
}

export const allCommonModules: ModulesMap = { Rest };
16 changes: 16 additions & 0 deletions src/common/lib/types/defaultmessage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import Message, { fromEncoded, fromEncodedArray } from './message';
import * as API from '../../../../ably';
import Platform from 'common/platform';

/**
`DefaultMessage` is the class returned by `DefaultRest` and `DefaultRealtime`’s `Message` static property. It introduces the static methods described in the `MessageStatic` interface of the public API of the non tree-shakable version of the library.
*/
export class DefaultMessage extends Message {
static async fromEncoded(encoded: unknown, inputOptions?: API.Types.ChannelOptions): Promise<Message> {
return fromEncoded(Platform.Crypto, encoded, inputOptions);
}

static async fromEncodedArray(encodedArray: Array<unknown>, options?: API.Types.ChannelOptions): Promise<Message[]> {
return fromEncodedArray(Platform.Crypto, encodedArray, options);
}
}
60 changes: 36 additions & 24 deletions src/common/lib/types/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import PresenceMessage from './presencemessage';
import * as Utils from '../util/utils';
import { Bufferlike as BrowserBufferlike } from '../../../platform/web/lib/util/bufferutils';
import * as API from '../../../../ably';
import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic';

export type CipherOptions = {
channelCipher: {
Expand Down Expand Up @@ -42,10 +43,13 @@ function normaliseContext(context: CipherOptions | EncodingDecodingContext | Cha
return context as EncodingDecodingContext;
}

function normalizeCipherOptions(options: API.Types.ChannelOptions | null): ChannelOptions {
function normalizeCipherOptions(
Crypto: IUntypedCryptoStatic | null,
options: API.Types.ChannelOptions | null
): ChannelOptions {
if (options && options.cipher) {
if (!Platform.Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead');
const cipher = Platform.Crypto.getCipher(options.cipher);
if (!Crypto) Utils.throwMissingModuleError('Crypto');
const cipher = Crypto.getCipher(options.cipher);
return {
cipher: cipher.cipherParams,
channelCipher: cipher.cipher,
Expand All @@ -71,6 +75,35 @@ function getMessageSize(msg: Message) {
return size;
}

export async function fromEncoded(
Crypto: IUntypedCryptoStatic | null,
encoded: unknown,
inputOptions?: API.Types.ChannelOptions
): Promise<Message> {
const msg = Message.fromValues(encoded);
const options = normalizeCipherOptions(Crypto, inputOptions ?? null);
/* if decoding fails at any point, catch and return the message decoded to
* the fullest extent possible */
try {
await Message.decode(msg, options);
} catch (e) {
Logger.logAction(Logger.LOG_ERROR, 'Message.fromEncoded()', (e as Error).toString());
}
return msg;
}

export async function fromEncodedArray(
Crypto: IUntypedCryptoStatic | null,
encodedArray: Array<unknown>,
options?: API.Types.ChannelOptions
): Promise<Message[]> {
return Promise.all(
encodedArray.map(function (encoded) {
return fromEncoded(Crypto, encoded, options);
})
);
}

class Message {
name?: string;
id?: string;
Expand Down Expand Up @@ -330,27 +363,6 @@ class Message {
return result;
}

static async fromEncoded(encoded: unknown, inputOptions?: API.Types.ChannelOptions): Promise<Message> {
const msg = Message.fromValues(encoded);
const options = normalizeCipherOptions(inputOptions ?? null);
/* if decoding fails at any point, catch and return the message decoded to
* the fullest extent possible */
try {
await Message.decode(msg, options);
} catch (e) {
Logger.logAction(Logger.LOG_ERROR, 'Message.fromEncoded()', (e as Error).toString());
}
return msg;
}

static async fromEncodedArray(encodedArray: Array<unknown>, options?: API.Types.ChannelOptions): Promise<Message[]> {
return Promise.all(
encodedArray.map(function (encoded) {
return Message.fromEncoded(encoded, options);
})
);
}

/* This should be called on encode()d (and encrypt()d) Messages (as it
* assumes the data is a string or buffer) */
static getMessagesSize(messages: Message[]): number {
Expand Down
5 changes: 5 additions & 0 deletions src/common/lib/util/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Platform from 'common/platform';
import ErrorInfo, { PartialErrorInfo } from 'common/lib/types/errorinfo';
import { ModulesMap } from '../client/modulesmap';

function randomPosn(arrOrStr: Array<unknown> | string) {
return Math.floor(Math.random() * arrOrStr.length);
Expand Down Expand Up @@ -551,3 +552,7 @@ export function arrEquals(a: any[], b: any[]) {
})
);
}

export function throwMissingModuleError(moduleName: keyof ModulesMap): never {
throw new ErrorInfo(`${moduleName} module not provided`, 400, 40000);
}
10 changes: 5 additions & 5 deletions src/common/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import IBufferUtils from './types/IBufferUtils';
import Transport from './lib/transport/transport';
import * as WebBufferUtils from '../platform/web/lib/util/bufferutils';
import * as NodeBufferUtils from '../platform/nodejs/lib/util/bufferutils';
import { IUntypedCryptoStatic } from '../common/types/ICryptoStatic';

type Bufferlike = WebBufferUtils.Bufferlike | NodeBufferUtils.Bufferlike;
type BufferUtilsOutput = WebBufferUtils.Output | NodeBufferUtils.Output;
Expand All @@ -23,12 +24,11 @@ export default class Platform {
*/
static BufferUtils: IBufferUtils<Bufferlike, BufferUtilsOutput, ToBufferOutput>;
/*
This should be a class whose static methods implement the ICryptoStatic
interface, but (for the same reasons as described in the BufferUtils
comment above) Platform doesn’t currently allow us to express the
generic parameters, hence keeping the type as `any`.
We’d like this to be ICryptoStatic with the correct generic arguments,
but Platform doesn’t currently allow that, as described in the BufferUtils
comment above.
*/
static Crypto: any;
static Crypto: IUntypedCryptoStatic | null;
static Http: typeof IHttp;
static Transports: Array<(connectionManager: typeof ConnectionManager) => Transport>;
static Defaults: IDefaults;
Expand Down
8 changes: 8 additions & 0 deletions src/common/types/ICryptoStatic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@ export default interface ICryptoStatic<IV, InputPlaintext, OutputCiphertext, Inp
params: IGetCipherParams<IV>
): IGetCipherReturnValue<ICipher<InputPlaintext, OutputCiphertext, InputCiphertext, OutputPlaintext>>;
}

/*
A less strongly typed version of ICryptoStatic to use until we
can make Platform a generic type (see comment there).
*/
export interface IUntypedCryptoStatic extends API.Types.Crypto {
getCipher(params: any): any;
}
Loading

0 comments on commit 8dbfcc7

Please sign in to comment.