diff --git a/src/components/upgradeable/upgradeable.cairo b/src/components/upgradeable/upgradeable.cairo index 0353c69..acf5056 100644 --- a/src/components/upgradeable/upgradeable.cairo +++ b/src/components/upgradeable/upgradeable.cairo @@ -6,7 +6,9 @@ pub mod UpgradeableComponent { // ************************************************************************* // IMPORTS // ************************************************************************* - use starknet::{ClassHash, SyscallResultTrait, get_caller_address}; + use starknet::{ + ClassHash, SyscallResultTrait, get_caller_address, get_contract_address, ContractAddress + }; use core::num::traits::zero::Zero; use token_bound_accounts::components::account::account::AccountComponent; use token_bound_accounts::components::account::account::AccountComponent::InternalImpl; @@ -23,13 +25,14 @@ pub mod UpgradeableComponent { #[event] #[derive(Drop, starknet::Event)] pub enum Event { - Upgraded: Upgraded + TBAUpgraded: TBAUpgraded } /// @notice Emitted when the contract is upgraded. /// @param class_hash implementation hash to be upgraded to #[derive(Drop, starknet::Event)] - pub struct Upgraded { + pub struct TBAUpgraded { + pub account_address: ContractAddress, pub class_hash: ClassHash } @@ -68,7 +71,12 @@ pub mod UpgradeableComponent { // upgrade account starknet::syscalls::replace_class_syscall(new_class_hash).unwrap_syscall(); - self.emit(Upgraded { class_hash: new_class_hash }); + self + .emit( + TBAUpgraded { + account_address: get_contract_address(), class_hash: new_class_hash + } + ); } } } diff --git a/tests/test_upgradeable.cairo b/tests/test_upgradeable.cairo index bc11a06..62205cd 100644 --- a/tests/test_upgradeable.cairo +++ b/tests/test_upgradeable.cairo @@ -16,7 +16,7 @@ use token_bound_accounts::interfaces::IUpgradeable::{ IUpgradeableDispatcher, IUpgradeableDispatcherTrait }; use token_bound_accounts::components::presets::account_preset::AccountPreset; -use token_bound_accounts::components::account::account::AccountComponent; +use token_bound_accounts::components::upgradeable::upgradeable::UpgradeableComponent; use token_bound_accounts::test_helper::{ erc721_helper::{IERC721Dispatcher, IERC721DispatcherTrait, ERC721}, @@ -66,7 +66,6 @@ fn __setup__() -> (ContractAddress, ContractAddress) { #[test] fn test_upgrade() { let (contract_address, erc721_contract_address) = __setup__(); - let new_class_hash = declare("UpgradedAccount").unwrap().class_hash; // get token owner @@ -89,7 +88,6 @@ fn test_upgrade() { #[should_panic(expected: ('Account: unauthorized',))] fn test_upgrade_with_unauthorized() { let (contract_address, _) = __setup__(); - let new_class_hash = declare("UpgradedAccount").unwrap().class_hash; // call upgrade function with an unauthorized address @@ -97,3 +95,36 @@ fn test_upgrade_with_unauthorized() { let safe_upgrade_dispatcher = IUpgradeableDispatcher { contract_address }; safe_upgrade_dispatcher.upgrade(new_class_hash); } + +#[test] +fn test_upgrade_emits_event() { + let (contract_address, erc721_contract_address) = __setup__(); + let new_class_hash = declare("UpgradedAccount").unwrap().class_hash; + + // get token owner + let token_dispatcher = IERC721Dispatcher { contract_address: erc721_contract_address }; + let token_owner = token_dispatcher.ownerOf(1.try_into().unwrap()); + + // spy on emitted events + let mut spy = spy_events(); + + // call the upgrade function + let dispatcher = IUpgradeableDispatcher { contract_address }; + start_cheat_caller_address(contract_address, token_owner); + dispatcher.upgrade(new_class_hash); + + // check events are emitted + spy + .assert_emitted( + @array![ + ( + contract_address, + UpgradeableComponent::Event::TBAUpgraded( + UpgradeableComponent::TBAUpgraded { + account_address: contract_address, class_hash: new_class_hash + } + ) + ) + ] + ); +}