diff --git a/contracts/uint128x2.sol b/contracts/uint128x2.sol index a6eac64..0ee82ea 100644 --- a/contracts/uint128x2.sol +++ b/contracts/uint128x2.sol @@ -4,15 +4,15 @@ pragma solidity ^0.8.20; type uint128x2 is uint256; -function uint128x2From(uint256 hi, uint256 lo) pure returns (uint128x2) { - return uint128x2.wrap(hi << 128 | lo); +function uint128x2From(uint256 high, uint256 low) pure returns (uint128x2) { + return uint128x2.wrap(high << 128 | low); } -function hi128(uint128x2 n) pure returns (uint256) { +function hi(uint128x2 n) pure returns (uint256) { return uint128x2.unwrap(n) >> 128; } -function lo128(uint128x2 n) pure returns (uint256) { +function lo(uint128x2 n) pure returns (uint256) { return uint128(uint128x2.unwrap(n)); } @@ -24,14 +24,18 @@ function decLo(uint128x2 self, uint256 delta) pure returns (uint128x2) { function inc(uint128x2 self, uint256 delta) pure returns (uint128x2) { unchecked { - return uint128x2.wrap(uint128x2.unwrap(self) + (delta | delta << 128)); + return uint128x2.wrap(uint128x2.unwrap(self) + (delta | (delta << 128))); } } function dec(uint128x2 self, uint256 delta) pure returns (uint128x2) { unchecked { - return uint128x2.wrap(uint128x2.unwrap(self) - (delta | delta << 128)); + return uint128x2.wrap(uint128x2.unwrap(self) - (delta | (delta << 128))); } } -using {hi128, lo128, decLo, inc, dec} for uint128x2 global; +function equal(uint128x2 self, uint128x2 other) pure returns (bool) { + return uint128x2.unwrap(self) == uint128x2.unwrap(other); +} + +using {hi, lo, decLo, inc, dec, equal as ==} for uint128x2 global; diff --git a/test/uint128x2.t.sol b/test/uint128x2.t.sol index ee2dc6d..beee2d2 100644 --- a/test/uint128x2.t.sol +++ b/test/uint128x2.t.sol @@ -2,14 +2,33 @@ pragma solidity ^0.8.0; -import {Test} from "forge-std/Test.sol"; import {uint128x2, uint128x2From} from "contracts/uint128x2.sol"; +import {Test} from "forge-std/Test.sol"; -contract integersTest is Test { +contract uint128x2Test is Test { function testAccessors() public pure { uint128x2 pair = uint128x2From(1, 2); - assertEq(pair.lo128(), 2); - assertEq(pair.hi128(), 1); + assertEq(pair.lo(), 2); + assertEq(pair.hi(), 1); + } + + function testIncDec() public pure { + uint128x2 pair = uint128x2From(300, 500); + pair = pair.inc(200); + + assertEq(pair.hi(), 500); + assertEq(pair.lo(), 700); + + pair = pair.dec(200); + + assertEq(pair.hi(), 300); + assertEq(pair.lo(), 500); + + uint128x2 pair2 = uint128x2From(type(uint128).max, 10); + uint128x2 pair3 = pair2.dec(10); + pair3 = pair3.inc(10); + + assert(pair2 == pair3); } }