Skip to content

Commit

Permalink
evm: Make getSendAdaptersByChain also return index
Browse files Browse the repository at this point in the history
  • Loading branch information
bruce-riley committed Dec 9, 2024
1 parent 4607891 commit f442f4f
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 23 deletions.
33 changes: 22 additions & 11 deletions evm/src/AdapterRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ abstract contract AdapterRegistry {
uint8 index; // the index into the integrator's adapters array
}

/// @dev Data maintained for each send adapter enabled for an integrator and chain.
struct PerSendAdapterInfo {
address addr;
uint8 index;
}

/// @dev Bitmap encoding the enabled adapters.
struct _EnabledAdapterBitmap {
uint128 bitmap; // MAX_ADAPTERS = 128
Expand Down Expand Up @@ -76,7 +82,7 @@ abstract contract AdapterRegistry {
error AdapterAlreadyDisabled(address adapter);

/// @notice Error when the number of registered adapters
/// exceeeds (MAX_ADAPTERS = 128).
/// exceeds (MAX_ADAPTERS = 128).
/// @dev Selector: 0x5bde12c0.
error TooManyAdapters();

Expand Down Expand Up @@ -143,7 +149,7 @@ abstract contract AdapterRegistry {
function _getPerChainSendAdapterArrayStorage()
private
pure
returns (mapping(address => mapping(uint16 => address[])) storage $)
returns (mapping(address => mapping(uint16 => PerSendAdapterInfo[])) storage $)
{
uint256 slot = uint256(ENABLED_SEND_ADAPTER_ARRAY_SLOT);
assembly ("memory-safe") {
Expand Down Expand Up @@ -235,9 +241,10 @@ abstract contract AdapterRegistry {
if (_isSendAdapterEnabledForChain(integrator, chain, adapter)) {
revert AdapterAlreadyEnabled(adapter);
}
mapping(address => mapping(uint16 => address[])) storage sendAdapterArray =
uint8 index = _getAdapterInfosStorage()[integrator][adapter].index;
mapping(address => mapping(uint16 => PerSendAdapterInfo[])) storage sendAdapterArray =
_getPerChainSendAdapterArrayStorage();
sendAdapterArray[integrator][chain].push(adapter);
sendAdapterArray[integrator][chain].push(PerSendAdapterInfo({addr: adapter, index: index}));
emit SendAdapterEnabledForChain(integrator, chain, adapter);
}

Expand Down Expand Up @@ -268,16 +275,16 @@ abstract contract AdapterRegistry {
internal
onlyRegisteredAdapter(integrator, chain, adapter)
{
mapping(address => mapping(uint16 => address[])) storage enabledSendAdapters =
mapping(address => mapping(uint16 => PerSendAdapterInfo[])) storage enabledSendAdapters =
_getPerChainSendAdapterArrayStorage();
address[] storage adapters = enabledSendAdapters[integrator][chain];
PerSendAdapterInfo[] storage adapters = enabledSendAdapters[integrator][chain];

// Get the index of the disabled adapter in the enabled adapters array
// and replace it with the last element in the array.
uint256 len = adapters.length;
bool found = false;
for (uint256 i = 0; i < len;) {
if (adapters[i] == adapter) {
if (adapters[i].addr == adapter) {
// Swap the last element with the element to be removed
adapters[i] = adapters[len - 1];
// Remove the last element
Expand Down Expand Up @@ -342,10 +349,10 @@ abstract contract AdapterRegistry {
view
returns (bool)
{
address[] storage adapters = _getPerChainSendAdapterArrayStorage()[integrator][chain];
PerSendAdapterInfo[] storage adapters = _getPerChainSendAdapterArrayStorage()[integrator][chain];
uint256 length = adapters.length;
for (uint256 i = 0; i < length;) {
if (adapters[i] == adapter) {
if (adapters[i].addr == adapter) {
return true;
}
unchecked {
Expand Down Expand Up @@ -388,7 +395,7 @@ abstract contract AdapterRegistry {
internal
view
virtual
returns (address[] storage array)
returns (PerSendAdapterInfo[] storage array)
{
if (chain == 0) {
revert InvalidChain(chain);
Expand Down Expand Up @@ -453,7 +460,11 @@ abstract contract AdapterRegistry {
/// @param integrator The integrator address.
/// @param chain The Wormhole chain ID for the desired adapters.
/// @return result The enabled send side adapters for the given integrator and chain.
function getSendAdaptersByChain(address integrator, uint16 chain) public view returns (address[] memory result) {
function getSendAdaptersByChain(address integrator, uint16 chain)
public
view
returns (PerSendAdapterInfo[] memory result)
{
if (chain == 0) {
revert InvalidChain(chain);
}
Expand Down
10 changes: 5 additions & 5 deletions evm/src/Endpoint.sol
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ contract Endpoint is IEndpointAdmin, IEndpointIntegrator, IEndpointAdapter, Mess
returns (uint64 sequence)
{
// get the enabled send adapters for [msg.sender][dstChain]
address[] memory sendAdapters = getSendAdaptersByChain(msg.sender, dstChain);
PerSendAdapterInfo[] memory sendAdapters = getSendAdaptersByChain(msg.sender, dstChain);
uint256 len = sendAdapters.length;
if (len == 0) {
revert AdapterNotEnabled();
Expand All @@ -395,9 +395,9 @@ contract Endpoint is IEndpointAdmin, IEndpointIntegrator, IEndpointAdapter, Mess
for (uint256 i = 0; i < len;) {
bytes memory adapterInstructions; // TODO: Pass this in.
// quote the delivery price
uint256 deliveryPrice = IAdapter(sendAdapters[i]).quoteDeliveryPrice(dstChain, adapterInstructions);
uint256 deliveryPrice = IAdapter(sendAdapters[i].addr).quoteDeliveryPrice(dstChain, adapterInstructions);
// call sendMessage
IAdapter(sendAdapters[i]).sendMessage{value: deliveryPrice}(
IAdapter(sendAdapters[i].addr).sendMessage{value: deliveryPrice}(
sender, sequence, dstChain, dstAddr, payloadHash, refundAddress, adapterInstructions
);
unchecked {
Expand Down Expand Up @@ -574,12 +574,12 @@ contract Endpoint is IEndpointAdmin, IEndpointIntegrator, IEndpointAdapter, Mess
/// @param dstChain The Wormhole chain ID of the recipient.
/// @return totalCost The total cost of delivering a message to the recipient chain in this chain's native token.
function _quoteDeliveryPrice(address integrator, uint16 dstChain) internal view returns (uint256 totalCost) {
address[] memory sendAdapters = getSendAdaptersByChain(integrator, dstChain);
PerSendAdapterInfo[] memory sendAdapters = getSendAdaptersByChain(integrator, dstChain);
uint256 len = sendAdapters.length;
totalCost = 0;
for (uint256 i = 0; i < len;) {
bytes memory adapterInstructions; // TODO: Pass this in.
totalCost += IAdapter(sendAdapters[i]).quoteDeliveryPrice(dstChain, adapterInstructions);
totalCost += IAdapter(sendAdapters[i].addr).quoteDeliveryPrice(dstChain, adapterInstructions);
unchecked {
++i;
}
Expand Down
6 changes: 3 additions & 3 deletions evm/test/AdapterRegistry.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ contract ConcreteAdapterRegistry is AdapterRegistry {
function getEnabledSendAdaptersBitmapForChain(address integrator, uint16 chain)
public
view
returns (address[] memory adapters)
returns (PerSendAdapterInfo[] memory adapters)
{
return _getEnabledSendAdaptersArrayForChain(integrator, chain);
}
Expand Down Expand Up @@ -298,9 +298,9 @@ contract AdapterRegistryTest is Test {
adapterRegistry.addAdapter(me, adapter3);
adapterRegistry.enableSendAdapter(me, chain2, adapter3);
adapterRegistry.addAdapter(me, adapter4);
address[] memory chain1Addrs = adapterRegistry.getSendAdaptersByChain(me, chain1);
AdapterRegistry.PerSendAdapterInfo[] memory chain1Addrs = adapterRegistry.getSendAdaptersByChain(me, chain1);
require(chain1Addrs.length == 2, "Wrong number of adapters enabled on chain one");
address[] memory chain2Addrs = adapterRegistry.getSendAdaptersByChain(me, chain2);
AdapterRegistry.PerSendAdapterInfo[] memory chain2Addrs = adapterRegistry.getSendAdaptersByChain(me, chain2);
require(chain2Addrs.length == 1, "Wrong number of adapters enabled on chain two");
adapterRegistry.enableSendAdapter(me, chain2, adapter4);
adapterRegistry.disableSendAdapter(me, chain2, adapter3);
Expand Down
11 changes: 7 additions & 4 deletions evm/test/Endpoint.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ contract EndpointTest is Test {
vm.expectEmit(true, true, false, true);
emit AdapterRegistry.AdapterAdded(integrator, taddr2, 2);
endpoint.addAdapter(integrator, taddr2);
address[] memory adapters = endpoint.getSendAdaptersByChain(integrator, 1);
AdapterRegistry.PerSendAdapterInfo[] memory adapters = endpoint.getSendAdaptersByChain(integrator, 1);
require(adapters.length == 1, "Wrong number of adapters enabled on chain one, should be 1");
// Enable another adapter on chain one and one on chain two.
vm.expectEmit(true, true, false, true);
Expand All @@ -255,11 +255,14 @@ contract EndpointTest is Test {
// And verify they got set properly.
adapters = endpoint.getSendAdaptersByChain(integrator, 1);
require(adapters.length == 2, "Wrong number of adapters enabled on chain one");
require(adapters[0] == taddr1, "Wrong adapter one on chain one");
require(adapters[1] == taddr2, "Wrong adapter two on chain one");
require(adapters[0].addr == taddr1, "Wrong adapter one on chain one");
require(adapters[0].index == 0, "Wrong adapter index one on chain one");
require(adapters[1].addr == taddr2, "Wrong adapter two on chain one");
require(adapters[1].index == 1, "Wrong adapter index two on chain one");
adapters = endpoint.getSendAdaptersByChain(integrator, 2);
require(adapters.length == 1, "Wrong number of adapters enabled on chain two");
require(adapters[0] == taddr3, "Wrong adapter one on chain two");
require(adapters[0].addr == taddr3, "Wrong adapter one on chain two");
require(adapters[0].index == 2, "Wrong adapter index one on chain two");
vm.expectEmit(true, true, false, true);
emit AdapterRegistry.SendAdapterDisabledForChain(integrator, 2, taddr3);
endpoint.disableSendAdapter(integrator, 2, taddr3);
Expand Down

0 comments on commit f442f4f

Please sign in to comment.