diff --git a/programs/svm-spoke/src/constraints.rs b/programs/svm-spoke/src/constraints.rs index 2c23d93cb..a1a685b32 100644 --- a/programs/svm-spoke/src/constraints.rs +++ b/programs/svm-spoke/src/constraints.rs @@ -1,4 +1,8 @@ use anchor_lang::prelude::*; +use anchor_spl::{ + associated_token::get_associated_token_address_with_program_id, + token_interface::{Mint, TokenAccount, TokenInterface}, +}; use crate::{ state::State, @@ -13,3 +17,17 @@ pub fn is_local_or_remote_owner(signer: &Signer, state: &Account) -> bool pub fn is_relay_hash_valid(relay_hash: &[u8; 32], relay_data: &V3RelayData, state: &Account) -> bool { relay_hash == &get_v3_relay_hash(relay_data, state.chain_id) } + +// Implements the same underlying logic as in Anchor's associated_token constraint macro, except for token_program_check +// as that would duplicate Anchor's token constraint macro that the caller already uses. +// https://github.com/coral-xyz/anchor/blob/e6d7dafe12da661a36ad1b4f3b5970e8986e5321/lang/syn/src/codegen/accounts/constraints.rs#L1132 +pub fn is_valid_associated_token_account( + token_account: &InterfaceAccount, + mint: &InterfaceAccount, + token_program: &Interface, + authority: &Pubkey, +) -> bool { + &token_account.owner == authority + && token_account.key() + == get_associated_token_address_with_program_id(authority, &mint.key(), &token_program.key()) +} diff --git a/programs/svm-spoke/src/error.rs b/programs/svm-spoke/src/error.rs index cca8fab68..76250aafa 100644 --- a/programs/svm-spoke/src/error.rs +++ b/programs/svm-spoke/src/error.rs @@ -72,6 +72,8 @@ pub enum SvmError { NonZeroRefundClaim, #[msg("Invalid claim initializer!")] InvalidClaimInitializer, + #[msg("Invalid refund token account!")] + InvalidRefundTokenAccount, #[msg("Seed must be 0 in production!")] InvalidProductionSeed, #[msg("Invalid remaining accounts for ATA creation!")] diff --git a/programs/svm-spoke/src/instructions/refund_claims.rs b/programs/svm-spoke/src/instructions/refund_claims.rs index 7fa215380..3a48cb6d2 100644 --- a/programs/svm-spoke/src/instructions/refund_claims.rs +++ b/programs/svm-spoke/src/instructions/refund_claims.rs @@ -3,6 +3,7 @@ use anchor_spl::token_interface::{transfer_checked, Mint, TokenAccount, TokenInt use crate::{ constants::DISCRIMINATOR_SIZE, + constraints::is_valid_associated_token_account, error::SvmError, event::ClaimedRelayerRefund, state::{ClaimAccount, State}, @@ -57,15 +58,25 @@ pub struct ClaimRelayerRefund<'info> { #[account(mint::token_program = token_program)] pub mint: InterfaceAccount<'info, Mint>, - // This method allows relayer to claim refunds on any custom token account. - #[account(mut, token::mint = mint, token::token_program = token_program)] + /// CHECK: This is used for claim_account PDA derivation and it is up to the caller to ensure it is valid. + pub refund_address: UncheckedAccount<'info>, + + // If refund_address is the same as signer this method allows relayer to claim refunds on any custom token account. + // Otherwise this must be the associated token account of the provided refund_address. + #[account( + mut, + token::mint = mint, + token::token_program = token_program, + constraint = refund_address.key().eq(&signer.key()) + || is_valid_associated_token_account(&token_account, &mint, &token_program, &refund_address.key()) + @ SvmError::InvalidRefundTokenAccount + )] pub token_account: InterfaceAccount<'info, TokenAccount>, - // Only relayer can claim the refund with this method as the claim account is derived from the relayer's address. #[account( mut, close = initializer, - seeds = [b"claim_account", mint.key().as_ref(), signer.key().as_ref()], + seeds = [b"claim_account", mint.key().as_ref(), refund_address.key().as_ref()], bump )] pub claim_account: Account<'info, ClaimAccount>, @@ -99,84 +110,12 @@ pub fn claim_relayer_refund(ctx: Context) -> Result<()> { emit_cpi!(ClaimedRelayerRefund { l2_token_address: ctx.accounts.mint.key(), claim_amount, - refund_address: ctx.accounts.signer.key(), + refund_address: ctx.accounts.refund_address.key(), }); Ok(()) // There is no need to reset the claim amount as the account will be closed at the end of instruction. } -#[event_cpi] -#[derive(Accounts)] -#[instruction(refund_address: Pubkey)] -pub struct ClaimRelayerRefundFor<'info> { - pub signer: Signer<'info>, - - /// CHECK: We don't need any additional checks as long as this is the same account that initialized the claim account. - #[account(mut, address = claim_account.initializer @ SvmError::InvalidClaimInitializer)] - pub initializer: UncheckedAccount<'info>, - - #[account(seeds = [b"state", state.seed.to_le_bytes().as_ref()], bump)] - pub state: Account<'info, State>, - - #[account( - mut, - associated_token::mint = mint, - associated_token::authority = state, - associated_token::token_program = token_program - )] - pub vault: InterfaceAccount<'info, TokenAccount>, - - // Mint address has been checked when executing the relayer refund leaf and it is part of claim account derivation. - #[account(mint::token_program = token_program)] - pub mint: InterfaceAccount<'info, Mint>, - - #[account( - mut, - associated_token::mint = mint, - associated_token::authority = refund_address, - associated_token::token_program = token_program - )] - pub token_account: InterfaceAccount<'info, TokenAccount>, - - #[account( - mut, - close = initializer, - seeds = [b"claim_account", mint.key().as_ref(), refund_address.as_ref()], - bump - )] - pub claim_account: Account<'info, ClaimAccount>, - - pub token_program: Interface<'info, TokenInterface>, -} - -pub fn claim_relayer_refund_for(ctx: Context, refund_address: Pubkey) -> Result<()> { - // Ensure the claim account holds a non-zero amount. - let claim_amount = ctx.accounts.claim_account.amount; - if claim_amount == 0 { - return err!(SvmError::ZeroRefundClaim); - } - - // Derive the signer seeds for the state required for the transfer form vault. - let state_seed_bytes = ctx.accounts.state.seed.to_le_bytes(); - let seeds = &[b"state", state_seed_bytes.as_ref(), &[ctx.bumps.state]]; - let signer_seeds = &[&seeds[..]]; - - // Transfer the claim amount from the vault to the relayer token account. - let transfer_accounts = TransferChecked { - from: ctx.accounts.vault.to_account_info(), - mint: ctx.accounts.mint.to_account_info(), - to: ctx.accounts.token_account.to_account_info(), - authority: ctx.accounts.state.to_account_info(), - }; - let cpi_context = - CpiContext::new_with_signer(ctx.accounts.token_program.to_account_info(), transfer_accounts, signer_seeds); - transfer_checked(cpi_context, claim_amount, ctx.accounts.mint.decimals)?; - - emit_cpi!(ClaimedRelayerRefund { l2_token_address: ctx.accounts.mint.key(), claim_amount, refund_address }); - - Ok(()) // There is no need to reset the claim amount as the account will be closed at the end of instruction. -} - // Though claim accounts are being closed automatically when claiming the refund, there might be a scenario where // relayer refunds were executed with ATA after initializing the claim account. In such cases, the initializer should be // able to close the claim account manually. diff --git a/programs/svm-spoke/src/lib.rs b/programs/svm-spoke/src/lib.rs index eea54150d..1d5830554 100644 --- a/programs/svm-spoke/src/lib.rs +++ b/programs/svm-spoke/src/lib.rs @@ -455,18 +455,15 @@ pub mod svm_spoke { /// - state (Account): Spoke state PDA. Seed: ["state",state.seed] where seed is 0 on mainnet. /// - vault (InterfaceAccount): The ATA for the refunded mint. Authority must be the state. /// - mint (InterfaceAccount): The mint account for the token being refunded. - /// - token_account (InterfaceAccount): The ATA for the token being refunded to. + /// - refund_address: token account authority receiving the refund. + /// - token_account (InterfaceAccount): The receiving token account for the refund. When refund_address is different + /// from the signer, this must match its ATA. /// - claim_account (Account): The claim account PDA. Seed: ["claim_account",mint,refund_address]. /// - token_program (Interface): The token program. pub fn claim_relayer_refund(ctx: Context) -> Result<()> { instructions::claim_relayer_refund(ctx) } - /// Functionally identical to claim_relayer_refund() except the refund is sent to a specified refund address. - pub fn claim_relayer_refund_for(ctx: Context, refund_address: Pubkey) -> Result<()> { - instructions::claim_relayer_refund_for(ctx, refund_address) - } - /// Creates token accounts in batch for a set of addresses. /// /// This helper function allows the caller to pass in a set of remaining accounts to create a batch of Associated diff --git a/test/svm/SvmSpoke.RefundClaims.ts b/test/svm/SvmSpoke.RefundClaims.ts index 01c35ff09..dcb0e4890 100644 --- a/test/svm/SvmSpoke.RefundClaims.ts +++ b/test/svm/SvmSpoke.RefundClaims.ts @@ -4,7 +4,14 @@ import { Keypair, PublicKey } from "@solana/web3.js"; import { assert } from "chai"; import { common } from "./SvmSpoke.common"; import { MerkleTree } from "@uma/common/dist/MerkleTree"; -import { createMint, getOrCreateAssociatedTokenAccount, mintTo, TOKEN_PROGRAM_ID } from "@solana/spl-token"; +import { + AuthorityType, + createMint, + getOrCreateAssociatedTokenAccount, + mintTo, + setAuthority, + TOKEN_PROGRAM_ID, +} from "@solana/spl-token"; import { RelayerRefundLeafSolana, RelayerRefundLeafType } from "../../src/types/svm"; import { loadExecuteRelayerRefundLeafParams, readEventsUntilFound, relayerRefundHashFn } from "../../src/svm"; @@ -31,6 +38,7 @@ describe("svm_spoke.refund_claims", () => { state: PublicKey; vault: PublicKey; mint: PublicKey; + refundAddress: PublicKey; tokenAccount: PublicKey; claimAccount: PublicKey; tokenProgram: PublicKey; @@ -142,6 +150,7 @@ describe("svm_spoke.refund_claims", () => { state, vault, mint, + refundAddress: relayer.publicKey, tokenAccount, claimAccount, tokenProgram: TOKEN_PROGRAM_ID, @@ -169,10 +178,7 @@ describe("svm_spoke.refund_claims", () => { const iRelayerBal = (await connection.getTokenAccountBalance(tokenAccount)).value.amount; // Claim refund for the relayer. - const tx = await program.methods - .claimRelayerRefundFor(relayer.publicKey) - .accounts(claimRelayerRefundAccounts) - .rpc(); + const tx = await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); // The relayer should have received funds from the vault. const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount; @@ -194,11 +200,11 @@ describe("svm_spoke.refund_claims", () => { await executeRelayerRefundToClaim(relayerRefund); // Claim refund for the relayer. - await program.methods.claimRelayerRefundFor(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); // The claim account should have been automatically closed, so repeated claim should fail. try { - await program.methods.claimRelayerRefundFor(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund from closed account should fail"); } catch (error: any) { assert.instanceOf(error, AnchorError); @@ -212,7 +218,7 @@ describe("svm_spoke.refund_claims", () => { // After reinitalizing the claim account, the repeated claim should still fail. await initializeClaimAccount(); try { - await program.methods.claimRelayerRefundFor(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund from reinitalized account should fail"); } catch (error: any) { assert.instanceOf(error, AnchorError); @@ -231,7 +237,7 @@ describe("svm_spoke.refund_claims", () => { const iRelayerBal = (await connection.getTokenAccountBalance(tokenAccount)).value.amount; // Claim refund for the relayer. - await await program.methods.claimRelayerRefundFor(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); // The relayer should have received both refunds. const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount; @@ -256,7 +262,7 @@ describe("svm_spoke.refund_claims", () => { // Claiming with default initializer should fail. try { - await program.methods.claimRelayerRefundFor(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); } catch (error: any) { assert.instanceOf(error, AnchorError); assert.strictEqual( @@ -268,7 +274,7 @@ describe("svm_spoke.refund_claims", () => { // Claim refund for the relayer passing the correct initializer account. claimRelayerRefundAccounts.initializer = anotherInitializer.publicKey; - await program.methods.claimRelayerRefundFor(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); // The relayer should have received funds from the vault. const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount; @@ -329,25 +335,50 @@ describe("svm_spoke.refund_claims", () => { } }); - it("Cannot claim refund on behalf of relayer to wrong token account", async () => { + it("Cannot claim refund on behalf of relayer to wrongly owned token account", async () => { // Execute relayer refund using claim account. const relayerRefund = new BN(500000); await executeRelayerRefundToClaim(relayerRefund); - // Claim refund for the relayer to a custom token account. + // Claim refund for the relayer to a custom token account owned by another authority. const wrongOwner = Keypair.generate().publicKey; const wrongTokenAccount = (await getOrCreateAssociatedTokenAccount(connection, payer, mint, wrongOwner)).address; claimRelayerRefundAccounts.tokenAccount = wrongTokenAccount; try { - await program.methods.claimRelayerRefundFor(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); + assert.fail("Claiming refund to custom token account should fail"); + } catch (error: any) { + assert.instanceOf(error, AnchorError); + assert.strictEqual( + error.error.errorCode.code, + "InvalidRefundTokenAccount", + "Expected error code InvalidRefundTokenAccount" + ); + } + }); + + it("Cannot claim refund on behalf of relayer to wrong associated token account", async () => { + // Execute relayer refund using claim account. + const relayerRefund = new BN(500000); + await executeRelayerRefundToClaim(relayerRefund); + + // Claim refund for the relayer to a custom token account owned by the relayer, but not being its associated token account. + const wrongOwner = Keypair.generate(); + const wrongTokenAccount = (await getOrCreateAssociatedTokenAccount(connection, payer, mint, wrongOwner.publicKey)) + .address; + claimRelayerRefundAccounts.tokenAccount = wrongTokenAccount; + await setAuthority(connection, payer, wrongTokenAccount, wrongOwner, AuthorityType.AccountOwner, relayer.publicKey); + + try { + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund to custom token account should fail"); } catch (error: any) { assert.instanceOf(error, AnchorError); assert.strictEqual( error.error.errorCode.code, - "ConstraintTokenOwner", - "Expected error code ConstraintTokenOwner" + "InvalidRefundTokenAccount", + "Expected error code InvalidRefundTokenAccount" ); } }); @@ -389,6 +420,7 @@ describe("svm_spoke.refund_claims", () => { await executeRelayerRefundToClaim(relayerRefund); // Claim refund for the relayer with the default signer should fail as relayer address is part of claim account derivation. + claimRelayerRefundAccounts.refundAddress = owner; try { await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund with wrong signer should fail");