diff --git a/src/LBPair.sol b/src/LBPair.sol index 94885341..a5a815f2 100644 --- a/src/LBPair.sol +++ b/src/LBPair.sol @@ -1112,7 +1112,7 @@ contract LBPair is LBToken, ReentrancyGuardUpgradeable, ILBPair { uint128 _protocolFees ) private { assembly { - sstore(_pairFees.slot, and(shl(_OFFSET_PROTOCOL_FEE, _protocolFees), _totalFees)) + sstore(_pairFees.slot, or(shl(_OFFSET_PROTOCOL_FEE, _protocolFees), _totalFees)) } } diff --git a/test/LBPair.Fees.t.sol b/test/LBPair.Fees.t.sol index c8e33e97..17e5ebd8 100644 --- a/test/LBPair.Fees.t.sol +++ b/test/LBPair.Fees.t.sol @@ -167,44 +167,120 @@ contract LiquidityBinPairFeesTest is TestHelper { ); } - function testClaimProtocolFees() public { - uint256 amountYInLiquidity = 100e18; - uint256 amountXOutForSwap = 1e6; - uint256 amountYOutForSwap = 1e6; - uint24 startId = ID_ONE; + struct FeeInfo { + uint256 feeXTotal; + uint256 feeYTotal; + uint256 feeXProtocol; + uint256 feeYProtocol; + } - addLiquidity(amountYInLiquidity, startId, 5, 0); + function _getGlobalFees() internal view returns (FeeInfo memory) { + (uint256 feesXTotal, uint256 feesYTotal, uint256 feesXProtocol, uint256 feesYProtocol) = pair.getGlobalFees(); + return FeeInfo(feesXTotal, feesYTotal, feesXProtocol, feesYProtocol); + } - (uint256 amountYInForSwap, uint256 feesFromGetSwapIn) = router.getSwapIn(pair, amountXOutForSwap, false); + function testClaimProtocolFees() public { + addLiquidity(100e18, ID_ONE, 5, 0); + + // Add Y fees + (uint256 amountIn, uint256 feesIn) = router.getSwapIn(pair, 1e6, false); + token18D.mint(address(pair), amountIn); - token18D.mint(address(pair), amountYInForSwap); vm.prank(ALICE); pair.swap(false, DEV); - (, uint256 feesYTotal, , uint256 feesYProtocol) = pair.getGlobalFees(); - assertEq(feesFromGetSwapIn, feesYTotal); - assertGt(feesYTotal, 0); + // Claiming rewards for Y + FeeInfo memory feesBefore = _getGlobalFees(); - address protocolFeesReceiver = factory.feeRecipient(); + assertEq(feesBefore.feeXTotal, 0); + assertEq(feesBefore.feeXProtocol, 0); + + assertGt(feesBefore.feeYTotal, 0); + assertEq(feesBefore.feeYTotal, feesIn); + + assertGt(feesBefore.feeYProtocol, 0); + assertLt(feesBefore.feeYProtocol, feesBefore.feeYTotal); + address protocolFeesReceiver = factory.feeRecipient(); uint256 balanceBefore = token18D.balanceOf(protocolFeesReceiver); + + vm.prank(protocolFeesReceiver); pair.collectProtocolFees(); - assertEq(token18D.balanceOf(protocolFeesReceiver) - balanceBefore, feesYProtocol - 1); - // Claiming twice + assertEq(token18D.balanceOf(protocolFeesReceiver) - balanceBefore, feesBefore.feeYProtocol - 1); + FeeInfo memory feesAfter = _getGlobalFees(); + + assertEq(feesAfter.feeXTotal, 0); + assertEq(feesAfter.feeXProtocol, 0); + + assertGt(feesAfter.feeYTotal, 0); + assertEq(feesAfter.feeYTotal, feesBefore.feeYTotal - (feesBefore.feeYProtocol - 1)); + assertEq(feesAfter.feeYProtocol, 1); + + // Claiming twice pair.collectProtocolFees(); - assertEq(token18D.balanceOf(protocolFeesReceiver) - balanceBefore, feesYProtocol - 1); + assertEq(token18D.balanceOf(protocolFeesReceiver) - balanceBefore, feesBefore.feeYProtocol - 1); - //Claiming rewards for X - (uint256 amountXInForSwap, ) = router.getSwapIn(pair, amountXOutForSwap, true); + FeeInfo memory feesAfter2 = _getGlobalFees(); - token6D.mint(address(pair), amountXInForSwap); - vm.prank(BOB); + assertEq(feesAfter2.feeXTotal, feesAfter.feeXTotal); + assertEq(feesAfter2.feeXProtocol, feesAfter.feeXProtocol); + + assertEq(feesAfter2.feeYTotal, feesAfter.feeYTotal); + assertEq(feesAfter2.feeYProtocol, feesAfter.feeYProtocol); + + // Add X fees + (amountIn, feesIn) = router.getSwapIn(pair, 1e18, true); + token6D.mint(address(pair), amountIn); + + vm.prank(ALICE); pair.swap(true, DEV); + + // Claiming rewards for X + feesBefore = _getGlobalFees(); + + assertEq(feesBefore.feeYTotal, feesAfter2.feeYTotal); + assertEq(feesBefore.feeYProtocol, feesAfter2.feeYProtocol); + + assertGt(feesBefore.feeXTotal, 0); + assertEq(feesBefore.feeXTotal, feesIn); + + assertGt(feesBefore.feeXProtocol, 0); + assertLt(feesBefore.feeXProtocol, feesBefore.feeXTotal); + balanceBefore = token6D.balanceOf(protocolFeesReceiver); + + vm.prank(protocolFeesReceiver); + pair.collectProtocolFees(); + + assertEq(token6D.balanceOf(protocolFeesReceiver) - balanceBefore, feesBefore.feeXProtocol - 1); + + feesAfter = _getGlobalFees(); + + assertEq(feesAfter.feeYTotal, feesBefore.feeYTotal); + assertEq(feesAfter.feeYProtocol, feesBefore.feeYProtocol); + + assertGt(feesAfter.feeXTotal, 0); + assertEq(feesAfter.feeXTotal, feesBefore.feeXTotal - (feesBefore.feeXProtocol - 1)); + + assertGt(feesAfter.feeXProtocol, 0); + assertEq(feesAfter.feeXProtocol, 1); + + // Claiming twice pair.collectProtocolFees(); - assertEq(token6D.balanceOf(protocolFeesReceiver) - balanceBefore, feesYProtocol - 1); + assertEq(token6D.balanceOf(protocolFeesReceiver) - balanceBefore, feesBefore.feeXProtocol - 1); + + FeeInfo memory feesAfter2X = _getGlobalFees(); + + assertEq(feesAfter2X.feeXTotal, feesAfter.feeXTotal); + assertEq(feesAfter2X.feeXProtocol, feesAfter.feeXProtocol); + + assertEq(feesAfter2X.feeXTotal, feesAfter.feeXTotal); + assertEq(feesAfter2X.feeXProtocol, feesAfter.feeXProtocol); + + assertEq(feesAfter2X.feeYTotal, feesAfter.feeYTotal); + assertEq(feesAfter2X.feeYProtocol, feesAfter.feeYProtocol); } function testForceDecay() public {