Skip to content

Commit

Permalink
feat: enable get-starknet to watch account change (#505)
Browse files Browse the repository at this point in the history
* refactor: add account discover service

* feat: support multiple account in SNAP

* feat: watch the SNAP current account

* fix: lint

---------

Co-authored-by: Florin Dzeladini <[email protected]>
  • Loading branch information
stanleyyconsensys and khanti42 authored Feb 6, 2025
1 parent 6124ed3 commit c129f1c
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 66 deletions.
22 changes: 18 additions & 4 deletions packages/get-starknet/src/__tests__/helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,25 @@ export class MockProvider implements MetaMaskProvider {
request = jest.fn();
}

export class MockMetaMaskSnapWallet extends MetaMaskSnapWallet {
public pollingDelayMs = 0;

public pollingTimeoutMs = 0;

public startPolling(): void {
super.startPolling();
}

public stopPolling(): void {
super.stopPolling();
}
}

/**
* Create a wallet instance.
*/
export function createWallet() {
return new MetaMaskSnapWallet(new MockProvider());
return new MockMetaMaskSnapWallet(new MockProvider());
}

/**
Expand All @@ -101,17 +115,17 @@ export function mockWalletInit({
}) {
const installSpy = jest.spyOn(MetaMaskSnap.prototype, 'installIfNot');
const getCurrentNetworkSpy = jest.spyOn(MetaMaskSnap.prototype, 'getCurrentNetwork');
const recoverDefaultAccountSpy = jest.spyOn(MetaMaskSnap.prototype, 'recoverDefaultAccount');
const getCurrentAccountSpy = jest.spyOn(MetaMaskSnap.prototype, 'getCurrentAccount');
const initSpy = jest.spyOn(MetaMaskSnapWallet.prototype, 'init');

installSpy.mockResolvedValue(install);
getCurrentNetworkSpy.mockResolvedValue(currentNetwork);
recoverDefaultAccountSpy.mockResolvedValue(generateAccount({ address }));
getCurrentAccountSpy.mockResolvedValue(generateAccount({ address }));

return {
initSpy,
installSpy,
getCurrentNetworkSpy,
recoverDefaultAccountSpy,
getCurrentAccountSpy,
};
}
34 changes: 8 additions & 26 deletions packages/get-starknet/src/snap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -229,42 +229,20 @@ export class MetaMaskSnap {
return network;
}

async recoverDefaultAccount(chainId: string): Promise<AccContract> {
const result = await this.recoverAccounts({
chainId,
startScanIndex: 0,
maxScanned: 1,
maxMissed: 1,
});
return result[0];
}

async recoverAccounts({
chainId,
startScanIndex = 0,
maxScanned = 1,
maxMissed = 1,
}: {
chainId?: string;
startScanIndex?: number;
maxScanned?: number;
maxMissed?: number;
}): Promise<AccContract[]> {
async getCurrentAccount({ chainId, fromState }: { chainId?: string; fromState?: boolean }): Promise<AccContract> {
return (await this.#provider.request({
method: 'wallet_invokeSnap',
params: {
snapId: this.#snapId,
request: {
method: 'starkNet_recoverAccounts',
method: 'starkNet_getCurrentAccount',
params: await this.#getSnapParams({
startScanIndex,
maxScanned,
maxMissed,
chainId,
fromState,
}),
},
},
})) as AccContract[];
})) as AccContract;
}

async switchNetwork(chainId: string): Promise<boolean> {
Expand Down Expand Up @@ -424,6 +402,10 @@ export class MetaMaskSnap {
}

async installIfNot(): Promise<boolean> {
// if the snap is already installed, return true, to bypass the prompt
if (await this.isInstalled()) {
return true;
}
const response = (await this.#provider.request({
method: 'wallet_requestSnaps',
params: {
Expand Down
124 changes: 109 additions & 15 deletions packages/get-starknet/src/wallet.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
import { Mutex } from 'async-mutex';
import type { WalletEventHandlers } from 'get-starknet-core';
import { Provider } from 'starknet';

import { SepoliaNetwork, mockWalletInit, createWallet } from './__tests__/helper';
import {
SepoliaNetwork,
mockWalletInit,
createWallet,
generateAccount,
MainnetNetwork,
MockMetaMaskSnapWallet,
} from './__tests__/helper';
import { MetaMaskAccount } from './accounts';
import { WalletSupportedSpecs } from './rpcs';
import type { AccContract, Network } from './type';
import type { Network } from './type';

describe('MetaMaskSnapWallet', () => {
const setupEventTest = async (eventName: keyof WalletEventHandlers) => {
const handlers = [jest.fn(), jest.fn()];
const wallet = createWallet();

for (const handler of handlers) {
wallet.on(eventName, handler);
}

// Having a delay to make sure the polling is done
await new Promise((resolve) => setTimeout(resolve, 100));

// Due to polling operation is a endless loop, we need to stop it manually,
for (const handler of handlers) {
wallet.off(eventName, handler);
}

return { handlers, wallet };
};

describe('enable', () => {
it('returns an account address', async () => {
const expectedAccountAddress = '0x04882a372da3dfe1c53170ad75893832469bf87b62b13e84662565c4a88f25cd'; // in hex
Expand All @@ -19,15 +46,6 @@ describe('MetaMaskSnapWallet', () => {

expect(address).toStrictEqual(expectedAccountAddress);
});

it('throws `Unable to recover accounts` error if the account address not return from the Snap', async () => {
const { recoverDefaultAccountSpy } = mockWalletInit({});
recoverDefaultAccountSpy.mockResolvedValue({} as unknown as AccContract);

const wallet = createWallet();

await expect(wallet.enable()).rejects.toThrow('Unable to recover accounts');
});
});

describe('init', () => {
Expand Down Expand Up @@ -137,18 +155,94 @@ describe('MetaMaskSnapWallet', () => {
});

describe('on', () => {
it('does nothing and not throw any error', async () => {
it('adds an event handler and starts polling if not already started', async () => {
mockWalletInit({});
const addHandlerSpy = jest.spyOn(Set.prototype, 'add');
const startPollingSpy = jest.spyOn(MockMetaMaskSnapWallet.prototype, 'startPolling');

const { handlers } = await setupEventTest('accountsChanged');

expect(addHandlerSpy).toHaveBeenCalledTimes(handlers.length);
for (let i = 0; i < handlers.length; i++) {
expect(addHandlerSpy).toHaveBeenNthCalledWith(i + 1, handlers[i]);
}
expect(startPollingSpy).toHaveBeenCalledTimes(1);
});

it('throws an error for unsupported events', () => {
const wallet = createWallet();

expect(() => wallet.on('accountsChanged', jest.fn())).not.toThrow();
expect(() => wallet.on('unsupportedEvent' as any, jest.fn())).toThrow('Unsupported event: unsupportedEvent');
});
});

describe('off', () => {
it('does nothing and not throw any error', async () => {
it('removes an event handler and stops polling if no handlers remain', async () => {
mockWalletInit({});
const deleteHandlerSpy = jest.spyOn(Set.prototype, 'delete');
const stopPollingSpy = jest.spyOn(MockMetaMaskSnapWallet.prototype, 'stopPolling');

const { handlers } = await setupEventTest('accountsChanged');

expect(deleteHandlerSpy).toHaveBeenCalledTimes(handlers.length);
for (let i = 0; i < handlers.length; i++) {
expect(deleteHandlerSpy).toHaveBeenNthCalledWith(i + 1, handlers[i]);
}
expect(stopPollingSpy).toHaveBeenCalledTimes(1);
});

it('throws an error for unsupported events', () => {
const wallet = createWallet();
expect(() => wallet.off('unsupportedEvent' as any, jest.fn())).toThrow('Unsupported event: unsupportedEvent');
});
});

describe('event handling', () => {
it('dispatchs a `accountsChanged` event', async () => {
const { address: initialAddress } = generateAccount({ address: '0xInitialAddress' });
const { address: newAddress } = generateAccount({ address: '0xNewAddress' });

// The code simulates a scenario where the initial address is the default account address.
// Later, the address is changed to a new address and remains unchanged.
// - `mockResolvedValueOnce` sets the initial address as the default account address.
// - `mockResolvedValue` from `mockWalletInit` sets the new address as the new default.
const { getCurrentAccountSpy } = mockWalletInit({ address: newAddress });
getCurrentAccountSpy.mockResolvedValueOnce(generateAccount({ address: initialAddress }));

expect(() => wallet.off('accountsChanged', jest.fn())).not.toThrow();
const { handlers } = await setupEventTest('accountsChanged');

for (const handler of handlers) {
expect(handler).toHaveBeenCalledWith([newAddress], undefined);
}
});

it('dispatchs a `networkChanged` event', async () => {
// The code simulates a scenario where the MainnetNetwork is the default network.
// Later, the network is changed to SepoliaNetwork and remains unchanged.
// - `mockResolvedValueOnce` sets the MainnetNetwork as the default network.
// - `mockResolvedValue` from `mockWalletInit` sets the SepoliaNetwork as the new network.
const { address } = generateAccount({});
const { getCurrentNetworkSpy } = mockWalletInit({ currentNetwork: SepoliaNetwork, address });
getCurrentNetworkSpy.mockResolvedValueOnce(MainnetNetwork);

const { handlers } = await setupEventTest('networkChanged');

for (const handler of handlers) {
expect(handler).toHaveBeenCalledWith(SepoliaNetwork.chainId, [address]);
}
});

it.each(['accountsChanged', 'networkChanged'])(
'does not dispatchs a %s event if the wallet object is not initialized yet',
async (event: keyof WalletEventHandlers) => {
mockWalletInit({});

const { handlers } = await setupEventTest(event);

for (const handler of handlers) {
expect(handler).toHaveBeenCalledTimes(0);
}
},
);
});
});
Loading

0 comments on commit c129f1c

Please sign in to comment.