Skip to content

Commit

Permalink
chore: add AddRewardTokensTest
Browse files Browse the repository at this point in the history
  • Loading branch information
pyk committed May 30, 2024
1 parent 3b21c5d commit 1797d09
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 162 deletions.
95 changes: 48 additions & 47 deletions src/LlamaLocker.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ contract LlamaLocker is ERC721Holder, Ownable2Step {
using Math for uint256;

IERC721 public nft;
IERC20[] public tokens; // yield token
address[] public tokens; // reward tokens
uint256 public constant EPOCH_DURATION = 7 days;
uint256 public constant LOCK_DURATION_IN_EPOCH = 4; // 4 epochs
uint256 public totalLockedNFT;
Expand All @@ -46,7 +46,7 @@ contract LlamaLocker is ERC721Holder, Ownable2Step {
uint256 amount;
}

struct YieldInfo {
struct RewardTokenInfo {
uint208 amountPerSecond;
uint48 epochEndAt;
uint208 amountPerNFTStored;
Expand All @@ -60,7 +60,7 @@ contract LlamaLocker is ERC721Holder, Ownable2Step {

Epoch[] public epochs;
mapping(uint256 nftId => LockInfo info) public locks;
mapping(IERC20 token => YieldInfo info) private yields;
mapping(address token => RewardTokenInfo info) public rewards;
mapping(uint256 nftId => mapping(IERC20 token => NFTYield)) private nftYields;

// TODO(pyk): check variables below
Expand All @@ -76,8 +76,10 @@ contract LlamaLocker is ERC721Holder, Ownable2Step {
error InvalidYieldRecipient();
error InvalidLockOwner();
error InvalidUnlockWindow();
error InvalidRewardTokenCount();
error InvalidRewardToken();

event TokenAdded(IERC20 token);
event NewRewardToken(address token);
event RewardDistributed(IERC20 token, uint256 amount);
event Locked(address owner, uint256 nftId);
event Unlocked(address recipient, uint256 nftId);
Expand Down Expand Up @@ -147,6 +149,7 @@ contract LlamaLocker is ERC721Holder, Ownable2Step {
LockInfo memory lockInfo = locks[tokenId_];
if (unlocker_ != lockInfo.owner) revert InvalidLockOwner();
// TODO(pyk): mark lockInfo as unlocked here
// TODO(pyk): claim unclaimed rewards here

uint256 lockedDurationInEpoch = currendEpochIndex_ - lockInfo.lockedAtEpochIndex;
if (lockedDurationInEpoch == 0) revert InvalidUnlockWindow();
Expand All @@ -172,61 +175,59 @@ contract LlamaLocker is ERC721Holder, Ownable2Step {
}
}

//************************************************************//
// Yield //
//************************************************************//

function addTokens(IERC20[] memory _tokens) external onlyOwner {
uint256 tokenCount = _tokens.length;
if (tokenCount == 0) revert Empty();

for (uint256 i = 0; i < _tokens.length; i++) {
IERC20 token = tokens[i];

if (yields[token].updatedAt > 0) revert TokenExists();
tokens.push(token);
yields[token].updatedAt = block.timestamp.toUint48();
yields[token].epochEndAt = block.timestamp.toUint48();
function _addRewardToken(address token_) private {
if (address(token_) == address(0)) revert InvalidRewardToken();
if (rewards[token_].updatedAt > 0) revert InvalidRewardToken();
tokens.push(token_);
rewards[token_].updatedAt = block.timestamp.toUint48();
rewards[token_].epochEndAt = block.timestamp.toUint48();
emit NewRewardToken(token_);
}

emit TokenAdded(token);
function addRewardTokens(address[] memory tokens_) external onlyOwner {
uint256 tokenCount = tokens_.length;
if (tokenCount == 0) revert InvalidRewardTokenCount();
_backfillEpochs();
for (uint256 i = 0; i < tokens_.length; i++) {
_addRewardToken(tokens_[i]);
}
}

function getTokenCount() external view returns (uint256 count) {
function getRewardTokenCount() external view returns (uint256 count) {
count = tokens.length;
}

function getYieldInfo(IERC20 _token) external view returns (YieldInfo memory info) {
info = yields[_token];
}
// function getYieldInfo(IERC20 _token) external view returns (YieldInfo memory info) {
// info = yields[_token];
// }

/// @dev Calculate yield amount per NFT
function _yieldAmountPerNFT(IERC20 _token) internal view returns (uint256) {
if (totalLockedNFT == 0) return yields[_token].amountPerNFTStored;

YieldInfo memory yieldInfo = yields[_token];
uint256 prevYieldAmountPerNFT = uint256(yieldInfo.amountPerNFTStored);
uint256 epochEndAt = Math.min(uint256(yieldInfo.epochEndAt), block.timestamp);
uint256 timeDelta = epochEndAt - uint256(yieldInfo.updatedAt);
uint256 yieldAmountPerNFT = (timeDelta * yieldInfo.amountPerSecond) / totalLockedNFT;
return prevYieldAmountPerNFT + yieldAmountPerNFT;
}
// function _yieldAmountPerNFT(IERC20 _token) internal view returns (uint256) {
// if (totalLockedNFT == 0) return yields[_token].amountPerNFTStored;

// YieldInfo memory yieldInfo = yields[_token];
// uint256 prevYieldAmountPerNFT = uint256(yieldInfo.amountPerNFTStored);
// uint256 epochEndAt = Math.min(uint256(yieldInfo.epochEndAt), block.timestamp);
// uint256 timeDelta = epochEndAt - uint256(yieldInfo.updatedAt);
// uint256 yieldAmountPerNFT = (timeDelta * yieldInfo.amountPerSecond) / totalLockedNFT;
// return prevYieldAmountPerNFT + yieldAmountPerNFT;
// }

function _getClaimableYield(uint256 _nftId, IERC20 _token) internal view returns (uint256) {
NFTYield memory nftYield = nftYields[_nftId][_token];
uint256 amountPerNFT = _yieldAmountPerNFT(_token);
return (amountPerNFT - nftYield.paidAmount) + nftYield.amount;
}
// function _getClaimableYield(uint256 _nftId, IERC20 _token) internal view returns (uint256) {
// NFTYield memory nftYield = nftYields[_nftId][_token];
// uint256 amountPerNFT = _yieldAmountPerNFT(_token);
// return (amountPerNFT - nftYield.paidAmount) + nftYield.amount;
// }

/// @dev Get claimable yield of nftId
function getClaimableYield(uint256 _nftId) external view returns (Claimable[] memory claimables) {
claimables = new Claimable[](tokens.length);
for (uint256 i = 0; i < tokens.length; i++) {
IERC20 token = tokens[i];
claimables[i].token = token;
claimables[i].amount = _getClaimableYield(_nftId, token);
}
}
// function getClaimableYield(uint256 _nftId) external view returns (Claimable[] memory claimables) {
// claimables = new Claimable[](tokens.length);
// for (uint256 i = 0; i < tokens.length; i++) {
// IERC20 token = tokens[i];
// claimables[i].token = token;
// claimables[i].amount = _getClaimableYield(_nftId, token);
// }
// }

// TODO(pyk): review functions below

Expand Down
97 changes: 97 additions & 0 deletions test/AddRewardTokens.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.23;

import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol";
import {Test} from "forge-std/Test.sol";

import {LlamaLocker} from "../src/LlamaLocker.sol";
import {MockNFT} from "./MockNFT.sol";
import {MockToken} from "./MockToken.sol";

contract AddRewardTokensTest is Test {
MockNFT public nft;
MockToken public token0;
MockToken public token1;
LlamaLocker public locker;

address public admin = makeAddr("admin");
address public alice = makeAddr("alice");
address public bob = makeAddr("bob");

function setUp() public {
token0 = new MockToken();
token1 = new MockToken();
nft = new MockNFT();
locker = new LlamaLocker(admin, address(nft));
}

function test_addRewardTokens_Valid() public {
vm.warp(1714608000);
LlamaLocker llama = new LlamaLocker(admin, makeAddr("nft"));

vm.startPrank(admin);

address[] memory rewardTokens = new address[](2);
rewardTokens[0] = address(token0);
rewardTokens[1] = address(token1);

vm.warp(1717026037);
llama.addRewardTokens(rewardTokens);

// addRewardTokens() should backfill epochs
assertEq(llama.epochs(0), 1714608000, "invalid epoch 0");
assertEq(llama.epochs(1), 1715212800, "invalid epoch 1");
assertEq(llama.epochs(2), 1715817600, "invalid epoch 2");
assertEq(llama.epochs(3), 1716422400, "invalid epoch 3");

vm.expectRevert();
llama.epochs(4);

// addRewardTokens() should increase token count
assertEq(llama.getRewardTokenCount(), 2, "invalid reward token count");

// addRewardTokens() should set initial values
(uint208 amountPerSecond, uint48 epochEndAt, uint208 amountPerNFTStored, uint48 updatedAt) =
llama.rewards(address(token0));
assertEq(amountPerSecond, 0, "invalid token0 amountPerSecond");
assertEq(epochEndAt, block.timestamp, "invalid token0 epochEndAt");
assertEq(amountPerNFTStored, 0, "invalid token0 amountPerNFTStored");
assertEq(updatedAt, block.timestamp, "invalid token0 updatedAt");

(amountPerSecond, epochEndAt, amountPerNFTStored, updatedAt) = llama.rewards(address(token1));
assertEq(amountPerSecond, 0, "invalid token1 amountPerSecond");
assertEq(epochEndAt, block.timestamp, "invalid token1 epochEndAt");
assertEq(amountPerNFTStored, 0, "invalid token1 amountPerNFTStored");
assertEq(updatedAt, block.timestamp, "invalid token1 updatedAt");
}

function test_addRewardTokens_Unauthorized() public {
address[] memory rewardTokens = new address[](0);
vm.startPrank(alice);
vm.expectRevert(abi.encodeWithSelector(Ownable.OwnableUnauthorizedAccount.selector, alice));
locker.addRewardTokens(rewardTokens);
}

function test_addRewardTokens_InvalidRewardTokenCount() public {
address[] memory rewardTokens = new address[](0);
vm.startPrank(admin);
vm.expectRevert(abi.encodeWithSelector(LlamaLocker.InvalidRewardTokenCount.selector));
locker.addRewardTokens(rewardTokens);
}

function test_addRewardTokens_InvalidRewardToken() public {
vm.startPrank(admin);

address[] memory rewardTokens = new address[](1);
rewardTokens[0] = address(0);
vm.expectRevert(abi.encodeWithSelector(LlamaLocker.InvalidRewardToken.selector));
locker.addRewardTokens(rewardTokens);

rewardTokens = new address[](1);
rewardTokens[0] = address(token0);
locker.addRewardTokens(rewardTokens);

vm.expectRevert(abi.encodeWithSelector(LlamaLocker.InvalidRewardToken.selector));
locker.addRewardTokens(rewardTokens);
}
}
Loading

0 comments on commit 1797d09

Please sign in to comment.