Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable get-starknet to watch account change #505

Merged
merged 8 commits into from
Feb 6, 2025
22 changes: 18 additions & 4 deletions packages/get-starknet/src/__tests__/helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,25 @@ export class MockProvider implements MetaMaskProvider {
request = jest.fn();
}

export class MockMetaMaskSnapWallet extends MetaMaskSnapWallet {
public pollingDelayMs = 0;

public pollingTimeoutMs = 0;

public startPolling(): void {
super.startPolling();
}

public stopPolling(): void {
super.stopPolling();
}
}

/**
* Create a wallet instance.
*/
export function createWallet() {
return new MetaMaskSnapWallet(new MockProvider());
return new MockMetaMaskSnapWallet(new MockProvider());
}

/**
Expand All @@ -101,17 +115,17 @@ export function mockWalletInit({
}) {
const installSpy = jest.spyOn(MetaMaskSnap.prototype, 'installIfNot');
const getCurrentNetworkSpy = jest.spyOn(MetaMaskSnap.prototype, 'getCurrentNetwork');
const recoverDefaultAccountSpy = jest.spyOn(MetaMaskSnap.prototype, 'recoverDefaultAccount');
const getCurrentAccountSpy = jest.spyOn(MetaMaskSnap.prototype, 'getCurrentAccount');
const initSpy = jest.spyOn(MetaMaskSnapWallet.prototype, 'init');

installSpy.mockResolvedValue(install);
getCurrentNetworkSpy.mockResolvedValue(currentNetwork);
recoverDefaultAccountSpy.mockResolvedValue(generateAccount({ address }));
getCurrentAccountSpy.mockResolvedValue(generateAccount({ address }));

return {
initSpy,
installSpy,
getCurrentNetworkSpy,
recoverDefaultAccountSpy,
getCurrentAccountSpy,
};
}
34 changes: 8 additions & 26 deletions packages/get-starknet/src/snap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -229,42 +229,20 @@ export class MetaMaskSnap {
return network;
}

async recoverDefaultAccount(chainId: string): Promise<AccContract> {
const result = await this.recoverAccounts({
chainId,
startScanIndex: 0,
maxScanned: 1,
maxMissed: 1,
});
return result[0];
}

async recoverAccounts({
chainId,
startScanIndex = 0,
maxScanned = 1,
maxMissed = 1,
}: {
chainId?: string;
startScanIndex?: number;
maxScanned?: number;
maxMissed?: number;
}): Promise<AccContract[]> {
async getCurrentAccount({ chainId, fromState }: { chainId?: string; fromState?: boolean }): Promise<AccContract> {
return (await this.#provider.request({
method: 'wallet_invokeSnap',
params: {
snapId: this.#snapId,
request: {
method: 'starkNet_recoverAccounts',
method: 'starkNet_getCurrentAccount',
params: await this.#getSnapParams({
startScanIndex,
maxScanned,
maxMissed,
chainId,
fromState,
}),
},
},
})) as AccContract[];
})) as AccContract;
}

async switchNetwork(chainId: string): Promise<boolean> {
Expand Down Expand Up @@ -424,6 +402,10 @@ export class MetaMaskSnap {
}

async installIfNot(): Promise<boolean> {
// if the snap is already installed, return true, to bypass the prompt
if (await this.isInstalled()) {
return true;
}
const response = (await this.#provider.request({
method: 'wallet_requestSnaps',
params: {
Expand Down
124 changes: 109 additions & 15 deletions packages/get-starknet/src/wallet.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
import { Mutex } from 'async-mutex';
import type { WalletEventHandlers } from 'get-starknet-core';
import { Provider } from 'starknet';

import { SepoliaNetwork, mockWalletInit, createWallet } from './__tests__/helper';
import {
SepoliaNetwork,
mockWalletInit,
createWallet,
generateAccount,
MainnetNetwork,
MockMetaMaskSnapWallet,
} from './__tests__/helper';
import { MetaMaskAccount } from './accounts';
import { WalletSupportedSpecs } from './rpcs';
import type { AccContract, Network } from './type';
import type { Network } from './type';

describe('MetaMaskSnapWallet', () => {
const setupEventTest = async (eventName: keyof WalletEventHandlers) => {
const handlers = [jest.fn(), jest.fn()];
const wallet = createWallet();

for (const handler of handlers) {
wallet.on(eventName, handler);
}

// Having a delay to make sure the polling is done
await new Promise((resolve) => setTimeout(resolve, 100));

// Due to polling operation is a endless loop, we need to stop it manually,
for (const handler of handlers) {
wallet.off(eventName, handler);
}

return { handlers, wallet };
};

describe('enable', () => {
it('returns an account address', async () => {
const expectedAccountAddress = '0x04882a372da3dfe1c53170ad75893832469bf87b62b13e84662565c4a88f25cd'; // in hex
Expand All @@ -19,15 +46,6 @@ describe('MetaMaskSnapWallet', () => {

expect(address).toStrictEqual(expectedAccountAddress);
});

it('throws `Unable to recover accounts` error if the account address not return from the Snap', async () => {
const { recoverDefaultAccountSpy } = mockWalletInit({});
recoverDefaultAccountSpy.mockResolvedValue({} as unknown as AccContract);

const wallet = createWallet();

await expect(wallet.enable()).rejects.toThrow('Unable to recover accounts');
});
});

describe('init', () => {
Expand Down Expand Up @@ -137,18 +155,94 @@ describe('MetaMaskSnapWallet', () => {
});

describe('on', () => {
it('does nothing and not throw any error', async () => {
it('adds an event handler and starts polling if not already started', async () => {
mockWalletInit({});
const addHandlerSpy = jest.spyOn(Set.prototype, 'add');
const startPollingSpy = jest.spyOn(MockMetaMaskSnapWallet.prototype, 'startPolling');

const { handlers } = await setupEventTest('accountsChanged');

expect(addHandlerSpy).toHaveBeenCalledTimes(handlers.length);
for (let i = 0; i < handlers.length; i++) {
expect(addHandlerSpy).toHaveBeenNthCalledWith(i + 1, handlers[i]);
}
expect(startPollingSpy).toHaveBeenCalledTimes(1);
});

it('throws an error for unsupported events', () => {
const wallet = createWallet();

expect(() => wallet.on('accountsChanged', jest.fn())).not.toThrow();
expect(() => wallet.on('unsupportedEvent' as any, jest.fn())).toThrow('Unsupported event: unsupportedEvent');
});
});

describe('off', () => {
it('does nothing and not throw any error', async () => {
it('removes an event handler and stops polling if no handlers remain', async () => {
mockWalletInit({});
const deleteHandlerSpy = jest.spyOn(Set.prototype, 'delete');
const stopPollingSpy = jest.spyOn(MockMetaMaskSnapWallet.prototype, 'stopPolling');

const { handlers } = await setupEventTest('accountsChanged');

expect(deleteHandlerSpy).toHaveBeenCalledTimes(handlers.length);
for (let i = 0; i < handlers.length; i++) {
expect(deleteHandlerSpy).toHaveBeenNthCalledWith(i + 1, handlers[i]);
}
expect(stopPollingSpy).toHaveBeenCalledTimes(1);
});

it('throws an error for unsupported events', () => {
const wallet = createWallet();
expect(() => wallet.off('unsupportedEvent' as any, jest.fn())).toThrow('Unsupported event: unsupportedEvent');
});
});

describe('event handling', () => {
it('dispatchs a `accountsChanged` event', async () => {
const { address: initialAddress } = generateAccount({ address: '0xInitialAddress' });
const { address: newAddress } = generateAccount({ address: '0xNewAddress' });

// The code simulates a scenario where the initial address is the default account address.
// Later, the address is changed to a new address and remains unchanged.
// - `mockResolvedValueOnce` sets the initial address as the default account address.
// - `mockResolvedValue` from `mockWalletInit` sets the new address as the new default.
const { getCurrentAccountSpy } = mockWalletInit({ address: newAddress });
getCurrentAccountSpy.mockResolvedValueOnce(generateAccount({ address: initialAddress }));

expect(() => wallet.off('accountsChanged', jest.fn())).not.toThrow();
const { handlers } = await setupEventTest('accountsChanged');

for (const handler of handlers) {
expect(handler).toHaveBeenCalledWith([newAddress], undefined);
}
});

it('dispatchs a `networkChanged` event', async () => {
// The code simulates a scenario where the MainnetNetwork is the default network.
// Later, the network is changed to SepoliaNetwork and remains unchanged.
// - `mockResolvedValueOnce` sets the MainnetNetwork as the default network.
// - `mockResolvedValue` from `mockWalletInit` sets the SepoliaNetwork as the new network.
const { address } = generateAccount({});
const { getCurrentNetworkSpy } = mockWalletInit({ currentNetwork: SepoliaNetwork, address });
getCurrentNetworkSpy.mockResolvedValueOnce(MainnetNetwork);

const { handlers } = await setupEventTest('networkChanged');

for (const handler of handlers) {
expect(handler).toHaveBeenCalledWith(SepoliaNetwork.chainId, [address]);
}
});

it.each(['accountsChanged', 'networkChanged'])(
'does not dispatchs a %s event if the wallet object is not initialized yet',
async (event: keyof WalletEventHandlers) => {
mockWalletInit({});

const { handlers } = await setupEventTest(event);

for (const handler of handlers) {
expect(handler).toHaveBeenCalledTimes(0);
}
},
);
});
});
Loading