diff --git a/packages/signer/signer-btc/src/internal/app-binder/command/SignPsbtCommand.test.ts b/packages/signer/signer-btc/src/internal/app-binder/command/SignPsbtCommand.test.ts index 0e2888ebb..dbc5dc74a 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/command/SignPsbtCommand.test.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/command/SignPsbtCommand.test.ts @@ -8,7 +8,7 @@ import { type SignPsbtCommandArgs, } from "@internal/app-binder/command/SignPsbtCommand"; -const GLOBAL_COMMITMENTS = Uint8Array.from([ +const GLOBAL_COMMITMENT = Uint8Array.from([ 0x05, 0x51, 0x9b, 0x38, 0xda, 0xe7, 0x44, 0x47, 0xb7, 0x21, 0x51, 0xf3, 0x54, 0xcb, 0x13, 0x8c, 0xa3, 0x59, 0x1a, 0x5f, 0xf8, 0xac, 0x81, 0x32, 0x89, 0xb1, 0x8a, 0x00, 0x4e, 0x31, 0x32, 0x16, 0x20, 0x3a, 0x22, 0x1f, 0x4b, 0xb9, 0x5e, @@ -16,13 +16,13 @@ const GLOBAL_COMMITMENTS = Uint8Array.from([ 0xa3, 0x43, 0x51, 0x65, 0xd3, 0xdf, 0xb7, 0x35, 0xce, 0x2d, 0xf5, 0xf5, 0x8f, ]); -const INPUTS_COMMITMENTS = Uint8Array.from([ +const INPUTS_ROOT = Uint8Array.from([ 0x01, 0x2a, 0xc8, 0xcd, 0xbc, 0x6f, 0xd6, 0x43, 0x70, 0x05, 0x56, 0x63, 0xf9, 0x50, 0x2f, 0xe3, 0x66, 0xed, 0xf8, 0x49, 0x70, 0xcc, 0x7d, 0x7e, 0xe8, 0xf6, 0xba, 0x47, 0x59, 0x9f, 0x11, 0x05, 0xc2, ]); -const OUTPUTS_COMMITMENTS = Uint8Array.from([ +const OUTPUTS_ROOT = Uint8Array.from([ 0x01, 0xd9, 0x35, 0x14, 0xd4, 0x29, 0x68, 0x8d, 0x76, 0x57, 0xc9, 0xaf, 0x0a, 0x08, 0x86, 0xac, 0x74, 0x4b, 0xd0, 0x88, 0x1c, 0x4a, 0x19, 0x10, 0xb5, 0x37, 0xfa, 0xba, 0x28, 0xcd, 0xca, 0x2e, 0x11, @@ -43,22 +43,22 @@ const SIGN_PSBT_APDU = Uint8Array.from([ 0x00, 0x01, 0xc5, - ...GLOBAL_COMMITMENTS, + ...GLOBAL_COMMITMENT, 0x01, - ...INPUTS_COMMITMENTS, + ...INPUTS_ROOT, 0x01, - ...OUTPUTS_COMMITMENTS, + ...OUTPUTS_ROOT, ...WALLET_ID, ...WALLET_HMAC, ]); describe("SignPsbtCommand", () => { const args: SignPsbtCommandArgs = { - globalCommitments: GLOBAL_COMMITMENTS, + globalCommitment: GLOBAL_COMMITMENT, inputsCount: 1, - inputsCommitments: INPUTS_COMMITMENTS, + inputsRoot: INPUTS_ROOT, outputsCount: 1, - outputsCommitments: OUTPUTS_COMMITMENTS, + outputsRoot: OUTPUTS_ROOT, walletId: WALLET_ID, walletHmac: WALLET_HMAC, }; diff --git a/packages/signer/signer-btc/src/internal/app-binder/command/SignPsbtCommand.ts b/packages/signer/signer-btc/src/internal/app-binder/command/SignPsbtCommand.ts index 93f0b5ea6..f63d970f5 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/command/SignPsbtCommand.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/command/SignPsbtCommand.ts @@ -18,11 +18,11 @@ import { PROTOCOL_VERSION } from "@internal/app-binder/command/utils/constants"; import { BtcCommandUtils } from "@internal/utils/BtcCommandUtils"; export type SignPsbtCommandArgs = { - globalCommitments: Uint8Array; + globalCommitment: Uint8Array; inputsCount: number; - inputsCommitments: Uint8Array; + inputsRoot: Uint8Array; outputsCount: number; - outputsCommitments: Uint8Array; + outputsRoot: Uint8Array; walletId: Uint8Array; walletHmac: Uint8Array; }; @@ -53,21 +53,21 @@ export class SignPsbtCommand p2: PROTOCOL_VERSION, }); const { - globalCommitments, + globalCommitment, inputsCount, - inputsCommitments, + inputsRoot, outputsCount, - outputsCommitments, + outputsRoot, walletHmac, walletId, } = this._args; return builder - .addBufferToData(globalCommitments) + .addBufferToData(globalCommitment) .add8BitUIntToData(inputsCount) - .addBufferToData(inputsCommitments) + .addBufferToData(inputsRoot) .add8BitUIntToData(outputsCount) - .addBufferToData(outputsCommitments) + .addBufferToData(outputsRoot) .addBufferToData(walletId) .addBufferToData(walletHmac) .build(); diff --git a/packages/signer/signer-btc/src/internal/app-binder/device-action/SignMessage/SignMessageDeviceAction.test.ts b/packages/signer/signer-btc/src/internal/app-binder/device-action/SignMessage/SignMessageDeviceAction.test.ts index d2b50da47..bcdabe776 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/device-action/SignMessage/SignMessageDeviceAction.test.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/device-action/SignMessage/SignMessageDeviceAction.test.ts @@ -11,6 +11,7 @@ import { type SignMessageDAState } from "@api/index"; import { makeDeviceActionInternalApiMock } from "@internal/app-binder/device-action/__test-utils__/makeInternalApi"; import { setupOpenAppDAMock } from "@internal/app-binder/device-action/__test-utils__/setupOpenAppDAMock"; import { testDeviceActionStates } from "@internal/app-binder/device-action/__test-utils__/testDeviceActionStates"; +import { type DataStoreService } from "@internal/data-store/service/DataStoreService"; import { SignMessageDeviceAction } from "./SignMessageDeviceAction"; @@ -34,7 +35,7 @@ describe("SignMessageDeviceAction", () => { }; } - beforeEach(() => { + afterEach(() => { jest.resetAllMocks(); }); @@ -46,6 +47,7 @@ describe("SignMessageDeviceAction", () => { input: { derivationPath: "44'/60'/0'/0/0", message: "Hello world", + dataStoreService: "DataStoreService" as unknown as DataStoreService, }, }); @@ -99,7 +101,6 @@ describe("SignMessageDeviceAction", () => { deviceAction, expectedStates, makeDeviceActionInternalApiMock(), - done, ); // Verify mocks calls parameters @@ -110,9 +111,11 @@ describe("SignMessageDeviceAction", () => { input: { derivationPath: "44'/60'/0'/0/0", message: "Hello world", + dataStoreService: "DataStoreService", }, }), ); + done(); }, }); }); @@ -145,6 +148,7 @@ describe("SignMessageDeviceAction", () => { input: { derivationPath: "44'/60'/0'/0/0", message: "Hello world", + dataStoreService: "DataStoreService" as unknown as DataStoreService, }, }); @@ -163,6 +167,7 @@ describe("SignMessageDeviceAction", () => { input: { derivationPath: "44'/60'/0'/0/0", message: "Hello world", + dataStoreService: "DataStoreService" as unknown as DataStoreService, }, }); @@ -217,6 +222,7 @@ describe("SignMessageDeviceAction", () => { input: { derivationPath: "44'/60'/0'/0/0", message: "Hello world", + dataStoreService: "DataStoreService" as unknown as DataStoreService, }, }); @@ -269,6 +275,7 @@ describe("SignMessageDeviceAction", () => { input: { derivationPath: "44'/60'/0'/0/0", message: "Hello world", + dataStoreService: "DataStoreService" as unknown as DataStoreService, }, }); @@ -323,6 +330,7 @@ describe("SignMessageDeviceAction", () => { input: { derivationPath: "44'/60'/0'/0/0", message: "Hello world", + dataStoreService: "DataStoreService" as unknown as DataStoreService, }, }); diff --git a/packages/signer/signer-btc/src/internal/app-binder/device-action/SignMessage/SignMessageDeviceAction.ts b/packages/signer/signer-btc/src/internal/app-binder/device-action/SignMessage/SignMessageDeviceAction.ts index 9590a15fe..6e57795d3 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/device-action/SignMessage/SignMessageDeviceAction.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/device-action/SignMessage/SignMessageDeviceAction.ts @@ -25,10 +25,11 @@ import { SendSignMessageTask, type SendSignMessageTaskArgs, } from "@internal/app-binder/task/SignMessageTask"; +import { type DataStoreService } from "@internal/data-store/service/DataStoreService"; export type MachineDependencies = { readonly signMessage: (arg0: { - input: SendSignMessageTaskArgs; + input: SendSignMessageTaskArgs & { dataStoreService: DataStoreService }; }) => Promise>; }; @@ -161,6 +162,7 @@ export class SignMessageDeviceAction extends XStateDeviceAction< input: ({ context }) => ({ derivationPath: context.input.derivationPath, message: context.input.message, + dataStoreService: context.input.dataStoreService, }), onDone: { target: "SignMessageResultCheck", @@ -215,8 +217,18 @@ export class SignMessageDeviceAction extends XStateDeviceAction< input: { derivationPath: string; message: string; + dataStoreService: DataStoreService; }; - }) => new SendSignMessageTask(internalApi, arg0.input).run(); + }) => { + const { + input: { derivationPath, message, dataStoreService }, + } = arg0; + return new SendSignMessageTask( + internalApi, + { derivationPath, message }, + dataStoreService, + ).run(); + }; return { signMessage, diff --git a/packages/signer/signer-btc/src/internal/app-binder/device-action/SignPsbt/SignPsbtDeviceAction.test.ts b/packages/signer/signer-btc/src/internal/app-binder/device-action/SignPsbt/SignPsbtDeviceAction.test.ts index a4b39b68b..eab133ab6 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/device-action/SignPsbt/SignPsbtDeviceAction.test.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/device-action/SignPsbt/SignPsbtDeviceAction.test.ts @@ -12,6 +12,13 @@ import { type RegisteredWallet } from "@api/model/Wallet"; import { makeDeviceActionInternalApiMock } from "@internal/app-binder/device-action/__test-utils__/makeInternalApi"; import { setupOpenAppDAMock } from "@internal/app-binder/device-action/__test-utils__/setupOpenAppDAMock"; import { testDeviceActionStates } from "@internal/app-binder/device-action/__test-utils__/testDeviceActionStates"; +import { type BuildPsbtTaskResult } from "@internal/app-binder/task/BuildPsbtTask"; +import { type DataStoreService } from "@internal/data-store/service/DataStoreService"; +import { type PsbtMapper } from "@internal/psbt/service/psbt/PsbtMapper"; +import { type ValueParser } from "@internal/psbt/service/value/ValueParser"; +import { type Wallet } from "@internal/wallet/model/Wallet"; +import { type WalletBuilder } from "@internal/wallet/service/WalletBuilder"; +import { type WalletSerializer } from "@internal/wallet/service/WalletSerializer"; import { SignPsbtDeviceAction } from "./SignPsbtDeviceAction"; @@ -27,26 +34,31 @@ jest.mock( ); describe("SignPsbtDeviceAction", () => { - const signPersonalPsbtMock = jest.fn(); + const signPsbtMock = jest.fn(); + const prepareWalletPolicyMock = jest.fn(); + const buildPsbtMock = jest.fn(); function extractDependenciesMock() { return { - signPsbt: signPersonalPsbtMock, + signPsbt: signPsbtMock, + prepareWalletPolicy: prepareWalletPolicyMock, + buildPsbt: buildPsbtMock, }; } - beforeEach(() => { - jest.resetAllMocks(); - }); - describe("Success case", () => { it("should call external dependencies with the correct parameters", (done) => { setupOpenAppDAMock(); const deviceAction = new SignPsbtDeviceAction({ input: { - wallet: {} as unknown as RegisteredWallet, + wallet: "ApiWallet" as unknown as RegisteredWallet, psbt: "Hello world", + walletBuilder: "WalletBuilder" as unknown as WalletBuilder, + walletSerializer: "WalletSerializer" as unknown as WalletSerializer, + dataStoreService: "DataStoreService" as unknown as DataStoreService, + psbtMapper: "PsbtMapper" as unknown as PsbtMapper, + valueParser: "ValueParser" as unknown as ValueParser, }, }); @@ -54,14 +66,30 @@ describe("SignPsbtDeviceAction", () => { jest .spyOn(deviceAction, "extractDependencies") .mockReturnValue(extractDependenciesMock()); - signPersonalPsbtMock.mockResolvedValueOnce( + prepareWalletPolicyMock.mockResolvedValueOnce( + CommandResultFactory({ + data: "Wallet" as unknown as Wallet, + }), + ); + buildPsbtMock.mockResolvedValueOnce( CommandResultFactory({ - data: [Uint8Array.from([0x01, 0x02, 0x03])], + data: "BuildPsbtResult" as unknown as BuildPsbtTaskResult, + }), + ); + signPsbtMock.mockResolvedValueOnce( + CommandResultFactory({ + data: [ + { + inputIndex: 0, + pubKeyAugmented: Uint8Array.from([0x04, 0x05, 0x06]), + signature: Uint8Array.from([0x01, 0x02, 0x03]), + }, + ], }), ); // Expected intermediate values for the following state sequence: - // Initial -> OpenApp -> BuildContext -> ProvideContext -> SignTypedData + // Initial -> OpenApp -> PrepareWalletPolicy -> BuildPsbt -> SignPsbt const expectedStates: Array = [ { intermediateValue: { @@ -75,6 +103,18 @@ describe("SignPsbtDeviceAction", () => { }, status: DeviceActionStatus.Pending, }, + { + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + status: DeviceActionStatus.Pending, + }, + { + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + status: DeviceActionStatus.Pending, + }, { intermediateValue: { requiredUserInteraction: UserInteractionRequired.SignTransaction, @@ -82,7 +122,13 @@ describe("SignPsbtDeviceAction", () => { status: DeviceActionStatus.Pending, }, { - output: [Uint8Array.from([0x01, 0x02, 0x03])], + output: [ + { + inputIndex: 0, + pubKeyAugmented: Uint8Array.from([0x04, 0x05, 0x06]), + signature: Uint8Array.from([0x01, 0x02, 0x03]), + }, + ], status: DeviceActionStatus.Completed, }, ]; @@ -91,27 +137,47 @@ describe("SignPsbtDeviceAction", () => { deviceAction, expectedStates, makeDeviceActionInternalApiMock(), - done, ); // Verify mocks calls parameters observable.subscribe({ complete: () => { - expect(signPersonalPsbtMock).toHaveBeenCalledWith( + expect(prepareWalletPolicyMock).toHaveBeenCalledWith( + expect.objectContaining({ + input: { wallet: "ApiWallet", walletBuilder: "WalletBuilder" }, + }), + ); + expect(buildPsbtMock).toHaveBeenCalledWith( expect.objectContaining({ input: { - wallet: {} as unknown as RegisteredWallet, psbt: "Hello world", + wallet: "Wallet", + dataStoreService: "DataStoreService", + psbtMapper: "PsbtMapper", }, }), ); + expect(signPsbtMock).toHaveBeenCalledWith( + expect.objectContaining({ + input: { + wallet: "Wallet", + buildPsbtResult: "BuildPsbtResult", + walletSerializer: "WalletSerializer", + valueParser: "ValueParser", + }, + }), + ); + done(); }, }); }); }); describe("error cases", () => { - it("Error if the open app fails", (done) => { + beforeEach(() => { + jest.resetAllMocks(); + }); + it("Error if open app fails", (done) => { setupOpenAppDAMock(new UnknownDeviceExchangeError("Mocked error")); const expectedStates: Array = [ @@ -137,6 +203,11 @@ describe("SignPsbtDeviceAction", () => { input: { wallet: {} as unknown as RegisteredWallet, psbt: "Hello world", + walletBuilder: {} as WalletBuilder, + walletSerializer: {} as WalletSerializer, + dataStoreService: {} as DataStoreService, + psbtMapper: {} as PsbtMapper, + valueParser: {} as ValueParser, }, }); @@ -148,13 +219,18 @@ describe("SignPsbtDeviceAction", () => { ); }); - it("Error if the signPsbt fails", (done) => { + it("Error if prepareWallet fails", (done) => { setupOpenAppDAMock(); const deviceAction = new SignPsbtDeviceAction({ input: { wallet: {} as unknown as RegisteredWallet, psbt: "Hello world", + walletBuilder: {} as WalletBuilder, + walletSerializer: {} as WalletSerializer, + dataStoreService: {} as DataStoreService, + psbtMapper: {} as PsbtMapper, + valueParser: {} as ValueParser, }, }); @@ -162,7 +238,7 @@ describe("SignPsbtDeviceAction", () => { jest .spyOn(deviceAction, "extractDependencies") .mockReturnValue(extractDependenciesMock()); - signPersonalPsbtMock.mockResolvedValueOnce( + prepareWalletPolicyMock.mockResolvedValueOnce( CommandResultFactory({ error: new UnknownDeviceExchangeError("Mocked error"), }), @@ -184,7 +260,7 @@ describe("SignPsbtDeviceAction", () => { { status: DeviceActionStatus.Pending, intermediateValue: { - requiredUserInteraction: UserInteractionRequired.SignTransaction, + requiredUserInteraction: UserInteractionRequired.None, }, }, { @@ -201,13 +277,18 @@ describe("SignPsbtDeviceAction", () => { ); }); - it("Error if the signPsbt throws an exception", (done) => { + it("Error if buildPsbt fails", (done) => { setupOpenAppDAMock(); const deviceAction = new SignPsbtDeviceAction({ input: { wallet: {} as unknown as RegisteredWallet, psbt: "Hello world", + walletBuilder: {} as WalletBuilder, + walletSerializer: {} as WalletSerializer, + dataStoreService: {} as DataStoreService, + psbtMapper: {} as PsbtMapper, + valueParser: {} as ValueParser, }, }); @@ -215,8 +296,15 @@ describe("SignPsbtDeviceAction", () => { jest .spyOn(deviceAction, "extractDependencies") .mockReturnValue(extractDependenciesMock()); - signPersonalPsbtMock.mockRejectedValueOnce( - new InvalidStatusWordError("Mocked error"), + prepareWalletPolicyMock.mockResolvedValueOnce( + CommandResultFactory({ + data: {} as Wallet, + }), + ); + buildPsbtMock.mockResolvedValueOnce( + CommandResultFactory({ + error: new UnknownDeviceExchangeError("Mocked error"), + }), ); const expectedStates: Array = [ @@ -235,12 +323,18 @@ describe("SignPsbtDeviceAction", () => { { status: DeviceActionStatus.Pending, intermediateValue: { - requiredUserInteraction: UserInteractionRequired.SignTransaction, + requiredUserInteraction: UserInteractionRequired.None, + }, + }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, }, }, { status: DeviceActionStatus.Error, - error: new InvalidStatusWordError("Mocked error"), + error: new UnknownDeviceExchangeError("Mocked error"), }, ]; @@ -252,13 +346,18 @@ describe("SignPsbtDeviceAction", () => { ); }); - it("Error if signPsbt return an error", (done) => { + it("Error if signPsbt fails", (done) => { setupOpenAppDAMock(); const deviceAction = new SignPsbtDeviceAction({ input: { wallet: {} as unknown as RegisteredWallet, psbt: "Hello world", + walletBuilder: {} as WalletBuilder, + walletSerializer: {} as WalletSerializer, + dataStoreService: {} as DataStoreService, + psbtMapper: {} as PsbtMapper, + valueParser: {} as ValueParser, }, }); @@ -266,7 +365,17 @@ describe("SignPsbtDeviceAction", () => { jest .spyOn(deviceAction, "extractDependencies") .mockReturnValue(extractDependenciesMock()); - signPersonalPsbtMock.mockResolvedValueOnce( + prepareWalletPolicyMock.mockResolvedValueOnce( + CommandResultFactory({ + data: {} as Wallet, + }), + ); + buildPsbtMock.mockResolvedValueOnce( + CommandResultFactory({ + data: {} as BuildPsbtTaskResult, + }), + ); + signPsbtMock.mockResolvedValueOnce( CommandResultFactory({ error: new UnknownDeviceExchangeError("Mocked error"), }), @@ -285,6 +394,18 @@ describe("SignPsbtDeviceAction", () => { requiredUserInteraction: UserInteractionRequired.ConfirmOpenApp, }, }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + }, { status: DeviceActionStatus.Pending, intermediateValue: { @@ -305,6 +426,80 @@ describe("SignPsbtDeviceAction", () => { ); }); + it("Error if signPsbt throws an exception", (done) => { + setupOpenAppDAMock(); + + const deviceAction = new SignPsbtDeviceAction({ + input: { + wallet: {} as unknown as RegisteredWallet, + psbt: "Hello world", + walletBuilder: {} as WalletBuilder, + walletSerializer: {} as WalletSerializer, + dataStoreService: {} as DataStoreService, + psbtMapper: {} as PsbtMapper, + valueParser: {} as ValueParser, + }, + }); + + // Mock the dependencies to return some sample data + jest + .spyOn(deviceAction, "extractDependencies") + .mockReturnValue(extractDependenciesMock()); + prepareWalletPolicyMock.mockResolvedValueOnce( + CommandResultFactory({ data: {} as Wallet }), + ); + buildPsbtMock.mockResolvedValueOnce( + CommandResultFactory({ data: {} as BuildPsbtTaskResult }), + ); + signPsbtMock.mockRejectedValueOnce( + new InvalidStatusWordError("Mocked error"), + ); + + const expectedStates: Array = [ + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.ConfirmOpenApp, + }, + }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.SignTransaction, + }, + }, + { + status: DeviceActionStatus.Error, + error: new InvalidStatusWordError("Mocked error"), + }, + ]; + + testDeviceActionStates( + deviceAction, + expectedStates, + makeDeviceActionInternalApiMock(), + done, + ); + }); + it("Return a Left if the final state has no signature", (done) => { setupOpenAppDAMock(); @@ -312,6 +507,11 @@ describe("SignPsbtDeviceAction", () => { input: { wallet: {} as unknown as RegisteredWallet, psbt: "Hello world", + walletBuilder: {} as WalletBuilder, + walletSerializer: {} as WalletSerializer, + dataStoreService: {} as DataStoreService, + psbtMapper: {} as PsbtMapper, + valueParser: {} as ValueParser, }, }); @@ -319,7 +519,17 @@ describe("SignPsbtDeviceAction", () => { jest .spyOn(deviceAction, "extractDependencies") .mockReturnValue(extractDependenciesMock()); - signPersonalPsbtMock.mockResolvedValueOnce( + prepareWalletPolicyMock.mockResolvedValueOnce( + CommandResultFactory({ + data: {} as Wallet, + }), + ); + buildPsbtMock.mockResolvedValueOnce( + CommandResultFactory({ + data: {} as BuildPsbtTaskResult, + }), + ); + signPsbtMock.mockResolvedValueOnce( CommandResultFactory({ data: undefined, }), @@ -338,6 +548,18 @@ describe("SignPsbtDeviceAction", () => { requiredUserInteraction: UserInteractionRequired.ConfirmOpenApp, }, }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + }, + { + status: DeviceActionStatus.Pending, + intermediateValue: { + requiredUserInteraction: UserInteractionRequired.None, + }, + }, { status: DeviceActionStatus.Pending, intermediateValue: { diff --git a/packages/signer/signer-btc/src/internal/app-binder/device-action/SignPsbt/SignPsbtDeviceAction.ts b/packages/signer/signer-btc/src/internal/app-binder/device-action/SignPsbt/SignPsbtDeviceAction.ts index 53aac18e5..1216b1c26 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/device-action/SignPsbt/SignPsbtDeviceAction.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/device-action/SignPsbt/SignPsbtDeviceAction.ts @@ -19,15 +19,48 @@ import { type SignPsbtDAInternalState, type SignPsbtDAOutput, } from "@api/app-binder/SignPsbtDeviceActionTypes"; -import { type Psbt } from "@api/model/Psbt"; +import { type Psbt as ApiPsbt } from "@api/model/Psbt"; import { type Wallet as ApiWallet } from "@api/model/Wallet"; import { type BtcErrorCodes } from "@internal/app-binder/command/utils/bitcoinAppErrors"; -import { SignPsbtTask } from "@internal/app-binder/task/SignPsbtTask"; +import { + BuildPsbtTask, + type BuildPsbtTaskResult, +} from "@internal/app-binder/task/BuildPsbtTask"; +import { PrepareWalletPolicyTask } from "@internal/app-binder/task/PrepareWalletPolicyTask"; +import { + type PsbtSignature, + SignPsbtTask, +} from "@internal/app-binder/task/SignPsbtTask"; +import type { DataStoreService } from "@internal/data-store/service/DataStoreService"; +import type { PsbtMapper } from "@internal/psbt/service/psbt/PsbtMapper"; +import type { ValueParser } from "@internal/psbt/service/value/ValueParser"; +import { type Wallet as InternalWallet } from "@internal/wallet/model/Wallet"; +import { type WalletBuilder } from "@internal/wallet/service/WalletBuilder"; +import { type WalletSerializer } from "@internal/wallet/service/WalletSerializer"; export type MachineDependencies = { + readonly prepareWalletPolicy: (arg0: { + input: { + wallet: ApiWallet; + walletBuilder: WalletBuilder; + }; + }) => Promise>; + readonly buildPsbt: (arg0: { + input: { + psbt: ApiPsbt; + wallet: InternalWallet; + dataStoreService: DataStoreService; + psbtMapper: PsbtMapper; + }; + }) => Promise>; readonly signPsbt: (arg0: { - input: { wallet: ApiWallet; psbt: Psbt }; - }) => Promise>; + input: { + wallet: InternalWallet; + buildPsbtResult: BuildPsbtTaskResult; + walletSerializer: WalletSerializer; + valueParser: ValueParser; + }; + }) => Promise>; }; export type ExtractMachineDependencies = ( @@ -61,7 +94,8 @@ export class SignPsbtDeviceAction extends XStateDeviceAction< SignPsbtDAInternalState >; - const { signPsbt } = this.extractDependencies(internalApi); + const { signPsbt, prepareWalletPolicy, buildPsbt } = + this.extractDependencies(internalApi); return setup({ types: { @@ -74,6 +108,8 @@ export class SignPsbtDeviceAction extends XStateDeviceAction< openAppStateMachine: new OpenAppDeviceAction({ input: { appName: "Bitcoin" }, }).makeStateMachine(internalApi), + prepareWalletPolicy: fromPromise(prepareWalletPolicy), + buildPsbt: fromPromise(buildPsbt), signPsbt: fromPromise(signPsbt), }, guards: { @@ -99,8 +135,10 @@ export class SignPsbtDeviceAction extends XStateDeviceAction< }, _internalState: { error: null, - signature: null, wallet: null, + buildPsbtResult: null, + signatures: null, + signedPsbt: null, }, }; }, @@ -140,12 +178,92 @@ export class SignPsbtDeviceAction extends XStateDeviceAction< CheckOpenAppDeviceActionResult: { always: [ { - target: "SignPsbt", + target: "PrepareWalletPolicy", guard: "noInternalError", }, "Error", ], }, + PrepareWalletPolicy: { + invoke: { + id: "prepareWalletPolicy", + src: "prepareWalletPolicy", + input: ({ context }) => ({ + wallet: context.input.wallet, + walletBuilder: context.input.walletBuilder, + }), + onDone: { + target: "PrepareWalletPolicyResultCheck", + actions: [ + assign({ + _internalState: ({ event, context }) => { + if (isSuccessCommandResult(event.output)) { + return { + ...context._internalState, + wallet: event.output.data, + }; + } + return { + ...context._internalState, + error: event.output.error, + }; + }, + }), + ], + }, + onError: { + target: "Error", + actions: "assignErrorFromEvent", + }, + }, + }, + PrepareWalletPolicyResultCheck: { + always: [ + { guard: "noInternalError", target: "BuildPsbt" }, + { target: "Error" }, + ], + }, + BuildPsbt: { + invoke: { + id: "buildPsbt", + src: "buildPsbt", + input: ({ context }) => ({ + psbt: context.input.psbt, + wallet: context._internalState.wallet!, + dataStoreService: context.input.dataStoreService, + psbtMapper: context.input.psbtMapper, + }), + onDone: { + target: "BuildPsbtResultCheck", + actions: [ + assign({ + _internalState: ({ event, context }) => { + if (isSuccessCommandResult(event.output)) { + return { + ...context._internalState, + buildPsbtResult: event.output.data, + }; + } + return { + ...context._internalState, + error: event.output.error, + }; + }, + }), + ], + }, + onError: { + target: "Error", + actions: "assignErrorFromEvent", + }, + }, + }, + BuildPsbtResultCheck: { + always: [ + { guard: "noInternalError", target: "SignPsbt" }, + { target: "Error" }, + ], + }, SignPsbt: { entry: assign({ intermediateValue: { @@ -161,8 +279,10 @@ export class SignPsbtDeviceAction extends XStateDeviceAction< id: "signPsbt", src: "signPsbt", input: ({ context }) => ({ - psbt: context.input.psbt, - wallet: context.input.wallet, + walletSerializer: context.input.walletSerializer, + valueParser: context.input.valueParser, + buildPsbtResult: context._internalState.buildPsbtResult!, + wallet: context._internalState.wallet!, }), onDone: { target: "SignPsbtResultCheck", @@ -172,7 +292,7 @@ export class SignPsbtDeviceAction extends XStateDeviceAction< if (isSuccessCommandResult(event.output)) { return { ...context._internalState, - signature: event.output.data, + signatures: event.output.data, }; } return { @@ -202,23 +322,69 @@ export class SignPsbtDeviceAction extends XStateDeviceAction< type: "final", }, }, - output: ({ context }) => - context._internalState.signature - ? Right(context._internalState.signature) - : Left( - context._internalState.error || - new UnknownDAError("No error in final state"), - ), + output: ({ + context: { + _internalState: { signatures, error }, + }, + }) => + signatures + ? Right(signatures) + : Left(error || new UnknownDAError("No error in final state")), }); } extractDependencies(internalApi: InternalApi): MachineDependencies { + const prepareWalletPolicy = async (arg0: { + input: { wallet: ApiWallet; walletBuilder: WalletBuilder }; + }): Promise> => { + const { + input: { walletBuilder, wallet }, + } = arg0; + return await new PrepareWalletPolicyTask( + internalApi, + { wallet }, + walletBuilder, + ).run(); + }; + const buildPsbt = async (arg0: { + input: { + psbt: ApiPsbt; + wallet: InternalWallet; + dataStoreService: DataStoreService; + psbtMapper: PsbtMapper; + }; + }): Promise> => { + const { + input: { psbt, wallet, dataStoreService, psbtMapper }, + } = arg0; + return new BuildPsbtTask( + { psbt, wallet }, + dataStoreService, + psbtMapper, + ).run(); + }; const signPsbt = async (arg0: { - input: { wallet: ApiWallet; psbt: Psbt }; - }): Promise> => { - return await new SignPsbtTask(internalApi, arg0.input).run(); + input: { + wallet: InternalWallet; + buildPsbtResult: BuildPsbtTaskResult; + walletSerializer: WalletSerializer; + valueParser: ValueParser; + }; + }): Promise> => { + const { + input: { wallet, buildPsbtResult, walletSerializer, valueParser }, + } = arg0; + return await new SignPsbtTask( + internalApi, + { wallet, ...buildPsbtResult }, + walletSerializer, + valueParser, + ).run(); }; + return { + prepareWalletPolicy, + buildPsbt, signPsbt, }; } diff --git a/packages/signer/signer-btc/src/internal/app-binder/task/BuildPsbtTask.test.ts b/packages/signer/signer-btc/src/internal/app-binder/task/BuildPsbtTask.test.ts index b8657e3ed..c6d0f377c 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/task/BuildPsbtTask.test.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/task/BuildPsbtTask.test.ts @@ -1,13 +1,16 @@ import { CommandResultFactory, - isSuccessCommandResult, UnknownDeviceExchangeError, } from "@ledgerhq/device-management-kit"; -import { Left, Nothing, Right } from "purify-ts"; +import { Left, Right } from "purify-ts"; import { type Psbt } from "@api/model/Psbt"; import { BuildPsbtTask } from "@internal/app-binder/task/BuildPsbtTask"; -import { type PsbtCommitment } from "@internal/data-store/service/DataStoreService"; +import { DataStore } from "@internal/data-store/model/DataStore"; +import { + type DataStoreService, + type PsbtCommitment, +} from "@internal/data-store/service/DataStoreService"; import { type Psbt as InternalPsbt } from "@internal/psbt/model/Psbt"; import { type Wallet } from "@internal/wallet/model/Wallet"; @@ -15,31 +18,45 @@ describe("BuildPsbtTask", () => { it("should build psbt and fill datastore", async () => { // given const psbtMapper = { - map: jest.fn(() => - Right({ - getGlobalValue: jest.fn(() => Nothing), - } as unknown as InternalPsbt), - ), + map: jest.fn(() => Right("InternalPsbt" as unknown as InternalPsbt)), }; const dataStoreService = { merklizeWallet: jest.fn(), - merklizePsbt: jest.fn(() => Right({} as PsbtCommitment)), - merklizeChunks: jest.fn(), - }; + merklizePsbt: jest.fn(() => + Right("PsbtCommitment" as unknown as PsbtCommitment), + ), + } as unknown as DataStoreService; + const dataStore = new DataStore(); const task = new BuildPsbtTask( { - wallet: {} as unknown as Wallet, - psbt: { - getGlobalValue: jest.fn(), - } as unknown as Psbt, + wallet: "Wallet" as unknown as Wallet, + psbt: "ApiPsbt" as unknown as Psbt, }, - psbtMapper, dataStoreService, + psbtMapper, + () => dataStore, ); // when const result = await task.run(); // then - expect(isSuccessCommandResult(result)).toBe(true); + expect(psbtMapper.map).toHaveBeenCalledWith("ApiPsbt"); + expect(dataStoreService.merklizePsbt).toHaveBeenCalledWith( + dataStore, + "InternalPsbt", + ); + expect(dataStoreService.merklizeWallet).toHaveBeenCalledWith( + dataStore, + "Wallet", + ); + expect(result).toStrictEqual( + CommandResultFactory({ + data: { + psbtCommitment: "PsbtCommitment", + dataStore, + psbt: "InternalPsbt", + }, + }), + ); }); it("should return an error if datastore fails", async () => { // given @@ -57,8 +74,8 @@ describe("BuildPsbtTask", () => { wallet: {} as unknown as Wallet, psbt: {} as unknown as Psbt, }, - psbtMapper, dataStoreService, + psbtMapper, ); // when const result = await task.run(); @@ -85,8 +102,8 @@ describe("BuildPsbtTask", () => { wallet: {} as unknown as Wallet, psbt: {} as unknown as Psbt, }, - psbtMapper, dataStoreService, + psbtMapper, ); // when const result = await task.run(); diff --git a/packages/signer/signer-btc/src/internal/app-binder/task/BuildPsbtTask.ts b/packages/signer/signer-btc/src/internal/app-binder/task/BuildPsbtTask.ts index 1b97bf821..198132eef 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/task/BuildPsbtTask.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/task/BuildPsbtTask.ts @@ -12,75 +12,29 @@ import { type DataStoreService, type PsbtCommitment, } from "@internal/data-store/service/DataStoreService"; -import { DefaultDataStoreService } from "@internal/data-store/service/DefaultDataStoreService"; -import { MerkleMapBuilder } from "@internal/merkle-tree/service/MerkleMapBuilder"; -import { MerkleTreeBuilder } from "@internal/merkle-tree/service/MerkleTreeBuilder"; -import { Sha256HasherService } from "@internal/merkle-tree/service/Sha256HasherService"; -import { - type Psbt as InternalPsbt, - PsbtGlobal, -} from "@internal/psbt/model/Psbt"; -import { DefaultKeySerializer } from "@internal/psbt/service/key/DefaultKeySerializer"; -import { DefaultKeyPairSerializer } from "@internal/psbt/service/key-pair/DefaultKeyPairSerializer"; -import { DefaultPsbtMapper } from "@internal/psbt/service/psbt/DefaultPsbtMapper"; -import { DefaultPsbtSerializer } from "@internal/psbt/service/psbt/DefaultPsbtSerializer"; -import { DefaultPsbtV2Normalizer } from "@internal/psbt/service/psbt/DefaultPsbtV2Normalizer"; +import { type Psbt as InternalPsbt } from "@internal/psbt/model/Psbt"; import type { PsbtMapper } from "@internal/psbt/service/psbt/PsbtMapper"; -import { DefaultValueFactory } from "@internal/psbt/service/value/DefaultValueFactory"; -import { DefaultValueParser } from "@internal/psbt/service/value/DefaultValueParser"; -import { type ValueParser } from "@internal/psbt/service/value/ValueParser"; import { type Wallet } from "@internal/wallet/model/Wallet"; -import { DefaultWalletSerializer } from "@internal/wallet/service/DefaultWalletSerializer"; -type BuildPsbtTaskResponse = { +export type BuildPsbtTaskResult = { psbtCommitment: PsbtCommitment; dataStore: DataStore; - inputsCount: number; - outputsCount: number; + psbt: InternalPsbt; }; export class BuildPsbtTask { - private readonly _dataStoreService: DataStoreService; - private readonly _psbtMapper: PsbtMapper; - private readonly _valueParser: ValueParser; - constructor( private readonly _args: { wallet: Wallet; psbt: Psbt; }, - psbtMapper?: PsbtMapper, - dataStoreService?: DataStoreService, - ) { - this._valueParser = new DefaultValueParser(); - const merkleTreeBuilder = new MerkleTreeBuilder(new Sha256HasherService()); - const merkleMapBuilder = new MerkleMapBuilder(merkleTreeBuilder); - const hasher = new Sha256HasherService(); - - this._psbtMapper = - psbtMapper || - new DefaultPsbtMapper( - new DefaultPsbtSerializer( - this._valueParser, - new DefaultKeyPairSerializer(new DefaultKeySerializer()), - ), - new DefaultPsbtV2Normalizer( - this._valueParser, - new DefaultValueFactory(), - ), - ); - this._dataStoreService = - dataStoreService || - new DefaultDataStoreService( - merkleTreeBuilder, - merkleMapBuilder, - new DefaultWalletSerializer(hasher), - hasher, - ); - } + private readonly _dataStoreService: DataStoreService, + private readonly _psbtMapper: PsbtMapper, + private readonly _dataStoreFactory = () => new DataStore(), + ) {} - async run(): Promise> { - const dataStore = new DataStore(); + async run(): Promise> { + const dataStore = this._dataStoreFactory(); let psbt: InternalPsbt; return await EitherAsync(async ({ liftEither }) => { // map the input PSBT (V1 or V2, string or byte array) into a normalized and parsed PSBTv2 @@ -99,18 +53,7 @@ export class BuildPsbtTask { data: { psbtCommitment, dataStore, - inputsCount: psbt - .getGlobalValue(PsbtGlobal.INPUT_COUNT) - .mapOrDefault( - (value) => this._valueParser.getVarint(value.data).orDefault(0), - 0, - ), - outputsCount: psbt - .getGlobalValue(PsbtGlobal.OUTPUT_COUNT) - .mapOrDefault( - (value) => this._valueParser.getVarint(value.data).orDefault(0), - 0, - ), + psbt, }, }); }, diff --git a/packages/signer/signer-btc/src/internal/app-binder/task/ContinueTask.ts b/packages/signer/signer-btc/src/internal/app-binder/task/ContinueTask.ts index a155a86e6..c1a462c97 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/task/ContinueTask.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/task/ContinueTask.ts @@ -19,21 +19,18 @@ import { type DataStore } from "@internal/data-store/model/DataStore"; import { BtcCommandUtils } from "@internal/utils/BtcCommandUtils"; export class ContinueTask { - private readonly _clientCommandInterpreter: ClientCommandInterpreter; private readonly _context: ClientCommandContext; constructor( private readonly _api: InternalApi, dataStore: DataStore, - clientCommandInterpreter?: ClientCommandInterpreter, + private readonly _clientCommandInterpreter = new ClientCommandInterpreter(), ) { this._context = { dataStore, queue: [], yieldedResults: [], }; - this._clientCommandInterpreter = - clientCommandInterpreter || new ClientCommandInterpreter(); } async run( diff --git a/packages/signer/signer-btc/src/internal/app-binder/task/PrepareWalletPolicyTask.ts b/packages/signer/signer-btc/src/internal/app-binder/task/PrepareWalletPolicyTask.ts index 898f7641c..d127d90df 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/task/PrepareWalletPolicyTask.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/task/PrepareWalletPolicyTask.ts @@ -11,27 +11,17 @@ import { import { GetExtendedPublicKeyCommand } from "@internal/app-binder/command/GetExtendedPublicKeyCommand"; import { GetMasterFingerprintCommand } from "@internal/app-binder/command/GetMasterFingerprintCommand"; import { type BtcErrorCodes } from "@internal/app-binder/command/utils/bitcoinAppErrors"; -import { MerkleTreeBuilder } from "@internal/merkle-tree/service/MerkleTreeBuilder"; -import { Sha256HasherService } from "@internal/merkle-tree/service/Sha256HasherService"; import { type Wallet as InternalWallet } from "@internal/wallet/model/Wallet"; -import { DefaultWalletBuilder } from "@internal/wallet/service/DefaultWalletBuilder"; import { type WalletBuilder } from "@internal/wallet/service/WalletBuilder"; export type PrepareWalletPolicyTaskArgs = { wallet: ApiWallet }; export class PrepareWalletPolicyTask { - private readonly _walletBuilder: WalletBuilder; constructor( private readonly _api: InternalApi, private readonly _args: PrepareWalletPolicyTaskArgs, - walletBuilder?: WalletBuilder, - ) { - this._walletBuilder = - walletBuilder || - new DefaultWalletBuilder( - new MerkleTreeBuilder(new Sha256HasherService()), - ); - } + private readonly _walletBuilder: WalletBuilder, + ) {} private isDefaultWallet(wallet: ApiWallet): wallet is DefaultWallet { return "derivationPath" in wallet; diff --git a/packages/signer/signer-btc/src/internal/app-binder/task/SignMessageTask.test.ts b/packages/signer/signer-btc/src/internal/app-binder/task/SignMessageTask.test.ts index 0474ca51e..7585fb3ef 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/task/SignMessageTask.test.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/task/SignMessageTask.test.ts @@ -1,32 +1,24 @@ import { ApduResponse, CommandResultFactory, - CommandResultStatus, type InternalApi, InvalidStatusWordError, } from "@ledgerhq/device-management-kit"; -import { Left, Right } from "purify-ts"; import { type Signature } from "@api/model/Signature"; -import type { ClientCommandContext } from "@internal/app-binder/command/client-command-handlers/ClientCommandHandlersTypes"; -import { ClientCommandHandlerError } from "@internal/app-binder/command/client-command-handlers/Errors"; -import { ContinueCommand } from "@internal/app-binder/command/ContinueCommand"; -import { ClientCommandInterpreter } from "@internal/app-binder/command/service/ClientCommandInterpreter"; -import { SignMessageCommand } from "@internal/app-binder/command/SignMessageCommand"; import { CHUNK_SIZE, - ClientCommandCodes, SHA256_SIZE, - SW_INTERRUPTED_EXECUTION, } from "@internal/app-binder/command/utils/constants"; -import { DefaultDataStoreService } from "@internal/data-store/service/DefaultDataStoreService"; +import { type ContinueTask } from "@internal/app-binder/task/ContinueTask"; +import { DataStore } from "@internal/data-store/model/DataStore"; +import { type DataStoreService } from "@internal/data-store/service/DataStoreService"; import { SendSignMessageTask } from "./SignMessageTask"; const EXACT_ONE_CHUNK_MESSAGE = "a".repeat(CHUNK_SIZE); const EXACT_TWO_CHUNKS_MESSAGE = "a".repeat(CHUNK_SIZE * 2); const DERIVATION_PATH = "44'/0'/0'/0/0"; -const PREIMAGE = new Uint8Array([1, 2, 3, 4]); const MERKLE_ROOT = new Uint8Array(SHA256_SIZE).fill(0x01); const SIGNATURE: Signature = { @@ -35,11 +27,6 @@ const SIGNATURE: Signature = { s: "0x6950d02e74e9c102c164a225533082cabdd890efc463f67f60cefe8c3f87cfce", }; -const APDU_RESPONSE_YELD: ApduResponse = { - statusCode: SW_INTERRUPTED_EXECUTION, - data: new Uint8Array([ClientCommandCodes.YIELD]), -}; - const SIGNATURE_APDU = new Uint8Array([ 0x1b, 0x97, 0xa4, 0xca, 0x8f, 0x69, 0x46, 0x33, 0x59, 0x26, 0x01, 0xf5, 0xa2, 0x3e, 0x0b, 0xcc, 0x55, 0x3c, 0x9d, 0x0a, 0x90, 0xd3, 0xa3, 0x42, 0x2d, 0x57, @@ -71,27 +58,28 @@ describe("SignMessageTask", () => { message: EXACT_ONE_CHUNK_MESSAGE, }; - jest - .spyOn(DefaultDataStoreService.prototype, "merklizeChunks") - .mockImplementation((_, chunks) => { - expect(chunks.length).toBe(1); - return MERKLE_ROOT; - }); + const dataStoreService = { + merklizeChunks: jest.fn().mockReturnValue(MERKLE_ROOT), + } as unknown as DataStoreService; - (apiMock.sendCommand as jest.Mock).mockResolvedValueOnce( - CommandResultFactory({ - data: new ApduResponse({ - data: SIGNATURE_APDU, - statusCode: new Uint8Array([0x90, 0x00]), - }), - }), - ); + const continueTaskFactory = () => + ({ + run: jest.fn().mockReturnValue(signatureResult), + }) as unknown as ContinueTask; // WHEN - const result = await new SendSignMessageTask(apiMock, args).run(); + const result = await new SendSignMessageTask( + apiMock, + args, + dataStoreService, + continueTaskFactory, + ).run(); // THEN - expect(apiMock.sendCommand).toHaveBeenCalledTimes(1); + expect(dataStoreService.merklizeChunks).toHaveBeenCalledWith( + expect.any(DataStore), + [Uint8Array.from(new Array(64).fill(0x61))], + ); expect(result).toStrictEqual(CommandResultFactory({ data: SIGNATURE })); }); @@ -102,122 +90,32 @@ describe("SignMessageTask", () => { message: EXACT_TWO_CHUNKS_MESSAGE, }; - jest - .spyOn(DefaultDataStoreService.prototype, "merklizeChunks") - .mockImplementation((_, chunks) => { - expect(chunks.length).toBe(2); - return MERKLE_ROOT; - }); - - (apiMock.sendCommand as jest.Mock).mockResolvedValueOnce(signatureResult); - - // WHEN - const result = await new SendSignMessageTask(apiMock, args).run(); - - // THEN - expect(apiMock.sendCommand).toHaveBeenCalledTimes(1); - expect(result).toStrictEqual(CommandResultFactory({ data: SIGNATURE })); - }); - - it("should handle interrupted execution with interactive commands", async () => { - // GIVEN - const args = { - derivationPath: DERIVATION_PATH, - message: EXACT_TWO_CHUNKS_MESSAGE, - }; - - jest - .spyOn(DefaultDataStoreService.prototype, "merklizeChunks") - .mockImplementation((_, chunks) => { - expect(chunks.length).toBe(2); - return MERKLE_ROOT; - }); + const dataStoreService = { + merklizeChunks: jest.fn().mockReturnValue(MERKLE_ROOT), + } as unknown as DataStoreService; - (apiMock.sendCommand as jest.Mock) - .mockResolvedValueOnce( - CommandResultFactory({ - data: APDU_RESPONSE_YELD, - }), - ) - .mockResolvedValueOnce( - CommandResultFactory({ - data: { - statusCode: SW_INTERRUPTED_EXECUTION, - data: new Uint8Array([ClientCommandCodes.GET_PREIMAGE]), - }, - }), - ) - .mockResolvedValueOnce(signatureResult); - - const getClientCommandPayloadMock = jest - .spyOn(ClientCommandInterpreter.prototype, "getClientCommandPayload") - - .mockImplementation( - (request: Uint8Array, context: ClientCommandContext) => { - const commandCode = request[0]; - if (commandCode === ClientCommandCodes.YIELD) { - // simulate YIELD command - context.yieldedResults.push(new Uint8Array([])); - return Right(new Uint8Array([0x00])); - } - if (commandCode === ClientCommandCodes.GET_PREIMAGE) { - // simulate GET_PREIMAGE command - return Right(PREIMAGE); - } - return Left(new ClientCommandHandlerError("error")); - }, - ); + const continueTaskFactory = () => + ({ + run: jest.fn().mockReturnValue(signatureResult), + }) as unknown as ContinueTask; // WHEN - const result = await new SendSignMessageTask(apiMock, args).run(); + const result = await new SendSignMessageTask( + apiMock, + args, + dataStoreService, + continueTaskFactory, + ).run(); // THEN - // expected number of sendCommand calls: - // 1. SignMessageCommand - // 2. ContinueCommand after YIELD - // 3. ContinueCommand after GET_PREIMAGE - expect(apiMock.sendCommand).toHaveBeenCalledTimes(3); - - // check that sendCommand was called with the correct commands - expect(apiMock.sendCommand).toHaveBeenNthCalledWith( - 1, - new SignMessageCommand({ - derivationPath: DERIVATION_PATH, - messageLength: new TextEncoder().encode(EXACT_TWO_CHUNKS_MESSAGE) - .length, - messageMerkleRoot: MERKLE_ROOT, - }), + expect(dataStoreService.merklizeChunks).toHaveBeenCalledWith( + expect.any(DataStore), + [ + Uint8Array.from(new Array(64).fill(0x61)), + Uint8Array.from(new Array(64).fill(0x61)), + ], ); - - expect(apiMock.sendCommand).toHaveBeenNthCalledWith( - 2, - new ContinueCommand({ - payload: new Uint8Array([0x00]), - }), - ); - - expect(apiMock.sendCommand).toHaveBeenNthCalledWith( - 3, - new ContinueCommand({ - payload: PREIMAGE, - }), - ); - - // check the final result expect(result).toStrictEqual(CommandResultFactory({ data: SIGNATURE })); - - // check that getClientCommandPayload was called correctly - expect(getClientCommandPayloadMock).toHaveBeenCalledTimes(2); - expect(getClientCommandPayloadMock).toHaveBeenNthCalledWith( - 1, - new Uint8Array([ClientCommandCodes.YIELD]), - expect.any(Object), - ); - expect(getClientCommandPayloadMock).toHaveBeenNthCalledWith( - 2, - new Uint8Array([ClientCommandCodes.GET_PREIMAGE]), - expect.any(Object), - ); }); it("should return an error if the initial SignMessageCommand fails", async () => { @@ -230,110 +128,29 @@ describe("SignMessageTask", () => { const resultError = CommandResultFactory({ error: new InvalidStatusWordError("error"), }); + const dataStoreService = { + merklizeChunks: jest.fn().mockReturnValue(MERKLE_ROOT), + } as unknown as DataStoreService; - jest - .spyOn(DefaultDataStoreService.prototype, "merklizeChunks") - .mockImplementation((_, chunks) => { - expect(chunks.length).toBe(1); - return MERKLE_ROOT; - }); - - (apiMock.sendCommand as jest.Mock).mockResolvedValueOnce(resultError); + const continueTaskFactory = () => + ({ + run: jest.fn().mockReturnValue(resultError), + }) as unknown as ContinueTask; // WHEN - const result = await new SendSignMessageTask(apiMock, args).run(); + const result = await new SendSignMessageTask( + apiMock, + args, + dataStoreService, + continueTaskFactory, + ).run(); // THEN - expect(apiMock.sendCommand).toHaveBeenCalledTimes(1); - expect(result.status).toBe(CommandResultStatus.Error); - if (result.status === CommandResultStatus.Error) { - expect(result.error).toBeInstanceOf(InvalidStatusWordError); - } - }); - - it("should return an error if a ContinueCommand fails during interactive execution", async () => { - // GIVEN - const args = { - derivationPath: DERIVATION_PATH, - message: EXACT_TWO_CHUNKS_MESSAGE, - }; - - jest - .spyOn(DefaultDataStoreService.prototype, "merklizeChunks") - .mockImplementation((_, chunks) => { - expect(chunks.length).toBe(2); - return MERKLE_ROOT; - }); - - const resultError = CommandResultFactory({ - error: new InvalidStatusWordError("error"), - }); - - (apiMock.sendCommand as jest.Mock) - .mockResolvedValueOnce( - CommandResultFactory({ - data: APDU_RESPONSE_YELD, - }), - ) - .mockResolvedValueOnce(resultError); - - const getClientCommandPayloadMock = jest - .spyOn(ClientCommandInterpreter.prototype, "getClientCommandPayload") - - .mockImplementation( - (request: Uint8Array, context: ClientCommandContext) => { - const commandCode = request[0]; - if (commandCode === ClientCommandCodes.YIELD) { - // simulate YIELD command - context.yieldedResults.push(new Uint8Array([])); - return Right(new Uint8Array([0x00])); - } - // no need GET_PREIMAGE since as it should fail before - return Left(new ClientCommandHandlerError("error")); - }, - ); - - // WHEN - const result = await new SendSignMessageTask(apiMock, args).run(); - - // THEN - // expected number of sendCommand calls: - // 1. SignMessageCommand - // 2. ContinueCommand after YIELD (which fails) - expect(apiMock.sendCommand).toHaveBeenCalledTimes(2); - - // check that sendCommand was called with the correct commands - expect(apiMock.sendCommand).toHaveBeenNthCalledWith( - 1, - new SignMessageCommand({ - derivationPath: DERIVATION_PATH, - messageLength: new TextEncoder().encode(EXACT_TWO_CHUNKS_MESSAGE) - .length, - messageMerkleRoot: MERKLE_ROOT, - }), - ); - - expect(apiMock.sendCommand).toHaveBeenNthCalledWith( - 2, - new ContinueCommand({ - payload: new Uint8Array([0x00]), - }), - ); - - // check the final result expect(result).toStrictEqual( CommandResultFactory({ - error: new InvalidStatusWordError("Invalid response from the device"), + error: new InvalidStatusWordError("error"), }), ); - - // check that getClientCommandPayload was called correctly - expect(getClientCommandPayloadMock).toHaveBeenCalledTimes(1); - expect(getClientCommandPayloadMock).toHaveBeenNthCalledWith( - 1, - new Uint8Array([ClientCommandCodes.YIELD]), - expect.any(Object), - ); }); }); }); diff --git a/packages/signer/signer-btc/src/internal/app-binder/task/SignMessageTask.ts b/packages/signer/signer-btc/src/internal/app-binder/task/SignMessageTask.ts index 8199db8fe..312e9d773 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/task/SignMessageTask.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/task/SignMessageTask.ts @@ -1,8 +1,6 @@ import { type CommandResult, - CommandResultFactory, type InternalApi, - InvalidStatusWordError, isSuccessCommandResult, } from "@ledgerhq/device-management-kit"; @@ -13,12 +11,7 @@ import { CHUNK_SIZE } from "@internal/app-binder/command/utils/constants"; import { ContinueTask } from "@internal/app-binder/task/ContinueTask"; import { DataStore } from "@internal/data-store/model/DataStore"; import { type DataStoreService } from "@internal/data-store/service/DataStoreService"; -import { DefaultDataStoreService } from "@internal/data-store/service/DefaultDataStoreService"; -import { MerkleMapBuilder } from "@internal/merkle-tree/service/MerkleMapBuilder"; -import { MerkleTreeBuilder } from "@internal/merkle-tree/service/MerkleTreeBuilder"; -import { Sha256HasherService } from "@internal/merkle-tree/service/Sha256HasherService"; import { BtcCommandUtils } from "@internal/utils/BtcCommandUtils"; -import { DefaultWalletSerializer } from "@internal/wallet/service/DefaultWalletSerializer"; export type SendSignMessageTaskArgs = { derivationPath: string; @@ -26,28 +19,18 @@ export type SendSignMessageTaskArgs = { }; export class SendSignMessageTask { - private dataStoreService: DataStoreService; - constructor( - private api: InternalApi, - private args: SendSignMessageTaskArgs, - ) { - const merkleTreeBuilder = new MerkleTreeBuilder(new Sha256HasherService()); - const merkleMapBuilder = new MerkleMapBuilder(merkleTreeBuilder); - const walletSerializer = new DefaultWalletSerializer( - new Sha256HasherService(), - ); - - this.dataStoreService = new DefaultDataStoreService( - merkleTreeBuilder, - merkleMapBuilder, - walletSerializer, - new Sha256HasherService(), - ); - } + private readonly _api: InternalApi, + private readonly _args: SendSignMessageTaskArgs, + private readonly _dataStoreService: DataStoreService, + private readonly _continueTaskFactory = ( + api: InternalApi, + dataStore: DataStore, + ) => new ContinueTask(api, dataStore), + ) {} async run(): Promise> { - const { derivationPath, message } = this.args; + const { derivationPath, message } = this._args; const dataStore = new DataStore(); @@ -57,23 +40,21 @@ export class SendSignMessageTask { chunks.push(messageBuffer.subarray(i, i + CHUNK_SIZE)); } - const merkleRoot = this.dataStoreService.merklizeChunks(dataStore, chunks); + const merkleRoot = this._dataStoreService.merklizeChunks(dataStore, chunks); - const signMessageFirstCommandResponse = await this.api.sendCommand( + const signMessageFirstCommandResponse = await this._api.sendCommand( new SignMessageCommand({ derivationPath, messageLength: messageBuffer.length, messageMerkleRoot: merkleRoot, }), ); - const response = await new ContinueTask(this.api, dataStore).run( + const response = await this._continueTaskFactory(this._api, dataStore).run( signMessageFirstCommandResponse, ); if (isSuccessCommandResult(response)) { return BtcCommandUtils.getSignature(response); } - return CommandResultFactory({ - error: new InvalidStatusWordError("Invalid response from the device"), - }); + return response; } } diff --git a/packages/signer/signer-btc/src/internal/app-binder/task/SignPsbtTask.test.ts b/packages/signer/signer-btc/src/internal/app-binder/task/SignPsbtTask.test.ts index 35f23b895..611091f04 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/task/SignPsbtTask.test.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/task/SignPsbtTask.test.ts @@ -1,199 +1,155 @@ import { - ApduResponse, CommandResultFactory, type InternalApi, UnknownDeviceExchangeError, } from "@ledgerhq/device-management-kit"; +import { Maybe, Nothing } from "purify-ts"; -import { type DefaultWallet } from "@api/model/Wallet"; import { SignPsbtCommand } from "@internal/app-binder/command/SignPsbtCommand"; -import { BuildPsbtTask } from "@internal/app-binder/task/BuildPsbtTask"; -import { ContinueTask } from "@internal/app-binder/task/ContinueTask"; -import { PrepareWalletPolicyTask } from "@internal/app-binder/task/PrepareWalletPolicyTask"; +import { type ContinueTask } from "@internal/app-binder/task/ContinueTask"; import { SignPsbtTask } from "@internal/app-binder/task/SignPsbtTask"; +import { type DataStore } from "@internal/data-store/model/DataStore"; +import type { PsbtCommitment } from "@internal/data-store/service/DataStoreService"; +import { type Psbt } from "@internal/psbt/model/Psbt"; +import { type ValueParser } from "@internal/psbt/service/value/ValueParser"; +import { type Wallet } from "@internal/wallet/model/Wallet"; import { type WalletSerializer } from "@internal/wallet/service/WalletSerializer"; -const mockRunBuildPsbt = jest.fn(); -const mockRunPrepareWallet = jest.fn(); -const mockRunContinue = jest.fn(() => - CommandResultFactory({ - data: new ApduResponse({ - statusCode: Uint8Array.from([0xe0, 0x00]), - data: Uint8Array.from([]), - }), - }), -); - -jest.mock("@internal/app-binder/task/BuildPsbtTask", () => ({ - BuildPsbtTask: jest.fn().mockImplementation(() => ({ - run: mockRunBuildPsbt, - })), -})); -jest.mock("@internal/app-binder/task/ContinueTask", () => ({ - ContinueTask: jest.fn().mockImplementation(() => ({ - run: mockRunContinue, - getYieldedResults: jest.fn(() => []), - })), -})); -jest.mock("@internal/app-binder/task/PrepareWalletPolicyTask", () => ({ - PrepareWalletPolicyTask: jest.fn().mockImplementation(() => ({ - run: mockRunPrepareWallet, - })), -})); +const SIGN_PSBT_YIELD_RESULT = Uint8Array.from([ + 0x00, 0x20, 0xf1, 0xe8, 0x42, 0x44, 0x7f, 0xae, 0x7b, 0x1c, 0x6e, 0xb7, 0xa8, + 0xa7, 0x85, 0xf7, 0x76, 0xfa, 0x19, 0xa9, 0x3a, 0xb9, 0x6c, 0xc1, 0xee, 0xee, + 0xe9, 0x47, 0xc1, 0x71, 0x13, 0x38, 0x5f, 0x5f, 0x12, 0x4d, 0x63, 0x5c, 0xf2, + 0x52, 0xae, 0x26, 0xa6, 0x7b, 0xe2, 0x77, 0x71, 0x2e, 0xad, 0x07, 0xb4, 0x48, + 0x96, 0xdf, 0xb0, 0x16, 0xfc, 0x9d, 0x03, 0xa3, 0xe9, 0x22, 0xbd, 0x9a, 0x01, + 0x66, 0x3c, 0x59, 0x59, 0x41, 0x13, 0xe5, 0x71, 0x00, 0x06, 0x3d, 0x9d, 0xcc, + 0xd7, 0x8f, 0xb3, 0x93, 0x82, 0xdb, 0xf8, 0x0a, 0x8f, 0x11, 0x50, 0xfd, 0x59, + 0xd9, 0xfe, 0xb7, 0x9e, 0x25, 0x3b, 0xd2, +]); describe("SignPsbtTask", () => { describe("run", () => { - it("should call all tasks", async () => { + it("should return signatures", async () => { // given const api = { sendCommand: jest.fn(), } as unknown as InternalApi; - const psbt = ""; - const wallet = {} as DefaultWallet; + const psbt = { + getGlobalValue: () => Maybe.of(Uint8Array.from([0x03])), + } as unknown as Psbt; + const wallet = { + hmac: Uint8Array.from([0x04]), + } as Wallet; + const psbtCommitment = { + globalCommitment: Uint8Array.from([0x03]), + inputsRoot: Uint8Array.from([0x01]), + outputsRoot: Uint8Array.from([0x02]), + } as PsbtCommitment; + const dataStore = {} as DataStore; const walletSerializer = { getId: jest.fn(() => Uint8Array.from([0x05])), } as unknown as WalletSerializer; - mockRunBuildPsbt.mockReturnValue( - CommandResultFactory({ - data: { - psbtCommitment: { - inputsRoot: Uint8Array.from([0x01]), - outputsRoot: Uint8Array.from([0x02]), - globalCommitment: Uint8Array.from([0x03]), - }, - inputsCount: 42, - outputsCount: 42, - }, - }), - ); - mockRunPrepareWallet.mockReturnValue( - CommandResultFactory({ - data: { - hmac: Uint8Array.from([0x04]), - }, - }), - ); + const valueParser = { + getVarint: jest.fn(() => Maybe.of(42)), + } as unknown as ValueParser; + const continueTaskFactory = () => + ({ + run: jest.fn().mockResolvedValue( + CommandResultFactory({ + data: [], + }), + ), + getYieldedResults: () => [SIGN_PSBT_YIELD_RESULT], + }) as unknown as ContinueTask; // when - await new SignPsbtTask( + const signatures = await new SignPsbtTask( api, { psbt, wallet, + psbtCommitment, + dataStore, }, walletSerializer, + valueParser, + continueTaskFactory, ).run(); // then - expect(BuildPsbtTask).toHaveBeenCalled(); - expect(ContinueTask).toHaveBeenCalled(); - expect(PrepareWalletPolicyTask).toHaveBeenCalled(); expect(api.sendCommand).toHaveBeenCalledWith( new SignPsbtCommand({ - globalCommitments: Uint8Array.from([0x03]), + globalCommitment: Uint8Array.from([0x03]), inputsCount: 42, - inputsCommitments: Uint8Array.from([0x01]), + inputsRoot: Uint8Array.from([0x01]), outputsCount: 42, - outputsCommitments: Uint8Array.from([0x02]), + outputsRoot: Uint8Array.from([0x02]), walletId: Uint8Array.from([0x05]), walletHmac: Uint8Array.from([0x04]), }), ); - }); - }); - describe("errors", () => { - it("should return an error if build psbt fails", async () => { - // given - const api = { - sendCommand: jest.fn(), - } as unknown as InternalApi; - const psbt = ""; - const wallet = {} as DefaultWallet; - mockRunPrepareWallet.mockReturnValue( - CommandResultFactory({ - data: {}, - }), - ); - mockRunBuildPsbt.mockReturnValue( - CommandResultFactory({ - error: new UnknownDeviceExchangeError("Failed"), - }), - ); - // when - const result = await new SignPsbtTask(api, { - psbt, - wallet, - }).run(); - // then - expect(result).toStrictEqual( - CommandResultFactory({ - error: new UnknownDeviceExchangeError("Failed"), - }), - ); - }); - it("should return an error if prepare wallet fails", async () => { - // given - const api = { - sendCommand: jest.fn(), - } as unknown as InternalApi; - const psbt = ""; - const wallet = {} as DefaultWallet; - mockRunBuildPsbt.mockReturnValue( - CommandResultFactory({ - data: {}, - }), - ); - mockRunPrepareWallet.mockReturnValue( - CommandResultFactory({ - error: new UnknownDeviceExchangeError("Failed"), - }), - ); - // when - const result = await new SignPsbtTask(api, { - psbt, - wallet, - }).run(); - // then - expect(result).toStrictEqual( + expect(signatures).toStrictEqual( CommandResultFactory({ - error: new UnknownDeviceExchangeError("Failed"), + data: [ + { + inputIndex: 0, + pubKeyAugmented: Uint8Array.from([ + 0xf1, 0xe8, 0x42, 0x44, 0x7f, 0xae, 0x7b, 0x1c, 0x6e, 0xb7, + 0xa8, 0xa7, 0x85, 0xf7, 0x76, 0xfa, 0x19, 0xa9, 0x3a, 0xb9, + 0x6c, 0xc1, 0xee, 0xee, 0xe9, 0x47, 0xc1, 0x71, 0x13, 0x38, + 0x5f, 0x5f, + ]), + signature: Uint8Array.from([ + 0x12, 0x4d, 0x63, 0x5c, 0xf2, 0x52, 0xae, 0x26, 0xa6, 0x7b, + 0xe2, 0x77, 0x71, 0x2e, 0xad, 0x07, 0xb4, 0x48, 0x96, 0xdf, + 0xb0, 0x16, 0xfc, 0x9d, 0x03, 0xa3, 0xe9, 0x22, 0xbd, 0x9a, + 0x01, 0x66, 0x3c, 0x59, 0x59, 0x41, 0x13, 0xe5, 0x71, 0x00, + 0x06, 0x3d, 0x9d, 0xcc, 0xd7, 0x8f, 0xb3, 0x93, 0x82, 0xdb, + 0xf8, 0x0a, 0x8f, 0x11, 0x50, 0xfd, 0x59, 0xd9, 0xfe, 0xb7, + 0x9e, 0x25, 0x3b, 0xd2, + ]), + }, + ], }), ); }); + }); + describe("errors", () => { it("should return an error if continue task fails", async () => { // given const api = { sendCommand: jest.fn(), } as unknown as InternalApi; - const psbt = ""; - const wallet = {} as DefaultWallet; + const psbt = { + getGlobalValue: jest.fn(() => Nothing), + } as unknown as Psbt; + const wallet = {} as Wallet; + const psbtCommitment = {} as PsbtCommitment; + const dataStore = {} as DataStore; const walletSerializer = { getId: jest.fn(() => Uint8Array.from([0x05])), } as unknown as WalletSerializer; - mockRunContinue.mockReturnValue( - CommandResultFactory({ - error: new UnknownDeviceExchangeError("Failed"), - }), - ); - mockRunPrepareWallet.mockReturnValue( - CommandResultFactory({ - data: {}, - }), - ); - mockRunBuildPsbt.mockReturnValue( - CommandResultFactory({ - data: { - psbtCommitment: {}, - }, - }), - ); + const valueParser = { + getVarint: jest.fn(() => Maybe.of(42)), + } as unknown as ValueParser; + const continueTaskFactory = () => + ({ + run: jest.fn().mockResolvedValue( + CommandResultFactory({ + error: new UnknownDeviceExchangeError("Failed"), + }), + ), + }) as unknown as ContinueTask; // when const result = await new SignPsbtTask( api, { psbt, wallet, + psbtCommitment, + dataStore, }, walletSerializer, + valueParser, + continueTaskFactory, ).run(); // then expect(result).toStrictEqual( diff --git a/packages/signer/signer-btc/src/internal/app-binder/task/SignPsbtTask.ts b/packages/signer/signer-btc/src/internal/app-binder/task/SignPsbtTask.ts index 7e652eeee..aa9f4255b 100644 --- a/packages/signer/signer-btc/src/internal/app-binder/task/SignPsbtTask.ts +++ b/packages/signer/signer-btc/src/internal/app-binder/task/SignPsbtTask.ts @@ -1,96 +1,98 @@ import { - CommandResult, + ByteArrayParser, + type CommandResult, CommandResultFactory, type InternalApi, isSuccessCommandResult, } from "@ledgerhq/device-management-kit"; -import { injectable } from "inversify"; +import { Maybe } from "purify-ts"; -import { Psbt } from "@api/model/Psbt"; -import { Wallet as ApiWallet } from "@api/model/Wallet"; import { SignPsbtCommand } from "@internal/app-binder/command/SignPsbtCommand"; -import { BtcErrorCodes } from "@internal/app-binder/command/utils/bitcoinAppErrors"; -import { BuildPsbtTask } from "@internal/app-binder/task/BuildPsbtTask"; +import { type BtcErrorCodes } from "@internal/app-binder/command/utils/bitcoinAppErrors"; +import { type BuildPsbtTaskResult } from "@internal/app-binder/task/BuildPsbtTask"; import { ContinueTask } from "@internal/app-binder/task/ContinueTask"; -import { PrepareWalletPolicyTask } from "@internal/app-binder/task/PrepareWalletPolicyTask"; -import { DataStore } from "@internal/data-store/model/DataStore"; -import { PsbtCommitment } from "@internal/data-store/service/DataStoreService"; -import { Sha256HasherService } from "@internal/merkle-tree/service/Sha256HasherService"; -import { Wallet as InternalWallet } from "@internal/wallet/model/Wallet"; -import { DefaultWalletSerializer } from "@internal/wallet/service/DefaultWalletSerializer"; +import { type DataStore } from "@internal/data-store/model/DataStore"; +import { PsbtGlobal } from "@internal/psbt/model/Psbt"; +import type { ValueParser } from "@internal/psbt/service/value/ValueParser"; +import { extractVarint } from "@internal/utils/Varint"; +import { type Wallet as InternalWallet } from "@internal/wallet/model/Wallet"; import type { WalletSerializer } from "@internal/wallet/service/WalletSerializer"; -export type SignPsbtTaskArgs = { - psbt: Psbt; - wallet: ApiWallet; +export type SignPsbtTaskArgs = BuildPsbtTaskResult & { + wallet: InternalWallet; }; -@injectable() +export type PsbtSignature = { + inputIndex: number; + pubKeyAugmented: Uint8Array; + signature: Uint8Array; +}; + +export type SignPsbtTaskResult = CommandResult; + export class SignPsbtTask { - private readonly _walletSerializer: WalletSerializer; constructor( private readonly _api: InternalApi, private readonly _args: SignPsbtTaskArgs, - walletSerializer?: WalletSerializer, - ) { - const hasher = new Sha256HasherService(); - this._walletSerializer = - walletSerializer || new DefaultWalletSerializer(hasher); - } - - private async runPrepareWalletPolicy() { - return new PrepareWalletPolicyTask(this._api, { - wallet: this._args.wallet, - }).run(); - } - private async runBuildPsbt(wallet: InternalWallet) { - return new BuildPsbtTask({ wallet, psbt: this._args.psbt }).run(); - } + private readonly _walletSerializer: WalletSerializer, + private readonly _valueParser: ValueParser, + private readonly _continueTaskFactory = ( + api: InternalApi, + dataStore: DataStore, + ) => new ContinueTask(api, dataStore), + ) {} - private async runSignPsbt( - psbtCommitment: PsbtCommitment, - dataStore: DataStore, - inputsCount: number, - outputsCount: number, - wallet: InternalWallet, - ): Promise> { + async run(): Promise { + const { + psbtCommitment: { globalCommitment, inputsRoot, outputsRoot }, + psbt, + wallet, + dataStore, + } = this._args; const signPsbtCommandResult = await this._api.sendCommand( new SignPsbtCommand({ - globalCommitments: psbtCommitment.globalCommitment, - inputsCount, - inputsCommitments: psbtCommitment.inputsRoot, - outputsCount, - outputsCommitments: psbtCommitment.outputsRoot, + globalCommitment, + inputsRoot, + outputsRoot, + inputsCount: psbt + .getGlobalValue(PsbtGlobal.INPUT_COUNT) + .chain((value) => this._valueParser.getVarint(value.data)) + .orDefault(0), + outputsCount: psbt + .getGlobalValue(PsbtGlobal.OUTPUT_COUNT) + .chain((value) => this._valueParser.getVarint(value.data)) + .orDefault(0), walletId: this._walletSerializer.getId(wallet), walletHmac: wallet.hmac, }), ); - const continueTask = new ContinueTask(this._api, dataStore); + const continueTask = this._continueTaskFactory(this._api, dataStore); const result = await continueTask.run(signPsbtCommandResult); if (isSuccessCommandResult(result)) { const signatureList = continueTask.getYieldedResults(); - return CommandResultFactory({ data: signatureList }); + const signatures = signatureList.map((sig) => { + const parser = new ByteArrayParser(sig); + const inputIndex = extractVarint(parser).mapOrDefault( + (val) => val.value, + 0, + ); + const pubKeyAugmentedLength = Maybe.fromNullable( + parser.extract8BitUInt(), + ).orDefault(0); + const pubKeyAugmented = Maybe.fromNullable( + parser.extractFieldByLength(pubKeyAugmentedLength), + ).orDefault(Uint8Array.from([])); + const signature = Maybe.fromNullable( + parser.extractFieldByLength(parser.getUnparsedRemainingLength()), + ).orDefault(Uint8Array.from([])); + return { signature, inputIndex, pubKeyAugmented }; + }); + return CommandResultFactory({ + data: signatures, + }); } return result; } - - async run() { - const walletResult = await this.runPrepareWalletPolicy(); - if (!isSuccessCommandResult(walletResult)) { - return walletResult; - } - const psbtResult = await this.runBuildPsbt(walletResult.data); - if (!isSuccessCommandResult(psbtResult)) { - return psbtResult; - } - return await this.runSignPsbt( - psbtResult.data.psbtCommitment, - psbtResult.data.dataStore, - psbtResult.data.inputsCount, - psbtResult.data.outputsCount, - walletResult.data, - ); - } }