diff --git a/contracts/ChainlinkPriceFeed.sol b/contracts/ChainlinkPriceFeed.sol index b72d5bb..21b27ed 100644 --- a/contracts/ChainlinkPriceFeed.sol +++ b/contracts/ChainlinkPriceFeed.sol @@ -4,10 +4,11 @@ pragma solidity 0.7.6; import { Address } from "@openzeppelin/contracts/utils/Address.sol"; import { SafeMath } from "@openzeppelin/contracts/math/SafeMath.sol"; import { AggregatorV3Interface } from "@chainlink/contracts/src/v0.6/interfaces/AggregatorV3Interface.sol"; +import { IChainlinkPriceFeed } from "./interface/IChainlinkPriceFeed.sol"; import { IPriceFeed } from "./interface/IPriceFeed.sol"; import { BlockContext } from "./base/BlockContext.sol"; -contract ChainlinkPriceFeed is IPriceFeed, BlockContext { +contract ChainlinkPriceFeed is IChainlinkPriceFeed, IPriceFeed, BlockContext { using SafeMath for uint256; using Address for address; @@ -24,10 +25,25 @@ contract ChainlinkPriceFeed is IPriceFeed, BlockContext { return _aggregator.decimals(); } - function getAggregator() external view returns (address) { + function getAggregator() external view override returns (address) { return address(_aggregator); } + function getRoundData(uint80 roundId) external view override returns (uint256, uint256) { + // NOTE: aggregator will revert if roundId is invalid (but there might not be a revert message sometimes) + // will return (roundId, 0, 0, 0, roundId) if round is not complete (not existed yet) + // https://docs.chain.link/docs/historical-price-data/ + (, int256 price, , uint256 updatedAt, ) = _aggregator.getRoundData(roundId); + + // CPF_IP: Invalid Price + require(price > 0, "CPF_IP"); + + // CPF_RINC: Round Is Not Complete + require(updatedAt > 0, "CPF_RINC"); + + return (uint256(price), updatedAt); + } + function getPrice(uint256 interval) external view override returns (uint256) { // there are 3 timestamps: base(our target), previous & current // base: now - _interval @@ -85,21 +101,6 @@ contract ChainlinkPriceFeed is IPriceFeed, BlockContext { return weightedPrice == 0 ? latestPrice : weightedPrice.div(interval); } - function getRoundData(uint80 roundId) external view returns (uint256, uint256) { - // NOTE: aggregator will revert if roundId is invalid (but there might not be a revert message sometimes) - // will return (roundId, 0, 0, 0, roundId) if round is not complete (not existed yet) - // https://docs.chain.link/docs/historical-price-data/ - (, int256 price, , uint256 updatedAt, ) = _aggregator.getRoundData(roundId); - - // CPF_IP: Invalid Price - require(price > 0, "CPF_IP"); - - // CPF_RINC: Round Is Not Complete - require(updatedAt > 0, "CPF_RINC"); - - return (uint256(price), updatedAt); - } - function _getLatestRoundData() private view diff --git a/contracts/ChainlinkPriceFeedV2.sol b/contracts/ChainlinkPriceFeedV2.sol index 9d81585..67e8d4c 100644 --- a/contracts/ChainlinkPriceFeedV2.sol +++ b/contracts/ChainlinkPriceFeedV2.sol @@ -3,11 +3,12 @@ pragma solidity 0.7.6; import { Address } from "@openzeppelin/contracts/utils/Address.sol"; import { AggregatorV3Interface } from "@chainlink/contracts/src/v0.6/interfaces/AggregatorV3Interface.sol"; +import { IChainlinkPriceFeed } from "./interface/IChainlinkPriceFeed.sol"; import { IPriceFeedV2 } from "./interface/IPriceFeedV2.sol"; import { BlockContext } from "./base/BlockContext.sol"; import { CachedTwap } from "./twap/CachedTwap.sol"; -contract ChainlinkPriceFeedV2 is IPriceFeedV2, BlockContext, CachedTwap { +contract ChainlinkPriceFeedV2 is IChainlinkPriceFeed, IPriceFeedV2, BlockContext, CachedTwap { using Address for address; AggregatorV3Interface private immutable _aggregator; @@ -38,6 +39,25 @@ contract ChainlinkPriceFeedV2 is IPriceFeedV2, BlockContext, CachedTwap { return _aggregator.decimals(); } + function getAggregator() external view override returns (address) { + return address(_aggregator); + } + + function getRoundData(uint80 roundId) external view override returns (uint256, uint256) { + // NOTE: aggregator will revert if roundId is invalid (but there might not be a revert message sometimes) + // will return (roundId, 0, 0, 0, roundId) if round is not complete (not existed yet) + // https://docs.chain.link/docs/historical-price-data/ + (, int256 price, , uint256 updatedAt, ) = _aggregator.getRoundData(roundId); + + // CPF_IP: Invalid Price + require(price > 0, "CPF_IP"); + + // CPF_RINC: Round Is Not Complete + require(updatedAt > 0, "CPF_RINC"); + + return (uint256(price), updatedAt); + } + function getPrice(uint256 interval) external view override returns (uint256) { (uint80 round, uint256 latestPrice, uint256 latestTimestamp) = _getLatestRoundData(); diff --git a/contracts/interface/IChainlinkPriceFeed.sol b/contracts/interface/IChainlinkPriceFeed.sol new file mode 100644 index 0000000..f380635 --- /dev/null +++ b/contracts/interface/IChainlinkPriceFeed.sol @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity 0.7.6; + +interface IChainlinkPriceFeed { + function getAggregator() external view returns (address); + + /// @param roundId The roundId that fed into Chainlink aggregator. + function getRoundData(uint80 roundId) external view returns (uint256, uint256); +} diff --git a/test/ChainlinkPriceFeed.spec.ts b/test/ChainlinkPriceFeed.spec.ts index 5a0407f..2b65094 100644 --- a/test/ChainlinkPriceFeed.spec.ts +++ b/test/ChainlinkPriceFeed.spec.ts @@ -4,6 +4,7 @@ import { BigNumber } from "ethers" import { parseEther, parseUnits } from "ethers/lib/utils" import { ethers, waffle } from "hardhat" import { ChainlinkPriceFeed, TestAggregatorV3, TestAggregatorV3__factory } from "../typechain" +import { computeRoundId } from "./shared/chainlink" interface ChainlinkPriceFeedFixture { chainlinkPriceFeed: ChainlinkPriceFeed @@ -30,12 +31,6 @@ async function chainlinkPriceFeedFixture(): Promise { return { chainlinkPriceFeed, aggregator, chainlinkPriceFeed2, aggregator2 } } -// https://docs.chain.link/docs/historical-price-data/#roundid-in-proxy -function computeRoundId(phaseId: number, aggregatorRoundId: number): string { - const roundId = (BigInt(phaseId) << BigInt("64")) | BigInt(aggregatorRoundId) - return roundId.toString() -} - describe("ChainlinkPriceFeed Spec", () => { const [admin] = waffle.provider.getWallets() const loadFixture: ReturnType = waffle.createFixtureLoader([admin]) diff --git a/test/ChainlinkPriceFeedV2.spec.ts b/test/ChainlinkPriceFeedV2.spec.ts index f353b7c..f699e3a 100644 --- a/test/ChainlinkPriceFeedV2.spec.ts +++ b/test/ChainlinkPriceFeedV2.spec.ts @@ -1,12 +1,16 @@ import { MockContract, smock } from "@defi-wonderland/smock" import { expect } from "chai" -import { parseEther } from "ethers/lib/utils" +import { BigNumber } from "ethers" +import { parseEther, parseUnits } from "ethers/lib/utils" import { ethers, waffle } from "hardhat" import { ChainlinkPriceFeedV2, TestAggregatorV3, TestAggregatorV3__factory } from "../typechain" +import { computeRoundId } from "./shared/chainlink" interface ChainlinkPriceFeedFixture { chainlinkPriceFeed: ChainlinkPriceFeedV2 aggregator: MockContract + chainlinkPriceFeed2: ChainlinkPriceFeedV2 + aggregator2: MockContract } async function chainlinkPriceFeedFixture(): Promise { @@ -17,7 +21,17 @@ async function chainlinkPriceFeedFixture(): Promise { const chainlinkPriceFeedFactory = await ethers.getContractFactory("ChainlinkPriceFeedV2") const chainlinkPriceFeed = (await chainlinkPriceFeedFactory.deploy(aggregator.address, 900)) as ChainlinkPriceFeedV2 - return { chainlinkPriceFeed, aggregator } + const aggregatorFactory2 = await smock.mock("TestAggregatorV3") + const aggregator2 = await aggregatorFactory2.deploy() + aggregator2.decimals.returns(() => 8) + + const chainlinkPriceFeedFactory2 = await ethers.getContractFactory("ChainlinkPriceFeedV2") + const chainlinkPriceFeed2 = (await chainlinkPriceFeedFactory2.deploy( + aggregator2.address, + 900, + )) as ChainlinkPriceFeedV2 + + return { chainlinkPriceFeed, aggregator, chainlinkPriceFeed2, aggregator2 } } describe("ChainlinkPriceFeedV2 Spec", () => { @@ -25,30 +39,39 @@ describe("ChainlinkPriceFeedV2 Spec", () => { const loadFixture: ReturnType = waffle.createFixtureLoader([admin]) let chainlinkPriceFeed: ChainlinkPriceFeedV2 let aggregator: MockContract - let currentTime: number - let roundData: any[] - - async function updatePrice(index: number, price: number, forward: boolean = true): Promise { - roundData.push([index, parseEther(price.toString()), currentTime, currentTime, index]) - aggregator.latestRoundData.returns(() => { - return roundData[roundData.length - 1] - }) - await chainlinkPriceFeed.update() - - if (forward) { - currentTime += 15 - await ethers.provider.send("evm_setNextBlockTimestamp", [currentTime]) - await ethers.provider.send("evm_mine", []) - } - } + let priceFeedDecimals: number + let chainlinkPriceFeed2: ChainlinkPriceFeedV2 + let aggregator2: MockContract + let priceFeedDecimals2: number beforeEach(async () => { const _fixture = await loadFixture(chainlinkPriceFeedFixture) chainlinkPriceFeed = _fixture.chainlinkPriceFeed aggregator = _fixture.aggregator + priceFeedDecimals = await chainlinkPriceFeed.decimals() + chainlinkPriceFeed2 = _fixture.chainlinkPriceFeed2 + aggregator2 = _fixture.aggregator2 + priceFeedDecimals2 = await chainlinkPriceFeed2.decimals() }) describe("edge cases, have the same timestamp for several rounds", () => { + let currentTime: number + let roundData: any[] + + async function updatePrice(index: number, price: number, forward: boolean = true): Promise { + roundData.push([index, parseEther(price.toString()), currentTime, currentTime, index]) + aggregator.latestRoundData.returns(() => { + return roundData[roundData.length - 1] + }) + await chainlinkPriceFeed.update() + + if (forward) { + currentTime += 15 + await ethers.provider.send("evm_setNextBlockTimestamp", [currentTime]) + await ethers.provider.send("evm_mine", []) + } + } + it("force error, can't update if timestamp is the same", async () => { currentTime = (await waffle.provider.getBlock("latest")).timestamp roundData = [ @@ -65,4 +88,91 @@ describe("ChainlinkPriceFeedV2 Spec", () => { await expect(chainlinkPriceFeed.update()).to.be.revertedWith("CT_IT") }) }) + + describe("getRoundData", async () => { + let currentTime: number + + beforeEach(async () => { + currentTime = (await waffle.provider.getBlock("latest")).timestamp + + await aggregator2.setRoundData( + computeRoundId(1, 1), + parseUnits("1800", priceFeedDecimals2), + BigNumber.from(currentTime), + BigNumber.from(currentTime), + computeRoundId(1, 1), + ) + await aggregator2.setRoundData( + computeRoundId(1, 2), + parseUnits("1900", priceFeedDecimals2), + BigNumber.from(currentTime + 15), + BigNumber.from(currentTime + 15), + computeRoundId(1, 2), + ) + await aggregator2.setRoundData( + computeRoundId(2, 10000), + parseUnits("1700", priceFeedDecimals2), + BigNumber.from(currentTime + 30), + BigNumber.from(currentTime + 30), + computeRoundId(2, 10000), + ) + + // updatedAt is 0 means the round is not complete and should not be used + await aggregator2.setRoundData( + computeRoundId(2, 20000), + parseUnits("-0.1", priceFeedDecimals2), + BigNumber.from(currentTime + 45), + BigNumber.from(0), + computeRoundId(2, 20000), + ) + + // updatedAt is 0 means the round is not complete and should not be used + await aggregator2.setRoundData( + computeRoundId(2, 20001), + parseUnits("5000", priceFeedDecimals2), + BigNumber.from(currentTime + 45), + BigNumber.from(0), + computeRoundId(2, 20001), + ) + }) + + it("computeRoundId", async () => { + expect(computeRoundId(1, 1)).to.be.eq(await aggregator2.computeRoundId(1, 1)) + expect(computeRoundId(1, 2)).to.be.eq(await aggregator2.computeRoundId(1, 2)) + expect(computeRoundId(2, 10000)).to.be.eq(await aggregator2.computeRoundId(2, 10000)) + }) + + it("getRoundData with valid roundId", async () => { + expect(await chainlinkPriceFeed2.getRoundData(computeRoundId(1, 1))).to.be.deep.eq([ + parseUnits("1800", priceFeedDecimals2), + BigNumber.from(currentTime), + ]) + + expect(await chainlinkPriceFeed2.getRoundData(computeRoundId(1, 2))).to.be.deep.eq([ + parseUnits("1900", priceFeedDecimals2), + BigNumber.from(currentTime + 15), + ]) + + expect(await chainlinkPriceFeed2.getRoundData(computeRoundId(2, 10000))).to.be.deep.eq([ + parseUnits("1700", priceFeedDecimals2), + BigNumber.from(currentTime + 30), + ]) + }) + + it("force error, getRoundData when price <= 0", async () => { + // price < 0 + await expect(chainlinkPriceFeed2.getRoundData(computeRoundId(2, 20000))).to.be.revertedWith("CPF_IP") + + // price = 0 + await expect(chainlinkPriceFeed2.getRoundData("123")).to.be.revertedWith("CPF_IP") + }) + + it("force error, getRoundData when round is not complete", async () => { + await expect(chainlinkPriceFeed2.getRoundData(computeRoundId(2, 20001))).to.be.revertedWith("CPF_RINC") + }) + }) + + it("getAggregator", async () => { + expect(await chainlinkPriceFeed2.getAggregator()).to.be.eq(aggregator2.address) + }) }) diff --git a/test/shared/chainlink.ts b/test/shared/chainlink.ts new file mode 100644 index 0000000..71dd1f4 --- /dev/null +++ b/test/shared/chainlink.ts @@ -0,0 +1,5 @@ +// https://docs.chain.link/docs/historical-price-data/#roundid-in-proxy +export function computeRoundId(phaseId: number, aggregatorRoundId: number): string { + const roundId = (BigInt(phaseId) << BigInt("64")) | BigInt(aggregatorRoundId) + return roundId.toString() +}