From 54ddebf2416d935a90b38d2df378a3c872e1bab6 Mon Sep 17 00:00:00 2001 From: Alex Donesky Date: Fri, 27 Sep 2024 11:12:18 -0500 Subject: [PATCH] Update `QueuedRequestController` to support Multichain API (#4718) ## Explanation The `QueuedRequestController` previously batched and processed requests based solely on their `origin`. This approach doesn't account for scenarios where a dapp interacts with multiple networks simultaneously from the same origin as will be the case on the multichain API. This PR updates the `QueuedRequestController` to batch and process requests based on both `origin` and `networkClientId`. This ensures that: - Requests from the same origin but targeting different networks are processed in separate batches. - Network switches occur appropriately between batches when the `networkClientId` changes. - Each request is processed on the correct network. Key changes: - **Batching Logic**: Modified the batching mechanism to consider both `origin` and `networkClientId`, for the legacy API this shouldn't change anything, but it allows us to cleanly batch for the multichain API until we remove queueing altogther. - **Request Validation**: Added validation to throw an error if a request doesn't include a `networkClientId`. - **Dependency Removal**: Removed reliance on `SelectedNetworkController` for determining `networkClientId`. Now `NetworkClientId` is expected on every request. Which should always be the case if the middleware ordering is correct on both APIs. - **Test Enhancements**: Added and updated tests to cover new scenarios involving multiple `origin`s and `networkClientId`s. Draft PR/Branch off caip-multichain feature branch with preview build: https://github.com/MetaMask/metamask-extension/pull/27408 ## Videos demonstrating functionality w/ multichain API: ### 2 requests same network different dapps: https://github.com/user-attachments/assets/b7087322-093b-4229-86f9-d643540a5f5e ### 2 requests different network same dapp https://github.com/user-attachments/assets/80b9e820-30fd-4bd7-a82e-588243fa3a44 ### 2 requests same network same dapp https://github.com/user-attachments/assets/91e2e35d-2ab6-4863-a9e4-0b955815354e ### More complex flows: https://drive.google.com/file/d/1PvCmmCQbxXqglsFLUlyT_7P_znukLUxI/view?usp=drive_link https://drive.google.com/file/d/18R8zEPz_zsfhsw-DOkRFwkYRI1URynis/view?usp=sharing ### Queueing working same as before on legacy API [Branch/PR with preview builds on develop used to test ](https://github.com/MetaMask/metamask-extension/pull/27430) https://github.com/user-attachments/assets/485d0a28-7075-4c65-be4b-b9022b6ef28f ## References - N/A ## Changelog ### `@metamask/queued-request-controller` - **CHANGED**: Batch processing now considers both `origin` and `networkClientId`, ensuring requests targeting different networks are processed separately. - **REMOVED**: Dependency on `SelectedNetworkController`; the controller no longer uses `SelectedNetworkController:getNetworkClientIdForDomain`. - **CHANGED**: Incoming requests to `enqueueRequest` now **must** include a `networkClientId`; an error is thrown if it's missing. This was previously a [required part of the type](https://github.com/MetaMask/core/blob/66c94ae4f0a54764ca890c927ffdbe3c2d6cd846/packages/queued-request-controller/src/types.ts#L5) but since consumers like the extension do not have extensive typescript coverage this isn't definitively enforced. ## Checklist - [ ] I've updated the test suite for new or updated code as appropriate - [ ] I've updated documentation (JSDoc, Markdown, etc.) for new or updated code as appropriate - [ ] I've highlighted breaking changes using the "BREAKING" category above as appropriate - [ ] I've prepared draft pull requests for clients and consumer packages to resolve any breaking changes --- .../src/QueuedRequestController.test.ts | 438 +++++++++++++++--- .../src/QueuedRequestController.ts | 85 ++-- 2 files changed, 422 insertions(+), 101 deletions(-) diff --git a/packages/queued-request-controller/src/QueuedRequestController.test.ts b/packages/queued-request-controller/src/QueuedRequestController.test.ts index 239ae82885..30c8ffbdfc 100644 --- a/packages/queued-request-controller/src/QueuedRequestController.test.ts +++ b/packages/queued-request-controller/src/QueuedRequestController.test.ts @@ -4,7 +4,6 @@ import { type NetworkControllerGetStateAction, type NetworkControllerSetActiveNetworkAction, } from '@metamask/network-controller'; -import type { SelectedNetworkControllerGetNetworkClientIdForDomainAction } from '@metamask/selected-network-controller'; import { createDeferredPromise } from '@metamask/utils'; import type { @@ -35,6 +34,24 @@ describe('QueuedRequestController', () => { }); describe('enqueueRequest', () => { + it('throws an error if networkClientId is not provided', async () => { + const controller = buildQueuedRequestController(); + await expect(() => + controller.enqueueRequest( + // @ts-expect-error: networkClientId is intentionally not provided + { + method: 'doesnt matter', + id: 'doesnt matter', + jsonrpc: '2.0' as const, + origin: 'example.metamask.io', + }, + () => new Promise((resolve) => setTimeout(resolve, 10)), + ), + ).rejects.toThrow( + 'Error while attempting to enqueue request: networkClientId is required.', + ); + }); + it('skips the queue if the queue is empty and no request is being processed', async () => { const controller = buildQueuedRequestController(); @@ -70,9 +87,6 @@ describe('QueuedRequestController', () => { selectedNetworkClientId: 'selectedNetworkClientId', }), networkControllerSetActiveNetwork: mockSetActiveNetwork, - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation((_origin) => 'differentNetworkClientId'), }); const onNetworkSwitched = jest.fn(); messenger.subscribe( @@ -87,7 +101,11 @@ describe('QueuedRequestController', () => { }); await controller.enqueueRequest( - { ...buildRequest(), method: 'method_requiring_network_switch' }, + { + ...buildRequest(), + networkClientId: 'differentNetworkClientId', + method: 'method_requiring_network_switch', + }, () => new Promise((resolve) => setTimeout(resolve, 10)), ); @@ -107,9 +125,6 @@ describe('QueuedRequestController', () => { selectedNetworkClientId: 'selectedNetworkClientId', }), networkControllerSetActiveNetwork: mockSetActiveNetwork, - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation((_origin) => 'differentNetworkClientId'), }); const onNetworkSwitched = jest.fn(); messenger.subscribe( @@ -139,9 +154,6 @@ describe('QueuedRequestController', () => { selectedNetworkClientId: 'selectedNetworkClientId', }), networkControllerSetActiveNetwork: mockSetActiveNetwork, - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation((_origin) => 'selectedNetworkClientId'), }); const onNetworkSwitched = jest.fn(); messenger.subscribe( @@ -354,6 +366,79 @@ describe('QueuedRequestController', () => { expect(controller.state.queuedRequestCount).toBe(0); }); + it('processes queued requests on same origin but different network clientId', async () => { + const controller = buildQueuedRequestController(); + const executionOrder: string[] = []; + + const firstRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://example.metamask.io', + networkClientId: 'network1', + }, + async () => { + executionOrder.push('Request 1 (network1)'); + await new Promise((resolve) => setTimeout(resolve, 10)); + }, + ); + + // Ensure first request skips queue + expect(controller.state.queuedRequestCount).toBe(0); + + const secondRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://example.metamask.io', + networkClientId: 'network2', + }, + async () => { + executionOrder.push('Request 2 (network2)'); + await new Promise((resolve) => setTimeout(resolve, 10)); + }, + ); + + const thirdRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://example.metamask.io', + networkClientId: 'network1', + }, + async () => { + executionOrder.push('Request 3 (network1)'); + await new Promise((resolve) => setTimeout(resolve, 10)); + }, + ); + + const fourthRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://example.metamask.io', + networkClientId: 'network2', + }, + async () => { + executionOrder.push('Request 4 (network2)'); + await new Promise((resolve) => setTimeout(resolve, 10)); + }, + ); + + expect(controller.state.queuedRequestCount).toBe(3); + + await Promise.all([ + firstRequest, + secondRequest, + thirdRequest, + fourthRequest, + ]); + + expect(controller.state.queuedRequestCount).toBe(0); + expect(executionOrder).toStrictEqual([ + 'Request 1 (network1)', + 'Request 2 (network2)', + 'Request 3 (network1)', + 'Request 4 (network2)', + ]); + }); + it('preserves request order within each batch', async () => { const controller = buildQueuedRequestController(); const executionOrder: string[] = []; @@ -455,13 +540,6 @@ describe('QueuedRequestController', () => { selectedNetworkClientId: 'selectedNetworkClientId', }), networkControllerSetActiveNetwork: mockSetActiveNetwork, - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation((origin) => - origin === 'https://secondorigin.metamask.io' - ? 'differentNetworkClientId' - : 'selectedNetworkClientId', - ), }); const onNetworkSwitched = jest.fn(); messenger.subscribe( @@ -483,7 +561,11 @@ describe('QueuedRequestController', () => { () => new Promise((resolve) => setTimeout(resolve, 100)), ); const secondRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://secondorigin.metamask.io' }, + { + ...buildRequest(), + networkClientId: 'differentNetworkClientId', + origin: 'https://secondorigin.metamask.io', + }, secondRequestNext, ); // ensure test starts with one request queued up @@ -503,16 +585,14 @@ describe('QueuedRequestController', () => { }); it('does not switch networks if a new batch has the same network client', async () => { + const networkClientId = 'selectedNetworkClientId'; const mockSetActiveNetwork = jest.fn(); const { messenger } = buildControllerMessenger({ networkControllerGetState: jest.fn().mockReturnValue({ ...getDefaultNetworkControllerState(), - selectedNetworkClientId: 'selectedNetworkClientId', + selectedNetworkClientId: networkClientId, }), networkControllerSetActiveNetwork: mockSetActiveNetwork, - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation(() => 'selectedNetworkClientId'), }); const onNetworkSwitched = jest.fn(); messenger.subscribe( @@ -534,7 +614,11 @@ describe('QueuedRequestController', () => { () => new Promise((resolve) => setTimeout(resolve, 100)), ); const secondRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://secondorigin.metamask.io' }, + { + ...buildRequest(), + networkClientId, + origin: 'https://secondorigin.metamask.io', + }, secondRequestNext, ); // ensure test starts with one request queued up @@ -548,6 +632,244 @@ describe('QueuedRequestController', () => { expect(onNetworkSwitched).not.toHaveBeenCalled(); }); + it('queues request if a request from the same origin but different networkClientId is being processed', async () => { + const controller = buildQueuedRequestController(); + // Trigger first request + const firstRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://example.metamask.io', + networkClientId: 'network1', + }, + () => new Promise((resolve) => setTimeout(resolve, 10)), + ); + // ensure first request skips queue + expect(controller.state.queuedRequestCount).toBe(0); + + const secondRequestNext = jest.fn(); + const secondRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://example.metamask.io', + networkClientId: 'network2', + }, + secondRequestNext, + ); + + expect(controller.state.queuedRequestCount).toBe(1); + expect(secondRequestNext).not.toHaveBeenCalled(); + + await firstRequest; + await secondRequest; + }); + + it('processes requests from different origins but same networkClientId in separate batches without network switch', async () => { + const mockSetActiveNetwork = jest.fn(); + const { messenger } = buildControllerMessenger({ + networkControllerGetState: jest.fn().mockReturnValue({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'network1', + }), + networkControllerSetActiveNetwork: mockSetActiveNetwork, + }); + const controller = buildQueuedRequestController({ + messenger: buildQueuedRequestControllerMessenger(messenger), + }); + + // Trigger first request + const firstRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://firstorigin.metamask.io', + networkClientId: 'network1', + }, + () => new Promise((resolve) => setTimeout(resolve, 10)), + ); + // Ensure first request skips queue + expect(controller.state.queuedRequestCount).toBe(0); + + const secondRequestNext = jest.fn(); + const secondRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://secondorigin.metamask.io', + networkClientId: 'network1', + }, + secondRequestNext, + ); + + expect(controller.state.queuedRequestCount).toBe(1); + expect(secondRequestNext).not.toHaveBeenCalled(); + + await firstRequest; + await secondRequest; + + expect(mockSetActiveNetwork).not.toHaveBeenCalled(); + }); + + it('switches networks between batches with different networkClientIds', async () => { + const mockSetActiveNetwork = jest.fn(); + const { messenger } = buildControllerMessenger({ + networkControllerGetState: jest.fn().mockReturnValue({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'network1', + }), + networkControllerSetActiveNetwork: mockSetActiveNetwork, + }); + + const controller = buildQueuedRequestController({ + messenger: buildQueuedRequestControllerMessenger(messenger), + }); + + const firstRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://firstorigin.metamask.io', + networkClientId: 'network1', + }, + () => new Promise((resolve) => setTimeout(resolve, 10)), + ); + + expect(controller.state.queuedRequestCount).toBe(0); + + const secondRequestNext = jest.fn(); + const secondRequest = controller.enqueueRequest( + { + ...buildRequest(), + origin: 'https://secondorigin.metamask.io', + networkClientId: 'network2', + }, + secondRequestNext, + ); + + expect(controller.state.queuedRequestCount).toBe(1); + expect(secondRequestNext).not.toHaveBeenCalled(); + + await firstRequest; + + expect(mockSetActiveNetwork).toHaveBeenCalledWith('network2'); + + await secondRequest; + + expect(controller.state.queuedRequestCount).toBe(0); + + expect(secondRequestNext).toHaveBeenCalled(); + }); + + it('processes complex interleaved requests from multiple origins and networkClientIds correctly', async () => { + const events: string[] = []; + + const mockSetActiveNetwork = jest.fn((networkClientId: string) => { + events.push(`network switched to ${networkClientId}`); + return Promise.resolve(); + }); + + const { messenger } = buildControllerMessenger({ + networkControllerGetState: jest + .fn() + .mockReturnValueOnce({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'NetworkClientId1', + }) + .mockReturnValueOnce({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'NetworkClientId2', + }) + .mockReturnValueOnce({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'NetworkClientId2', + }) + .mockReturnValueOnce({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'NetworkClientId1', + }) + .mockReturnValueOnce({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'NetworkClientId3', + }), + networkControllerSetActiveNetwork: mockSetActiveNetwork, + }); + + const controller = buildQueuedRequestController({ + messenger: buildQueuedRequestControllerMessenger(messenger), + }); + + const createRequestNext = (requestName: string) => + jest.fn(() => { + events.push(`${requestName} processed`); + return Promise.resolve(); + }); + + const request1Next = createRequestNext('request1'); + const request2Next = createRequestNext('request2'); + const request3Next = createRequestNext('request3'); + const request4Next = createRequestNext('request4'); + const request5Next = createRequestNext('request5'); + + const enqueueRequest = ( + origin: string, + networkClientId: string, + next: jest.Mock, + ) => + controller.enqueueRequest( + { + ...buildRequest(), + origin, + networkClientId, + }, + () => Promise.resolve(next()), + ); + + const request1Promise = enqueueRequest( + 'https://origin1.metamask.io', + 'NetworkClientId1', + request1Next, + ); + const request2Promise = enqueueRequest( + 'https://origin1.metamask.io', + 'NetworkClientId2', + request2Next, + ); + const request3Promise = enqueueRequest( + 'https://origin2.metamask.io', + 'NetworkClientId2', + request3Next, + ); + const request4Promise = enqueueRequest( + 'https://origin2.metamask.io', + 'NetworkClientId1', + request4Next, + ); + const request5Promise = enqueueRequest( + 'https://origin1.metamask.io', + 'NetworkClientId3', + request5Next, + ); + + expect(controller.state.queuedRequestCount).toBe(4); + + await request1Promise; + await request2Promise; + await request3Promise; + await request4Promise; + await request5Promise; + + expect(events).toStrictEqual([ + 'request1 processed', + 'network switched to NetworkClientId2', + 'request2 processed', + 'request3 processed', + 'network switched to NetworkClientId1', + 'request4 processed', + 'network switched to NetworkClientId3', + 'request5 processed', + ]); + + expect(mockSetActiveNetwork).toHaveBeenCalledTimes(3); + + expect(controller.state.queuedRequestCount).toBe(0); + }); + describe('when the network switch for a single request fails', () => { it('throws error', async () => { const switchError = new Error('switch error'); @@ -559,9 +881,6 @@ describe('QueuedRequestController', () => { networkControllerSetActiveNetwork: jest .fn() .mockRejectedValue(switchError), - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation((_origin) => 'differentNetworkClientId'), }); const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), @@ -573,6 +892,7 @@ describe('QueuedRequestController', () => { controller.enqueueRequest( { ...buildRequest(), + networkClientId: 'differentNetworkClientId', method: 'method_requiring_network_switch', origin: 'https://example.metamask.io', }, @@ -590,35 +910,26 @@ describe('QueuedRequestController', () => { }), networkControllerSetActiveNetwork: jest .fn() - .mockRejectedValue(switchError), - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation((origin) => - origin === 'https://firstorigin.metamask.io' - ? 'differentNetworkClientId' - : 'selectedNetworkClientId', - ), + .mockRejectedValueOnce(switchError), }); const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), shouldRequestSwitchNetwork: ({ method }) => method === 'method_requiring_network_switch', }); + const firstRequest = controller.enqueueRequest( { ...buildRequest(), + networkClientId: 'differentNetworkClientId', method: 'method_requiring_network_switch', origin: 'https://firstorigin.metamask.io', }, () => new Promise((resolve) => setTimeout(resolve, 10)), ); - // ensure first request skips queue expect(controller.state.queuedRequestCount).toBe(0); - const secondRequestNext = jest - .fn() - .mockImplementation( - () => new Promise((resolve) => setTimeout(resolve, 100)), - ); + + const secondRequestNext = jest.fn().mockResolvedValue(undefined); const secondRequest = controller.enqueueRequest( { ...buildRequest(), @@ -628,7 +939,7 @@ describe('QueuedRequestController', () => { secondRequestNext, ); - await expect(firstRequest).rejects.toThrow(switchError); + await expect(firstRequest).rejects.toThrow('switch error'); await secondRequest; expect(secondRequestNext).toHaveBeenCalled(); @@ -638,27 +949,23 @@ describe('QueuedRequestController', () => { describe('when the network switch for a batch fails', () => { it('throws error', async () => { const switchError = new Error('switch error'); + const { messenger } = buildControllerMessenger({ networkControllerGetState: jest.fn().mockReturnValue({ ...getDefaultNetworkControllerState(), - selectedNetworkClientId: 'selectedNetworkClientId', + selectedNetworkClientId: 'mainnet', }), networkControllerSetActiveNetwork: jest .fn() - .mockRejectedValue(switchError), - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation((origin) => - origin === 'https://secondorigin.metamask.io' - ? 'differentNetworkClientId' - : 'selectedNetworkClientId', - ), + .mockRejectedValueOnce(switchError), }); const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), shouldRequestSwitchNetwork: ({ method }) => method === 'method_requiring_network_switch', }); + + // no switch required const firstRequest = controller.enqueueRequest( { ...buildRequest(), @@ -677,6 +984,7 @@ describe('QueuedRequestController', () => { const secondRequest = controller.enqueueRequest( { ...buildRequest(), + networkClientId: 'differentNetworkClientId', method: 'method_requiring_network_switch', origin: 'https://secondorigin.metamask.io', }, @@ -695,18 +1003,11 @@ describe('QueuedRequestController', () => { const { messenger } = buildControllerMessenger({ networkControllerGetState: jest.fn().mockReturnValue({ ...getDefaultNetworkControllerState(), - selectedNetworkClientId: 'selectedNetworkClientId', + selectedNetworkClientId: 'mainnet', }), networkControllerSetActiveNetwork: jest .fn() - .mockRejectedValue(switchError), - selectedNetworkControllerGetNetworkClientIdForDomain: jest - .fn() - .mockImplementation((origin) => - origin === 'https://secondorigin.metamask.io' - ? 'differentNetworkClientId' - : 'selectedNetworkClientId', - ), + .mockRejectedValueOnce(switchError), }); const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), @@ -731,6 +1032,7 @@ describe('QueuedRequestController', () => { const secondRequest = controller.enqueueRequest( { ...buildRequest(), + networkClientId: 'differentNetworkClientId', method: 'method_requiring_network_switch', origin: 'https://secondorigin.metamask.io', }, @@ -969,19 +1271,15 @@ describe('QueuedRequestController', () => { * action. * @param options.networkControllerSetActiveNetwork - A handler for the * `NetworkController:setActiveNetwork` action. - * @param options.selectedNetworkControllerGetNetworkClientIdForDomain - A handler for the - * `SelectedNetworkController:getNetworkClientIdForDomain` action. * @returns A controller messenger with QueuedRequestController types, and * mocks for all allowed actions. */ function buildControllerMessenger({ networkControllerGetState, networkControllerSetActiveNetwork, - selectedNetworkControllerGetNetworkClientIdForDomain, }: { networkControllerGetState?: NetworkControllerGetStateAction['handler']; networkControllerSetActiveNetwork?: NetworkControllerSetActiveNetworkAction['handler']; - selectedNetworkControllerGetNetworkClientIdForDomain?: SelectedNetworkControllerGetNetworkClientIdForDomainAction['handler']; } = {}): { messenger: ControllerMessenger< QueuedRequestControllerActions | AllowedActions, @@ -993,9 +1291,6 @@ function buildControllerMessenger({ mockNetworkControllerSetActiveNetwork: jest.Mocked< NetworkControllerSetActiveNetworkAction['handler'] >; - mockSelectedNetworkControllerGetNetworkClientIdForDomain: jest.Mocked< - SelectedNetworkControllerGetNetworkClientIdForDomainAction['handler'] - >; } { const messenger = new ControllerMessenger< QueuedRequestControllerActions | AllowedActions, @@ -1018,17 +1313,11 @@ function buildControllerMessenger({ 'NetworkController:setActiveNetwork', mockNetworkControllerSetActiveNetwork, ); - const mockSelectedNetworkControllerGetNetworkClientIdForDomain = - selectedNetworkControllerGetNetworkClientIdForDomain ?? jest.fn(); - messenger.registerActionHandler( - 'SelectedNetworkController:getNetworkClientIdForDomain', - mockSelectedNetworkControllerGetNetworkClientIdForDomain, - ); + return { messenger, mockNetworkControllerGetState, mockNetworkControllerSetActiveNetwork, - mockSelectedNetworkControllerGetNetworkClientIdForDomain, }; } @@ -1046,7 +1335,6 @@ function buildQueuedRequestControllerMessenger( allowedActions: [ 'NetworkController:getState', 'NetworkController:setActiveNetwork', - 'SelectedNetworkController:getNetworkClientIdForDomain', ], allowedEvents: ['SelectedNetworkController:stateChange'], }); diff --git a/packages/queued-request-controller/src/QueuedRequestController.ts b/packages/queued-request-controller/src/QueuedRequestController.ts index 712caa5b32..0f35ffdfdd 100644 --- a/packages/queued-request-controller/src/QueuedRequestController.ts +++ b/packages/queued-request-controller/src/QueuedRequestController.ts @@ -5,13 +5,11 @@ import type { } from '@metamask/base-controller'; import { BaseController } from '@metamask/base-controller'; import type { + NetworkClientId, NetworkControllerGetStateAction, NetworkControllerSetActiveNetworkAction, } from '@metamask/network-controller'; -import type { - SelectedNetworkControllerGetNetworkClientIdForDomainAction, - SelectedNetworkControllerStateChangeEvent, -} from '@metamask/selected-network-controller'; +import type { SelectedNetworkControllerStateChangeEvent } from '@metamask/selected-network-controller'; import { SelectedNetworkControllerEventTypes } from '@metamask/selected-network-controller'; import { createDeferredPromise } from '@metamask/utils'; @@ -64,8 +62,7 @@ export type QueuedRequestControllerActions = export type AllowedActions = | NetworkControllerGetStateAction - | NetworkControllerSetActiveNetworkAction - | SelectedNetworkControllerGetNetworkClientIdForDomainAction; + | NetworkControllerSetActiveNetworkAction; export type AllowedEvents = SelectedNetworkControllerStateChangeEvent; @@ -94,6 +91,12 @@ type QueuedRequest = { * The origin of the queued request. */ origin: string; + + /** + * The networkClientId of the queuedRequest. + */ + networkClientId: NetworkClientId; + /** * A callback used to continue processing the request, called when the request is dequeued. */ @@ -125,6 +128,12 @@ export class QueuedRequestController extends BaseController< */ #originOfCurrentBatch: string | undefined; + /** + * The networkClientId of the current batch of requests being processed, or `undefined` if there are no + * requests currently being processed. + */ + #networkClientIdOfCurrentBatch?: NetworkClientId; + /** * The list of all queued requests, in chronological order. */ @@ -224,6 +233,9 @@ export class QueuedRequestController extends BaseController< ); } + // Note: since we're using queueing for multichain requests to start, this flush could incorrectly flush + // multichain requests if the user switches networks on a dapp while multichain request is in the queue. + // we intend to remove queueing for multichain requests in the future, so for now we have to live with this. #flushQueueForOrigin(flushOrigin: string) { this.#requestQueue .filter(({ origin }) => origin === flushOrigin) @@ -252,17 +264,24 @@ export class QueuedRequestController extends BaseController< async #processNextBatch() { const firstRequest = this.#requestQueue.shift() as QueuedRequest; this.#originOfCurrentBatch = firstRequest.origin; + this.#networkClientIdOfCurrentBatch = firstRequest.networkClientId; const batch = [firstRequest.processRequest]; - while (this.#requestQueue[0]?.origin === this.#originOfCurrentBatch) { + + // alternatively we could still batch by only origin but switch networks in batches by + // adding the network clientId to the values in the batch array + while ( + this.#requestQueue[0]?.networkClientId === + this.#networkClientIdOfCurrentBatch && + this.#requestQueue[0]?.origin === this.#originOfCurrentBatch + ) { const nextEntry = this.#requestQueue.shift() as QueuedRequest; batch.push(nextEntry.processRequest); } - // If globally selected network is different from origin selected network, // switch network before processing batch let networkSwitchError: unknown; try { - await this.#switchNetworkIfNecessary(); + await this.#switchNetworkIfNecessary(firstRequest.networkClientId); } catch (error: unknown) { networkSwitchError = error; } @@ -277,34 +296,27 @@ export class QueuedRequestController extends BaseController< * Switch the globally selected network client to match the network * client of the current batch. * + * @param requestNetworkClientId - the networkClientId of the next request to process. * @throws Throws an error if the current selected `networkClientId` or the * `networkClientId` on the request are invalid. */ - async #switchNetworkIfNecessary() { - // This branch is unreachable; it's just here for type reasons. - /* istanbul ignore next */ - if (!this.#originOfCurrentBatch) { - throw new Error('Current batch origin must be initialized first'); - } - const originNetworkClientId = this.messagingSystem.call( - 'SelectedNetworkController:getNetworkClientIdForDomain', - this.#originOfCurrentBatch, - ); + async #switchNetworkIfNecessary(requestNetworkClientId: NetworkClientId) { const { selectedNetworkClientId } = this.messagingSystem.call( 'NetworkController:getState', ); - if (originNetworkClientId === selectedNetworkClientId) { + + if (requestNetworkClientId === selectedNetworkClientId) { return; } await this.messagingSystem.call( 'NetworkController:setActiveNetwork', - originNetworkClientId, + requestNetworkClientId, ); this.messagingSystem.publish( 'QueuedRequestController:networkSwitched', - originNetworkClientId, + requestNetworkClientId, ); } @@ -317,12 +329,19 @@ export class QueuedRequestController extends BaseController< }); } - async #waitForDequeue(origin: string): Promise { + async #waitForDequeue({ + origin, + networkClientId, + }: { + origin: string; + networkClientId: NetworkClientId; + }): Promise { const { promise, reject, resolve } = createDeferredPromise({ suppressUnhandledRejection: true, }); this.#requestQueue.push({ origin, + networkClientId, processRequest: (error: unknown) => { if (error) { reject(error); @@ -354,23 +373,36 @@ export class QueuedRequestController extends BaseController< request: QueuedRequestMiddlewareJsonRpcRequest, requestNext: () => Promise, ): Promise { + if (request.networkClientId === undefined) { + // This error will occur if selectedNetworkMiddleware does not precede queuedRequestMiddleware in the middleware stack + throw new Error( + 'Error while attempting to enqueue request: networkClientId is required.', + ); + } if (this.#originOfCurrentBatch === undefined) { this.#originOfCurrentBatch = request.origin; } + if (this.#networkClientIdOfCurrentBatch === undefined) { + this.#networkClientIdOfCurrentBatch = request.networkClientId; + } try { // Queue request for later processing // Network switch is handled when this batch is processed if ( this.state.queuedRequestCount > 0 || - this.#originOfCurrentBatch !== request.origin + this.#originOfCurrentBatch !== request.origin || + this.#networkClientIdOfCurrentBatch !== request.networkClientId ) { this.#showApprovalRequest(); - await this.#waitForDequeue(request.origin); + await this.#waitForDequeue({ + origin: request.origin, + networkClientId: request.networkClientId, + }); } else if (this.#shouldRequestSwitchNetwork(request)) { // Process request immediately // Requires switching network now if necessary - await this.#switchNetworkIfNecessary(); + await this.#switchNetworkIfNecessary(request.networkClientId); } this.#processingRequestCount += 1; try { @@ -382,6 +414,7 @@ export class QueuedRequestController extends BaseController< } finally { if (this.#processingRequestCount === 0) { this.#originOfCurrentBatch = undefined; + this.#networkClientIdOfCurrentBatch = undefined; if (this.#requestQueue.length > 0) { // The next batch is triggered here. We intentionally omit the `await` because we don't // want the next batch to block resolution of the current request.