Skip to content

Commit

Permalink
solana-ibc: add support for witness account (#388)
Browse files Browse the repository at this point in the history
Co-authored-by: dhruvja <[email protected]>
  • Loading branch information
mina86 and dhruvja authored Oct 18, 2024
1 parent a21cd4b commit ba1046e
Show file tree
Hide file tree
Showing 10 changed files with 314 additions and 83 deletions.
1 change: 1 addition & 0 deletions solana/restaking/programs/restaking/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ no-entrypoint = []
no-idl = []
no-log-ix-name = []
cpi = ["no-entrypoint"]
witness = ["solana-ibc/witness"]
default = []

[dependencies]
Expand Down
46 changes: 31 additions & 15 deletions solana/restaking/programs/restaking/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub mod restaking {
/// sent in the same order as given below
/// - Chain Data
/// - trie
/// - witness (if compiled with `witness` Cargo feature)
/// - Guest blockchain program ID
pub fn deposit<'a, 'info>(
ctx: Context<'a, 'a, 'a, 'info, Deposit<'info>>,
Expand Down Expand Up @@ -106,8 +107,12 @@ pub mod restaking {
let validator_key = match service {
Service::GuestChain { validator } => validator,
};
let remaining_accounts = validation::validate_remaining_accounts(
ctx.remaining_accounts,
&guest_chain_program_id,
)?;
let borrowed_chain_data =
ctx.remaining_accounts[0].data.try_borrow().unwrap();
remaining_accounts.chain.try_borrow_data().unwrap();
let mut chain_data: &[u8] = &borrowed_chain_data;
let chain =
solana_ibc::chain::ChainData::try_deserialize(&mut chain_data)
Expand All @@ -118,20 +123,19 @@ pub mod restaking {
let amount = validator.map_or(u128::from(amount), |val| {
u128::from(val.stake) + u128::from(amount)
});
validation::validate_remaining_accounts(
ctx.remaining_accounts,
&guest_chain_program_id,
)?;
core::mem::drop(borrowed_chain_data);

let cpi_accounts = SetStake {
sender: ctx.accounts.depositor.to_account_info(),
chain: ctx.remaining_accounts[0].clone(),
trie: ctx.remaining_accounts[1].clone(),
chain: remaining_accounts.chain.clone(),
trie: remaining_accounts.trie.clone(),
#[cfg(feature = "witness")]
witness: remaining_accounts.witness.clone(),
system_program: ctx.accounts.system_program.to_account_info(),
instruction: ctx.accounts.instruction.to_account_info(),
};
let cpi_program = ctx.remaining_accounts[2].clone();
let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
let cpi_ctx =
CpiContext::new(remaining_accounts.program.clone(), cpi_accounts);
solana_ibc::cpi::set_stake(cpi_ctx, validator_key, amount)
}

Expand Down Expand Up @@ -359,6 +363,8 @@ pub mod restaking {
sender: ctx.accounts.withdrawer.to_account_info(),
chain: chain.to_account_info(),
trie: ctx.accounts.trie.to_account_info(),
#[cfg(feature = "witness")]
witness: ctx.accounts.witness.to_account_info(),
system_program: ctx.accounts.system_program.to_account_info(),
instruction: validation::check_instructions_sysvar(
&ctx.accounts.instruction,
Expand Down Expand Up @@ -550,7 +556,6 @@ pub mod restaking {
) -> Result<()> {
let vault_params = &mut ctx.accounts.vault_params;
let staking_params = &mut ctx.accounts.staking_params;
let guest_chain = &ctx.remaining_accounts[0];

let token_account = &ctx.accounts.receipt_token_account;
if token_account.amount < 1 {
Expand All @@ -561,6 +566,10 @@ pub mod restaking {
Some(id) => id,
None => return Err(error!(ErrorCodes::OperationNotAllowed)),
};
let remaining_accounts = validation::validate_remaining_accounts(
ctx.remaining_accounts,
&guest_chain_program_id,
)?;
if vault_params.service.is_some() {
return Err(error!(ErrorCodes::ServiceAlreadySet));
}
Expand All @@ -577,7 +586,8 @@ pub mod restaking {
let validator_key = match service {
Service::GuestChain { validator } => validator,
};
let borrowed_chain_data = guest_chain.data.try_borrow().unwrap();
let borrowed_chain_data =
remaining_accounts.chain.try_borrow_data().unwrap();
let mut chain_data: &[u8] = &borrowed_chain_data;
let chain =
solana_ibc::chain::ChainData::try_deserialize(&mut chain_data)
Expand All @@ -593,15 +603,17 @@ pub mod restaking {

let cpi_accounts = SetStake {
sender: ctx.accounts.depositor.to_account_info(),
chain: guest_chain.to_account_info(),
trie: ctx.remaining_accounts[1].clone(),
chain: remaining_accounts.chain.clone(),
trie: remaining_accounts.trie.clone(),
#[cfg(feature = "witness")]
witness: remaining_accounts.witness.clone(),
system_program: ctx.accounts.system_program.to_account_info(),
instruction: validation::check_instructions_sysvar(
&ctx.accounts.instruction,
)?,
};
let cpi_program = ctx.remaining_accounts[2].clone();
let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
let cpi_ctx =
CpiContext::new(remaining_accounts.program.clone(), cpi_accounts);
solana_ibc::cpi::set_stake(cpi_ctx, validator_key, amount)
}

Expand Down Expand Up @@ -865,6 +877,10 @@ pub struct Withdraw<'info> {
#[account(mut, seeds = [TRIE_SEED], bump, seeds::program = guest_chain_program.key())]
/// CHECK:
pub trie: AccountInfo<'info>,
#[cfg(feature = "witness")]
#[account(mut, seeds = [solana_ibc::WITNESS_SEED, trie.key().as_ref()], bump)]
/// CHECK:
pub witness: AccountInfo<'info>,

pub token_mint: Box<Account<'info, Mint>>,
#[account(mut, token::mint = token_mint)]
Expand Down
98 changes: 74 additions & 24 deletions solana/restaking/programs/restaking/src/validation.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,101 @@
use anchor_lang::prelude::*;
use solana_ibc::{CHAIN_SEED, TRIE_SEED};

use crate::ErrorCodes;

pub(crate) struct RemainingAccounts<'a, 'info> {
pub chain: &'a AccountInfo<'info>,
pub trie: &'a AccountInfo<'info>,
#[cfg(feature = "witness")]
pub witness: &'a AccountInfo<'info>,
pub program: &'a AccountInfo<'info>,
}

/// Validates accounts needed for CPI call to the guest chain.
///
/// Right now, this method would only validate accounts for calling `set_stake`
/// method in the guest chain. Later when we expand to other services, we could
/// extend this method below to do the validation for those accounts as well.
///
/// Accounts needed for calling `set_stake`
/// - chain: PDA with seeds ["chain"]. Should be writable
/// - trie: PDA with seeds ["trie"]
/// - chain: PDA with seeds ["chain"]. Must be writable.
/// - trie: PDA with seeds ["trie"]. Must be writable.
/// - witness: Only if compiled with `witness` Cargo feature. PDA with seeds
/// `["witness", trie.key()]`. Must be writable.
/// - guest chain program ID: Should match the expected guest chain program ID
///
/// Note: The accounts should be sent in above order.
pub(crate) fn validate_remaining_accounts(
accounts: &[AccountInfo<'_>],
pub(crate) fn validate_remaining_accounts<'a, 'info>(
accounts: &'a [AccountInfo<'info>],
expected_guest_chain_program_id: &Pubkey,
) -> Result<()> {
) -> Result<RemainingAccounts<'a, 'info>> {
let accounts = &mut accounts.iter();

// Chain account
let seeds = [CHAIN_SEED];
let seeds = seeds.as_ref();
let chain = next_pda_account(
accounts,
[solana_ibc::CHAIN_SEED].as_ref(),
expected_guest_chain_program_id,
true,
"chain",
)?;

let (storage_account, _bump) =
Pubkey::find_program_address(seeds, expected_guest_chain_program_id);
if &storage_account != accounts[0].key && accounts[0].is_writable {
return Err(error!(ErrorCodes::AccountValidationFailedForCPI));
}
// Trie account
let seeds = [TRIE_SEED];
let seeds = seeds.as_ref();
let trie = next_pda_account(
accounts,
[solana_ibc::TRIE_SEED].as_ref(),
expected_guest_chain_program_id,
true,
"trie",
)?;

let (storage_account, _bump) =
Pubkey::find_program_address(seeds, expected_guest_chain_program_id);
if &storage_account != accounts[1].key && accounts[1].is_writable {
return Err(error!(ErrorCodes::AccountValidationFailedForCPI));
}
// Trie account
#[cfg(feature = "witness")]
let witness = next_pda_account(
accounts,
[solana_ibc::WITNESS_SEED, trie.key().as_ref()].as_ref(),
expected_guest_chain_program_id,
true,
"witness",
)?;

// Guest chain program ID
if expected_guest_chain_program_id != accounts[2].key {
return Err(error!(ErrorCodes::AccountValidationFailedForCPI));
}
let program = next_account_info(accounts)
.ok()
.filter(|info| expected_guest_chain_program_id == info.key)
.ok_or_else(|| error!(ErrorCodes::AccountValidationFailedForCPI))?;

Ok(())
Ok(RemainingAccounts {
chain,
trie,
program,
#[cfg(feature = "witness")]
witness,
})
}

fn next_pda_account<'a, 'info>(
accounts: &mut impl core::iter::Iterator<Item = &'a AccountInfo<'info>>,
seeds: &[&[u8]],
program_id: &Pubkey,
must_be_mut: bool,
account_name: &str,
) -> Result<&'a AccountInfo<'info>> {
(|| {
let info = next_account_info(accounts).ok()?;
let addr = Pubkey::try_find_program_address(seeds, program_id)?.0;
if &addr == info.key && (!must_be_mut || info.is_writable) {
Some(info)
} else {
None
}
})()
.ok_or_else(|| {
error!(ErrorCodes::AccountValidationFailedForCPI)
.with_account_name(account_name)
})
}


/// Verifies that given account is the Instruction sysvars and returns it if it
/// is.
pub(crate) fn check_instructions_sysvar<'info>(
Expand Down
78 changes: 49 additions & 29 deletions solana/restaking/tests/instructions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
guestChainProgramID,
restakingProgramID,
} from "./helper";
import { Transaction, TransactionInstruction } from "@solana/web3.js";

export const depositInstruction = async (
program: anchor.Program<Restaking>,
Expand Down Expand Up @@ -159,36 +160,55 @@ export const withdrawInstruction = async (
withdrawer
);

const tx = await program.methods
.withdraw()
.preInstructions([
let instruction = new TransactionInstruction({
keys: [
{ pubkey: withdrawer, isSigner: true, isWritable: true },
{ pubkey: withdrawer, isSigner: false, isWritable: true },
{ pubkey: vaultParamsPDA, isSigner: false, isWritable: true },
{ pubkey: stakingParamsPDA, isSigner: false, isWritable: true },
{ pubkey: guestChainPDA, isSigner: false, isWritable: true },
{ pubkey: triePDA, isSigner: false, isWritable: true },
{ pubkey: stakedTokenMint, isSigner: false, isWritable: true },
{
pubkey: withdrawerStakedTokenAccount,
isSigner: false,
isWritable: true,
},
{ pubkey: vaultTokenAccountPDA, isSigner: false, isWritable: true },
{ pubkey: receiptTokenMint, isSigner: false, isWritable: true },
{ pubkey: escrowReceiptTokenPDA, isSigner: false, isWritable: true },
{ pubkey: guestChainProgramID, isSigner: false, isWritable: true },
{ pubkey: spl.TOKEN_PROGRAM_ID, isSigner: false, isWritable: true },
{
pubkey: anchor.web3.SystemProgram.programId,
isSigner: false,
isWritable: true,
},
{
pubkey: new anchor.web3.PublicKey(mpl.MPL_TOKEN_METADATA_PROGRAM_ID),
isSigner: false,
isWritable: true,
},
{ pubkey: anchor.web3.SYSVAR_RENT_PUBKEY, isSigner: false, isWritable: true },
{ pubkey: masterEditionPDA, isSigner: false, isWritable: true },
{ pubkey: nftMetadataPDA, isSigner: false, isWritable: true },
{
pubkey: anchor.web3.SYSVAR_INSTRUCTIONS_PUBKEY,
isSigner: false,
isWritable: true,
},
],
programId: restakingProgramID,
data: Buffer.from([183, 18, 70, 156, 148, 109, 161, 34]),
});

let tx = new Transaction()
.add(
anchor.web3.ComputeBudgetProgram.setComputeUnitLimit({
units: 1000000,
}),
])
.accounts({
signer: withdrawer,
withdrawer,
vaultParams: vaultParamsPDA,
stakingParams: stakingParamsPDA,
guestChain: guestChainPDA,
trie: triePDA,
tokenMint: stakedTokenMint,
withdrawerTokenAccount: withdrawerStakedTokenAccount,
vaultTokenAccount: vaultTokenAccountPDA,
receiptTokenMint,
escrowReceiptTokenAccount: escrowReceiptTokenPDA,
guestChainProgram: guestChainProgramID,
tokenProgram: spl.TOKEN_PROGRAM_ID,
masterEditionAccount: masterEditionPDA,
nftMetadata: nftMetadataPDA,
systemProgram: anchor.web3.SystemProgram.programId,
metadataProgram: new anchor.web3.PublicKey(
mpl.MPL_TOKEN_METADATA_PROGRAM_ID
),
instruction: anchor.web3.SYSVAR_INSTRUCTIONS_PUBKEY,
})
.transaction();
})
)
.add(instruction);

return tx;
};
Expand Down Expand Up @@ -315,7 +335,7 @@ export const setServiceInstruction = async (
validator: anchor.web3.PublicKey,
receiptTokenMint: anchor.web3.PublicKey,
/// Token which is staked
stakeTokenMint: anchor.web3.PublicKey,
stakeTokenMint: anchor.web3.PublicKey
) => {
const { vaultParamsPDA } = getVaultParamsPDA(receiptTokenMint);
const { stakingParamsPDA } = getStakingParamsPDA();
Expand Down
Loading

0 comments on commit ba1046e

Please sign in to comment.