diff --git a/packages/drift/src/adapter/MockAdapter.test.ts b/packages/drift/src/adapter/MockAdapter.test.ts index 24e1255d..8dbd37b0 100644 --- a/packages/drift/src/adapter/MockAdapter.test.ts +++ b/packages/drift/src/adapter/MockAdapter.test.ts @@ -1,46 +1,549 @@ import { MockAdapter } from "src/adapter/MockAdapter"; +import type { Event } from "src/adapter/contract/types/event"; +import type { Block } from "src/adapter/network/types/Block"; +import type { + Transaction, + TransactionReceipt, +} from "src/adapter/network/types/Transaction"; +import type { + DecodeFunctionDataParams, + EncodeFunctionDataParams, + GetEventsParams, + ReadParams, + WriteParams, +} from "src/adapter/types"; import { IERC20 } from "src/utils/testing/IERC20"; import { describe, expect, it } from "vitest"; describe("MockAdapter", () => { - it("Includes a mock network", async () => { - const adapter = new MockAdapter(); - const blockStub = { - blockNumber: 100n, - timestamp: 200n, - }; - adapter.network.stubGetBlock({ - value: blockStub, + describe("getBalance", () => { + it("Resolves to a default value", async () => { + const adapter = new MockAdapter(); + expect(await adapter.getBalance("0x0")).toBeTypeOf("bigint"); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + adapter.onGetBalance().resolves(123n); + expect(await adapter.getBalance("0x")).toBe(123n); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const defaultValue = await adapter.getBalance("0x"); + adapter.onGetBalance("0xAlice").resolves(defaultValue + 1n); + adapter.onGetBalance("0xBob").resolves(defaultValue + 2n); + expect(await adapter.getBalance("0xAlice")).toBe(defaultValue + 1n); + expect(await adapter.getBalance("0xBob")).toBe(defaultValue + 2n); + }); + }); + + describe("getBlock", () => { + it("Resolves to a default value", async () => { + const adapter = new MockAdapter(); + expect(adapter.getBlock()).resolves.toEqual({ + blockNumber: expect.any(BigInt), + timestamp: expect.any(BigInt), + }); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + const block: Block = { + blockNumber: 123n, + timestamp: 123n, + }; + adapter.onGetBlock().resolves(block); + expect(await adapter.getBlock()).toBe(block); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const { blockNumber, timestamp } = await adapter.getBlock(); + const block1: Block = { + blockNumber: blockNumber ?? 0n + 1n, + timestamp: timestamp + 1n, + }; + const block2: Block = { + blockNumber: blockNumber ?? 0n + 2n, + timestamp: timestamp + 2n, + }; + adapter.onGetBlock({ blockNumber: 1n }).resolves(block1); + adapter.onGetBlock({ blockNumber: 2n }).resolves(block2); + expect(await adapter.getBlock({ blockNumber: 1n })).toBe(block1); + expect(await adapter.getBlock({ blockNumber: 2n })).toBe(block2); + }); + }); + + describe("getChainId", () => { + it("Resolves to a default value", async () => { + const adapter = new MockAdapter(); + expect(await adapter.getChainId()).toBeTypeOf("number"); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + adapter.onGetChainId().resolves(123); + expect(await adapter.getChainId()).toBe(123); + }); + }); + + describe("getTransaction", () => { + it("Resolves to undefined by default", async () => { + const adapter = new MockAdapter(); + expect(await adapter.getTransaction("0x")).toBeUndefined(); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + const transaction: Transaction = { + blockNumber: 123n, + gas: 123n, + gasPrice: 123n, + input: "0x", + nonce: 123, + type: "0x123", + value: 123n, + }; + adapter.onGetTransaction().resolves(transaction); + expect(await adapter.getTransaction("0x")).toBe(transaction); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const transaction1: Transaction = { + input: "0x1", + blockNumber: 123n, + gas: 123n, + gasPrice: 123n, + nonce: 123, + type: "0x123", + value: 123n, + }; + const transaction2: Transaction = { + ...transaction1, + input: "0x2", + }; + adapter.onGetTransaction("0x1").resolves(transaction1); + adapter.onGetTransaction("0x2").resolves(transaction2); + expect(await adapter.getTransaction("0x1")).toBe(transaction1); + expect(await adapter.getTransaction("0x2")).toBe(transaction2); }); - const block = await adapter.network.getBlock(); - expect(block).toBe(blockStub); }); - it("Stubs the signer address", async () => { - const adapter = new MockAdapter(); - const signer = await adapter.getSignerAddress(); - expect(signer).toBeTypeOf("string"); + describe("waitForTransaction", () => { + it("Resolves to undefined by default", async () => { + const adapter = new MockAdapter(); + expect(adapter.waitForTransaction("0x")).resolves.toBeUndefined(); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + const receipt: TransactionReceipt = { + blockNumber: 123n, + blockHash: "0x", + cumulativeGasUsed: 123n, + effectiveGasPrice: 123n, + from: "0x", + gasUsed: 123n, + logsBloom: "0x", + status: "success", + to: "0x", + transactionHash: "0x", + transactionIndex: 123, + }; + adapter.onWaitForTransaction().resolves(receipt); + expect(await adapter.waitForTransaction("0x")).toBe(receipt); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const receipt1: TransactionReceipt = { + transactionHash: "0x1", + blockNumber: 123n, + blockHash: "0x", + cumulativeGasUsed: 123n, + effectiveGasPrice: 123n, + from: "0x", + gasUsed: 123n, + logsBloom: "0x", + status: "success", + to: "0x", + transactionIndex: 123, + }; + const receipt2: TransactionReceipt = { + ...receipt1, + transactionHash: "0x2", + }; + adapter.onWaitForTransaction("0x1").resolves(receipt1); + adapter.onWaitForTransaction("0x2").resolves(receipt2); + expect(await adapter.waitForTransaction("0x1")).toBe(receipt1); + expect(await adapter.waitForTransaction("0x2")).toBe(receipt2); + }); }); - it("Creates mock read contracts", async () => { - const mockAdapter = new MockAdapter(); - const contract = mockAdapter.readContract(IERC20.abi); - contract.stubRead({ - functionName: "balanceOf", - value: 100n, + describe("encodeFunctionData", () => { + it("Returns a default value", async () => { + const adapter = new MockAdapter(); + expect( + adapter.encodeFunctionData({ + abi: IERC20.abi, + fn: "balanceOf", + args: { owner: "0x" }, + }), + ).toBeTypeOf("string"); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + adapter + .onEncodeFunctionData({ + abi: IERC20.abi, + fn: "balanceOf", + args: { owner: "0x" }, + }) + .returns("0x123"); + expect( + adapter.encodeFunctionData({ + abi: IERC20.abi, + fn: "balanceOf", + args: { owner: "0x" }, + }), + ).toBe("0x123"); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const params1: EncodeFunctionDataParams = + { + abi: IERC20.abi, + fn: "balanceOf", + args: { owner: "0x1" }, + }; + const params2: EncodeFunctionDataParams = + { + ...params1, + args: { owner: "0x2" }, + }; + adapter.onEncodeFunctionData(params1).returns("0x1"); + adapter.onEncodeFunctionData(params2).returns("0x2"); + expect(adapter.encodeFunctionData(params1)).toBe("0x1"); + expect(adapter.encodeFunctionData(params2)).toBe("0x2"); }); - const balance = await contract.read("balanceOf", { owner: "0xMe" }); - expect(balance).toBe(100n); }); - it("Creates mock read-write contracts", async () => { - const mockAdapter = new MockAdapter(); - const contract = mockAdapter.readWriteContract(IERC20.abi); - contract.stubWrite("approve", "0xDone"); - const txHash = await contract.write("approve", { - spender: "0xYou", - value: 100n, + describe("decodeFunctionData", () => { + it("Throws an error by default", async () => { + const adapter = new MockAdapter(); + expect( + (async () => + adapter.decodeFunctionData({ + abi: IERC20.abi, + fn: "balanceOf", + data: "0x", + }))(), + ).rejects.toThrowError(); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + adapter + .onDecodeFunctionData({ + abi: IERC20.abi, + fn: "balanceOf", + data: "0x", + }) + .returns(123n); + expect( + adapter.decodeFunctionData({ + abi: IERC20.abi, + fn: "balanceOf", + data: "0x", + }), + ).toBe(123n); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const params1: DecodeFunctionDataParams = + { + abi: IERC20.abi, + fn: "balanceOf", + data: "0x1", + }; + const params2: DecodeFunctionDataParams = + { + ...params1, + data: "0x2", + }; + adapter.onDecodeFunctionData(params1).returns(1n); + adapter.onDecodeFunctionData(params2).returns(2n); + expect(adapter.decodeFunctionData(params1)).toBe(1n); + expect(adapter.decodeFunctionData(params2)).toBe(2n); + }); + }); + + describe("getEvents", () => { + it("Rejects with an error by default", async () => { + const adapter = new MockAdapter(); + expect( + adapter.getEvents({ + abi: IERC20.abi, + address: "0x", + event: "Transfer", + }), + ).rejects.toThrowError(); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + const events: Event[] = [ + { + eventName: "Transfer", + args: { + from: "0x", + to: "0x", + value: 123n, + }, + }, + ]; + adapter + .onGetEvents({ + abi: IERC20.abi, + address: "0x", + event: "Transfer", + }) + .resolves(events); + expect( + await adapter.getEvents({ + abi: IERC20.abi, + address: "0x", + event: "Transfer", + }), + ).toBe(events); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const params1: GetEventsParams = { + abi: IERC20.abi, + address: "0x1", + event: "Transfer", + filter: { from: "0x1" }, + }; + const params2: GetEventsParams = { + ...params1, + filter: { from: "0x2" }, + }; + const events1: Event[] = [ + { + eventName: "Transfer", + args: { + from: "0x1", + to: "0x1", + value: 123n, + }, + }, + ]; + const events2: Event[] = [ + { + eventName: "Transfer", + args: { + from: "0x2", + to: "0x2", + value: 123n, + }, + }, + ]; + adapter.onGetEvents(params1).resolves(events1); + adapter.onGetEvents(params2).resolves(events2); + expect(await adapter.getEvents(params1)).toBe(events1); + expect(await adapter.getEvents(params2)).toBe(events2); + }); + }); + + describe("read", () => { + it("Rejects with an error by default", async () => { + const adapter = new MockAdapter(); + expect( + adapter.read({ + abi: IERC20.abi, + address: "0x", + fn: "symbol", + }), + ).rejects.toThrowError(); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + adapter + .onRead({ + abi: IERC20.abi, + address: "0x", + fn: "symbol", + }) + .resolves("ABC"); + expect( + await adapter.read({ + abi: IERC20.abi, + address: "0x", + fn: "symbol", + }), + ).toBe("ABC"); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const params1: ReadParams = { + abi: IERC20.abi, + address: "0x1", + fn: "allowance", + args: { owner: "0x1", spender: "0x1" }, + }; + const params2: ReadParams = { + ...params1, + args: { owner: "0x2", spender: "0x2" }, + }; + adapter.onRead(params1).resolves(1n); + adapter.onRead(params2).resolves(2n); + expect(await adapter.read(params1)).toBe(1n); + expect(await adapter.read(params2)).toBe(2n); + }); + + it.todo("Can be stubbed with partial args", async () => { + const adapter = new MockAdapter(); + adapter + .onRead({ + abi: IERC20.abi, + address: "0x", + fn: "balanceOf", + }) + .resolves(123n); + expect( + await adapter.read({ + abi: IERC20.abi, + address: "0x", + fn: "balanceOf", + args: { owner: "0x" }, + }), + ).toBe(123n); + }); + }); + + describe("simulateWrite", () => { + it("Rejects with an error by default", async () => { + const adapter = new MockAdapter(); + expect( + adapter.simulateWrite({ + abi: IERC20.abi, + address: "0x", + fn: "transfer", + args: { to: "0x", value: 123n }, + }), + ).rejects.toThrowError(); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + adapter + .onSimulateWrite({ + abi: IERC20.abi, + address: "0x", + fn: "transfer", + args: { to: "0x", value: 123n }, + }) + .resolves(true); + expect( + await adapter.simulateWrite({ + abi: IERC20.abi, + address: "0x", + fn: "transfer", + args: { to: "0x", value: 123n }, + }), + ).toBe(true); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const params1: WriteParams = { + abi: IERC20.abi, + address: "0x1", + fn: "transfer", + args: { to: "0x1", value: 123n }, + }; + const params2: WriteParams = { + ...params1, + args: { to: "0x2", value: 123n }, + }; + adapter.onSimulateWrite(params1).resolves(true); + adapter.onSimulateWrite(params2).resolves(false); + expect(await adapter.simulateWrite(params1)).toBe(true); + expect(await adapter.simulateWrite(params2)).toBe(false); + }); + }); + + describe("write", () => { + it("Returns a default value", async () => { + const adapter = new MockAdapter(); + expect( + await adapter.write({ + abi: IERC20.abi, + address: "0x", + fn: "transfer", + args: { to: "0x", value: 123n }, + }), + ).toBeTypeOf("string"); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + adapter + .onWrite({ + abi: IERC20.abi, + address: "0x", + fn: "transfer", + args: { to: "0x", value: 123n }, + }) + .returns("0x123"); + expect( + await adapter.write({ + abi: IERC20.abi, + address: "0x", + fn: "transfer", + args: { to: "0x", value: 123n }, + }), + ).toBe("0x123"); + }); + + it("Can be stubbed with specific args", async () => { + const adapter = new MockAdapter(); + const params1: WriteParams = { + abi: IERC20.abi, + address: "0x", + fn: "transfer", + args: { to: "0x1", value: 123n }, + }; + const params2: WriteParams = { + ...params1, + args: { to: "0x2", value: 123n }, + }; + adapter.onWrite(params1).returns("0x1"); + adapter.onWrite(params2).returns("0x2"); + expect(await adapter.write(params1)).toBe("0x1"); + expect(await adapter.write(params2)).toBe("0x2"); + }); + }); + + describe("getSignerAddress", () => { + it("Returns a default value", async () => { + const adapter = new MockAdapter(); + expect(await adapter.getSignerAddress()).toBeTypeOf("string"); + }); + + it("Can be stubbed", async () => { + const adapter = new MockAdapter(); + adapter.onGetSignerAddress().resolves("0x123"); + expect(await adapter.getSignerAddress()).toBe("0x123"); }); - expect(txHash).toBe("0xDone"); }); }); diff --git a/packages/drift/src/adapter/MockAdapter.ts b/packages/drift/src/adapter/MockAdapter.ts index f808cb2b..2dd8c3fa 100644 --- a/packages/drift/src/adapter/MockAdapter.ts +++ b/packages/drift/src/adapter/MockAdapter.ts @@ -1,13 +1,390 @@ import type { Abi } from "abitype"; -import { ReadContractStub } from "src/adapter/contract/stubs/ReadContractStub"; -import { ReadWriteContractStub } from "src/adapter/contract/stubs/ReadWriteContractStub"; -import { MockNetwork } from "src/adapter/network/MockNetwork"; -import type { ReadWriteAdapter } from "src/adapter/types"; +import { type SinonStub, stub as createStub } from "sinon"; +import type { Event, EventName } from "src/adapter/contract/types/event"; +import type { + FunctionName, + FunctionReturn, +} from "src/adapter/contract/types/function"; +import type { + NetworkGetBalanceArgs, + NetworkGetBlockArgs, + NetworkGetTransactionArgs, + NetworkWaitForTransactionArgs, +} from "src/adapter/network/types/NetworkAdapter"; +import type { + DecodeFunctionDataParams, + EncodeFunctionDataParams, + GetEventsParams, + ReadParams, + ReadWriteAdapter, + WriteParams, +} from "src/adapter/types"; +import type { Address, Bytes, TransactionHash } from "src/types"; +import type { OptionalKeys } from "src/utils/types"; +// TODO: Allow configuration of error throwing/default return value behavior export class MockAdapter implements ReadWriteAdapter { - network = new MockNetwork(); - getSignerAddress = async () => "0xMockSigner"; - readContract = (abi: TAbi) => new ReadContractStub(abi); - readWriteContract = (abi: TAbi) => - new ReadWriteContractStub(abi); + // stubs // + + protected stubs = new Map(); + + protected getStub SinonStub)>({ + key, + create, + }: { + key: string; + create?: TInsert; + }): TInsert extends false | undefined ? SinonStub | undefined : SinonStub { + let stub = this.stubs.get(key); + if (!stub && create) { + stub = typeof create === "function" ? create() : createStub(); + this.stubs.set(key, stub); + } + return stub as any; + } + + reset() { + this.stubs.clear(); + } + + // getBalance // + + protected get getBalanceStub() { + return this.getStub({ + key: "getBalance", + create: () => createStub().resolves(0n), + }); + } + + getBalance(...args: NetworkGetBalanceArgs) { + return this.getBalanceStub(...args); + } + + onGetBalance(...args: Partial) { + return this.getBalanceStub.withArgs(...args); + } + + // getBlock // + + protected get getBlockStub() { + return this.getStub({ + key: "getBlock", + create: () => + createStub().resolves({ + blockNumber: 0n, + timestamp: 0n, + }), + }); + } + + getBlock(...args: NetworkGetBlockArgs) { + return this.getBlockStub(...args); + } + + onGetBlock(...args: Partial) { + return this.getBlockStub.withArgs(...args); + } + + // getChainId // + + protected get getChainIdStub() { + return this.getStub({ + key: "getChainId", + create: () => createStub().resolves(0), + }); + } + + getChainId() { + return this.getChainIdStub(); + } + + onGetChainId() { + return this.getChainIdStub; + } + + // getTransaction // + + protected get getTransactionStub() { + return this.getStub({ + key: "getTransaction", + create: () => createStub().resolves(undefined), + }); + } + + getTransaction(...args: NetworkGetTransactionArgs) { + return this.getTransactionStub(...args); + } + + onGetTransaction(...args: Partial) { + return this.getTransactionStub.withArgs(...args); + } + + // waitForTransaction // + + protected get waitForTransactionStub() { + return this.getStub({ + key: "waitForTransaction", + create: () => createStub().resolves(undefined), + }); + } + + waitForTransaction(...args: NetworkWaitForTransactionArgs) { + return this.waitForTransactionStub(...args); + } + + onWaitForTransaction(...args: Partial) { + return this.waitForTransactionStub.withArgs(...args); + } + + // encodeFunction // + + protected get encodeFunctionDataStub() { + return this.getStub({ + key: "encodeFunctionData", + create: () => createStub().returns("0x0"), + }); + } + + encodeFunctionData< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: EncodeFunctionDataParams): Bytes { + return this.encodeFunctionDataStub(params); + } + + onEncodeFunctionData< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: EncodeFunctionDataStubParams) { + return this.encodeFunctionDataStub.withArgs(params); + } + + // decodeFunction // + + // TODO: This should be specific to the abi to ensure the correct return type. + protected decodeFunctionDataStubKey({ + fn, + }: Partial) { + return `decodeFunctionData:${fn}`; + } + + decodeFunctionData< + TAbi extends Abi, + TFunctionName extends FunctionName, + >( + params: DecodeFunctionDataParams, + ): FunctionReturn { + const stub = this.getStub({ + key: this.decodeFunctionDataStubKey(params), + }); + if (!stub) { + throw new NotImplementedError({ + name: params.fn || params.data, + method: "decodeFunctionData", + stubMethod: "onDecodeFunctionData", + }); + } + return stub(params); + } + + // TODO: Does calling `onDecodeFunctionData` without calling any methods on + // it, e.g. `returns`, break the error behavior? + onDecodeFunctionData< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: DecodeFunctionDataStubParams) { + return this.getStub({ + key: this.decodeFunctionDataStubKey(params), + create: true, + }).withArgs(params); + } + + // getEvents // + + protected getEventsStubKey({ + address, + event, + }: Partial>): string { + return `getEvents:${address}:${event}`; + } + + getEvents>( + params: GetEventsParams, + ): Promise[]> { + const stub = this.stubs.get(this.getEventsStubKey(params)); + if (!stub) { + return Promise.reject( + new NotImplementedError({ + name: params.event, + method: "getEvents", + stubMethod: "onGetEvents", + }), + ); + } + return Promise.resolve(stub(params)); + } + + onGetEvents>( + params: GetEventsParams, + ) { + return this.getStub< + [GetEventsParams], + Promise[]> + >({ + key: this.getEventsStubKey(params), + args: [params], + }); + } + + // read // + + protected readStubKey({ address, fn }: ReadStubParams) { + return `read:${address}:${fn}`; + } + + read< + TAbi extends Abi, + TFunctionName extends FunctionName, + >( + params: ReadParams, + ): Promise> { + const stub = this.stubs.get(this.readStubKey(params)); + if (!stub) { + return Promise.reject( + new NotImplementedError({ + name: params.fn, + method: "read", + stubMethod: "onRead", + }), + ); + } + return Promise.resolve(stub(params)); + } + + onRead< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: ReadStubParams) { + return this.getStub< + [ReadStubParams], + Promise> + >({ + key: this.readStubKey(params), + args: [params], + }); + } + + // simulateWrite // + + protected simulateWriteStubKey({ address, fn }: WriteStubParams) { + return `simulateWrite:${address}:${fn}`; + } + + simulateWrite< + TAbi extends Abi, + TFunctionName extends FunctionName, + >( + params: WriteParams, + ): Promise> { + const stub = this.stubs.get(this.simulateWriteStubKey(params)); + if (!stub) { + return Promise.reject( + new NotImplementedError({ + name: params.fn, + method: "simulateWrite", + stubMethod: "onSimulateWrite", + }), + ); + } + return Promise.resolve(stub(params)); + } + + onSimulateWrite< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: WriteStubParams) { + return this.getStub< + [WriteStubParams], + Promise> + >({ + key: this.simulateWriteStubKey(params), + args: [params], + }); + } + + // write // + + protected get writeStub() { + return this.getStub<[WriteStubParams], Bytes>({ + key: "write", + }).returns("0x0"); + } + + write< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: WriteParams): Promise { + return Promise.resolve(this.writeStub(params)); + } + + onWrite< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: WriteStubParams) { + return this.writeStub.withArgs(params); + } + + // getSignerAddress // + + protected get getSignerAddressStub() { + const key = "getSignerAddress"; + let stub = this.stubs.get(key); + if (!stub) { + stub = createStub().resolves("0xMockSigner"); + this.stubs.set(key, stub); + } + } + + onGetSignerAddress() { + return this.getSignerAddressStub; + } + + getSignerAddress(): Promise
{ + return Promise.resolve(this.getSignerAddressStub()); + } +} + +// TODO: Make address optional and create a key from the abi entry and fn name. +export type ReadStubParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, +> = OptionalKeys, "args">; + +export type WriteStubParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, +> = OptionalKeys, "args" | "abi">; + +export type EncodeFunctionDataStubParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, +> = OptionalKeys, "args">; + +export type DecodeFunctionDataStubParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, +> = OptionalKeys, "data" | "abi">; + +class NotImplementedError extends Error { + constructor({ + method, + stubMethod, + name, + }: { method: string; stubMethod: string; name?: string }) { + // TODO: This error message is not accurate. + super( + `Called ${method}${name ? ` for "${name}"` : ""} on a MockAdapter without a return value. The function must be stubbed first:\n\tadapter.${stubMethod}("${name}").resolves(value)`, + ); + this.name = "NotImplementedError"; + } } diff --git a/packages/drift/src/adapter/contract/stubs/ReadContractStub.test.ts b/packages/drift/src/adapter/contract/mocks/ReadContractStub.test.ts similarity index 97% rename from packages/drift/src/adapter/contract/stubs/ReadContractStub.test.ts rename to packages/drift/src/adapter/contract/mocks/ReadContractStub.test.ts index ce5d2253..ad721ef7 100644 --- a/packages/drift/src/adapter/contract/stubs/ReadContractStub.test.ts +++ b/packages/drift/src/adapter/contract/mocks/ReadContractStub.test.ts @@ -1,5 +1,5 @@ -import { ReadContractStub } from "src/adapter/contract/stubs/ReadContractStub"; -import type { Event } from "src/adapter/contract/types/Event"; +import { ReadContractStub } from "src/adapter/contract/mocks/ReadContractStub"; +import type { Event } from "src/adapter/contract/types/event"; import { IERC20 } from "src/utils/testing/IERC20"; import { ALICE, BOB, NANCY } from "src/utils/testing/accounts"; import { describe, expect, it } from "vitest"; diff --git a/packages/drift/src/adapter/contract/stubs/ReadContractStub.ts b/packages/drift/src/adapter/contract/mocks/ReadContractStub.ts similarity index 98% rename from packages/drift/src/adapter/contract/stubs/ReadContractStub.ts rename to packages/drift/src/adapter/contract/mocks/ReadContractStub.ts index 1fa981f6..67c20846 100644 --- a/packages/drift/src/adapter/contract/stubs/ReadContractStub.ts +++ b/packages/drift/src/adapter/contract/mocks/ReadContractStub.ts @@ -11,14 +11,14 @@ import type { ContractReadOptions, ContractWriteArgs, ContractWriteOptions, -} from "src/adapter/contract/types/Contract"; -import type { Event, EventName } from "src/adapter/contract/types/Event"; +} from "src/adapter/contract/types/contract"; +import type { Event, EventName } from "src/adapter/contract/types/event"; import type { DecodedFunctionData, FunctionArgs, FunctionName, FunctionReturn, -} from "src/adapter/contract/types/Function"; +} from "src/adapter/contract/types/function"; /** * A mock implementation of a `ReadContract` designed to facilitate unit diff --git a/packages/drift/src/adapter/contract/stubs/ReadWriteContractStub.test.ts b/packages/drift/src/adapter/contract/mocks/ReadWriteContractStub.test.ts similarity index 90% rename from packages/drift/src/adapter/contract/stubs/ReadWriteContractStub.test.ts rename to packages/drift/src/adapter/contract/mocks/ReadWriteContractStub.test.ts index b33b28df..cfa6b0fa 100644 --- a/packages/drift/src/adapter/contract/stubs/ReadWriteContractStub.test.ts +++ b/packages/drift/src/adapter/contract/mocks/ReadWriteContractStub.test.ts @@ -1,4 +1,4 @@ -import { ReadWriteContractStub } from "src/adapter/contract/stubs/ReadWriteContractStub"; +import { ReadWriteContractStub } from "src/adapter/contract/mocks/ReadWriteContractStub"; import { IERC20 } from "src/utils/testing/IERC20"; import { describe, expect, it } from "vitest"; diff --git a/packages/drift/src/adapter/contract/mocks/ReadWriteContractStub.ts b/packages/drift/src/adapter/contract/mocks/ReadWriteContractStub.ts new file mode 100644 index 00000000..cdd469fc --- /dev/null +++ b/packages/drift/src/adapter/contract/mocks/ReadWriteContractStub.ts @@ -0,0 +1,105 @@ +import type { Abi } from "abitype"; +import { type SinonStub, stub } from "sinon"; +import { ReadContractStub } from "src/adapter/contract/mocks/ReadContractStub"; +import type { + AdapterReadWriteContract, + ContractWriteArgs, + ContractWriteOptions, +} from "src/adapter/contract/types/contract"; +import type { + FunctionArgs, + FunctionName, +} from "src/adapter/contract/types/function"; +import { BOB } from "src/utils/testing/accounts"; + +/** + * A mock implementation of a writable Ethereum contract designed for unit + * testing purposes. The `ReadWriteContractStub` extends the functionalities of + * `ReadContractStub` and provides capabilities to stub out specific + * contract write behaviors. This makes it a valuable tool when testing + * scenarios that involve contract writing operations, without actually + * interacting with a real Ethereum contract. + * + * @example + * const contract = new ReadWriteContractStub(ERC20ABI); + * contract.stubWrite("addLiquidity", 100n); + * + * const result = await contract.write("addLiquidity", []); // 100n + * @extends {ReadContractStub} + * @implements {ReadWriteContract} + */ +export class ReadWriteContractStub + extends ReadContractStub + implements AdapterReadWriteContract +{ + protected writeStubMap = new Map< + FunctionName, + WriteStub> + >(); + + getSignerAddress = stub().resolves(BOB); + + /** + * Simulates a contract write operation for a given function. If the function + * is not previously stubbed using `stubWrite` from the parent class, an error + * will be thrown. + */ + async write< + TFunctionName extends FunctionName, + >( + ...[functionName, args, options]: ContractWriteArgs + ): Promise<`0x${string}`> { + const stub = this.getWriteStub(functionName); + if (!stub) { + throw new Error( + `Called write for ${functionName} on a stubbed contract without a return value. The function must be stubbed first:\n\tcontract.stubWrite("${functionName}", value)`, + ); + } + return stub(args, options); + } + + /** + * Stubs the return value for a given function when `simulateWrite` is called + * with that function name. This method overrides any previously stubbed + * values for the same function. + * + * *Note: The stub doesn't account for dynamic values based on provided + * arguments/options.* + */ + stubWrite>( + functionName: TFunctionName, + value: `0x${string}`, + ): void { + let writeStub = this.writeStubMap.get(functionName); + if (!writeStub) { + writeStub = stub(); + this.writeStubMap.set(functionName, writeStub); + } + writeStub.resolves(value); + } + + /** + * Retrieves the stub associated with a write function name. + * Useful for assertions in testing, such as checking call counts. + */ + getWriteStub< + TFunctionName extends FunctionName, + >(functionName: TFunctionName): WriteStub | undefined { + return this.writeStubMap.get(functionName) as WriteStub< + TAbi, + TFunctionName + >; + } +} + +/** + * Type representing a stub for the "write" and "simulateWrite" functions of a + * contract. + */ +type WriteStub< + TAbi extends Abi, + TFunctionName extends FunctionName, +> = SinonStub< + [args?: FunctionArgs, options?: ContractWriteOptions], + `0x${string}` +>; diff --git a/packages/drift/src/adapter/contract/stubs/ReadWriteContractStub.ts b/packages/drift/src/adapter/contract/stubs/ReadWriteContractStub.ts index e0fb0b28..cdd469fc 100644 --- a/packages/drift/src/adapter/contract/stubs/ReadWriteContractStub.ts +++ b/packages/drift/src/adapter/contract/stubs/ReadWriteContractStub.ts @@ -1,15 +1,15 @@ import type { Abi } from "abitype"; import { type SinonStub, stub } from "sinon"; -import { ReadContractStub } from "src/adapter/contract/stubs/ReadContractStub"; +import { ReadContractStub } from "src/adapter/contract/mocks/ReadContractStub"; import type { AdapterReadWriteContract, ContractWriteArgs, ContractWriteOptions, -} from "src/adapter/contract/types/Contract"; +} from "src/adapter/contract/types/contract"; import type { FunctionArgs, FunctionName, -} from "src/adapter/contract/types/Function"; +} from "src/adapter/contract/types/function"; import { BOB } from "src/utils/testing/accounts"; /** diff --git a/packages/drift/src/adapter/contract/types/Contract.ts b/packages/drift/src/adapter/contract/types/Contract.ts index eb9497f4..5833114c 100644 --- a/packages/drift/src/adapter/contract/types/Contract.ts +++ b/packages/drift/src/adapter/contract/types/Contract.ts @@ -3,14 +3,14 @@ import type { Event, EventFilter, EventName, -} from "src/adapter/contract/types/Event"; +} from "src/adapter/contract/types/event"; import type { DecodedFunctionData, FunctionArgs, FunctionName, FunctionReturn, -} from "src/adapter/contract/types/Function"; -import type { BlockTag } from "src/adapter/network/Block"; +} from "src/adapter/contract/types/function"; +import type { BlockTag } from "src/adapter/network/types/Block"; import type { EmptyObject } from "src/utils/types"; // https://ethereum.github.io/execution-apis/api-documentation/ @@ -21,7 +21,7 @@ import type { EmptyObject } from "src/utils/types"; */ export interface AdapterReadContract { abi: TAbi; - address: `0x${string}`; + address: string; /** * Reads a specified function from the contract. @@ -51,7 +51,7 @@ export interface AdapterReadContract { */ encodeFunctionData>( ...args: ContractEncodeFunctionDataArgs - ): `0x${string}`; + ): string; /** * Decodes a string of function calldata into it's arguments and function @@ -73,7 +73,7 @@ export interface AdapterReadWriteContract /** * Get the address of the signer for this contract. */ - getSignerAddress(): Promise<`0x${string}`>; + getSignerAddress(): Promise; /** * Writes to a specified function on the contract. @@ -81,7 +81,7 @@ export interface AdapterReadWriteContract */ write>( ...args: ContractWriteArgs - ): Promise<`0x${string}`>; + ): Promise; } // https://github.com/ethereum/execution-apis/blob/main/src/eth/execute.yaml#L1 @@ -99,8 +99,8 @@ export type ContractReadOptions = }; export type ContractReadArgs< - TAbi extends Abi, - TFunctionName extends FunctionName, + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, > = FunctionArgs extends EmptyObject ? [ functionName: TFunctionName, @@ -114,7 +114,7 @@ export type ContractReadArgs< ]; export interface ContractGetEventsOptions< - TAbi extends Abi, + TAbi extends Abi = Abi, TEventName extends EventName = EventName, > { filter?: EventFilter; @@ -123,8 +123,8 @@ export interface ContractGetEventsOptions< } export type ContractGetEventsArgs< - TAbi extends Abi, - TEventName extends EventName, + TAbi extends Abi = Abi, + TEventName extends EventName = EventName, > = [ eventName: TEventName, options?: ContractGetEventsOptions, @@ -132,16 +132,16 @@ export type ContractGetEventsArgs< // https://github.com/ethereum/execution-apis/blob/main/src/schemas/transaction.yaml#L274 export interface ContractWriteOptions { - type?: `0x${string}`; + type?: string; nonce?: bigint; - to?: `0x${string}`; - from?: `0x${string}`; + to?: string; + from?: string; /** * Gas limit */ gas?: bigint; value?: bigint; - input?: `0x${string}`; + input?: string; /** * The gas price willing to be paid by the sender in wei */ @@ -159,8 +159,8 @@ export interface ContractWriteOptions { * EIP-2930 access list */ accessList?: { - address: `0x${string}`; - storageKeys: `0x${string}`[]; + address: string; + storageKeys: string[]; }[]; /** * Chain ID that this transaction is valid on. @@ -169,8 +169,11 @@ export interface ContractWriteOptions { } export type ContractWriteArgs< - TAbi extends Abi, - TFunctionName extends FunctionName, + TAbi extends Abi = Abi, + TFunctionName extends FunctionName< + TAbi, + "nonpayable" | "payable" + > = FunctionName, > = FunctionArgs extends EmptyObject ? [ functionName: TFunctionName, @@ -184,10 +187,10 @@ export type ContractWriteArgs< ]; export type ContractEncodeFunctionDataArgs< - TAbi extends Abi, - TFunctionName extends FunctionName, + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, > = FunctionArgs extends EmptyObject ? [functionName: TFunctionName, args?: FunctionArgs] : [functionName: TFunctionName, args: FunctionArgs]; -export type ContractDecodeFunctionDataArgs = [data: `0x${string}`]; +export type ContractDecodeFunctionDataArgs = [data: string]; diff --git a/packages/drift/src/adapter/contract/types/Event.ts b/packages/drift/src/adapter/contract/types/Event.ts index 0d73256e..d524721f 100644 --- a/packages/drift/src/adapter/contract/types/Event.ts +++ b/packages/drift/src/adapter/contract/types/Event.ts @@ -5,7 +5,7 @@ import type { AbiParameters, AbiParametersToObject, NamedAbiParameter, -} from "src/adapter/contract/types/AbiEntry"; +} from "src/adapter/contract/types/abi"; /** * Get a union of event names from an abi diff --git a/packages/drift/src/adapter/contract/types/Function.ts b/packages/drift/src/adapter/contract/types/Function.ts index f2eaadf1..a9b16926 100644 --- a/packages/drift/src/adapter/contract/types/Function.ts +++ b/packages/drift/src/adapter/contract/types/Function.ts @@ -2,7 +2,7 @@ import type { Abi, AbiStateMutability } from "abitype"; import type { AbiFriendlyType, AbiObjectType, -} from "src/adapter/contract/types/AbiEntry"; +} from "src/adapter/contract/types/abi"; /** * Get a union of function names from an abi diff --git a/packages/drift/src/adapter/contract/types/AbiEntry.ts b/packages/drift/src/adapter/contract/types/abi.ts similarity index 99% rename from packages/drift/src/adapter/contract/types/AbiEntry.ts rename to packages/drift/src/adapter/contract/types/abi.ts index 11010a48..2b1c6aa5 100644 --- a/packages/drift/src/adapter/contract/types/AbiEntry.ts +++ b/packages/drift/src/adapter/contract/types/abi.ts @@ -137,7 +137,7 @@ type NamedParametersToObject< // key, so we have to use `number` as the key for any parameters that have // empty names ("") in arrays Extract extends never - ? unknown // <- No parameters with empty names + ? any // <- No parameters with empty names : { [index: number]: AbiParameterToPrimitiveType< Extract, diff --git a/packages/drift/src/adapter/contract/utils/arrayToFriendly.ts b/packages/drift/src/adapter/contract/utils/arrayToFriendly.ts index b5fb887b..20172eb1 100644 --- a/packages/drift/src/adapter/contract/utils/arrayToFriendly.ts +++ b/packages/drift/src/adapter/contract/utils/arrayToFriendly.ts @@ -3,7 +3,7 @@ import type { AbiArrayType, AbiEntryName, AbiFriendlyType, -} from "src/adapter/contract/types/AbiEntry"; +} from "src/adapter/contract/types/abi"; import { getAbiEntry } from "src/adapter/contract/utils/getAbiEntry"; /** diff --git a/packages/drift/src/adapter/contract/utils/arrayToObject.ts b/packages/drift/src/adapter/contract/utils/arrayToObject.ts index 66d9eaba..37c9f56b 100644 --- a/packages/drift/src/adapter/contract/utils/arrayToObject.ts +++ b/packages/drift/src/adapter/contract/utils/arrayToObject.ts @@ -3,7 +3,7 @@ import type { AbiArrayType, AbiEntryName, AbiObjectType, -} from "src/adapter/contract/types/AbiEntry"; +} from "src/adapter/contract/types/abi"; import { getAbiEntry } from "src/adapter/contract/utils/getAbiEntry"; /** diff --git a/packages/drift/src/adapter/contract/utils/getAbiEntry.ts b/packages/drift/src/adapter/contract/utils/getAbiEntry.ts index 3936d014..4775629b 100644 --- a/packages/drift/src/adapter/contract/utils/getAbiEntry.ts +++ b/packages/drift/src/adapter/contract/utils/getAbiEntry.ts @@ -2,7 +2,7 @@ import type { Abi, AbiItemType } from "abitype"; import type { AbiEntry, AbiEntryName, -} from "src/adapter/contract/types/AbiEntry"; +} from "src/adapter/contract/types/abi"; import { AbiEntryNotFoundError } from "src/errors/AbiEntryNotFound"; /** diff --git a/packages/drift/src/adapter/contract/utils/objectToArray.ts b/packages/drift/src/adapter/contract/utils/objectToArray.ts index 5eb6e1ec..97f938e1 100644 --- a/packages/drift/src/adapter/contract/utils/objectToArray.ts +++ b/packages/drift/src/adapter/contract/utils/objectToArray.ts @@ -3,7 +3,7 @@ import type { AbiArrayType, AbiEntryName, AbiObjectType, -} from "src/adapter/contract/types/AbiEntry"; +} from "src/adapter/contract/types/abi"; import { getAbiEntry } from "src/adapter/contract/utils/getAbiEntry"; /** diff --git a/packages/drift/src/adapter/network/MockNetwork.test.ts b/packages/drift/src/adapter/network/MockNetwork.test.ts index 4f4b697e..edf49b30 100644 --- a/packages/drift/src/adapter/network/MockNetwork.test.ts +++ b/packages/drift/src/adapter/network/MockNetwork.test.ts @@ -4,7 +4,7 @@ import { } from "src/adapter/network/MockNetwork"; import { ALICE } from "src/utils/testing/accounts"; import { describe, expect, it } from "vitest"; -import type { Transaction } from "./Transaction"; +import type { Transaction } from "./types/Transaction"; describe("MockNetwork", () => { it("stubs getBalance", async () => { diff --git a/packages/drift/src/adapter/network/MockNetwork.ts b/packages/drift/src/adapter/network/MockNetwork.ts index b234be6f..c6af2e29 100644 --- a/packages/drift/src/adapter/network/MockNetwork.ts +++ b/packages/drift/src/adapter/network/MockNetwork.ts @@ -1,22 +1,22 @@ import { type SinonStub, stub } from "sinon"; +import type { Block } from "src/adapter/network/types/Block"; import type { - AdapterNetwork, + NetworkAdapter, NetworkGetBalanceArgs, NetworkGetBlockArgs, NetworkGetTransactionArgs, NetworkWaitForTransactionArgs, -} from "src/adapter/network/AdapterNetwork"; -import type { Block } from "src/adapter/network/Block"; +} from "src/adapter/network/types/NetworkAdapter"; import type { Transaction, TransactionReceipt, -} from "src/adapter/network/Transaction"; +} from "src/adapter/network/types/Transaction"; /** * A mock implementation of a `Network` designed to facilitate unit * testing. */ -export class MockNetwork implements AdapterNetwork { +export class MockNetwork implements NetworkAdapter { protected getBalanceStub: | SinonStub<[NetworkGetBalanceArgs?], Promise> | undefined; diff --git a/packages/drift/src/adapter/network/Block.ts b/packages/drift/src/adapter/network/types/Block.ts similarity index 100% rename from packages/drift/src/adapter/network/Block.ts rename to packages/drift/src/adapter/network/types/Block.ts diff --git a/packages/drift/src/adapter/network/AdapterNetwork.ts b/packages/drift/src/adapter/network/types/NetworkAdapter.ts similarity index 81% rename from packages/drift/src/adapter/network/AdapterNetwork.ts rename to packages/drift/src/adapter/network/types/NetworkAdapter.ts index 83b3af4a..ea6e1360 100644 --- a/packages/drift/src/adapter/network/AdapterNetwork.ts +++ b/packages/drift/src/adapter/network/types/NetworkAdapter.ts @@ -1,15 +1,16 @@ -import type { Block, BlockTag } from "src/adapter/network/Block"; +import type { Block, BlockTag } from "src/adapter/network/types/Block"; import type { Transaction, TransactionReceipt, -} from "src/adapter/network/Transaction"; +} from "src/adapter/network/types/Transaction"; +import type { Address, HexString, TransactionHash } from "src/types"; // https://ethereum.github.io/execution-apis/api-documentation/ /** * An interface representing data the SDK needs to get from the network. */ -export interface AdapterNetwork { +export interface NetworkAdapter { /** * Get the balance of native currency for an account. */ @@ -43,7 +44,7 @@ export interface AdapterNetwork { export type NetworkGetBlockOptions = | { - blockHash?: `0x${string}`; + blockHash?: HexString; blockNumber?: never; blockTag?: never; } @@ -59,16 +60,16 @@ export type NetworkGetBlockOptions = }; export type NetworkGetBalanceArgs = [ - address: `0x${string}`, + address: Address, block?: NetworkGetBlockOptions, ]; export type NetworkGetBlockArgs = [options?: NetworkGetBlockOptions]; -export type NetworkGetTransactionArgs = [hash: `0x${string}`]; +export type NetworkGetTransactionArgs = [hash: TransactionHash]; export type NetworkWaitForTransactionArgs = [ - hash: `0x${string}`, + hash: TransactionHash, options?: { /** * The number of milliseconds to wait for the transaction until rejecting diff --git a/packages/drift/src/adapter/network/Transaction.ts b/packages/drift/src/adapter/network/types/Transaction.ts similarity index 100% rename from packages/drift/src/adapter/network/Transaction.ts rename to packages/drift/src/adapter/network/types/Transaction.ts diff --git a/packages/drift/src/adapter/types.ts b/packages/drift/src/adapter/types.ts index 986e6ab9..9a1018e4 100644 --- a/packages/drift/src/adapter/types.ts +++ b/packages/drift/src/adapter/types.ts @@ -1,25 +1,125 @@ import type { Abi } from "abitype"; import type { - AdapterReadContract, - AdapterReadWriteContract, -} from "src/adapter/contract/types/Contract"; -import type { AdapterNetwork } from "src/adapter/network/AdapterNetwork"; - -export interface ReadAdapter { - network: AdapterNetwork; - readContract: ( - abi: TAbi, - address: string, - ) => AdapterReadContract; + ContractGetEventsOptions, + ContractReadOptions, + ContractWriteOptions, +} from "src/adapter/contract/types/contract"; +import type { Event, EventName } from "src/adapter/contract/types/event"; +import type { + FunctionArgs, + FunctionName, + FunctionReturn, +} from "src/adapter/contract/types/function"; +import type { NetworkAdapter } from "src/adapter/network/types/NetworkAdapter"; +import type { TransactionReceipt } from "src/adapter/network/types/Transaction"; +import type { Address, Bytes, TransactionHash } from "src/types"; +import type { EmptyObject } from "src/utils/types"; + +export interface ReadAdapter extends NetworkAdapter { + read< + TAbi extends Abi, + TFunctionName extends FunctionName, + >( + params: ReadParams, + ): Promise>; + + getEvents>( + params: GetEventsParams, + ): Promise[]>; + + simulateWrite< + TAbi extends Abi, + TFunctionName extends FunctionName, + >( + params: WriteParams, + ): Promise>; + + encodeFunctionData< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: EncodeFunctionDataParams): Bytes; + + decodeFunctionData< + TAbi extends Abi, + TFunctionName extends FunctionName, + >( + params: DecodeFunctionDataParams, + ): FunctionReturn; } export interface ReadWriteAdapter extends ReadAdapter { - // Write-only properties - getSignerAddress: () => Promise; - readWriteContract: ( - abi: TAbi, - address: string, - ) => AdapterReadWriteContract; + getSignerAddress(): Promise
; + + write< + TAbi extends Abi, + TFunctionName extends FunctionName, + >(params: WriteParams): Promise; } export type Adapter = ReadAdapter | ReadWriteAdapter; + +export type ArgsParam< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, +> = FunctionArgs extends EmptyObject + ? { + args?: FunctionArgs; + } + : Abi extends TAbi + ? { + args?: FunctionArgs; + } + : { + args: FunctionArgs; + }; + +export type ReadParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, +> = ContractReadOptions & { + abi: TAbi; + address: Address; + fn: TFunctionName; +} & ArgsParam; + +export interface GetEventsParams< + TAbi extends Abi = Abi, + TEventName extends EventName = EventName, +> extends ContractGetEventsOptions { + abi: TAbi; + address: Address; + event: TEventName; +} + +export type WriteParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName< + TAbi, + "nonpayable" | "payable" + > = FunctionName, +> = ContractWriteOptions & { + abi: TAbi; + address: Address; + fn: TFunctionName; + onMined?: (receipt?: TransactionReceipt) => void; +} & ArgsParam; + +export type EncodeFunctionDataParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, +> = { + abi: TAbi; + fn: TFunctionName; +} & ArgsParam; + +export interface DecodeFunctionDataParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName, +> { + abi: TAbi; + data: Bytes; + // TODO: This is optional and only used to determine the return type, but is + // there another way to get the return type based on the function selector in + // the data? + fn?: TFunctionName; +} diff --git a/packages/drift/src/cache/DriftCache/createDriftCache.ts b/packages/drift/src/cache/DriftCache/createDriftCache.ts index 9e67142e..f6079d2c 100644 --- a/packages/drift/src/cache/DriftCache/createDriftCache.ts +++ b/packages/drift/src/cache/DriftCache/createDriftCache.ts @@ -30,6 +30,9 @@ export function createDriftCache( preloadRead: ({ value, ...params }) => cache.set(driftCache.readKey(params as DriftReadKeyParams), value), + preloadEvents: ({ value, ...params }) => + cache.set(driftCache.eventsKey(params), value), + invalidateRead: (params) => cache.delete(driftCache.readKey(params)), invalidateReadsMatching(params) { @@ -44,9 +47,6 @@ export function createDriftCache( } } }, - - preloadEvents: ({ value, ...params }) => - cache.set(driftCache.eventsKey(params), value), }); return driftCache; diff --git a/packages/drift/src/cache/DriftCache/types.ts b/packages/drift/src/cache/DriftCache/types.ts index 2a87c35d..f7133eae 100644 --- a/packages/drift/src/cache/DriftCache/types.ts +++ b/packages/drift/src/cache/DriftCache/types.ts @@ -1,18 +1,15 @@ import type { Abi } from "abitype"; -import type { Event, EventName } from "src/adapter/contract/types/Event"; +import type { Event, EventName } from "src/adapter/contract/types/event"; import type { FunctionName, FunctionReturn, -} from "src/adapter/contract/types/Function"; +} from "src/adapter/contract/types/function"; +import type { GetEventsParams, ReadParams } from "src/adapter/types"; import type { SimpleCache, SimpleCacheKey } from "src/cache/SimpleCache/types"; -import type { - DriftGetEventsParams, - DriftReadParams, -} from "src/drift/types/DriftContract"; -import type { DeepPartial } from "src/utils/types"; +import type { DeepPartial, MaybePromise } from "src/utils/types"; export type DriftCache = T & { - // Key Generators // + // Key generation // partialReadKey>( params: DeepPartial>, @@ -26,38 +23,38 @@ export type DriftCache = T & { params: DriftEventsKeyParams, ): SimpleCacheKey; - // Cache Management // + // Cache management // preloadRead>( params: DriftReadKeyParams & { value: FunctionReturn; }, - ): void | Promise; + ): MaybePromise; + + preloadEvents>( + params: DriftEventsKeyParams & { + value: readonly Event[]; + }, + ): MaybePromise; invalidateRead>( params: DriftReadKeyParams, - ): void | Promise; + ): MaybePromise; invalidateReadsMatching< TAbi extends Abi, TFunctionName extends FunctionName, >( params: DeepPartial>, - ): void | Promise; - - preloadEvents>( - params: DriftEventsKeyParams & { - value: readonly Event[]; - }, - ): void | Promise; + ): MaybePromise; }; export type DriftReadKeyParams< TAbi extends Abi = Abi, TFunctionName extends FunctionName = FunctionName, -> = Omit, "cache">; +> = Omit, "cache">; export type DriftEventsKeyParams< TAbi extends Abi = Abi, TEventName extends EventName = EventName, -> = Omit, "cache">; +> = Omit, "cache">; diff --git a/packages/drift/src/cache/utils/DriftCache.ts b/packages/drift/src/cache/utils/DriftCache.ts index f2c3da12..3e6bba2c 100644 --- a/packages/drift/src/cache/utils/DriftCache.ts +++ b/packages/drift/src/cache/utils/DriftCache.ts @@ -1,5 +1,5 @@ import type { Abi } from "abitype"; -import type { Event, EventName } from "src/adapter/contract/types/Event"; +import type { Event, EventName } from "src/adapter/contract/types/event"; import type { DriftEventsKeyParams } from "src/cache/DriftCache/types"; diff --git a/packages/drift/src/cache/utils/createCachedReadContract.test.ts b/packages/drift/src/cache/utils/createCachedReadContract.test.ts index 9a296d2a..2d4c20b1 100644 --- a/packages/drift/src/cache/utils/createCachedReadContract.test.ts +++ b/packages/drift/src/cache/utils/createCachedReadContract.test.ts @@ -1,5 +1,5 @@ -import { ReadContractStub } from "src/adapter/contract/stubs/ReadContractStub"; -import type { Event } from "src/adapter/contract/types/Event"; +import { ReadContractStub } from "src/adapter/contract/mocks/ReadContractStub"; +import type { Event } from "src/adapter/contract/types/event"; import { createCachedReadContract } from "src/cache/utils/createCachedReadContract"; import { IERC20 } from "src/utils/testing/IERC20"; import { ALICE, BOB } from "src/utils/testing/accounts"; diff --git a/packages/drift/src/cache/utils/createCachedReadContract.ts b/packages/drift/src/cache/utils/createCachedReadContract.ts index 6c219299..a60a3d87 100644 --- a/packages/drift/src/cache/utils/createCachedReadContract.ts +++ b/packages/drift/src/cache/utils/createCachedReadContract.ts @@ -1,10 +1,10 @@ import type { Abi } from "abitype"; import isMatch from "lodash.ismatch"; -import type { AdapterReadContract } from "src/adapter/contract/types/Contract"; +import type { AdapterReadContract } from "src/adapter/contract/types/contract"; import { createLruSimpleCache } from "src/cache/SimpleCache/createLruSimpleCache"; import { createSimpleCacheKey } from "src/cache/SimpleCache/createSimpleCacheKey"; import type { SimpleCache, SimpleCacheKey } from "src/cache/SimpleCache/types"; -import type { CachedReadContract } from "src/contract/CachedContract"; +import type { ReadContract } from "src/contract/types"; // TODO: Figure out a good default cache size const DEFAULT_CACHE_SIZE = 100; @@ -33,7 +33,7 @@ export function createCachedReadContract({ contract, cache = createLruSimpleCache({ max: DEFAULT_CACHE_SIZE }), namespace, -}: CreateCachedReadContractOptions): CachedReadContract { +}: CreateCachedReadContractOptions): ReadContract { // Because this is part of the public API, we won't know if the original // contract is a plain object or a class instance, so we use Object.create to // preserve the original contract's prototype chain when extending, ensuring @@ -42,7 +42,7 @@ export function createCachedReadContract({ const contractPrototype = Object.getPrototypeOf(contract); const newContract = Object.create(contractPrototype); - const overrides: Partial> = { + const overrides: Partial> = { cache, /** diff --git a/packages/drift/src/cache/utils/createCachedReadWriteContract.ts b/packages/drift/src/cache/utils/createCachedReadWriteContract.ts index 099f4dac..10733359 100644 --- a/packages/drift/src/cache/utils/createCachedReadWriteContract.ts +++ b/packages/drift/src/cache/utils/createCachedReadWriteContract.ts @@ -1,5 +1,5 @@ import type { Abi } from "abitype"; -import type { AdapterReadWriteContract } from "src/adapter/contract/types/Contract"; +import type { AdapterReadWriteContract } from "src/adapter/contract/types/contract"; import type { CachedReadWriteContract } from "src/cache/types/CachedContract"; import { type CreateCachedReadContractOptions, diff --git a/packages/drift/src/cache/utils/eventsKey.ts b/packages/drift/src/cache/utils/eventsKey.ts index 1202d252..37aefec5 100644 --- a/packages/drift/src/cache/utils/eventsKey.ts +++ b/packages/drift/src/cache/utils/eventsKey.ts @@ -1,5 +1,5 @@ import type { Abi } from "abitype"; -import type { EventName } from "src/adapter/contract/types/Event"; +import type { EventName } from "src/adapter/contract/types/event"; import type { DriftEventsKeyParams } from "src/cache/DriftCache/types"; import { createSimpleCacheKey } from "src/cache/SimpleCache/createSimpleCacheKey"; import type { SimpleCacheKey } from "src/cache/SimpleCache/types"; diff --git a/packages/drift/src/cache/utils/invalidateRead.ts b/packages/drift/src/cache/utils/invalidateRead.ts index 13b1a12f..a6dc5883 100644 --- a/packages/drift/src/cache/utils/invalidateRead.ts +++ b/packages/drift/src/cache/utils/invalidateRead.ts @@ -1,5 +1,5 @@ import type { Abi } from "abitype"; -import type { FunctionName } from "src/adapter/contract/types/Function"; +import type { FunctionName } from "src/adapter/contract/types/function"; import type { DriftReadKeyParams } from "src/cache/DriftCache/types"; import type { SimpleCache } from "src/cache/SimpleCache/types"; import { readKey } from "src/cache/utils/readKey"; diff --git a/packages/drift/src/cache/utils/invalidateReadsMatching.ts b/packages/drift/src/cache/utils/invalidateReadsMatching.ts index 073df1b7..b2498532 100644 --- a/packages/drift/src/cache/utils/invalidateReadsMatching.ts +++ b/packages/drift/src/cache/utils/invalidateReadsMatching.ts @@ -1,6 +1,6 @@ import type { Abi } from "abitype"; import isMatch from "lodash.ismatch"; -import type { FunctionName } from "src/adapter/contract/types/Function"; +import type { FunctionName } from "src/adapter/contract/types/function"; import type { DriftReadKeyParams } from "src/cache/DriftCache/types"; import type { SimpleCache, SimpleCacheKey } from "src/cache/SimpleCache/types"; import { partialReadKey } from "src/cache/utils/partialReadKey"; diff --git a/packages/drift/src/cache/utils/partialReadKey.ts b/packages/drift/src/cache/utils/partialReadKey.ts index 933dd59d..1e9bd920 100644 --- a/packages/drift/src/cache/utils/partialReadKey.ts +++ b/packages/drift/src/cache/utils/partialReadKey.ts @@ -1,5 +1,5 @@ import type { Abi } from "abitype"; -import type { FunctionName } from "src/adapter/contract/types/Function"; +import type { FunctionName } from "src/adapter/contract/types/function"; import type { DriftReadKeyParams } from "src/cache/DriftCache/types"; import { createSimpleCacheKey } from "src/cache/SimpleCache/createSimpleCacheKey"; import type { SimpleCacheKey } from "src/cache/SimpleCache/types"; diff --git a/packages/drift/src/cache/utils/preloadRead.ts b/packages/drift/src/cache/utils/preloadRead.ts index 012d1edc..d058a2af 100644 --- a/packages/drift/src/cache/utils/preloadRead.ts +++ b/packages/drift/src/cache/utils/preloadRead.ts @@ -1,5 +1,5 @@ import type { Abi } from "abitype"; -import type { FunctionName, FunctionReturn } from "src/adapter/contract/types/Function"; +import type { FunctionName, FunctionReturn } from "src/adapter/contract/types/function"; import type { DriftReadKeyParams } from "src/cache/DriftCache/types"; import type { SimpleCache } from "src/cache/SimpleCache/types"; import { readKey } from "src/cache/utils/readKey"; diff --git a/packages/drift/src/cache/utils/readKey.ts b/packages/drift/src/cache/utils/readKey.ts index 7d4f41c5..958f65d7 100644 --- a/packages/drift/src/cache/utils/readKey.ts +++ b/packages/drift/src/cache/utils/readKey.ts @@ -1,5 +1,5 @@ import type { Abi } from "abitype"; -import type { FunctionName } from "src/adapter/contract/types/Function"; +import type { FunctionName } from "src/adapter/contract/types/function"; import type { DriftReadKeyParams } from "src/cache/DriftCache/types"; import type { SimpleCacheKey } from "src/cache/SimpleCache/types"; import { partialReadKey } from "src/cache/utils/partialReadKey"; diff --git a/packages/drift/src/contract/CachedContract.ts b/packages/drift/src/contract/CachedContract.ts deleted file mode 100644 index 528fe5f4..00000000 --- a/packages/drift/src/contract/CachedContract.ts +++ /dev/null @@ -1,28 +0,0 @@ -import type { Abi } from "abitype"; -import type { - AdapterReadContract, - AdapterReadWriteContract, - ContractReadArgs, -} from "src/adapter/contract/types/Contract"; -import type { FunctionName } from "src/adapter/contract/types/Function"; -import type { SimpleCache } from "src/exports"; -import type { DeepPartial } from "src/utils/types"; - -export interface CachedReadContract - extends AdapterReadContract { - cache: SimpleCache; - namespace?: string; - clearCache(): void; - deleteRead>( - ...[functionName, args, options]: ContractReadArgs - ): void; - deleteReadsMatching>( - ...[functionName, args, options]: DeepPartial< - ContractReadArgs - > - ): void; -} - -export interface CachedReadWriteContract - extends CachedReadContract, - AdapterReadWriteContract {} diff --git a/packages/drift/src/contract/createReadContract.ts b/packages/drift/src/contract/createReadContract.ts new file mode 100644 index 00000000..15f0ff86 --- /dev/null +++ b/packages/drift/src/contract/createReadContract.ts @@ -0,0 +1,63 @@ +import type { Abi } from "abitype"; +import type { AdapterReadContract } from "src/adapter/contract/types/contract"; +import type { DriftCache } from "src/cache/DriftCache/types"; +import type { ReadContract } from "src/contract/types"; +import { extendInstance } from "src/utils/extendInstance"; + +interface CreateReadContractParams< +TAbi extends Abi, +TContract extends AdapterReadContract, +TCache extends DriftCache, +> { + contract: TContract; + cache: TCache; + namespace: string; +} + +/** + * Extends an {@linkcode AdapterReadContract} with additional API methods for + * use with Drift clients. + */ +export function createReadContract< + TAbi extends Abi, + TContract extends AdapterReadContract, + TCache extends DriftCache, +>(contract: TContract, cache: TCache): ReadContract { + const readContract: ReadContract = extendInstance< + TContract, + Omit + >(contract, { + cache, + + partialReadKey: (fn, args, options) => + cache.partialReadKey({ abi: contract.abi, fn, args, address: contract.address, namespace, }), + + readKey: (params) => driftCache.partialReadKey(params), + + eventsKey: ({ abi, namespace, ...params }) => + createSimpleCacheKey([namespace, "events", params]), + + preloadRead: ({ value, ...params }) => + cache.set(driftCache.readKey(params as DriftReadKeyParams), value), + + preloadEvents: ({ value, ...params }) => + cache.set(driftCache.eventsKey(params), value), + + invalidateRead: (params) => cache.delete(driftCache.readKey(params)), + + invalidateReadsMatching(params) { + const sourceKey = driftCache.partialReadKey(params); + + for (const [key] of cache.entries) { + if ( + typeof key === "object" && + isMatch(key, sourceKey as SimpleCacheKey[]) + ) { + cache.delete(key); + } + } + }, + }); + + return readContract; +} diff --git a/packages/drift/src/contract/types.ts b/packages/drift/src/contract/types.ts new file mode 100644 index 00000000..d4500b61 --- /dev/null +++ b/packages/drift/src/contract/types.ts @@ -0,0 +1,90 @@ +import type { Abi } from "abitype"; +import type { + AdapterReadContract, + AdapterReadWriteContract, + ContractGetEventsArgs, + ContractReadArgs, +} from "src/adapter/contract/types/contract"; +import type { EventName } from "src/adapter/contract/types/event"; +import type { Event } from "src/adapter/contract/types/event"; +import type { + FunctionName, + FunctionReturn, +} from "src/adapter/contract/types/function"; +import type { + DriftCache, + DriftReadKeyParams, +} from "src/cache/DriftCache/types"; +import type { SimpleCache } from "src/cache/SimpleCache/types"; +import type { Address } from "src/types"; +import type { MaybePromise } from "src/utils/types"; + +export interface ContractParams { + abi: TAbi; + address: Address; + cache?: SimpleCache; + /** + * A namespace to distinguish this instance from others in the cache by + * prefixing all cache keys. + */ + namespace?: string; +} + +export type ReadContract< + TAbi extends Abi = Abi, + TAdapterContract extends + AdapterReadContract = AdapterReadContract, + TCache extends DriftCache = DriftCache, +> = TAdapterContract & { + cache: TCache; + + // Key generation // + + partialReadKey>( + ...args: Partial> + ): string; + + readKey>( + ...args: ContractReadArgs + ): string; + + eventsKey>( + ...args: ContractGetEventsArgs + ): string; + + // Cache management // + + preloadRead>( + params: ContractReadKeyParams & { + value: FunctionReturn; + }, + ): MaybePromise; + + preloadEvents>( + ...args: ContractGetEventsArgs & { + value: readonly Event[]; + } + ): MaybePromise; + + invalidateRead>( + ...args: ContractReadArgs + ): MaybePromise; + + invalidateReadsMatching>( + ...args: Partial> + ): MaybePromise; + + invalidateAllReads(): void; +}; + +export interface ReadWriteContract + extends ReadContract, + AdapterReadWriteContract {} + +export type ContractReadKeyParams< + TAbi extends Abi = Abi, + TFunctionName extends FunctionName = FunctionName< + TAbi, + "pure" | "view" + >, +> = Omit, keyof ContractParams>; diff --git a/packages/drift/src/drift/Drift.ts b/packages/drift/src/drift/Drift.ts index 6e87cdbe..b6074d62 100644 --- a/packages/drift/src/drift/Drift.ts +++ b/packages/drift/src/drift/Drift.ts @@ -1,10 +1,10 @@ import type { Abi } from "abitype"; -import type { Event, EventName } from "src/adapter/contract/types/Event"; +import type { Event, EventName } from "src/adapter/contract/types/event"; import type { DecodedFunctionData, FunctionName, FunctionReturn, -} from "src/adapter/contract/types/Function"; +} from "src/adapter/contract/types/function"; import type { Adapter, ReadWriteAdapter } from "src/adapter/types"; import { createDriftCache } from "src/cache/DriftCache/createDriftCache"; import type { DriftCache } from "src/cache/DriftCache/types"; @@ -12,9 +12,9 @@ import type { SimpleCache } from "src/cache/SimpleCache/types"; import { createCachedReadContract } from "src/cache/utils/createCachedReadContract"; import { createCachedReadWriteContract } from "src/cache/utils/createCachedReadWriteContract"; import type { - CachedReadContract, - CachedReadWriteContract, -} from "src/contract/CachedContract"; + ReadContract, + ReadWriteContract, +} from "src/contract/types"; import type { ContractParams, DecodeFunctionDataParams, @@ -28,8 +28,8 @@ export type DriftContract< TAbi extends Abi, TAdapter extends Adapter = Adapter, > = TAdapter extends ReadWriteAdapter - ? CachedReadWriteContract - : CachedReadContract; + ? ReadWriteContract + : ReadContract; export interface DriftOptions { cache?: TCache; @@ -109,18 +109,9 @@ export class Drift< address, cache = this.cache, namespace = this.namespace, - }: ContractParams): DriftContract => - this.isReadWrite() - ? createCachedReadWriteContract({ - contract: this.adapter.readWriteContract(abi, address), - cache, - namespace, - }) - : (createCachedReadContract({ - contract: this.adapter.readContract(abi, address), - cache, - namespace, - }) as DriftContract); + }: ContractParams): DriftContract => { + con + } /** * Reads a specified function from a contract. diff --git a/packages/drift/src/drift/MockDrift.ts b/packages/drift/src/drift/MockDrift.ts index 024e59dd..16e2973e 100644 --- a/packages/drift/src/drift/MockDrift.ts +++ b/packages/drift/src/drift/MockDrift.ts @@ -1,7 +1,7 @@ import type { Abi } from "abitype"; import { MockAdapter } from "src/adapter/MockAdapter"; -import type { ReadWriteContractStub } from "src/adapter/contract/stubs/ReadWriteContractStub"; -import type { CachedReadWriteContract } from "src/contract/CachedContract"; +import type { ReadWriteContractStub } from "src/adapter/contract/mocks/ReadWriteContractStub"; +import type { ReadWriteContract } from "src/contract/types"; import { Drift, type DriftOptions } from "src/drift/Drift"; import type { SimpleCache } from "src/exports"; import type { ContractParams } from "src/types"; @@ -16,5 +16,5 @@ export class MockDrift extends Drift< declare contract: ( params: ContractParams, - ) => CachedReadWriteContract & ReadWriteContractStub; + ) => ReadWriteContract & ReadWriteContractStub; } diff --git a/packages/drift/src/types.ts b/packages/drift/src/types.ts index 2b66ee1c..3435f0dc 100644 --- a/packages/drift/src/types.ts +++ b/packages/drift/src/types.ts @@ -1,87 +1,4 @@ -import type { Abi } from "abitype"; -import type { - ContractGetEventsOptions, - ContractReadOptions, - ContractWriteOptions, -} from "src/adapter/contract/types/Contract"; -import type { EventName } from "src/adapter/contract/types/Event"; -import type { - FunctionArgs, - FunctionName, -} from "src/adapter/contract/types/Function"; -import type { TransactionReceipt } from "src/adapter/network/Transaction"; -import type { SimpleCache } from "src/cache/SimpleCache/types"; -import type { EmptyObject } from "src/utils/types"; - -export interface ContractParams { - abi: TAbi; - address: string; - cache?: SimpleCache; - /** - * A namespace to distinguish this instance from others in the cache by - * prefixing all cache keys. - */ - namespace?: string; -} - -export type ReadParams< - TAbi extends Abi, - TFunctionName extends FunctionName, -> = { - fn: TFunctionName; -} & (FunctionArgs extends EmptyObject - ? { - args?: FunctionArgs; - } - : { - args: FunctionArgs; - }) & - ContractReadOptions & - ContractParams; - -export interface GetEventsParams< - TAbi extends Abi, - TEventName extends EventName, -> extends ContractGetEventsOptions, - ContractParams { - event: TEventName; -} - -export type WriteParams< - TAbi extends Abi, - TFunctionName extends FunctionName, -> = ContractWriteOptions & { - abi: TAbi; - address: string; - fn: TFunctionName; - onMined?: (receipt?: TransactionReceipt) => void; -} & (FunctionArgs extends EmptyObject - ? { - args?: FunctionArgs; - } - : { - args: FunctionArgs; - }); - -export type EncodeFunctionDataParams< - TAbi extends Abi, - TFunctionName extends FunctionName, -> = { - abi: TAbi; - fn: TFunctionName; -} & (FunctionArgs extends EmptyObject - ? { - args?: FunctionArgs; - } - : { - args: FunctionArgs; - }); - -export interface DecodeFunctionDataParams< - TAbi extends Abi, - TFunctionName extends FunctionName, -> { - abi: TAbi; - data: string; - fn?: TFunctionName; -} +export type HexString = string; +export type Address = HexString; +export type Bytes = HexString; +export type TransactionHash = HexString; diff --git a/packages/drift/src/utils/types.ts b/packages/drift/src/utils/types.ts index d5b8c4bd..5b254c6d 100644 --- a/packages/drift/src/utils/types.ts +++ b/packages/drift/src/utils/types.ts @@ -1,4 +1,6 @@ -export type EmptyObject = Record; +export type EmptyObject = Record; + +export type MaybePromise = T | Promise; /** * Combines members of an intersection into a readable type. @@ -20,3 +22,12 @@ export type RequiredKeys = Prettify< [P in K]-?: NonNullable; } >; + +/** + * Make all properties in `T` whose keys are in the union `K` optional. + */ +export type OptionalKeys = Prettify< + Omit & { + [P in K]?: T[P]; + } +>;