Skip to content

Commit

Permalink
Extracted staker storage into separate file due to large amount of bo…
Browse files Browse the repository at this point in the history
…ilerplate.

Implemented custom packing and custom data SubPointers.
  • Loading branch information
baitcode committed Jan 6, 2025
1 parent 1a2d6fc commit b18d702
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 73 deletions.
1 change: 1 addition & 0 deletions src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub mod governor;
mod governor_test;

pub mod staker;
pub mod staker_storage;
#[cfg(test)]
mod staker_test;

Expand Down
34 changes: 2 additions & 32 deletions src/staker.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub mod Staker {
Vec, VecTrait, MutableVecTrait,
};
use crate::utils::fp::{UFixedPoint, UFixedPointZero};
use crate::staker_storage::{StakingLogRecord};

use starknet::{
get_block_timestamp, get_caller_address, get_contract_address,
Expand Down Expand Up @@ -107,41 +108,10 @@ pub mod Staker {
}
}

#[derive(Drop, Serde)]
struct StakingLogRecord {
timestamp: u64,
total_staked: u128,
cumulative_seconds_per_total_staked: UFixedPoint,
}

pub impl StakingLogRecordStorePacking of StorePacking<StakingLogRecord, (felt252, felt252)> {
fn pack(value: StakingLogRecord) -> (felt252, felt252) {
let first: felt252 = u256 {
high: value.timestamp.into(),
low: value.total_staked,
}.try_into().unwrap();

let second: felt252 = value.cumulative_seconds_per_total_staked
.try_into()
.unwrap();

(first, second)
}

fn unpack(value: (felt252, felt252)) -> StakingLogRecord {
let (packed_ts_total_staked, cumulative_seconds_per_total_staked) = value;
let medium: u256 = packed_ts_total_staked.into();
StakingLogRecord {
timestamp: medium.high.try_into().unwrap(),
total_staked: medium.low,
cumulative_seconds_per_total_staked: cumulative_seconds_per_total_staked.try_into().unwrap(),
}
}
}

#[storage]
struct Storage {
token: IERC20Dispatcher,

// owner, delegate => amount
staked: Map<(ContractAddress, ContractAddress), u128>,
amount_delegated: Map<ContractAddress, u128>,
Expand Down
118 changes: 118 additions & 0 deletions src/staker_storage.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use starknet::{ContractAddress, Store};
use starknet::storage_access::{StorePacking};
use starknet::storage::{StoragePointer, SubPointers, SubPointersMut, Mutable};
use crate::utils::fp::{UFixedPoint};


#[derive(Drop, Serde)]
pub(crate) struct StakingLogRecord {
pub(crate) timestamp: u64,
pub(crate) total_staked: u128,
pub(crate) cumulative_seconds_per_total_staked: UFixedPoint,
}

pub(crate) impl StakingLogRecordStorePacking of StorePacking<StakingLogRecord, (felt252, felt252)> {
fn pack(value: StakingLogRecord) -> (felt252, felt252) {
let first: felt252 = u256 {
high: value.timestamp.into(),
low: value.total_staked,
}.try_into().unwrap();

let second: felt252 = value.cumulative_seconds_per_total_staked
.try_into()
.unwrap();

(first, second)
}

fn unpack(value: (felt252, felt252)) -> StakingLogRecord {
let (packed_ts_total_staked, cumulative_seconds_per_total_staked) = value;
let medium: u256 = packed_ts_total_staked.into();
StakingLogRecord {
timestamp: medium.high.try_into().unwrap(),
total_staked: medium.low,
cumulative_seconds_per_total_staked: cumulative_seconds_per_total_staked.try_into().unwrap(),
}
}
}


#[derive(Drop, Copy)]
pub(crate) struct StakingLogRecordSubPointers {
pub(crate) timestamp: StoragePointer<u64>,
pub(crate) total_staked: StoragePointer<u128>,
pub(crate) cumulative_seconds_per_total_staked: StoragePointer<UFixedPoint>,
}

pub(crate) impl StakingLogRecordSubPointersImpl of SubPointers<StakingLogRecord> {

type SubPointersType = StakingLogRecordSubPointers;

fn sub_pointers(self: StoragePointer<StakingLogRecord>) -> StakingLogRecordSubPointers {
let base_address = self.__storage_pointer_address__;

let mut current_offset = self.__storage_pointer_offset__;
let __packed_low_128__ = StoragePointer {
__storage_pointer_address__: base_address,
__storage_pointer_offset__: current_offset,
};

let __packed_high_124__ = StoragePointer {
__storage_pointer_address__: base_address,
__storage_pointer_offset__: current_offset + Store::<u128>::size(),
};

current_offset = current_offset + Store::<felt252>::size();
let __packed_felt2__ = StoragePointer {
__storage_pointer_address__: base_address,
__storage_pointer_offset__: current_offset,
};

StakingLogRecordSubPointers {
timestamp: __packed_high_124__,
total_staked: __packed_low_128__,
cumulative_seconds_per_total_staked: __packed_felt2__,
}
}
}

#[derive(Drop, Copy)]
pub(crate) struct StakingLogRecordSubPointersMut {
pub(crate) timestamp: StoragePointer<Mutable<u64>>,
pub(crate) total_staked: StoragePointer<Mutable<u128>>,
pub(crate) cumulative_seconds_per_total_staked: StoragePointer<Mutable<UFixedPoint>>,
}

pub(crate) impl StakingLogRecordSubPointersMutImpl of SubPointersMut<StakingLogRecord> {

type SubPointersType = StakingLogRecordSubPointersMut;

fn sub_pointers_mut(
self: StoragePointer<Mutable<StakingLogRecord>>,
) -> StakingLogRecordSubPointersMut {
let base_address = self.__storage_pointer_address__;

let mut current_offset = self.__storage_pointer_offset__;
let __packed_low_128__ = StoragePointer {
__storage_pointer_address__: base_address,
__storage_pointer_offset__: current_offset,
};

let __packed_high_124__ = StoragePointer {
__storage_pointer_address__: base_address,
__storage_pointer_offset__: current_offset + Store::<u128>::size(),
};

current_offset = current_offset + Store::<felt252>::size();
let __packed_felt2__ = StoragePointer {
__storage_pointer_address__: base_address,
__storage_pointer_offset__: current_offset,
};

StakingLogRecordSubPointersMut {
timestamp: __packed_high_124__,
total_staked: __packed_low_128__,
cumulative_seconds_per_total_staked: __packed_felt2__,
}
}
}
4 changes: 2 additions & 2 deletions src/staker_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ mod staker_staked_seconds_per_total_staked_calculation {


fn assert_fp(value: UFixedPoint, integer: u128, fractional: u128) {
assert(value.get_integer() == integer, 'Integer part is not correct');
assert(value.get_fractional() == fractional, 'Fractional part is not correct');
assert_eq!(value.get_integer(), integer);
assert_eq!(value.get_fractional(), fractional);
}


Expand Down
78 changes: 39 additions & 39 deletions src/utils/fp_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ fn test_add() {
let f2 : UFixedPoint = 1_u64.into();
let res = f1 + f2;
let z: u256 = res.try_into().unwrap();
assert(z.low == 0, 'low 0');
assert(z.high == 18446744073709551616, 'high 18446744073709551616');
assert_eq!(z.low, 0);
assert_eq!(z.high, 18446744073709551616);
}

#[test]
fn test_fp_value_mapping() {
let f1 : UFixedPoint = 7_u64.into();
assert(f1.value.limb0 == 0x0, 'limb0 == 0');
assert(f1.value.limb1 == 0x7, 'limb1 == 7');
assert_eq!(f1.value.limb0, 0x0);
assert_eq!(f1.value.limb1, 0x7);

let val: u256 = f1.try_into().unwrap();
assert(val == 7_u256*0x100000000000000000000000000000000, 'val has to be 128 bit shifted');
assert_eq!(val, 7_u256*0x100000000000000000000000000000000);
}


Expand All @@ -34,35 +34,35 @@ fn test_mul() {

let expected = (7_u256*SCALE_FACTOR).wide_mul(7_u256*SCALE_FACTOR);

assert(expected.limb0 == 0, 'limb0==0');
assert(expected.limb1 == 0, 'limb1==0');
assert(expected.limb2 == 49, 'limb2==0');
assert(expected.limb3 == 0, 'limb3==0');
assert_eq!(expected.limb0, 0);
assert_eq!(expected.limb1, 0);
assert_eq!(expected.limb2, 49);
assert_eq!(expected.limb3, 0);

let res: u256 = (f1 * f2).try_into().unwrap();
assert(res.high == 49, 'high 49');
assert(res.low == 0, 'low 0');
assert_eq!(res.high, 49);
assert_eq!(res.low, 0);
}

#[test]
fn test_multiplication() {
let f1 : UFixedPoint = 9223372036854775808_u128.into();
assert(f1.value.limb0 == 0, 'f1.limb0 0= 0');
assert(f1.value.limb1 == 9223372036854775808_u128, 'f1.limb1 != 0');
assert(f1.value.limb2 == 0, 'f1.limb2 == 0');
assert(f1.value.limb3 == 0, 'f1.limb3 == 0');
assert_eq!(f1.value.limb0, 0);
assert_eq!(f1.value.limb1, 9223372036854775808_u128);
assert_eq!(f1.value.limb2, 0);
assert_eq!(f1.value.limb3, 0);

let res = f1 * f1;

assert(res.value.limb0 == 0, 'res.limb0 != 0');
assert(res.value.limb1 == 0x40000000000000000000000000000000, 'res.limb1 != 0');
assert(res.value.limb2 == 0, 'res.limb2 == 0');
assert(res.value.limb3 == 0, 'res.limb3 == 0');
assert_eq!(res.value.limb0, 0);
assert_eq!(res.value.limb1, 0x40000000000000000000000000000000);
assert_eq!(res.value.limb2, 0);
assert_eq!(res.value.limb3, 0);

let expected = 9223372036854775808_u128.wide_mul(9223372036854775808_u128) * SCALE_FACTOR;

assert(expected.low == 0, 'low == 0');
assert(expected.high == 0x40000000000000000000000000000000, 'high != 0');
assert_eq!(expected.low, 0);
assert_eq!(expected.high, 0x40000000000000000000000000000000);

let result: u256 = res.try_into().unwrap();
assert(result == expected, 'unexpected mult result');
Expand All @@ -72,51 +72,51 @@ fn test_multiplication() {
fn test_u256_conversion() {
let f: u256 = 0x0123456789ABCDEFFEDCBA987654321000112233445566778899AABBCCDDEEFF_u256;

assert(f.low == 0x00112233445566778899AABBCCDDEEFF, 'low');
assert(f.high == 0x0123456789ABCDEFFEDCBA9876543210, 'high');
assert_eq!(f.low, 0x00112233445566778899AABBCCDDEEFF);
assert_eq!(f.high, 0x0123456789ABCDEFFEDCBA9876543210);

// BITSHIFT DOWN
let fp: UFixedPoint = f.into();
assert(fp.get_integer() == f.high, 'integer == f.high');
assert(fp.get_fractional() == f.low, 'fractional == f.low');
assert_eq!(fp.get_integer(), f.high);
assert_eq!(fp.get_fractional(), f.low);

let fp = fp.bitshift_128_down();
assert(fp.get_integer() == 0, 'integer==0 bs_down');
assert(fp.get_fractional() == f.high, 'fractional == f.low bs_down');
assert_eq!(fp.get_integer(), 0);
assert_eq!(fp.get_fractional(), f.high);

let fp = fp.bitshift_128_down();
assert(fp.get_integer() == 0, 'integer==0 bs_down 2');
assert(fp.get_fractional() == 0, 'fractional == 0 bs_down 2');
assert_eq!(fp.get_integer(), 0);
assert_eq!(fp.get_fractional(), 0);

// BITSHIFT UP
let fp: UFixedPoint = f.into();
assert(fp.get_integer() == f.high, 'integer == f.high');
assert(fp.get_fractional() == f.low, 'fractional == f.low');
assert_eq!(fp.get_integer(), f.high);
assert_eq!(fp.get_fractional(), f.low);

let fp = fp.bitshift_128_up();
assert(fp.get_integer() == f.low, 'integer == f.high bs_up');
assert(fp.get_fractional() == 0, 'fractional == f.low bs_up');
assert_eq!(fp.get_integer(), f.low);
assert_eq!(fp.get_fractional(), 0);

let fp = fp.bitshift_128_up();
assert(fp.get_integer() == 0, 'integer == f.high bs_up');
assert(fp.get_fractional() == 0, 'fractional == f.low bs_up');
assert_eq!(fp.get_integer(), 0);
assert_eq!(fp.get_fractional(), 0);
}

fn run_division_test(left: u128, right: u128, expected_int: u128, expected_frac: u128) {
let f1 : UFixedPoint = left.into();
let f2 : UFixedPoint = right.into();
let res = f1 / f2;
assert(res.get_integer() == expected_int, 'integer');
assert(res.get_fractional() == expected_frac, 'fractional');
assert_eq!(res.get_integer(), expected_int);
assert_eq!(res.get_fractional(), expected_frac);
}

fn run_division_and_multiplication_test(numenator: u128, divisor: u128, mult: u128, expected_int: u128, expected_frac: u128) {
let f1 : UFixedPoint = numenator.into();
let f2 : UFixedPoint = divisor.into();
let f3 : UFixedPoint = mult.into();
let res = f1 / f2 * f3;
assert(res.get_integer() == expected_int, 'integer');
assert(res.get_fractional() == expected_frac, 'fractional');
assert_eq!(res.get_integer(), expected_int);
assert_eq!(res.get_fractional(), expected_frac);
}


Expand Down

0 comments on commit b18d702

Please sign in to comment.