Skip to content

Commit

Permalink
Merge pull request #39 from Darlington02/main
Browse files Browse the repository at this point in the history
fix: implement fixes for audit findings, update toolings and tests
  • Loading branch information
Darlington02 authored Apr 2, 2024
2 parents c41c299 + b102906 commit b65044e
Show file tree
Hide file tree
Showing 15 changed files with 267 additions and 124 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_contracts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ jobs:
- uses: software-mansion/setup-scarb@v1
- uses: foundry-rs/setup-snfoundry@v3
with:
starknet-foundry-version: 0.19.0
starknet-foundry-version: 0.20.1
- name: Run cairo tests
run: snforge test
2 changes: 1 addition & 1 deletion .tool-versions
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
scarb 2.6.0
starknet-foundry 0.19.0
starknet-foundry 0.20.1
4 changes: 2 additions & 2 deletions Scarb.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ version = 1

[[package]]
name = "snforge_std"
version = "0.19.0"
source = "git+https://github.com/foundry-rs/starknet-foundry.git?tag=v0.19.0#a3391dce5bdda51c63237032e6cfc64fb7a346d4"
version = "0.20.1"
source = "git+https://github.com/foundry-rs/starknet-foundry.git?tag=v0.20.1#fea2db8f2b20148cc15ee34b08de12028eb42942"

[[package]]
name = "token_bound_accounts"
Expand Down
3 changes: 2 additions & 1 deletion Scarb.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[package]
name = "token_bound_accounts"
version = "0.3.0"
edition = "2023_10"

# See more keys and their definitions at https://docs.swmansion.com/scarb/docs/reference/manifest

Expand All @@ -10,7 +11,7 @@ casm = true

[dependencies]
starknet = "2.6.0"
snforge_std = { git = "https://github.com/foundry-rs/starknet-foundry.git", tag = "v0.19.0" }
snforge_std = { git = "https://github.com/foundry-rs/starknet-foundry.git", tag = "v0.20.1" }

[tool.snforge]
# exit_first = true
114 changes: 67 additions & 47 deletions src/account/account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,21 @@
////////////////////////////////
#[starknet::component]
mod AccountComponent {
use core::num::traits::zero::Zero;
use starknet::{
get_tx_info, get_caller_address, get_contract_address, get_block_timestamp, ContractAddress,
account::Call, call_contract_syscall, replace_class_syscall, ClassHash, SyscallResultTrait
};
use ecdsa::check_ecdsa_signature;
use array::{SpanTrait, ArrayTrait};
use box::BoxTrait;
use option::OptionTrait;
use zeroable::Zeroable;
use token_bound_accounts::interfaces::IERC721::{IERC721DispatcherTrait, IERC721Dispatcher};
use token_bound_accounts::interfaces::IAccount::IAccount;
use token_bound_accounts::interfaces::IAccount::{TBA_INTERFACE_ID};
use token_bound_accounts::interfaces::IAccount::{
IAccount, IAccountDispatcherTrait, IAccountDispatcher, TBA_INTERFACE_ID
};

#[storage]
struct Storage {
Account_token_contract: ContractAddress, // contract address of NFT
Account_token_id: u256, // token ID of NFT
Account_unlock_timestamp: u64, // time to unlock account when locked
account_token_contract: ContractAddress, // contract address of NFT
account_token_id: u256, // token ID of NFT
account_unlock_timestamp: u64, // time to unlock account when locked
}

#[event]
Expand Down Expand Up @@ -82,6 +79,18 @@ mod AccountComponent {
self._is_valid_signature(hash, signature)
}

/// @notice used to validate signer
/// @param signer address to be validated
fn is_valid_signer(
self: @ComponentState<TContractState>, signer: ContractAddress
) -> felt252 {
if self._is_valid_signer(signer) {
return starknet::VALIDATED;
} else {
return 0;
}
}

fn __validate_deploy__(
self: @ComponentState<TContractState>,
class_hash: felt252,
Expand Down Expand Up @@ -109,7 +118,9 @@ mod AccountComponent {
fn __execute__(
ref self: ComponentState<TContractState>, mut calls: Array<Call>
) -> Array<Span<felt252>> {
self._assert_only_owner();
let caller = get_caller_address();
assert(self._is_valid_signer(caller), Errors::UNAUTHORIZED);

let (lock_status, _) = self._is_locked();
assert(!lock_status, Errors::LOCKED_ACCOUNT);

Expand All @@ -123,29 +134,32 @@ mod AccountComponent {
retdata
}

/// @notice gets the token bound NFT owner
/// @notice gets the NFT owner
/// @param token_contract the contract address of the NFT
/// @param token_id the token ID of the NFT
fn owner(
self: @ComponentState<TContractState>, token_contract: ContractAddress, token_id: u256
) -> ContractAddress {
fn owner(self: @ComponentState<TContractState>) -> ContractAddress {
let token_contract = self.account_token_contract.read();
let token_id = self.account_token_id.read();
self._get_owner(token_contract, token_id)
}

/// @notice returns the contract address and token ID of the NFT
/// @notice returns the contract address and token ID of the associated NFT
fn token(self: @ComponentState<TContractState>) -> (ContractAddress, u256) {
self._get_token()
}

// @notice protection mechanism for selling token bound accounts. can't execute when account is locked
// @param duration for which to lock account
fn lock(ref self: ComponentState<TContractState>, duration: u64) {
self._assert_only_owner();
let caller = get_caller_address();
assert(self._is_valid_signer(caller), Errors::UNAUTHORIZED);

let (lock_status, _) = self._is_locked();
assert(!lock_status, Errors::LOCKED_ACCOUNT);

let current_timestamp = get_block_timestamp();
let unlock_time = current_timestamp + duration;
self.Account_unlock_timestamp.write(unlock_time);
self.account_unlock_timestamp.write(unlock_time);
self
.emit(
AccountLocked {
Expand Down Expand Up @@ -176,27 +190,20 @@ mod AccountComponent {
impl InternalImpl<
TContractState, +HasComponent<TContractState>, +Drop<TContractState>
> of InternalTrait<TContractState> {
/// @notice initializes the account by setting the initial token conrtact and token id
/// @notice initializes the account by setting the initial token contract and token id
fn initializer(
ref self: ComponentState<TContractState>,
token_contract: ContractAddress,
token_id: u256
) {
self.Account_token_contract.write(token_contract);
self.Account_token_id.write(token_id);

let owner = self._get_owner(token_contract, token_id);
assert(owner.is_non_zero(), Errors::UNAUTHORIZED);
// initialize account
self.account_token_contract.write(token_contract);
self.account_token_id.write(token_id);
self.emit(AccountCreated { owner });
}

/// @notice check that caller is the token bound account
fn _assert_only_owner(ref self: ComponentState<TContractState>) {
let caller = get_caller_address();
let owner = self
._get_owner(self.Account_token_contract.read(), self.Account_token_id.read());
assert(caller == owner, Errors::UNAUTHORIZED);
}

/// @notice internal function for getting NFT owner
/// @param token_contract contract address of NFT
// @param token_id token ID of NFT
Expand All @@ -218,14 +225,14 @@ mod AccountComponent {

/// @notice internal transaction for returning the contract address and token ID of the NFT
fn _get_token(self: @ComponentState<TContractState>) -> (ContractAddress, u256) {
let contract = self.Account_token_contract.read();
let tokenId = self.Account_token_id.read();
let contract = self.account_token_contract.read();
let tokenId = self.account_token_id.read();
(contract, tokenId)
}

// @notice protection mechanism for TBA trading. Returns the lock-status (true or false), and the remaning time till account unlocks.
fn _is_locked(self: @ComponentState<TContractState>) -> (bool, u64) {
let unlock_timestamp = self.Account_unlock_timestamp.read();
let unlock_timestamp = self.account_unlock_timestamp.read();
let current_time = get_block_timestamp();
if (current_time < unlock_timestamp) {
let time_until_unlocks = unlock_timestamp - current_time;
Expand All @@ -235,16 +242,17 @@ mod AccountComponent {
}
}

/// @notice internal function for tx validation
fn _validate_transaction(self: @ComponentState<TContractState>) -> felt252 {
let tx_info = get_tx_info().unbox();
let tx_hash = tx_info.transaction_hash;
let signature = tx_info.signature;
assert(
self._is_valid_signature(tx_hash, signature) == starknet::VALIDATED,
Errors::INV_SIGNATURE
);
starknet::VALIDATED
// @notice internal function for validating signer
fn _is_valid_signer(
self: @ComponentState<TContractState>, signer: ContractAddress
) -> bool {
let owner = self
._get_owner(self.account_token_contract.read(), self.account_token_id.read());
if (signer == owner) {
return true;
} else {
return false;
}
}

/// @notice internal function for signature validation
Expand All @@ -254,16 +262,28 @@ mod AccountComponent {
let signature_length = signature.len();
assert(signature_length == 2_u32, Errors::INV_SIG_LEN);

let caller = get_caller_address();
let owner = self
._get_owner(self.Account_token_contract.read(), self.Account_token_id.read());
if (caller == owner) {
._get_owner(self.account_token_contract.read(), self.account_token_id.read());
let account = IAccountDispatcher { contract_address: owner };
if (account.is_valid_signature(hash, signature) == starknet::VALIDATED) {
return starknet::VALIDATED;
} else {
return 0;
}
}

/// @notice internal function for tx validation
fn _validate_transaction(self: @ComponentState<TContractState>) -> felt252 {
let tx_info = get_tx_info().unbox();
let tx_hash = tx_info.transaction_hash;
let signature = tx_info.signature;
assert(
self._is_valid_signature(tx_hash, signature) == starknet::VALIDATED,
Errors::INV_SIGNATURE
);
starknet::VALIDATED
}

/// @notice internal function for executing transactions
/// @param calls An array of transactions to be executed
fn _execute_calls(
Expand All @@ -277,7 +297,7 @@ mod AccountComponent {
Option::Some(call) => {
match call_contract_syscall(call.to, call.selector, call.calldata) {
Result::Ok(mut retdata) => { result.append(retdata); },
Result::Err(_) => { panic_with_felt252('multicall_failed'); }
Result::Err(_) => { panic(array!['multicall_failed']); }
}
},
Option::None(_) => { break (); }
Expand Down
7 changes: 3 additions & 4 deletions src/interfaces/IAccount.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@ use starknet::ClassHash;
use starknet::account::Call;

// SRC5 interface for token bound accounts
const TBA_INTERFACE_ID: felt252 = 0x539036932a2ab9c4734fbfd9872a1f7791a3f577e45477336ae0fd0a00c9ff;
const TBA_INTERFACE_ID: felt252 = 0xd050d1042482f6e9a28d0c039d0a8428266bf4fd59fe95cee66d8e0e8b3b2e;

#[starknet::interface]
trait IAccount<TContractState> {
fn is_valid_signature(
self: @TContractState, hash: felt252, signature: Span<felt252>
) -> felt252;
fn is_valid_signer(self: @TContractState, signer: ContractAddress) -> felt252;
fn __validate__(ref self: TContractState, calls: Array<Call>) -> felt252;
fn __validate_declare__(self: @TContractState, class_hash: felt252) -> felt252;
fn __validate_deploy__(
self: @TContractState, class_hash: felt252, contract_address_salt: felt252
) -> felt252;
fn __execute__(ref self: TContractState, calls: Array<Call>) -> Array<Span<felt252>>;
fn token(self: @TContractState) -> (ContractAddress, u256);
fn owner(
self: @TContractState, token_contract: ContractAddress, token_id: u256
) -> ContractAddress;
fn owner(self: @TContractState) -> ContractAddress;
fn lock(ref self: TContractState, duration: u64);
fn is_locked(self: @TContractState) -> (bool, u64);
fn supports_interface(self: @TContractState, interface_id: felt252) -> bool;
Expand Down
5 changes: 3 additions & 2 deletions src/presets/account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
////////////////////////////////
#[starknet::contract(account)]
mod Account {
use starknet::ContractAddress;
use starknet::{ContractAddress, get_caller_address};
use starknet::ClassHash;
use token_bound_accounts::account::AccountComponent;
use token_bound_accounts::upgradeable::UpgradeableComponent;
Expand Down Expand Up @@ -45,7 +45,8 @@ mod Account {
#[abi(embed_v0)]
impl UpgradeableImpl of IUpgradeable<ContractState> {
fn upgrade(ref self: ContractState, new_class_hash: ClassHash) {
self.account._assert_only_owner();
let caller = get_caller_address();
assert(self.account._is_valid_signer(caller), AccountComponent::Errors::UNAUTHORIZED);
let (lock_status, _) = self.account._is_locked();
assert(!lock_status, AccountComponent::Errors::LOCKED_ACCOUNT);
self.upgradeable._upgrade(new_class_hash);
Expand Down
15 changes: 5 additions & 10 deletions src/registry/registry.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,17 @@
mod Registry {
use core::result::ResultTrait;
use core::hash::HashStateTrait;
use core::pedersen::PedersenTrait;
use starknet::{
ContractAddress, get_caller_address, syscalls::call_contract_syscall, class_hash::ClassHash,
class_hash::Felt252TryIntoClassHash, syscalls::deploy_syscall, SyscallResultTrait
};
use zeroable::Zeroable;
use traits::{Into, TryInto};
use option::OptionTrait;
use array::{ArrayTrait, SpanTrait};
use pedersen::PedersenTrait;

use token_bound_accounts::interfaces::IERC721::{IERC721DispatcherTrait, IERC721Dispatcher};
use token_bound_accounts::interfaces::IRegistry::IRegistry;

#[storage]
struct Storage {
Registry_deployed_accounts: LegacyMap<
registry_deployed_accounts: LegacyMap<
(ContractAddress, u256), u8
>, // tracks no. of deployed accounts by registry for an NFT
}
Expand Down Expand Up @@ -72,10 +67,10 @@ mod Registry {
let (account_address, _) = result.unwrap_syscall();

let new_deployment_index: u8 = self
.Registry_deployed_accounts
.registry_deployed_accounts
.read((token_contract, token_id))
+ 1_u8;
self.Registry_deployed_accounts.write((token_contract, token_id), new_deployment_index);
self.registry_deployed_accounts.write((token_contract, token_id), new_deployment_index);

self.emit(AccountCreated { account_address, token_contract, token_id, });

Expand Down Expand Up @@ -120,7 +115,7 @@ mod Registry {
fn total_deployed_accounts(
self: @ContractState, token_contract: ContractAddress, token_id: u256
) -> u8 {
self.Registry_deployed_accounts.read((token_contract, token_id))
self.registry_deployed_accounts.read((token_contract, token_id))
}
}

Expand Down
1 change: 1 addition & 0 deletions src/test_helper.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod hello_starknet;
mod account_upgrade;
mod erc721_helper;
mod simple_account;
10 changes: 3 additions & 7 deletions src/test_helper/account_upgrade.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use array::{ArrayTrait, SpanTrait};
use starknet::{account::Call, ContractAddress, ClassHash};

#[starknet::interface]
Expand Down Expand Up @@ -55,11 +54,8 @@ mod UpgradedAccount {
get_tx_info, get_caller_address, get_contract_address, ContractAddress, account::Call,
call_contract_syscall, replace_class_syscall, ClassHash, SyscallResultTrait
};
use ecdsa::check_ecdsa_signature;
use array::{SpanTrait, ArrayTrait};
use box::BoxTrait;
use option::OptionTrait;
use zeroable::Zeroable;
use core::ecdsa::check_ecdsa_signature;
use core::zeroable::Zeroable;
use super::{IERC721DispatcherTrait, IERC721Dispatcher};

#[storage]
Expand Down Expand Up @@ -211,7 +207,7 @@ mod UpgradedAccount {
Option::Some(call) => {
match call_contract_syscall(call.to, call.selector, call.calldata) {
Result::Ok(mut retdata) => { result.append(retdata); },
Result::Err(_) => { panic_with_felt252('multicall_failed'); }
Result::Err(_) => { panic(array!['multicall_failed']); }
}
},
Option::None(_) => { break (); }
Expand Down
Loading

0 comments on commit b65044e

Please sign in to comment.