Skip to content

Commit

Permalink
Merge pull request #1311 from ably/1293-make-ICipher.decrypt-async
Browse files Browse the repository at this point in the history
Make `ICipher.decrypt` async
  • Loading branch information
lawrence-forooghian authored Jun 5, 2023
2 parents 333872a + c955702 commit f7dd898
Show file tree
Hide file tree
Showing 18 changed files with 411 additions and 295 deletions.
24 changes: 12 additions & 12 deletions ably.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2758,17 +2758,17 @@ declare namespace Types {
*
* @param JsonObject - A `Message`-like deserialized object.
* @param channelOptions - A {@link ChannelOptions} object. If you have an encrypted channel, use this to allow the library to decrypt the data.
* @returns A `Message` object.
* @returns A promise which will be fulfilled with a `Message` object.
*/
static fromEncoded: (JsonObject: any, channelOptions?: ChannelOptions) => Message;
static fromEncoded: (JsonObject: any, channelOptions?: ChannelOptions) => Promise<Message>;
/**
* A static factory method to create an array of `Message` objects from an array of deserialized Message-like object encoded using Ably's wire protocol.
*
* @param JsonArray - An array of `Message`-like deserialized objects.
* @param channelOptions - A {@link ChannelOptions} object. If you have an encrypted channel, use this to allow the library to decrypt the data.
* @returns An array of {@link Message} objects.
* @returns A promise which will be fulfilled with an array of {@link Message} objects.
*/
static fromEncodedArray: (JsonArray: any[], channelOptions?: ChannelOptions) => Message[];
static fromEncodedArray: (JsonArray: any[], channelOptions?: ChannelOptions) => Promise<Message[]>;
/**
* The client ID of the publisher of this message.
*/
Expand Down Expand Up @@ -2812,17 +2812,17 @@ declare namespace Types {
*
* @param JsonObject - A `Message`-like deserialized object.
* @param channelOptions - A {@link ChannelOptions} object. If you have an encrypted channel, use this to allow the library to decrypt the data.
* @returns A `Message` object.
* @returns A promise which will be fulfilled with a `Message` object.
*/
fromEncoded: (JsonObject: any, channelOptions?: ChannelOptions) => Message;
fromEncoded: (JsonObject: any, channelOptions?: ChannelOptions) => Promise<Message>;
/**
* A static factory method to create an array of `Message` objects from an array of deserialized Message-like object encoded using Ably's wire protocol.
*
* @param JsonArray - An array of `Message`-like deserialized objects.
* @param channelOptions - A {@link ChannelOptions} object. If you have an encrypted channel, use this to allow the library to decrypt the data.
* @returns An array of {@link Message} objects.
* @returns A promise which will be fulfilled with an array of {@link Message} objects.
*/
fromEncodedArray: (JsonArray: any[], channelOptions?: ChannelOptions) => Message[];
fromEncodedArray: (JsonArray: any[], channelOptions?: ChannelOptions) => Promise<Message[]>;
}

/**
Expand All @@ -2841,14 +2841,14 @@ declare namespace Types {
* @param JsonObject - The deserialized `PresenceMessage`-like object to decode and decrypt.
* @param channelOptions - A {@link ChannelOptions} object containing the cipher.
*/
static fromEncoded: (JsonObject: any, channelOptions?: ChannelOptions) => PresenceMessage;
static fromEncoded: (JsonObject: any, channelOptions?: ChannelOptions) => Promise<PresenceMessage>;
/**
* Decodes and decrypts an array of deserialized `PresenceMessage`-like object using the cipher in {@link ChannelOptions}. Any residual transforms that cannot be decoded or decrypted will be in the `encoding` property. Intended for users receiving messages from a source other than a REST or Realtime channel (for example a queue) to avoid having to parse the encoding string.
*
* @param JsonArray - An array of deserialized `PresenceMessage`-like objects to decode and decrypt.
* @param channelOptions - A {@link ChannelOptions} object containing the cipher.
*/
static fromEncodedArray: (JsonArray: any[], channelOptions?: ChannelOptions) => PresenceMessage[];
static fromEncodedArray: (JsonArray: any[], channelOptions?: ChannelOptions) => Promise<PresenceMessage[]>;
/**
* The type of {@link PresenceAction} the `PresenceMessage` is for.
*/
Expand Down Expand Up @@ -2889,14 +2889,14 @@ declare namespace Types {
* @param JsonObject - The deserialized `PresenceMessage`-like object to decode and decrypt.
* @param channelOptions - A {@link ChannelOptions} object containing the cipher.
*/
fromEncoded: (JsonObject: any, channelOptions?: ChannelOptions) => PresenceMessage;
fromEncoded: (JsonObject: any, channelOptions?: ChannelOptions) => Promise<PresenceMessage>;
/**
* Decodes and decrypts an array of deserialized `PresenceMessage`-like object using the cipher in {@link ChannelOptions}. Any residual transforms that cannot be decoded or decrypted will be in the `encoding` property. Intended for users receiving messages from a source other than a REST or Realtime channel (for example a queue) to avoid having to parse the encoding string.
*
* @param JsonArray - An array of deserialized `PresenceMessage`-like objects to decode and decrypt.
* @param channelOptions - A {@link ChannelOptions} object containing the cipher.
*/
fromEncodedArray: (JsonArray: any[], channelOptions?: ChannelOptions) => PresenceMessage[];
fromEncodedArray: (JsonArray: any[], channelOptions?: ChannelOptions) => Promise<PresenceMessage[]>;
}

/**
Expand Down
4 changes: 2 additions & 2 deletions src/common/lib/client/channel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ class Channel extends EventEmitter {
Utils.mixin(headers, rest.options.headers);

const options = this.channelOptions;
new PaginatedResource(rest, this.basePath + '/messages', headers, envelope, function (
new PaginatedResource(rest, this.basePath + '/messages', headers, envelope, async function (
body: any,
headers: Record<string, string>,
unpacked?: boolean
) {
return Message.fromResponseBody(body, options, unpacked ? undefined : format);
return await Message.fromResponseBody(body, options, unpacked ? undefined : format);
}).get(params as Record<string, unknown>, callback);
}

Expand Down
43 changes: 25 additions & 18 deletions src/common/lib/client/paginatedresource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ErrorInfo, { IPartialErrorInfo } from '../types/errorinfo';
import { PaginatedResultCallback } from '../../types/utils';
import Rest from './rest';

export type BodyHandler = (body: unknown, headers: Record<string, string>, packed?: boolean) => any;
export type BodyHandler = (body: unknown, headers: Record<string, string>, packed?: boolean) => Promise<any>;

function getRelParams(linkUrl: string) {
const urlMatch = linkUrl.match(/^\.\/(\w+)\?(.*)$/);
Expand Down Expand Up @@ -149,25 +149,32 @@ class PaginatedResource {
callback?.(err);
return;
}
let items, linkHeader, relParams;
try {
items = this.bodyHandler(body, headers || {}, unpacked);
} catch (e) {
/* If we got an error, the failure to parse the body is almost certainly
* due to that, so callback with that in preference over the parse error */
callback?.(err || e);
return;
}

if (headers && (linkHeader = headers['Link'] || headers['link'])) {
relParams = parseRelLinks(linkHeader);
}
const handleBody = async () => {
let items, linkHeader, relParams;

if (this.useHttpPaginatedResponse) {
callback(null, new HttpPaginatedResponse(this, items, headers || {}, statusCode as number, relParams, err));
} else {
callback(null, new PaginatedResult(this, items, relParams));
}
try {
items = await this.bodyHandler(body, headers || {}, unpacked);
} catch (e) {
/* If we got an error, the failure to parse the body is almost certainly
* due to that, so throw that in preference over the parse error */
throw err || e;
}

if (headers && (linkHeader = headers['Link'] || headers['link'])) {
relParams = parseRelLinks(linkHeader);
}

if (this.useHttpPaginatedResponse) {
return new HttpPaginatedResponse(this, items, headers || {}, statusCode as number, relParams, err);
} else {
return new PaginatedResult(this, items, relParams);
}
};

handleBody()
.then((result) => callback(null, result))
.catch((err) => callback(err, null));
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/common/lib/client/presence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ class Presence extends EventEmitter {
Utils.mixin(headers, rest.options.headers);

const options = this.channel.channelOptions;
new PaginatedResource(rest, this.basePath, headers, envelope, function (
new PaginatedResource(rest, this.basePath, headers, envelope, async function (
body: any,
headers: Record<string, string>,
unpacked?: boolean
) {
return PresenceMessage.fromResponseBody(body, options as CipherOptions, unpacked ? undefined : format);
return await PresenceMessage.fromResponseBody(body, options as CipherOptions, unpacked ? undefined : format);
}).get(params, callback);
}

Expand Down Expand Up @@ -84,12 +84,12 @@ class Presence extends EventEmitter {
Utils.mixin(headers, rest.options.headers);

const options = this.channel.channelOptions;
new PaginatedResource(rest, this.basePath + '/history', headers, envelope, function (
new PaginatedResource(rest, this.basePath + '/history', headers, envelope, async function (
body: any,
headers: Record<string, string>,
unpacked?: boolean
) {
return PresenceMessage.fromResponseBody(body, options as CipherOptions, unpacked ? undefined : format);
return await PresenceMessage.fromResponseBody(body, options as CipherOptions, unpacked ? undefined : format);
}).get(params, callback);
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/common/lib/client/push.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class DeviceRegistrations {

Utils.mixin(headers, rest.options.headers);

new PaginatedResource(rest, '/push/deviceRegistrations', headers, envelope, function (
new PaginatedResource(rest, '/push/deviceRegistrations', headers, envelope, async function (
body: any,
headers: Record<string, string>,
unpacked?: boolean
Expand Down Expand Up @@ -286,7 +286,7 @@ class ChannelSubscriptions {

Utils.mixin(headers, rest.options.headers);

new PaginatedResource(rest, '/push/channelSubscriptions', headers, envelope, function (
new PaginatedResource(rest, '/push/channelSubscriptions', headers, envelope, async function (
body: any,
headers: Record<string, string>,
unpacked?: boolean
Expand Down Expand Up @@ -334,7 +334,7 @@ class ChannelSubscriptions {

if (rest.options.pushFullWait) Utils.mixin(params, { fullWait: 'true' });

new PaginatedResource(rest, '/push/channels', headers, envelope, function (
new PaginatedResource(rest, '/push/channels', headers, envelope, async function (
body: unknown,
headers: Record<string, string>,
unpacked?: boolean
Expand Down
9 changes: 5 additions & 4 deletions src/common/lib/client/realtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ class Channels extends EventEmitter {
}
}

onChannelMessage(msg: ProtocolMessage) {
// Access to this method is synchronised by ConnectionManager#processChannelMessage.
async processChannelMessage(msg: ProtocolMessage) {
const channelName = msg.channel;
if (channelName === undefined) {
Logger.logAction(
Logger.LOG_ERROR,
'Channels.onChannelMessage()',
'Channels.processChannelMessage()',
'received event unspecified channel, action = ' + msg.action
);
return;
Expand All @@ -97,12 +98,12 @@ class Channels extends EventEmitter {
if (!channel) {
Logger.logAction(
Logger.LOG_ERROR,
'Channels.onChannelMessage()',
'Channels.processChannelMessage()',
'received event for non-existent channel: ' + channelName
);
return;
}
channel.onMessage(msg);
await channel.processMessage(msg);
}

/* called when a transport becomes connected; reattempt attach/detach
Expand Down
23 changes: 14 additions & 9 deletions src/common/lib/client/realtimechannel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,8 @@ class RealtimeChannel extends Channel {
this.sendMessage(msg, callback);
}

onMessage(message: ProtocolMessage): void {
// Access to this method is synchronised by ConnectionManager#processChannelMessage, in order to synchronise access to the state stored in _decodingContext.
async processMessage(message: ProtocolMessage): Promise<void> {
if (
message.action === actions.ATTACHED ||
message.action === actions.MESSAGE ||
Expand Down Expand Up @@ -656,12 +657,12 @@ class RealtimeChannel extends Channel {
for (let i = 0; i < presence.length; i++) {
try {
presenceMsg = presence[i];
PresenceMessage.decode(presenceMsg, options);
await PresenceMessage.decode(presenceMsg, options);
if (!presenceMsg.connectionId) presenceMsg.connectionId = connectionId;
if (!presenceMsg.timestamp) presenceMsg.timestamp = timestamp;
if (!presenceMsg.id) presenceMsg.id = id + ':' + i;
} catch (e) {
Logger.logAction(Logger.LOG_ERROR, 'RealtimeChannel.onMessage()', (e as Error).toString());
Logger.logAction(Logger.LOG_ERROR, 'RealtimeChannel.processMessage()', (e as Error).toString());
}
}
this.presence.setPresence(presence, isSync, syncChannelSerial as any);
Expand All @@ -672,7 +673,7 @@ class RealtimeChannel extends Channel {
if (this.state !== 'attached') {
Logger.logAction(
Logger.LOG_MAJOR,
'RealtimeChannel.onMessage()',
'RealtimeChannel.processMessage()',
'Message "' +
message.id +
'" skipped as this channel "' +
Expand Down Expand Up @@ -702,18 +703,18 @@ class RealtimeChannel extends Channel {
'" on this channel "' +
this.name +
'".';
Logger.logAction(Logger.LOG_ERROR, 'RealtimeChannel.onMessage()', msg);
Logger.logAction(Logger.LOG_ERROR, 'RealtimeChannel.processMessage()', msg);
this._startDecodeFailureRecovery(new ErrorInfo(msg, 40018, 400));
break;
}

for (let i = 0; i < messages.length; i++) {
const msg = messages[i];
try {
Message.decode(msg, this._decodingContext);
await Message.decode(msg, this._decodingContext);
} catch (e) {
/* decrypt failed .. the most likely cause is that we have the wrong key */
Logger.logAction(Logger.LOG_ERROR, 'RealtimeChannel.onMessage()', (e as Error).toString());
Logger.logAction(Logger.LOG_ERROR, 'RealtimeChannel.processMessage()', (e as Error).toString());
switch ((e as ErrorInfo).code) {
case 40018:
/* decode failure */
Expand Down Expand Up @@ -753,7 +754,7 @@ class RealtimeChannel extends Channel {
default:
Logger.logAction(
Logger.LOG_ERROR,
'RealtimeChannel.onMessage()',
'RealtimeChannel.processMessage()',
'Fatal protocol error: unrecognised action (' + message.action + ')'
);
this.connectionManager.abort(ConnectionErrors.unknownChannelErr());
Expand All @@ -762,7 +763,11 @@ class RealtimeChannel extends Channel {

_startDecodeFailureRecovery(reason: ErrorInfo): void {
if (!this._lastPayload.decodeFailureRecoveryInProgress) {
Logger.logAction(Logger.LOG_MAJOR, 'RealtimeChannel.onMessage()', 'Starting decode failure recovery process.');
Logger.logAction(
Logger.LOG_MAJOR,
'RealtimeChannel.processMessage()',
'Starting decode failure recovery process.'
);
this._lastPayload.decodeFailureRecoveryInProgress = true;
this._attach(true, reason, () => {
this._lastPayload.decodeFailureRecoveryInProgress = false;
Expand Down
2 changes: 1 addition & 1 deletion src/common/lib/client/rest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class Rest {
path,
headers,
envelope,
function (resbody: unknown, headers: Record<string, string>, unpacked?: boolean) {
async function (resbody: unknown, headers: Record<string, string>, unpacked?: boolean) {
return Utils.ensureArray(unpacked ? resbody : decoder(resbody as string & Buffer));
},
/* useHttpPaginatedResponse: */ true
Expand Down
38 changes: 36 additions & 2 deletions src/common/lib/transport/connectionmanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ class ConnectionManager extends EventEmitter {
suspendTimer?: number | NodeJS.Timeout | null;
retryTimer?: number | NodeJS.Timeout | null;
disconnectedRetryCount: number = 0;
pendingChannelMessagesState: {
// Whether a message is currently being processed
isProcessing: boolean;
// The messages remaining to be processed (excluding any message currently being processed)
queue: { message: ProtocolMessage; transport: Transport }[];
} = { isProcessing: false, queue: [] };

constructor(realtime: Realtime, options: ClientOptions) {
super();
Expand Down Expand Up @@ -1966,20 +1972,48 @@ class ConnectionManager extends EventEmitter {
}

onChannelMessage(message: ProtocolMessage, transport: Transport): void {
this.pendingChannelMessagesState.queue.push({ message, transport });

if (!this.pendingChannelMessagesState.isProcessing) {
this.processNextPendingChannelMessage();
}
}

private processNextPendingChannelMessage() {
if (this.pendingChannelMessagesState.queue.length > 0) {
this.pendingChannelMessagesState.isProcessing = true;

const pendingChannelMessage = this.pendingChannelMessagesState.queue.shift()!;
this.processChannelMessage(pendingChannelMessage.message, pendingChannelMessage.transport)
.catch((err) => {
Logger.logAction(
Logger.LOG_ERROR,
'ConnectionManager.processNextPendingChannelMessage() received error ',
err
);
})
.finally(() => {
this.pendingChannelMessagesState.isProcessing = false;
this.processNextPendingChannelMessage();
});
}
}

private async processChannelMessage(message: ProtocolMessage, transport: Transport) {
const onActiveTransport = this.activeProtocol && transport === this.activeProtocol.getTransport(),
onUpgradeTransport = Utils.arrIn(this.pendingTransports, transport) && this.state == this.states.synchronizing;

/* As the lib now has a period where the upgrade transport is synced but
* before it's become active (while waiting for the old one to become
* idle), message can validly arrive on it even though it isn't active */
if (onActiveTransport || onUpgradeTransport) {
this.realtime.channels.onChannelMessage(message);
await this.realtime.channels.processChannelMessage(message);
} else {
// Message came in on a defunct transport. Allow only acks, nacks, & errors for outstanding
// messages, no new messages (as sync has been sent on new transport so new messages will
// be resent there, or connection has been closed so don't want new messages)
if (Utils.arrIndexOf([actions.ACK, actions.NACK, actions.ERROR], message.action) > -1) {
this.realtime.channels.onChannelMessage(message);
await this.realtime.channels.processChannelMessage(message);
} else {
Logger.logAction(
Logger.LOG_MICRO,
Expand Down
Loading

0 comments on commit f7dd898

Please sign in to comment.