From f8f31940a007d7c9b924abcdd1450a51824cf0e5 Mon Sep 17 00:00:00 2001 From: vincent-merkl Date: Fri, 14 Feb 2025 09:40:03 +0100 Subject: [PATCH] Feat : referral program (#103) --- contracts/ReferralRegistry.sol | 319 +++++++++++++++++++++++++++ contracts/utils/Errors.sol | 2 + scripts/deployReferralRegistry.s.sol | 40 ++++ test/unit/ReferralRegistry.t.sol | 209 ++++++++++++++++++ 4 files changed, 570 insertions(+) create mode 100644 contracts/ReferralRegistry.sol create mode 100644 scripts/deployReferralRegistry.s.sol create mode 100644 test/unit/ReferralRegistry.t.sol diff --git a/contracts/ReferralRegistry.sol b/contracts/ReferralRegistry.sol new file mode 100644 index 0000000..2db7f75 --- /dev/null +++ b/contracts/ReferralRegistry.sol @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: BUSL-1.1 + +pragma solidity ^0.8.17; + +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; + +import { UUPSHelper } from "./utils/UUPSHelper.sol"; +import { IAccessControlManager } from "./interfaces/IAccessControlManager.sol"; +import { Errors } from "./utils/Errors.sol"; + +/// @title ReferralRegistry +/// @notice Allows to manage referral programs and claim rewards distributed through Merkl +/// @dev This contract uses UUPS upgradeability pattern and ReentrancyGuard for security +contract ReferralRegistry is UUPSHelper { + using SafeERC20 for IERC20; + struct ReferralProgram { + address owner; + bool requiresAuthorization; + bool requiresRefererToBeSet; + uint256 cost; + address paymentToken; + } + enum ReferralStatus { + NotAllowed, + Allowed, + Set + } + + /// @notice Address to receive fees + address public feeRecipient; + + /// @notice `AccessControlManager` contract handling access control + IAccessControlManager public accessControlManager; + + /// @notice Whether the contract has been made non upgradeable or not + uint128 public upgradeabilityDeactivated; + + /// @notice Cost to create a referral program + uint256 public costReferralProgram; + + /// @notice List of string keys that are currently in a referral program + string[] public referralKeys; + + /// @notice Mapping to store referral program details + mapping(string => ReferralProgram) public referralPrograms; + + /// @notice Mapping to determine if a user is allowed to be a referrer + mapping(string => mapping(address => ReferralStatus)) public refererStatus; + + /// @notice Mapping to store referrer codes + mapping(string => mapping(address => string)) public referrerCodeMapping; + + /// @notice Mapping to store referrer addresses by code + mapping(string => mapping(string => address)) public codeToReferrer; + + /// @notice Mapping to store user to referrer relationships + mapping(string => mapping(address => address)) public keyToUserToReferrer; + + /// @notice Adds a new referral key to the list + /// @param key The referral key to add + /// @param _cost The cost of the referral program + /// @param _requiresRefererToBeSet Whether the referral program requires a referrer to be set + /// @param _owner The owner of the referral program + /// @param _requiresAuthorization Whether the referral program requires authorization + /// @param _paymentToken The token used for payment in the referral program + function addReferralKey( + string calldata key, + uint256 _cost, + bool _requiresRefererToBeSet, + address _owner, + bool _requiresAuthorization, + address _paymentToken + ) external payable { + if (referralPrograms[key].owner != address(0)) revert Errors.KeyAlreadyUsed(); + if (msg.value != costReferralProgram) revert Errors.NotEnoughPayment(); + require( + _cost == 0 || (_cost > 0 && _requiresRefererToBeSet), + "Cost must be set if requiresRefererToBeSet is true" + ); + referralKeys.push(key); + referralPrograms[key] = ReferralProgram({ + owner: _owner, + requiresAuthorization: _requiresAuthorization, + cost: _cost, + requiresRefererToBeSet: _requiresRefererToBeSet, + paymentToken: _paymentToken + }); + if (costReferralProgram > 0) { + (bool sent, ) = feeRecipient.call{ value: msg.value }(""); + require(sent, "Failed to send Ether"); + } + emit ReferralKeyAdded(key); + } + + /// @notice Edits the parameters of a referral program + /// @param key The referral key to edit + /// @param newCost The new cost of the referral program + /// @param newRequiresAuthorization Whether the referral program requires authorization + /// @param newRequiresRefererToBeSet Whether the referral program requires a referrer to be set + /// @param newPaymentToken The new payment token of the referral program + function editReferralProgram( + string calldata key, + uint256 newCost, + bool newRequiresAuthorization, + bool newRequiresRefererToBeSet, + address newPaymentToken + ) external { + if (referralPrograms[key].owner != msg.sender) revert Errors.NotAllowed(); + referralPrograms[key] = ReferralProgram({ + owner: referralPrograms[key].owner, + requiresAuthorization: newRequiresAuthorization, + cost: newCost, + requiresRefererToBeSet: newRequiresRefererToBeSet, + paymentToken: newPaymentToken + }); + emit ReferralProgramModified( + key, + newCost, + newRequiresAuthorization, + newRequiresRefererToBeSet, + newPaymentToken + ); + } + + /// @notice Marks an address as allowed to be a referrer for a specific referral key + /// @param key The referral key for which the address is allowed + /// @param user The address to be marked as allowed + function allowReferrer(string calldata key, address user) external { + if (referralPrograms[key].owner != msg.sender) revert Errors.NotAllowed(); + refererStatus[key][user] = ReferralStatus.Allowed; + emit ReferrerAdded(key, user); + } + + /// @notice Allows a user to become a referrer for a specific referral key + /// @param key The referral key for which the user wants to become a referrer + /// @param referrerCode The code of the referrer + function becomeReferrer(string calldata key, string calldata referrerCode) external payable { + if (referralPrograms[key].owner == address(0)) revert Errors.NotAllowed(); + require(codeToReferrer[key][referrerCode] == address(0), "Referrer code already in use"); + ReferralProgram storage program = referralPrograms[key]; + if (program.requiresAuthorization) { + if (refererStatus[key][msg.sender] != ReferralStatus.Allowed) revert Errors.NotAllowed(); + } + refererStatus[key][msg.sender] = ReferralStatus.Set; + referrerCodeMapping[key][msg.sender] = referrerCode; + codeToReferrer[key][referrerCode] = msg.sender; + if (program.cost > 0) { + if (address(program.paymentToken) == address(0)) { + if (msg.value < program.cost) revert Errors.NotEnoughPayment(); + (bool sent, ) = program.owner.call{ value: msg.value }(""); + require(sent, "Failed to send Ether"); + } else { + IERC20(program.paymentToken).safeTransferFrom(msg.sender, program.owner, program.cost); + } + } + emit ReferrerAdded(key, msg.sender); + } + + /// @notice Allows a user to acknowledge that they are referred by a referrer + /// @param key The referral key for which the user is acknowledging the referrer + /// @param referrer The address of the referrer + function acknowledgeReferrer(string calldata key, address referrer) public { + if (referralPrograms[key].requiresRefererToBeSet) { + require(refererStatus[key][referrer] == ReferralStatus.Set, "Referrer has not created a referral link"); + } + keyToUserToReferrer[key][msg.sender] = referrer; + emit ReferrerAcknowledged(key, msg.sender, referrer); + } + + /// @notice Allows a user to acknowledge that they are referred by a referrer using a referrer code + /// @param key The referral key for which the user is acknowledging the referrer + /// @param referrerCode The code of the referrer + function acknowledgeReferrerByKey(string calldata key, string calldata referrerCode) external { + address referrer = codeToReferrer[key][referrerCode]; + acknowledgeReferrer(key, referrer); + } + + /// @notice Sets the cost of the referral program + /// @param _costReferralProgram The new cost of the referral program + function setCostReferralProgram(uint256 _costReferralProgram) external onlyGovernor { + costReferralProgram = _costReferralProgram; + emit CostReferralProgramSet(_costReferralProgram); + } + + /// @notice Receive function to accept ETH payments + receive() external payable { + // Custom logic for receiving ETH can be added here + } + + /*////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + EVENTS + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////*/ + event CostReferralProgramSet(uint256 newCost); + event ReferrerAcknowledged(string indexed key, address indexed user, address indexed referrer); + event ReferrerAdded(string indexed key, address indexed referrer); + event ReferralProgramModified( + string indexed key, + uint256 newCost, + bool newRequiresAuthorization, + bool newRequiresRefererToBeSet, + address newPaymentToken + ); + event ReferralKeyAdded(string indexed key); + event ReferralKeyRemoved(uint256 index); + event UpgradeabilityRevoked(); + + event Claimed(address indexed user, address indexed token, uint256 amount); + + /*////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + MODIFIERS + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////*/ + + /// @notice Checks whether the `msg.sender` has the governor role + modifier onlyGovernor() { + if (!accessControlManager.isGovernor(msg.sender)) revert Errors.NotGovernor(); + _; + } + + /// @notice Checks whether the contract is upgradeable or whether the caller is allowed to upgrade the contract + modifier onlyUpgradeableInstance() { + if (upgradeabilityDeactivated == 1) revert Errors.NotUpgradeable(); + else if (!accessControlManager.isGovernor(msg.sender)) revert Errors.NotGovernor(); + _; + } + + /*////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + CONSTRUCTOR + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////*/ + + constructor() initializer {} + function initialize( + IAccessControlManager _accessControlManager, + uint256 _costReferralProgram, + address _feeRecipient + ) external initializer { + if (address(_accessControlManager) == address(0)) revert Errors.ZeroAddress(); + accessControlManager = _accessControlManager; + costReferralProgram = _costReferralProgram; + feeRecipient = _feeRecipient; + } + + /// @inheritdoc UUPSHelper + function _authorizeUpgrade(address) internal view override onlyUpgradeableInstance {} + + /// @notice Prevents future contract upgrades + function revokeUpgradeability() external onlyGovernor { + upgradeabilityDeactivated = 1; + emit UpgradeabilityRevoked(); + } + + /*////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + VIEW FUNCTIONS + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////*/ + + /// @notice Gets the list of referral keys + /// @return The list of referral keys + function getReferralKeys() external view returns (string[] memory) { + return referralKeys; + } + /// @notice Gets the details of a referral program + /// @param key The referral key to get details for + /// @return The details of the referral program + function getReferralProgram(string calldata key) external view returns (ReferralProgram memory) { + return referralPrograms[key]; + } + + /// @notice Gets the referrer status for a specific user and referral key + /// @param key The referral key to check + /// @param user The user to check the referrer status for + /// @return The referrer status of the user for the given key + function getReferrerStatus(string calldata key, address user) external view returns (ReferralStatus) { + return refererStatus[key][user]; + } + + /// @notice Gets the referrer for a specific user and referral key + /// @param key The referral key to check + /// @param user The user to check the referrer for + /// @return The referrer of the user for the given key + function getReferrer(string calldata key, address user) external view returns (address) { + return keyToUserToReferrer[key][user]; + } + + /// @notice Gets the cost of a referral for a specific key + /// @param key The referral key to check + /// @return The cost of the referral for the given key + function getCostOfReferral(string calldata key) external view returns (uint256) { + return referralPrograms[key].cost; + } + + /// @notice Gets the payment token of a referral program + /// @param key The referral key to check + /// @return The payment token of the referral program + function getPaymentToken(string calldata key) external view returns (address) { + return referralPrograms[key].paymentToken; + } + + /// @notice Checks if a referral program requires authorization + /// @param key The referral key to check + /// @return True if the referral program requires authorization, false otherwise + function requiresAuthorization(string calldata key) external view returns (bool) { + return referralPrograms[key].requiresAuthorization; + } + + /// @notice Checks if a referral program requires a referrer to be set + /// @param key The referral key to check + /// @return True if the referral program requires a referrer to be set, false otherwise + function requiresRefererToBeSet(string calldata key) external view returns (bool) { + return referralPrograms[key].requiresRefererToBeSet; + } + + /// @notice Gets the status of a referrer for a specific referral key + /// @param key The referral key to check + /// @param referrer The referrer to check the status for + /// @return The status of the referrer for the given key + function getReferrerStatusByKey(string calldata key, address referrer) external view returns (ReferralStatus) { + return refererStatus[key][referrer]; + } +} diff --git a/contracts/utils/Errors.sol b/contracts/utils/Errors.sol index 19c3e1c..d79f2a5 100644 --- a/contracts/utils/Errors.sol +++ b/contracts/utils/Errors.sol @@ -19,9 +19,11 @@ library Errors { error InvalidReturnMessage(); error InvalidReward(); error InvalidSignature(); + error KeyAlreadyUsed(); error NoDispute(); error NoOverrideForCampaign(); error NotAllowed(); + error NotEnoughPayment(); error NotGovernor(); error NotGovernorOrGuardian(); error NotSigned(); diff --git a/scripts/deployReferralRegistry.s.sol b/scripts/deployReferralRegistry.s.sol new file mode 100644 index 0000000..5a861ff --- /dev/null +++ b/scripts/deployReferralRegistry.s.sol @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.17; + +import { console } from "forge-std/console.sol"; +import { BaseScript } from "./utils/Base.s.sol"; +import { ERC1967Proxy } from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import { ReferralRegistry } from "../contracts/ReferralRegistry.sol"; +import { DistributionCreator } from "../contracts/DistributionCreator.sol"; +import { IAccessControlManager } from "../contracts/interfaces/IAccessControlManager.sol"; +interface IDistributionCreator { + function distributor() external view returns (address); + + function feeRecipient() external view returns (address); + + function accessControlManager() external view returns (IAccessControlManager); +} + +contract DeployReferralRegistry is BaseScript { + function run() public { + uint256 deployerPrivateKey = vm.envUint("DEPLOYER_PRIVATE_KEY"); + vm.startBroadcast(deployerPrivateKey); + uint256 feeSetup = 0; + // uint32 cliffDuration = 1 weeks; + IDistributionCreator distributionCreator = IDistributionCreator(0x8BB4C975Ff3c250e0ceEA271728547f3802B36Fd); + address feeRecipient = distributionCreator.feeRecipient(); + IAccessControlManager accessControlManager = distributionCreator.accessControlManager(); + + // Deploy implementation + address implementation = address(new ReferralRegistry()); + console.log("ReferralRegistry Implementation:", implementation); + + // Deploy proxy + ERC1967Proxy proxy = new ERC1967Proxy(implementation, ""); + console.log("ReferralRegistry Proxy:", address(proxy)); + + // Initialize + ReferralRegistry(payable(address(proxy))).initialize(accessControlManager, feeSetup, feeRecipient); + vm.stopBroadcast(); + } +} diff --git a/test/unit/ReferralRegistry.t.sol b/test/unit/ReferralRegistry.t.sol new file mode 100644 index 0000000..0391201 --- /dev/null +++ b/test/unit/ReferralRegistry.t.sol @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: BUSL-1.1 + +pragma solidity ^0.8.17; + +import "forge-std/Test.sol"; +import "../../contracts/ReferralRegistry.sol"; +import "../../contracts/interfaces/IAccessControlManager.sol"; +import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; + +import { ERC1967Proxy } from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; + + +contract ReferralRegistryTest is Test { + ReferralRegistry referralRegistry; + ReferralRegistry referralRegistryImple; + IAccessControlManager accessControlManager; + address paymentToken; + + address owner = vm.addr(1); + address user = vm.addr(2); + address referrer = vm.addr(3); + address feeRecipient = vm.addr(4); + + string referralKey = "testKey"; + uint256 cost = 1000; + uint256 feeSetup = 100; + bool requiresRefererToBeSet = true; + bool requiresAuthorization = false; + + function deployUUPS(address implementation, bytes memory data) public returns (address) { + return address(new ERC1967Proxy(implementation, data)); + } + + function setUp() public { + accessControlManager = IAccessControlManager(address(new MockAccessControlManager())); + referralRegistryImple = new ReferralRegistry(); + paymentToken = address(new MockERC20()); + referralRegistry = ReferralRegistry(payable(deployUUPS(address(referralRegistryImple), hex""))); + referralRegistry.initialize(accessControlManager, feeSetup, feeRecipient); + } + + function testAddReferralKeyCostZero() public { + vm.prank(owner); + referralRegistry.setCostReferralProgram(0); + referralRegistry.addReferralKey(referralKey, cost, requiresRefererToBeSet, owner, requiresAuthorization, paymentToken); + + ReferralRegistry.ReferralProgram memory program = referralRegistry.getReferralProgram(referralKey); + assertEq(program.owner, owner); + assertEq(program.cost, cost); + assertEq(program.requiresRefererToBeSet, requiresRefererToBeSet); + assertEq(program.requiresAuthorization, requiresAuthorization); + assertEq(address(program.paymentToken), address(paymentToken)); + } + + function testAddReferralKey() public { + vm.prank(owner); + uint256 fee = referralRegistry.costReferralProgram(); + referralRegistry.addReferralKey{value: fee}(referralKey, cost, requiresRefererToBeSet, owner, requiresAuthorization, paymentToken); + + ReferralRegistry.ReferralProgram memory program = referralRegistry.getReferralProgram(referralKey); + assertEq(program.owner, owner); + assertEq(program.cost, cost); + assertEq(program.requiresRefererToBeSet, requiresRefererToBeSet); + assertEq(program.requiresAuthorization, requiresAuthorization); + assertEq(address(program.paymentToken), address(paymentToken)); + } + + function testEditReferralProgram() public { + vm.prank(owner); + uint256 fee = referralRegistry.costReferralProgram(); + + referralRegistry.addReferralKey{value: fee}(referralKey, cost, requiresRefererToBeSet, owner, requiresAuthorization, paymentToken); + + uint256 newCost = 2000; + bool newRequiresRefererToBeSet = false; + bool newRequiresAuthorization = false; + address newPaymentToken = address(new MockERC20()); + + vm.prank(owner); + referralRegistry.editReferralProgram(referralKey, newCost, newRequiresAuthorization, newRequiresRefererToBeSet, newPaymentToken); + + ReferralRegistry.ReferralProgram memory program = referralRegistry.getReferralProgram(referralKey); + assertEq(program.cost, newCost); + assertEq(program.requiresRefererToBeSet, newRequiresRefererToBeSet); + assertEq(program.requiresAuthorization, newRequiresAuthorization); + assertEq(address(program.paymentToken), address(newPaymentToken)); + } + + function testBecomeReferrer() public { + vm.prank(owner); + uint256 fee = referralRegistry.costReferralProgram(); + + referralRegistry.addReferralKey{value: fee}(referralKey, cost, requiresRefererToBeSet, owner, requiresAuthorization, paymentToken); + + string memory referrerCode = "referrerCode"; + vm.startPrank(referrer); + IERC20(paymentToken).approve(address(referralRegistry), cost); + referralRegistry.becomeReferrer(referralKey, referrerCode); + + ReferralRegistry.ReferralStatus status = referralRegistry.getReferrerStatus(referralKey, referrer); + assertEq(uint(status), uint(ReferralRegistry.ReferralStatus.Set)); + + string memory storedReferrerCode = referralRegistry.referrerCodeMapping(referralKey, referrer); + assertEq(storedReferrerCode, referrerCode); + + address storedReferrer = referralRegistry.codeToReferrer(referralKey, referrerCode); + assertEq(storedReferrer, referrer); + } + + function testAcknowledgeReferrer() public { + vm.prank(owner); + uint256 fee = referralRegistry.costReferralProgram(); + + referralRegistry.addReferralKey{value: fee}(referralKey, cost, requiresRefererToBeSet, owner, requiresAuthorization, paymentToken); + + string memory referrerCode = "referrerCode"; + vm.startPrank(referrer); + IERC20(paymentToken).approve(address(referralRegistry), cost); + referralRegistry.becomeReferrer(referralKey, referrerCode); + vm.stopPrank(); + vm.prank(user); + referralRegistry.acknowledgeReferrer(referralKey, referrer); + + address referrerOnChain = referralRegistry.getReferrer(referralKey, user); + assertEq(referrer, referrerOnChain); + } + + function testAcknowledgeReferrerByKey() public { + vm.prank(owner); + uint256 fee = referralRegistry.costReferralProgram(); + + referralRegistry.addReferralKey{value: fee}(referralKey, cost, requiresRefererToBeSet, owner, requiresAuthorization, paymentToken); + + string memory referrerCode = "referrerCode"; + vm.startPrank(referrer); + IERC20(paymentToken).approve(address(referralRegistry), cost); + referralRegistry.becomeReferrer(referralKey, referrerCode); + vm.stopPrank(); + vm.prank(user); + referralRegistry.acknowledgeReferrerByKey(referralKey, referrerCode); + + address referrerOnChain = referralRegistry.getReferrer(referralKey, user); + assertEq(referrer, referrerOnChain); + } + + function testAcknowledgeReferrerByKeyWithoutCost() public { + vm.prank(owner); + uint256 fee = referralRegistry.costReferralProgram(); + referralRegistry.addReferralKey{value: fee}(referralKey, 0, false, owner, false, address(0)); + + string memory referrerCode = "referrerCode"; + vm.startPrank(referrer); + referralRegistry.becomeReferrer(referralKey, referrerCode); + vm.stopPrank(); + vm.prank(user); + referralRegistry.acknowledgeReferrerByKey(referralKey, referrerCode); + + address referrerOnChain = referralRegistry.getReferrer(referralKey, user); + assertEq(referrer, referrerOnChain); + } +} + +contract MockAccessControlManager is IAccessControlManager { + function isGovernor(address) external pure returns (bool) { + return true; + } + + function isGovernorOrGuardian(address) external pure returns (bool) { + return true; + } +} + +contract MockERC20 is IERC20 { + function totalSupply() external pure returns (uint256) { + return 1000000; + } + + function balanceOf(address) external pure returns (uint256) { + return 1000000; + } + + function transfer(address, uint256) external pure returns (bool) { + return true; + } + + function allowance(address, address) external pure returns (uint256) { + return 1000000; + } + + function approve(address, uint256) external pure returns (bool) { + return true; + } + + function transferFrom(address, address, uint256) external pure returns (bool) { + return true; + } + + function name() external pure returns (string memory) { + return "MockERC20"; + } + + function symbol() external pure returns (string memory) { + return "MERC20"; + } + + function decimals() external pure returns (uint8) { + return 18; + } +}