diff --git a/packages/get-starknet/src/__tests__/helper.ts b/packages/get-starknet/src/__tests__/helper.ts index 8f949f76..524fe281 100644 --- a/packages/get-starknet/src/__tests__/helper.ts +++ b/packages/get-starknet/src/__tests__/helper.ts @@ -74,11 +74,25 @@ export class MockProvider implements MetaMaskProvider { request = jest.fn(); } +export class MockMetaMaskSnapWallet extends MetaMaskSnapWallet { + public pollingDelayMs = 0; + + public pollingTimeoutMs = 0; + + public startPolling(): void { + super.startPolling(); + } + + public stopPolling(): void { + super.stopPolling(); + } +} + /** * Create a wallet instance. */ export function createWallet() { - return new MetaMaskSnapWallet(new MockProvider()); + return new MockMetaMaskSnapWallet(new MockProvider()); } /** @@ -101,17 +115,17 @@ export function mockWalletInit({ }) { const installSpy = jest.spyOn(MetaMaskSnap.prototype, 'installIfNot'); const getCurrentNetworkSpy = jest.spyOn(MetaMaskSnap.prototype, 'getCurrentNetwork'); - const recoverDefaultAccountSpy = jest.spyOn(MetaMaskSnap.prototype, 'recoverDefaultAccount'); + const getCurrentAccountSpy = jest.spyOn(MetaMaskSnap.prototype, 'getCurrentAccount'); const initSpy = jest.spyOn(MetaMaskSnapWallet.prototype, 'init'); installSpy.mockResolvedValue(install); getCurrentNetworkSpy.mockResolvedValue(currentNetwork); - recoverDefaultAccountSpy.mockResolvedValue(generateAccount({ address })); + getCurrentAccountSpy.mockResolvedValue(generateAccount({ address })); return { initSpy, installSpy, getCurrentNetworkSpy, - recoverDefaultAccountSpy, + getCurrentAccountSpy, }; } diff --git a/packages/get-starknet/src/snap.ts b/packages/get-starknet/src/snap.ts index 17149863..4798514b 100644 --- a/packages/get-starknet/src/snap.ts +++ b/packages/get-starknet/src/snap.ts @@ -229,42 +229,20 @@ export class MetaMaskSnap { return network; } - async recoverDefaultAccount(chainId: string): Promise { - const result = await this.recoverAccounts({ - chainId, - startScanIndex: 0, - maxScanned: 1, - maxMissed: 1, - }); - return result[0]; - } - - async recoverAccounts({ - chainId, - startScanIndex = 0, - maxScanned = 1, - maxMissed = 1, - }: { - chainId?: string; - startScanIndex?: number; - maxScanned?: number; - maxMissed?: number; - }): Promise { + async getCurrentAccount({ chainId, fromState }: { chainId?: string; fromState?: boolean }): Promise { return (await this.#provider.request({ method: 'wallet_invokeSnap', params: { snapId: this.#snapId, request: { - method: 'starkNet_recoverAccounts', + method: 'starkNet_getCurrentAccount', params: await this.#getSnapParams({ - startScanIndex, - maxScanned, - maxMissed, chainId, + fromState, }), }, }, - })) as AccContract[]; + })) as AccContract; } async switchNetwork(chainId: string): Promise { @@ -424,6 +402,10 @@ export class MetaMaskSnap { } async installIfNot(): Promise { + // if the snap is already installed, return true, to bypass the prompt + if (await this.isInstalled()) { + return true; + } const response = (await this.#provider.request({ method: 'wallet_requestSnaps', params: { diff --git a/packages/get-starknet/src/wallet.test.ts b/packages/get-starknet/src/wallet.test.ts index 135f4bb0..13f14ba2 100644 --- a/packages/get-starknet/src/wallet.test.ts +++ b/packages/get-starknet/src/wallet.test.ts @@ -1,12 +1,39 @@ import { Mutex } from 'async-mutex'; +import type { WalletEventHandlers } from 'get-starknet-core'; import { Provider } from 'starknet'; -import { SepoliaNetwork, mockWalletInit, createWallet } from './__tests__/helper'; +import { + SepoliaNetwork, + mockWalletInit, + createWallet, + generateAccount, + MainnetNetwork, + MockMetaMaskSnapWallet, +} from './__tests__/helper'; import { MetaMaskAccount } from './accounts'; import { WalletSupportedSpecs } from './rpcs'; -import type { AccContract, Network } from './type'; +import type { Network } from './type'; describe('MetaMaskSnapWallet', () => { + const setupEventTest = async (eventName: keyof WalletEventHandlers) => { + const handlers = [jest.fn(), jest.fn()]; + const wallet = createWallet(); + + for (const handler of handlers) { + wallet.on(eventName, handler); + } + + // Having a delay to make sure the polling is done + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Due to polling operation is a endless loop, we need to stop it manually, + for (const handler of handlers) { + wallet.off(eventName, handler); + } + + return { handlers, wallet }; + }; + describe('enable', () => { it('returns an account address', async () => { const expectedAccountAddress = '0x04882a372da3dfe1c53170ad75893832469bf87b62b13e84662565c4a88f25cd'; // in hex @@ -19,15 +46,6 @@ describe('MetaMaskSnapWallet', () => { expect(address).toStrictEqual(expectedAccountAddress); }); - - it('throws `Unable to recover accounts` error if the account address not return from the Snap', async () => { - const { recoverDefaultAccountSpy } = mockWalletInit({}); - recoverDefaultAccountSpy.mockResolvedValue({} as unknown as AccContract); - - const wallet = createWallet(); - - await expect(wallet.enable()).rejects.toThrow('Unable to recover accounts'); - }); }); describe('init', () => { @@ -137,18 +155,94 @@ describe('MetaMaskSnapWallet', () => { }); describe('on', () => { - it('does nothing and not throw any error', async () => { + it('adds an event handler and starts polling if not already started', async () => { + mockWalletInit({}); + const addHandlerSpy = jest.spyOn(Set.prototype, 'add'); + const startPollingSpy = jest.spyOn(MockMetaMaskSnapWallet.prototype, 'startPolling'); + + const { handlers } = await setupEventTest('accountsChanged'); + + expect(addHandlerSpy).toHaveBeenCalledTimes(handlers.length); + for (let i = 0; i < handlers.length; i++) { + expect(addHandlerSpy).toHaveBeenNthCalledWith(i + 1, handlers[i]); + } + expect(startPollingSpy).toHaveBeenCalledTimes(1); + }); + + it('throws an error for unsupported events', () => { const wallet = createWallet(); - expect(() => wallet.on('accountsChanged', jest.fn())).not.toThrow(); + expect(() => wallet.on('unsupportedEvent' as any, jest.fn())).toThrow('Unsupported event: unsupportedEvent'); }); }); describe('off', () => { - it('does nothing and not throw any error', async () => { + it('removes an event handler and stops polling if no handlers remain', async () => { + mockWalletInit({}); + const deleteHandlerSpy = jest.spyOn(Set.prototype, 'delete'); + const stopPollingSpy = jest.spyOn(MockMetaMaskSnapWallet.prototype, 'stopPolling'); + + const { handlers } = await setupEventTest('accountsChanged'); + + expect(deleteHandlerSpy).toHaveBeenCalledTimes(handlers.length); + for (let i = 0; i < handlers.length; i++) { + expect(deleteHandlerSpy).toHaveBeenNthCalledWith(i + 1, handlers[i]); + } + expect(stopPollingSpy).toHaveBeenCalledTimes(1); + }); + + it('throws an error for unsupported events', () => { const wallet = createWallet(); + expect(() => wallet.off('unsupportedEvent' as any, jest.fn())).toThrow('Unsupported event: unsupportedEvent'); + }); + }); + + describe('event handling', () => { + it('dispatchs a `accountsChanged` event', async () => { + const { address: initialAddress } = generateAccount({ address: '0xInitialAddress' }); + const { address: newAddress } = generateAccount({ address: '0xNewAddress' }); + + // The code simulates a scenario where the initial address is the default account address. + // Later, the address is changed to a new address and remains unchanged. + // - `mockResolvedValueOnce` sets the initial address as the default account address. + // - `mockResolvedValue` from `mockWalletInit` sets the new address as the new default. + const { getCurrentAccountSpy } = mockWalletInit({ address: newAddress }); + getCurrentAccountSpy.mockResolvedValueOnce(generateAccount({ address: initialAddress })); - expect(() => wallet.off('accountsChanged', jest.fn())).not.toThrow(); + const { handlers } = await setupEventTest('accountsChanged'); + + for (const handler of handlers) { + expect(handler).toHaveBeenCalledWith([newAddress], undefined); + } + }); + + it('dispatchs a `networkChanged` event', async () => { + // The code simulates a scenario where the MainnetNetwork is the default network. + // Later, the network is changed to SepoliaNetwork and remains unchanged. + // - `mockResolvedValueOnce` sets the MainnetNetwork as the default network. + // - `mockResolvedValue` from `mockWalletInit` sets the SepoliaNetwork as the new network. + const { address } = generateAccount({}); + const { getCurrentNetworkSpy } = mockWalletInit({ currentNetwork: SepoliaNetwork, address }); + getCurrentNetworkSpy.mockResolvedValueOnce(MainnetNetwork); + + const { handlers } = await setupEventTest('networkChanged'); + + for (const handler of handlers) { + expect(handler).toHaveBeenCalledWith(SepoliaNetwork.chainId, [address]); + } }); + + it.each(['accountsChanged', 'networkChanged'])( + 'does not dispatchs a %s event if the wallet object is not initialized yet', + async (event: keyof WalletEventHandlers) => { + mockWalletInit({}); + + const { handlers } = await setupEventTest(event); + + for (const handler of handlers) { + expect(handler).toHaveBeenCalledTimes(0); + } + }, + ); }); }); diff --git a/packages/get-starknet/src/wallet.ts b/packages/get-starknet/src/wallet.ts index ef00e58d..28547afc 100644 --- a/packages/get-starknet/src/wallet.ts +++ b/packages/get-starknet/src/wallet.ts @@ -69,13 +69,15 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { #pollingController: AbortController | undefined; + // eslint-disable-next-line no-restricted-syntax #accountChangeHandlers: Set = new Set(); + // eslint-disable-next-line no-restricted-syntax #networkChangeHandlers: Set = new Set(); - static readonly pollingDelayMs = 100; + protected pollingDelayMs = 100; - static readonly pollingTimeoutMs = 5000; + protected pollingTimeoutMs = 5000; // eslint-disable-next-line @typescript-eslint/naming-convention, no-restricted-globals static readonly snapId = process.env.SNAP_ID ?? 'npm:@consensys/starknet-snap'; @@ -141,10 +143,10 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { } async #getWalletAddress(chainId: string) { - const accountResponse = await this.snap.recoverDefaultAccount(chainId); + const accountResponse = await this.snap.getCurrentAccount({ chainId, fromState: true }); if (!accountResponse?.address) { - throw new Error('Unable to recover accounts'); + throw new Error('Unable to retrieve the wallet account'); } return accountResponse.address; @@ -227,19 +229,23 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { throw new Error('Unable to find the selected network'); } + const address = await this.#getWalletAddress(network.chainId); if (!this.#network || network.chainId !== this.#network.chainId) { - // address is depends on network, if network changes, address will update - this.#selectedAddress = await this.#getWalletAddress(network.chainId); // provider is depends on network.nodeUrl, if network changes, set provider to undefine for reinitialization this.#provider = undefined; // account is depends on address and provider, if network changes, address will update, // hence set account to undefine for reinitialization - // TODO : This should be removed. The walletAccount is created with the SWO as input. - // This means account is not managed from within the SWO but from outside. - // Event handling helps ensure that the correct address is set. this.#account = undefined; } + if (address !== this.#selectedAddress) { + // account is depend on address, + // hence set account to undefine for reinitialization + this.#account = undefined; + } + + this.#selectedAddress = address; + this.#network = network; this.#chainId = network.chainId; this.isConnected = true; @@ -281,7 +287,7 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { throw new Error(`Unsupported event: ${String(event)}`); } if (!this.#pollingController) { - this.#startPolling(); + this.startPolling(); } } @@ -300,7 +306,7 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { throw new Error(`Unsupported event: ${String(event)}`); } if (this.#accountChangeHandlers.size + this.#networkChangeHandlers.size === 0) { - this.#stopPolling(); + this.stopPolling(); } } @@ -322,7 +328,7 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { while (!signal.aborted) { // Early exit if there are no handlers left if (this.#accountChangeHandlers.size + this.#networkChangeHandlers.size === 0) { - this.#stopPolling(); + this.stopPolling(); return; } @@ -337,13 +343,15 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { // Fetch network, assign address and chainId for thread safe. this.#init(), new Promise((_, reject) => - // Timeout after `MetaMaskSnapWallet.pollingTimeoutMs`. - setTimeout(() => reject(new Error('Polling timeout exceeded')), MetaMaskSnapWallet.pollingTimeoutMs), + // Timeout after `this.pollingTimeoutMs`. + setTimeout(() => reject(new Error('Polling timeout exceeded')), this.pollingTimeoutMs), ), ]); - // Check for network change - if (previousNetwork !== this.#chainId) { + // By checking the previous network is undefined + // it will not sending event to client when the wallet object initialized first time + if (previousNetwork !== this.#chainId && previousNetwork !== undefined) { + // With `Promise.allSettled`, we can handle all promises and continue even if some fail. await Promise.allSettled( Array.from(this.#networkChangeHandlers).map(async (callback) => resolver(callback, this.#chainId, [this.#selectedAddress]), @@ -351,8 +359,10 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { ); } - // Check for account change - if (previousAddress !== this.#selectedAddress) { + // By checking the previous address is undefined + // it will not sending event to client when the wallet object initialized first tim + if (previousAddress !== this.#selectedAddress && previousAddress !== undefined) { + // With `Promise.allSettled`, we can handle all promises and continue even if some fail. await Promise.allSettled( Array.from(this.#accountChangeHandlers).map(async (callback) => resolver(callback, [this.#selectedAddress]), @@ -363,17 +373,17 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { // Silently handle errors to avoid breaking the loop } - await new Promise((resolve) => setTimeout(resolve, MetaMaskSnapWallet.pollingDelayMs)); + await new Promise((resolve) => setTimeout(resolve, this.pollingDelayMs)); } }; - #startPolling(): void { + protected startPolling(): void { this.#pollingController = new AbortController(); // eslint-disable-next-line @typescript-eslint/no-floating-promises this.#pollingFunction(); } - #stopPolling(): void { + protected stopPolling(): void { if (this.#pollingController) { this.#pollingController.abort(); this.#pollingController = undefined;