diff --git a/staking/app/StakeConnection.ts b/staking/app/StakeConnection.ts index 08444821..c9273539 100644 --- a/staking/app/StakeConnection.ts +++ b/staking/app/StakeConnection.ts @@ -535,15 +535,6 @@ export class StakeConnection { * Locks all unvested tokens in governance */ public async lockAllUnvested(stakeAccount: StakeAccount) { - const balanceSummary = stakeAccount.getBalanceSummary(await this.getTime()); - - await this.lockTokens(stakeAccount, balanceSummary.unvested.unlocked); - } - - /** - * Locks the specified amount of tokens in governance. - */ - public async lockTokens(stakeAccount: StakeAccount, amount: PythBalance) { const vestingAccountState = stakeAccount.getVestingAccountState( await this.getTime() ); @@ -554,6 +545,15 @@ export class StakeConnection { ) { throw Error(`Unexpected account state ${vestingAccountState}`); } + + const balanceSummary = stakeAccount.getBalanceSummary(await this.getTime()); + await this.lockTokens(stakeAccount, balanceSummary.unvested.unlocked); + } + + /** + * Locks the specified amount of tokens in governance. + */ + public async lockTokens(stakeAccount: StakeAccount, amount: PythBalance) { const owner: PublicKey = stakeAccount.stakeAccountMetadata.owner; const amountBN = amount.toBN(); diff --git a/staking/programs/staking/src/lib.rs b/staking/programs/staking/src/lib.rs index 7ca914e9..e5f77803 100644 --- a/staking/programs/staking/src/lib.rs +++ b/staking/programs/staking/src/lib.rs @@ -585,7 +585,7 @@ pub mod staking { *ctx.bumps.get("new_custody_authority").unwrap(), *ctx.bumps.get("new_voter_record").unwrap(), &split_request.recipient, - Some(current_epoch), + None, ); let new_stake_account_positions = @@ -599,7 +599,10 @@ pub mod staking { let source_stake_account_positions = &mut ctx.accounts.source_stake_account_positions.load_mut()?; - // Pre-check + // Pre-check invariants + // Note that the accept operation requires the positions account to be empty, which should trivially + // pass this invariant check. However, we explicitly check invariants everywhere else, so may + // as well check in this operation also. utils::risk::validate( source_stake_account_positions, ctx.accounts.source_stake_account_custody.amount, @@ -625,7 +628,7 @@ pub mod staking { require!(split_request.amount > 0, ErrorCode::SplitZeroTokens); // Split vesting account - let (source_vesting_account, new_vesting_account) = ctx + let (source_vesting_schedule, new_vesting_schedule) = ctx .accounts .source_stake_account_metadata .lock @@ -635,22 +638,20 @@ pub mod staking { )?; ctx.accounts .source_stake_account_metadata - .set_lock(source_vesting_account); + .set_lock(source_vesting_schedule); ctx.accounts .new_stake_account_metadata - .set_lock(new_vesting_account); + .set_lock(new_vesting_schedule); - { - transfer( - CpiContext::from(&*ctx.accounts).with_signer(&[&[ - AUTHORITY_SEED.as_bytes(), - ctx.accounts.source_stake_account_positions.key().as_ref(), - &[ctx.accounts.source_stake_account_metadata.authority_bump], - ]]), - split_request.amount, - )?; - } + transfer( + CpiContext::from(&*ctx.accounts).with_signer(&[&[ + AUTHORITY_SEED.as_bytes(), + ctx.accounts.source_stake_account_positions.key().as_ref(), + &[ctx.accounts.source_stake_account_metadata.authority_bump], + ]]), + split_request.amount, + )?; ctx.accounts.source_stake_account_custody.reload()?; ctx.accounts.new_stake_account_custody.reload()?; diff --git a/staking/programs/staking/src/state/stake_account.rs b/staking/programs/staking/src/state/stake_account.rs index 20a0de97..e1dacbc3 100644 --- a/staking/programs/staking/src/state/stake_account.rs +++ b/staking/programs/staking/src/state/stake_account.rs @@ -47,14 +47,14 @@ impl StakeAccountMetadataV2 { metadata_bump: u8, custody_bump: u8, authority_bump: u8, - voter_record_bump: u8, + voter_bump: u8, owner: &Pubkey, transfer_epoch: Option, ) { self.metadata_bump = metadata_bump; self.custody_bump = custody_bump; self.authority_bump = authority_bump; - self.voter_bump = voter_record_bump; + self.voter_bump = voter_bump; self.owner = *owner; self.next_index = 0; self.transfer_epoch = transfer_epoch; diff --git a/staking/tests/split_vesting_account.ts b/staking/tests/split_vesting_account.ts index cb859271..9ab56ad4 100644 --- a/staking/tests/split_vesting_account.ts +++ b/staking/tests/split_vesting_account.ts @@ -16,8 +16,7 @@ import { OptionalBalanceSummary, } from "./utils/api_utils"; import assert from "assert"; -import { blob } from "stream/consumers"; -import { Key } from "@metaplex-foundation/mpl-token-metadata"; +import { expectFailWithCode } from "./utils/utils"; const ONE_MONTH = new BN(3600 * 24 * 30.5); const portNumber = getPortNumber(path.basename(__filename)); @@ -25,13 +24,10 @@ const portNumber = getPortNumber(path.basename(__filename)); describe("split vesting account", async () => { const pythMintAccount = new Keypair(); const pythMintAuthority = new Keypair(); - let EPOCH_DURATION: BN; let stakeConnection: StakeConnection; let controller: CustomAbortController; - let owner: PublicKey; - let pdaAuthority = new Keypair(); let pdaConnection: StakeConnection; @@ -49,9 +45,6 @@ describe("split vesting account", async () => { ) )); - EPOCH_DURATION = stakeConnection.config.epochDuration; - owner = stakeConnection.provider.wallet.publicKey; - pdaConnection = await connect(pdaAuthority); }); @@ -172,26 +165,6 @@ describe("split vesting account", async () => { ); } - async function assertFailsWithErrorCode( - thunk: () => Promise, - errorCode: string - ) { - let actualErrorCode: string | undefined = undefined; - try { - await thunk(); - } catch (err) { - if (err instanceof AnchorError) { - actualErrorCode = err.error.errorCode.code; - } - } - - assert.equal( - actualErrorCode, - errorCode, - `Call did not fail with the expected error code.` - ); - } - it("split/accept flow success", async () => { let [samConnection, aliceConnection] = await setupSplit("100", "100", "0"); @@ -244,13 +217,12 @@ describe("split vesting account", async () => { aliceConnection.userPublicKey() ); - await assertFailsWithErrorCode( - () => - pdaConnection.acceptSplit( - stakeAccount, - PythBalance.fromString("33"), - aliceConnection.userPublicKey() - ), + await expectFailWithCode( + pdaConnection.acceptSplit( + stakeAccount, + PythBalance.fromString("33"), + aliceConnection.userPublicKey() + ), "SplitWithStake" ); @@ -280,37 +252,41 @@ describe("split vesting account", async () => { ); // wrong balance - await assertFailsWithErrorCode( - () => - pdaConnection.acceptSplit( - stakeAccount, - PythBalance.fromString("34"), - aliceConnection.userPublicKey() - ), + await expectFailWithCode( + pdaConnection.acceptSplit( + stakeAccount, + PythBalance.fromString("34"), + aliceConnection.userPublicKey() + ), "InvalidApproval" ); // wrong recipient - await assertFailsWithErrorCode( - () => - pdaConnection.acceptSplit( - stakeAccount, - PythBalance.fromString("33"), - samConnection.userPublicKey() - ), + await expectFailWithCode( + pdaConnection.acceptSplit( + stakeAccount, + PythBalance.fromString("33"), + samConnection.userPublicKey() + ), "InvalidApproval" ); // wrong signer - await assertFailsWithErrorCode( - () => - aliceConnection.acceptSplit( - stakeAccount, - PythBalance.fromString("33"), - aliceConnection.userPublicKey() - ), + await expectFailWithCode( + aliceConnection.acceptSplit( + stakeAccount, + PythBalance.fromString("33"), + aliceConnection.userPublicKey() + ), "ConstraintAddress" ); + + // Passing the correct arguments should succeed + await pdaConnection.acceptSplit( + stakeAccount, + PythBalance.fromString("33"), + aliceConnection.userPublicKey() + ); }); after(async () => { diff --git a/staking/tests/utils/utils.ts b/staking/tests/utils/utils.ts index b6554d2d..18b3c7be 100644 --- a/staking/tests/utils/utils.ts +++ b/staking/tests/utils/utils.ts @@ -112,3 +112,28 @@ export async function expectFailApi(promise: Promise, error: string) { assert.equal(err.message, error); } } + +/** + * Awaits the api request and checks whether the error message matches the provided string + * @param promise : api promise + * @param errorCode : expected string + */ +export async function expectFailWithCode( + promise: Promise, + errorCode: string +) { + let actualErrorCode: string | undefined = undefined; + try { + await promise; + assert(false, "Operation should fail"); + } catch (err) { + if (err instanceof AnchorError) { + actualErrorCode = err.error.errorCode.code; + } + } + assert.equal( + actualErrorCode, + errorCode, + `Call did not fail with the expected error code.` + ); +}