Skip to content

Commit

Permalink
evm: require adapter instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
bruce-riley committed Dec 10, 2024
1 parent 82533d5 commit 8f9b270
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 43 deletions.
2 changes: 1 addition & 1 deletion evm/script/TestIntegrator.s.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ contract TestIntegrator {
address refundAddress = address(this);
bytes32 payloadHash = keccak256("hello, world");
IEndpointIntegrator(endpoint).sendMessage(
chain, UniversalAddressLibrary.fromAddress(dstAddr), payloadHash, refundAddress
chain, UniversalAddressLibrary.fromAddress(dstAddr), payloadHash, refundAddress, new bytes(0)
);
}

Expand Down
90 changes: 70 additions & 20 deletions evm/src/Endpoint.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pragma solidity ^0.8.19;
import "./interfaces/IEndpointAdmin.sol";
import "./interfaces/IEndpointIntegrator.sol";
import "./interfaces/IEndpointAdapter.sol";
import "./libraries/AdapterInstructions.sol";
import "./MessageSequence.sol";
import "./AdapterRegistry.sol";
import "./interfaces/IAdapter.sol";
Expand Down Expand Up @@ -378,13 +379,41 @@ contract Endpoint is IEndpointAdmin, IEndpointIntegrator, IEndpointAdapter, Mess
// =============== Message functions =======================================================

/// @inheritdoc IEndpointIntegrator
function sendMessage(uint16 dstChain, UniversalAddress dstAddr, bytes32 payloadHash, address refundAddress)
external
payable
returns (uint64 sequence)
{
function sendMessage(
uint16 dstChain,
UniversalAddress dstAddr,
bytes32 payloadHash,
address refundAddress,
bytes calldata adapterInstructions
) external payable returns (uint64 sequence) {
return _sendMessage(
SendMessageArgs({
dstChain: dstChain,
dstAddr: dstAddr,
payloadHash: payloadHash,
refundAddress: refundAddress,
adapterInstructions: adapterInstructions
})
);
}

/// @dev Used to get around "stack too deep.
struct SendMessageArgs {
uint16 dstChain;
UniversalAddress dstAddr;
bytes32 payloadHash;
address refundAddress;
bytes adapterInstructions;
}

function _sendMessage(SendMessageArgs memory args) internal returns (uint64 sequence) {
// Parse the adapter instructions so we can pass the appropriate one to each adapter.
AdapterInstructions.Instruction[] memory adapterInst = AdapterInstructions.parseInstructions(
args.adapterInstructions, _getRegisteredAdaptersStorage()[msg.sender].length
);

// get the enabled send adapters for [msg.sender][dstChain]
PerSendAdapterInfo[] memory sendAdapters = getSendAdaptersByChain(msg.sender, dstChain);
PerSendAdapterInfo[] memory sendAdapters = getSendAdaptersByChain(msg.sender, args.dstChain);
uint256 len = sendAdapters.length;
if (len == 0) {
revert AdapterNotEnabled();
Expand All @@ -393,25 +422,32 @@ contract Endpoint is IEndpointAdmin, IEndpointIntegrator, IEndpointAdapter, Mess
// get the next sequence number for msg.sender
sequence = _useMessageSequence(msg.sender);
for (uint256 i = 0; i < len;) {
bytes memory adapterInstructions; // TODO: Pass this in.
// quote the delivery price
uint256 deliveryPrice = IAdapter(sendAdapters[i].addr).quoteDeliveryPrice(dstChain, adapterInstructions);
uint256 deliveryPrice = IAdapter(sendAdapters[i].addr).quoteDeliveryPrice(
args.dstChain, adapterInst[sendAdapters[i].index].payload
);
// call sendMessage
IAdapter(sendAdapters[i].addr).sendMessage{value: deliveryPrice}(
sender, sequence, dstChain, dstAddr, payloadHash, refundAddress, adapterInstructions
sender,
sequence,
args.dstChain,
args.dstAddr,
args.payloadHash,
args.refundAddress,
adapterInst[sendAdapters[i].index].payload
);
unchecked {
++i;
}
}

emit MessageSent(
computeMessageDigest(ourChain, sender, sequence, dstChain, dstAddr, payloadHash),
computeMessageDigest(ourChain, sender, sequence, args.dstChain, args.dstAddr, args.payloadHash),
sender,
sequence,
dstAddr,
dstChain,
payloadHash
args.dstAddr,
args.dstChain,
args.payloadHash
);
}

Expand Down Expand Up @@ -544,13 +580,17 @@ contract Endpoint is IEndpointAdmin, IEndpointIntegrator, IEndpointAdapter, Mess
}

/// @inheritdoc IEndpointIntegrator
function quoteDeliveryPrice(address integrator, uint16 dstChain) external view returns (uint256) {
return _quoteDeliveryPrice(integrator, dstChain);
function quoteDeliveryPrice(address integrator, uint16 dstChain, bytes calldata adapterInstructions)
external
view
returns (uint256)
{
return _quoteDeliveryPrice(integrator, dstChain, adapterInstructions);
}

/// @inheritdoc IEndpointIntegrator
function quoteDeliveryPrice(uint16 dstChain) external view returns (uint256) {
return _quoteDeliveryPrice(msg.sender, dstChain);
function quoteDeliveryPrice(uint16 dstChain, bytes calldata adapterInstructions) external view returns (uint256) {
return _quoteDeliveryPrice(msg.sender, dstChain, adapterInstructions);
}

// =============== Internal ==============================================================
Expand All @@ -572,14 +612,24 @@ contract Endpoint is IEndpointAdmin, IEndpointIntegrator, IEndpointAdapter, Mess
/// @dev This sums up all the individual sendAdapter's quoteDeliveryPrice calls.
/// @param integrator The address of the integrator.
/// @param dstChain The Wormhole chain ID of the recipient.
/// @param adapterInstructions An array of adapter instructions to be passed into the adapters.
/// @return totalCost The total cost of delivering a message to the recipient chain in this chain's native token.
function _quoteDeliveryPrice(address integrator, uint16 dstChain) internal view returns (uint256 totalCost) {
function _quoteDeliveryPrice(address integrator, uint16 dstChain, bytes calldata adapterInstructions)
internal
view
returns (uint256 totalCost)
{
// Parse the adapter instructions so we can pass the appropriate one to each adapter.
AdapterInstructions.Instruction[] memory adapterInst = AdapterInstructions.parseInstructions(
adapterInstructions, _getRegisteredAdaptersStorage()[integrator].length
);

PerSendAdapterInfo[] memory sendAdapters = getSendAdaptersByChain(integrator, dstChain);
uint256 len = sendAdapters.length;
totalCost = 0;
for (uint256 i = 0; i < len;) {
bytes memory adapterInstructions; // TODO: Pass this in.
totalCost += IAdapter(sendAdapters[i].addr).quoteDeliveryPrice(dstChain, adapterInstructions);
totalCost +=
IAdapter(sendAdapters[i].addr).quoteDeliveryPrice(dstChain, adapterInst[sendAdapters[i].index].payload);
unchecked {
++i;
}
Expand Down
22 changes: 15 additions & 7 deletions evm/src/interfaces/IEndpointIntegrator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ interface IEndpointIntegrator is IMessageSequence {
/// @param dstChain The Wormhole chain ID of the recipient.
/// @param dstAddr The universal address of the peer on the recipient chain.
/// @param payloadHash keccak256 of a message to be sent to the recipient chain.
/// @return uint64 The sequence number of the message.
/// @param refundAddress The source chain refund address passed to the Adapter.
function sendMessage(uint16 dstChain, UniversalAddress dstAddr, bytes32 payloadHash, address refundAddress)
external
payable
returns (uint64);
/// @param adapterInstructions An array of adapter instructions to be passed into the adapters.
/// @return uint64 The sequence number of the message.
function sendMessage(
uint16 dstChain,
UniversalAddress dstAddr,
bytes32 payloadHash,
address refundAddress,
bytes calldata adapterInstructions
) external payable returns (uint64);

/// @notice Receives a message and marks it as executed.
/// @param srcChain The Wormhole chain ID of the sender.
Expand Down Expand Up @@ -74,13 +78,17 @@ interface IEndpointIntegrator is IMessageSequence {
/// @dev This sums up all the individual sendAdapter's quoteDeliveryPrice calls.
/// @param integrator The address of the integrator.
/// @param dstChain The Wormhole chain ID of the recipient.
/// @param adapterInstructions An array of adapter instructions to be passed into the adapters.
/// @return uint256 The total cost of delivering a message to the recipient chain in this chain's native token.
function quoteDeliveryPrice(address integrator, uint16 dstChain) external returns (uint256);
function quoteDeliveryPrice(address integrator, uint16 dstChain, bytes calldata adapterInstructions)
external
returns (uint256);

/// @notice Retrieves the quote for message delivery.
/// @dev This version must be called by the integrator.
/// @dev This sums up all the individual sendAdapter's quoteDeliveryPrice calls.
/// @param dstChain The Wormhole chain ID of the recipient.
/// @param adapterInstructions An array of adapter instructions to be passed into the adapters.
/// @return uint256 The total cost of delivering a message to the recipient chain in this chain's native token.
function quoteDeliveryPrice(uint16 dstChain) external view returns (uint256);
function quoteDeliveryPrice(uint16 dstChain, bytes calldata adapterInstructions) external view returns (uint256);
}
10 changes: 5 additions & 5 deletions evm/src/libraries/AdapterInstructions.sol
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ library AdapterInstructions {
/// @notice Encodes an adapter instruction.
/// @param instruction The instruction to be encoded.
/// @return encoded The encoded bytes, where the first byte is the index and the next two bytes are the instruction length.
function encodeInstruction(Instruction memory instruction) public pure returns (bytes memory encoded) {
function encodeInstruction(Instruction calldata instruction) public pure returns (bytes memory encoded) {
if (instruction.payload.length > type(uint16).max) {
revert PayloadTooLong(instruction.payload.length);
}
Expand All @@ -49,7 +49,7 @@ library AdapterInstructions {
/// @notice Encodes an array of adapter instructions.
/// @param instructions The array of instructions to be encoded.
/// @return address The encoded bytes, where the first byte is the number of entries.
function encodeInstructions(Instruction[] memory instructions) public pure returns (bytes memory) {
function encodeInstructions(Instruction[] calldata instructions) public pure returns (bytes memory) {
if (instructions.length > type(uint8).max) {
revert TooManyInstructions();
}
Expand All @@ -69,7 +69,7 @@ library AdapterInstructions {
/// @notice Parses a byte array into an adapter instruction.
/// @param encoded The encoded instruction.
/// @return instruction The parsed instruction.
function parseInstruction(bytes memory encoded) public pure returns (Instruction memory instruction) {
function parseInstruction(bytes calldata encoded) public pure returns (Instruction memory instruction) {
uint256 offset = 0;
(instruction, offset) = parseInstructionUnchecked(encoded, offset);
encoded.checkLength(offset);
Expand All @@ -80,7 +80,7 @@ library AdapterInstructions {
/// @param offset The current offset into the encoded buffer.
/// @return instruction The parsed instruction.
/// @return nextOffset The next index into the array (used for further parsing).
function parseInstructionUnchecked(bytes memory encoded, uint256 offset)
function parseInstructionUnchecked(bytes calldata encoded, uint256 offset)
public
pure
returns (Instruction memory instruction, uint256 nextOffset)
Expand All @@ -95,7 +95,7 @@ library AdapterInstructions {
/// @param encoded The encoded instructions.
/// @param numRegisteredAdapters The total number of registered adapters.
/// @return instructions A sparse array of adapter instructions, where the index into the array is the adapter index.
function parseInstructions(bytes memory encoded, uint256 numRegisteredAdapters)
function parseInstructions(bytes calldata encoded, uint256 numRegisteredAdapters)
public
pure
returns (Instruction[] memory instructions)
Expand Down
41 changes: 31 additions & 10 deletions evm/test/Endpoint.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import "../src/libraries/UniversalAddress.sol";
import {Endpoint} from "../src/Endpoint.sol";
import {AdapterRegistry} from "../src/AdapterRegistry.sol";
import {IAdapter} from "../src/interfaces/IAdapter.sol";
import "../src/libraries/AdapterInstructions.sol";

contract EndpointImpl is Endpoint {
uint16 public constant OurChainId = 0x2714;
Expand Down Expand Up @@ -350,11 +351,13 @@ contract EndpointTest is Test {
vm.expectEmit(true, true, false, true);
emit Endpoint.IntegratorRegistered(address(integrator), address(admin));
endpoint.register(admin);
bytes memory insts = defaultInsts();

// Sending with no adapters should revert.
vm.startPrank(integrator);
vm.expectRevert(abi.encodeWithSelector(Endpoint.AdapterNotEnabled.selector));
uint64 sequence = endpoint.sendMessage(2, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr);
uint64 sequence =
endpoint.sendMessage(2, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr, insts);

// Now enable some adapters.
vm.startPrank(admin);
Expand All @@ -380,7 +383,8 @@ contract EndpointTest is Test {
// Only an integrator can call send.
vm.startPrank(userA);
vm.expectRevert(abi.encodeWithSelector(Endpoint.AdapterNotEnabled.selector));
sequence = endpoint.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr);
sequence =
endpoint.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr, insts);

// Send a message on chain two. It should go out on the first two adapters, but not the third one.
vm.startPrank(integrator);
Expand All @@ -400,24 +404,34 @@ contract EndpointTest is Test {
chain,
payloadHash
);
sequence = endpoint.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr);
sequence =
endpoint.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr, insts);
require(sequence == 0, "Sequence number is wrong");
require(adapter1.getMessagesSent() == 1, "Failed to send a message on adapter 1");
require(adapter2.getMessagesSent() == 1, "Failed to send a message on adapter 2");
require(adapter3.getMessagesSent() == 0, "Should not have sent a message on adapter 3");

sequence = endpoint.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr);
sequence =
endpoint.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr, insts);
require(sequence == 1, "Second sequence number is wrong");
require(adapter1.getMessagesSent() == 2, "Failed to send second message on adapter 1");
require(adapter2.getMessagesSent() == 2, "Failed to send second message on adapter 2");
require(adapter3.getMessagesSent() == 0, "Should not have sent second message on adapter 3");

vm.expectRevert(abi.encodeWithSelector(AdapterRegistry.InvalidChain.selector, zeroChain));
sequence = endpoint.sendMessage(zeroChain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr);
sequence =
endpoint.sendMessage(zeroChain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr, insts);
require(sequence == 0, "Failed sequence number is wrong"); // 0 because of the revert

sequence = endpoint.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr);
sequence =
endpoint.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr, insts);
require(sequence == 2, "Third sequence number is wrong");

// Should be able to send with no adapter instructions.
sequence = endpoint.sendMessage(
chain, UniversalAddressLibrary.fromAddress(userA), payloadHash, refundAddr, new bytes(0)
);
require(sequence == 3, "Fourth sequence number is wrong");
}

function test_attestMessage() public {
Expand Down Expand Up @@ -729,28 +743,35 @@ contract EndpointTest is Test {
adapter1.setDeliveryPrice(100);
adapter2.setDeliveryPrice(200);
adapter3.setDeliveryPrice(300);
bytes memory insts = defaultInsts();

// Now enable some adapters.
vm.startPrank(admin);
endpoint.addAdapter(integrator, address(adapter1));
endpoint.enableSendAdapter(integrator, chain, address(adapter1));
uint256 price = endpoint.quoteDeliveryPrice(integrator, chain);
uint256 price = endpoint.quoteDeliveryPrice(integrator, chain, insts);
require(price == 100, "Single price is wrong");
endpoint.addAdapter(integrator, address(adapter2));
endpoint.enableSendAdapter(integrator, chain, address(adapter2));
price = endpoint.quoteDeliveryPrice(integrator, chain);
price = endpoint.quoteDeliveryPrice(integrator, chain, insts);
require(price == 300, "Double price is wrong");
endpoint.addAdapter(integrator, address(adapter3));
endpoint.enableSendAdapter(integrator, 3, address(adapter3));
price = endpoint.quoteDeliveryPrice(integrator, chain);
price = endpoint.quoteDeliveryPrice(integrator, chain, insts);
require(price == 300, "Triple price is wrong");
vm.startPrank(integrator);
price = endpoint.quoteDeliveryPrice(chain);
price = endpoint.quoteDeliveryPrice(chain, insts);
require(price == 300, "Triple price is wrong");
}

function test_getNumEnabledRecvAdaptersForChain() public view {
// This function is actually tested in the AdapterRegistry tests. Just call it here for code coverage.
require(endpoint.getNumEnabledRecvAdaptersForChain(address(0x01), 2) == 0, "Count should be zero");
}

/// @dev Builds empty adapter instructions.
function defaultInsts() public pure returns (bytes memory encoded) {
AdapterInstructions.Instruction[] memory insts = new AdapterInstructions.Instruction[](0);
encoded = AdapterInstructions.encodeInstructions(insts);
}
}

0 comments on commit 8f9b270

Please sign in to comment.