Skip to content

Commit

Permalink
Make crypto functionality tree-shakable
Browse files Browse the repository at this point in the history
We move the crypto functionality into a tree-shakable Crypto module.

Then, we split the decodeMessage* functions introduced in ed90ac8 in
two; we introduce new decodeEncryptedMessage* variants which import the
Crypto module and are hence able to decrypt encrypted messages, and we
change the existing functions to not import the Crypto module and to
fail if they are given cipher options.

Resolves #1396.
  • Loading branch information
lawrence-forooghian committed Aug 21, 2023
1 parent d2c43d4 commit 8e99076
Show file tree
Hide file tree
Showing 12 changed files with 176 additions and 31 deletions.
42 changes: 38 additions & 4 deletions scripts/moduleReport.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
const esbuild = require('esbuild');

// List of all modules accepted in ModulesMap
const moduleNames = ['Rest'];
const moduleNames = ['Rest', 'Crypto'];

// List of all free-standing functions exported by the library
const functionNames = ['generateRandomKey', 'getDefaultCryptoParams', 'decodeMessage', 'decodeMessages'];
// List of all free-standing functions exported by the library along with the
// ModulesMap entries that we expect them to transitively import
const functions = [
{ name: 'generateRandomKey', transitiveImports: ['Crypto'] },
{ name: 'getDefaultCryptoParams', transitiveImports: ['Crypto'] },
{ name: 'decodeMessage', transitiveImports: [] },
{ name: 'decodeEncryptedMessage', transitiveImports: ['Crypto'] },
{ name: 'decodeMessages', transitiveImports: [] },
{ name: 'decodeEncryptedMessages', transitiveImports: ['Crypto'] },
];

function formatBytes(bytes) {
const kibibytes = bytes / 1024;
Expand Down Expand Up @@ -39,7 +47,7 @@ const errors = [];
console.log(`${baseClient}: ${formatBytes(baseClientSize)}`);

// Then display the size of each export together with the base client
[...moduleNames, ...functionNames].forEach((exportName) => {
[...moduleNames, ...Object.values(functions).map((functionData) => functionData.name)].forEach((exportName) => {
const size = getImportSize([baseClient, exportName]);
console.log(`${baseClient} + ${exportName}: ${formatBytes(size)}`);

Expand All @@ -51,6 +59,32 @@ const errors = [];
});
});

for (const functionData of functions) {
const { name: functionName, transitiveImports } = functionData;

// First display the size of the function
const standaloneSize = getImportSize([functionName]);
console.log(`${functionName}: ${formatBytes(standaloneSize)}`);

// Then display the size of the function together with the modules we expect
// it to transitively import
if (transitiveImports.length > 0) {
const withTransitiveImportsSize = getImportSize([functionName, ...transitiveImports]);
console.log(`${functionName} + ${transitiveImports.join(' + ')}: ${formatBytes(withTransitiveImportsSize)}`);

if (withTransitiveImportsSize > standaloneSize) {
// Emit an error if the bundle size is increased by adding the modules
// that we expect this function to have transitively imported anyway.
// This seemed like a useful sense check, but it might need tweaking in
// the future if we make future optimisations that mean that the
// standalone functions don’t necessarily import the whole module.
errors.push(
new Error(`Adding ${transitiveImports.join(' + ')} to ${functionName} unexpectedly increases the bundle size.`)
);
}
}
}

if (errors.length > 0) {
for (const error of errors) {
console.log(error.message);
Expand Down
6 changes: 5 additions & 1 deletion src/common/lib/client/baseclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import Platform from '../../platform';
import PresenceMessage from '../types/presencemessage';
import { ModulesMap } from './modulesmap';
import { Rest } from './rest';
import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic';
import { throwMissingModuleError } from '../util/utils';

/**
`BaseClient` acts as the base class for all of the client classes exported by the SDK. It is an implementation detail and this class is not advertised publicly.
Expand All @@ -27,6 +29,7 @@ class BaseClient {
auth: Auth;

private readonly _rest: Rest | null;
readonly _Crypto: IUntypedCryptoStatic | null;

constructor(options: ClientOptions | string, modules: ModulesMap) {
if (!options) {
Expand Down Expand Up @@ -77,11 +80,12 @@ class BaseClient {
this.auth = new Auth(this, normalOptions);

this._rest = modules.Rest ? new modules.Rest(this) : null;
this._Crypto = modules.Crypto ?? null;
}

private get rest(): Rest {
if (!this._rest) {
throw new ErrorInfo('Rest module not provided', 400, 40000);
throwMissingModuleError('Crypto');
}
return this._rest;
}
Expand Down
7 changes: 3 additions & 4 deletions src/common/lib/client/channel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import { ChannelOptions } from '../../types/channel';
import { PaginatedResultCallback, StandardCallback } from '../../types/utils';
import BaseClient from './baseclient';
import * as API from '../../../../ably';
import Platform from 'common/platform';
import Defaults from '../util/defaults';
import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic';

Expand All @@ -34,7 +33,7 @@ function allEmptyIds(messages: Array<Message>) {
function normaliseChannelOptions(Crypto: IUntypedCryptoStatic | null, options?: ChannelOptions) {
const channelOptions = options || {};
if (channelOptions.cipher) {
if (!Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead');
if (!Crypto) Utils.throwMissingModuleError('Crypto');
const cipher = Crypto.getCipher(channelOptions.cipher);
channelOptions.cipher = cipher.cipherParams;
channelOptions.channelCipher = cipher.cipher;
Expand All @@ -61,11 +60,11 @@ class Channel extends EventEmitter {
this.name = name;
this.basePath = '/channels/' + encodeURIComponent(name);
this.presence = new Presence(this);
this.channelOptions = normaliseChannelOptions(Platform.Crypto, channelOptions);
this.channelOptions = normaliseChannelOptions(client._Crypto ?? null, channelOptions);
}

setOptions(options?: ChannelOptions): void {
this.channelOptions = normaliseChannelOptions(Platform.Crypto, options);
this.channelOptions = normaliseChannelOptions(this.client._Crypto ?? null, options);
}

history(
Expand Down
2 changes: 1 addition & 1 deletion src/common/lib/client/defaultrealtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { DefaultMessage } from '../types/defaultmessage';
*/
export class DefaultRealtime extends BaseRealtime {
constructor(options: ClientOptions) {
super(options, allCommonModules);
super(options, { ...allCommonModules, Crypto: DefaultRealtime.Crypto ?? undefined });
}

static Utils = Utils;
Expand Down
2 changes: 1 addition & 1 deletion src/common/lib/client/defaultrest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { DefaultMessage } from '../types/defaultmessage';
*/
export class DefaultRest extends BaseRest {
constructor(options: ClientOptions | string) {
super(options, allCommonModules);
super(options, { ...allCommonModules, Crypto: DefaultRest.Crypto ?? undefined });
}

private static _Crypto: typeof Platform.Crypto = null;
Expand Down
2 changes: 2 additions & 0 deletions src/common/lib/client/modulesmap.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { Rest } from './rest';
import { IUntypedCryptoStatic } from '../../types/ICryptoStatic';

export interface ModulesMap {
Rest?: typeof Rest;
Crypto?: IUntypedCryptoStatic;
}

export const allCommonModules: ModulesMap = { Rest };
2 changes: 1 addition & 1 deletion src/common/lib/types/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function normalizeCipherOptions(
options: API.Types.ChannelOptions | null
): ChannelOptions {
if (options && options.cipher) {
if (!Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead');
if (!Crypto) Utils.throwMissingModuleError('Crypto');
const cipher = Crypto.getCipher(options.cipher);
return {
cipher: cipher.cipherParams,
Expand Down
4 changes: 4 additions & 0 deletions src/common/lib/util/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,7 @@ export function toBase64(str: string) {
const textBuffer = bufferUtils.utf8Encode(str);
return bufferUtils.base64Encode(textBuffer);
}

export function throwMissingModuleError(moduleName: string): never {
throw new ErrorInfo(`${moduleName} module not provided`, 400, 40000);
}
4 changes: 0 additions & 4 deletions src/platform/web/modules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import Platform from '../../common/platform';
// Platform Specific
import BufferUtils from './lib/util/bufferutils';
// @ts-ignore
import { createCryptoClass } from './lib/util/crypto';
import Http from './lib/util/http';
import Config from './config';
// @ts-ignore
Expand All @@ -16,9 +15,6 @@ import { getDefaults } from '../../common/lib/util/defaults';
import WebStorage from './lib/util/webstorage';
import PlatformDefaults from './lib/util/defaults';

const Crypto = createCryptoClass(Config, BufferUtils);

Platform.Crypto = Crypto;
Platform.BufferUtils = BufferUtils;
Platform.Http = Http;
Platform.Config = Config;
Expand Down
10 changes: 7 additions & 3 deletions src/platform/web/modules/crypto.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import BufferUtils from '../lib/util/bufferutils';
import { createCryptoClass } from '../lib/util/crypto';
import Config from '../config';
import * as API from '../../../../ably';
import Platform from 'common/platform';

export const Crypto = /* @__PURE__@ */ createCryptoClass(Config, BufferUtils);

export const generateRandomKey: API.Types.Crypto['generateRandomKey'] = (keyLength) => {
return Platform.Crypto!.generateRandomKey(keyLength);
return Crypto.generateRandomKey(keyLength);
};

export const getDefaultCryptoParams: API.Types.Crypto['getDefaultParams'] = (params) => {
return Platform.Crypto!.getDefaultParams(params);
return Crypto.getDefaultParams(params);
};
14 changes: 11 additions & 3 deletions src/platform/web/modules/message.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import * as API from '../../../../ably';
import Platform from 'common/platform';
import { Crypto } from './crypto';
import { fromEncoded, fromEncodedArray } from '../../../common/lib/types/message';

// The type assertions for the decode* functions below are due to https://github.com/ably/ably-js/issues/1421

export const decodeMessage = ((obj, options) => {
return fromEncoded(Platform.Crypto, obj, options);
return fromEncoded(null, obj, options);
}) as API.Types.MessageStatic['fromEncoded'];

export const decodeEncryptedMessage = ((obj, options) => {
return fromEncoded(Crypto, obj, options);
}) as API.Types.MessageStatic['fromEncoded'];

export const decodeMessages = ((obj, options) => {
return fromEncodedArray(Platform.Crypto, obj, options);
return fromEncodedArray(null, obj, options);
}) as API.Types.MessageStatic['fromEncodedArray'];

export const decodeEncryptedMessages = ((obj, options) => {
return fromEncodedArray(Crypto, obj, options);
}) as API.Types.MessageStatic['fromEncodedArray'];
112 changes: 103 additions & 9 deletions test/browser/modules.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import {
generateRandomKey,
getDefaultCryptoParams,
decodeMessage,
decodeEncryptedMessage,
decodeMessages,
decodeEncryptedMessages,
Crypto,
} from '../../build/modules/index.js';

describe('browser/modules', function () {
Expand Down Expand Up @@ -80,14 +83,32 @@ describe('browser/modules', function () {
});

describe('Message standalone functions', () => {
async function testDecodesMessageData(functionUnderTest) {
const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json');

const item = testData.items[1];
const decoded = await functionUnderTest(item.encoded);

expect(decoded.data).to.be.an('ArrayBuffer');
}

describe('decodeMessage', () => {
it('decodes a message’s data', async () => {
testDecodesMessageData(decodeMessage);
});

it('throws an error when given channel options with a cipher', async () => {
const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json');
const key = BufferUtils.base64Decode(testData.key);
const iv = BufferUtils.base64Decode(testData.iv);

const item = testData.items[1];
const decoded = await decodeMessage(item.encoded);
expect(() => decodeMessage(testData.items[0].encrypted, { cipher: { key, iv } })).to.throw;
});
});

expect(decoded.data).to.be.an('ArrayBuffer');
describe('decodeEncryptedMessage', async () => {
it('decodes a message’s data', async () => {
testDecodesMessageData(decodeEncryptedMessage);
});

it('decrypts a message', async () => {
Expand All @@ -99,23 +120,46 @@ describe('browser/modules', function () {
for (const item of testData.items) {
const [decodedFromEncoded, decodedFromEncrypted] = await Promise.all([
decodeMessage(item.encoded),
decodeMessage(item.encrypted, { cipher: { key, iv } }),
decodeEncryptedMessage(item.encrypted, { cipher: { key, iv } }),
]);

testMessageEquality(decodedFromEncoded, decodedFromEncrypted);
}
});
});

async function testDecodesMessagesData(functionUnderTest) {
const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json');

const items = [testData.items[1], testData.items[3]];
const decoded = await functionUnderTest(items.map((item) => item.encoded));

expect(decoded[0].data).to.be.an('ArrayBuffer');
expect(decoded[1].data).to.be.an('array');
}

describe('decodeMessages', () => {
it('decodes messages’ data', async () => {
testDecodesMessagesData(decodeMessages);
});

it('throws an error when given channel options with a cipher', async () => {
const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json');
const key = BufferUtils.base64Decode(testData.key);
const iv = BufferUtils.base64Decode(testData.iv);

const items = [testData.items[1], testData.items[3]];
const decoded = await decodeMessages(items.map((item) => item.encoded));
expect(() =>
decodeMessages(
items.map((item) => item.encrypted),
{ cipher: { key, iv } }
)
).to.throw;
});
});

expect(decoded[0].data).to.be.an('ArrayBuffer');
expect(decoded[1].data).to.be.an('array');
describe('decodeEncryptedMessages', () => {
it('decodes messages’ data', async () => {
testDecodesMessagesData(decodeEncryptedMessages);
});

it('decrypts messages', async () => {
Expand All @@ -126,7 +170,7 @@ describe('browser/modules', function () {

const [decodedFromEncoded, decodedFromEncrypted] = await Promise.all([
decodeMessages(testData.items.map((item) => item.encoded)),
decodeMessages(
decodeEncryptedMessages(
testData.items.map((item) => item.encrypted),
{ cipher: { key, iv } }
),
Expand All @@ -138,4 +182,54 @@ describe('browser/modules', function () {
});
});
});

describe('Crypto', () => {
describe('without Crypto', () => {
for (const clientClass of [BaseRest, BaseRealtime]) {
describe(clientClass.name, () => {
it('throws an error when given channel options with a cipher', async () => {
const client = new clientClass(ablyClientOptions(), {});
const key = await generateRandomKey();
expect(() => client.channels.get('channel', { cipher: { key } })).to.throw;
});
});
}
});

describe('with Crypto', () => {
for (const clientClass of [BaseRest, BaseRealtime]) {
describe(clientClass.name, () => {
it('is able to publish encrypted messages', async () => {
const clientOptions = ablyClientOptions();

const key = await generateRandomKey();

// Publish the message on a channel configured to use encryption, and receive it on one not configured to use encryption

const rxClient = new BaseRealtime(clientOptions, {});
const rxChannel = rxClient.channels.get('channel');
await rxChannel.attach();

const rxMessagePromise = new Promise((resolve, _) => rxChannel.subscribe((message) => resolve(message)));

const encryptionChannelOptions = { cipher: { key } };

const txMessage = { name: 'message', data: 'data' };
const txClient = new clientClass(clientOptions, { Crypto });
const txChannel = txClient.channels.get('channel', encryptionChannelOptions);
await txChannel.publish(txMessage);

const rxMessage = await rxMessagePromise;

// Verify that the message was published with encryption
expect(rxMessage.encoding).to.equal('utf-8/cipher+aes-256-cbc');

// Verify that the message was correctly encrypted
const rxMessageDecrypted = await decodeEncryptedMessage(rxMessage, encryptionChannelOptions);
testMessageEquality(rxMessageDecrypted, txMessage);
});
});
}
});
});
});

0 comments on commit 8e99076

Please sign in to comment.