diff --git a/src/LZEndpointMock.sol b/src/LZEndpointMock.sol index c7f479f..3cf07a7 100644 --- a/src/LZEndpointMock.sol +++ b/src/LZEndpointMock.sol @@ -1,10 +1,10 @@ -// SPDX-License-Identifier: BUSL-1.1 +// SPDX-License-Identifier: MIT pragma solidity ^0.8.4; pragma abicoder v2; -import "./interfaces/ILayerZeroReceiver.sol"; -import "./interfaces/ILayerZeroEndpoint.sol"; +import "../interfaces/ILayerZeroReceiver.sol"; +import "../interfaces/ILayerZeroEndpoint.sol"; /* mocking multi endpoint connection. @@ -19,17 +19,38 @@ contract LZEndpointMock is ILayerZeroEndpoint { uint16 public mockChainId; address payable public mockOracle; address payable public mockRelayer; - uint256 public mockBlockConfirmations; + uint public mockBlockConfirmations; uint16 public mockLibraryVersion; - uint256 public mockStaticNativeFee; + uint public mockStaticNativeFee; uint16 public mockLayerZeroVersion; uint public nativeFee; uint public zroFee; + bool nextMsgBLocked; + + struct StoredPayload { + uint64 payloadLength; + address dstAddress; + bytes32 payloadHash; + } + + struct QueuedPayload { + address dstAddress; + uint64 nonce; + bytes payload; + } // inboundNonce = [srcChainId][srcAddress]. mapping(uint16 => mapping(bytes => uint64)) public inboundNonce; // outboundNonce = [dstChainId][srcAddress]. mapping(uint16 => mapping(address => uint64)) public outboundNonce; + // storedPayload = [srcChainId][srcAddress] + mapping(uint16 => mapping(bytes => StoredPayload)) public storedPayload; + // msgToDeliver = [srcChainId][srcAddress] + mapping(uint16 => mapping(bytes => QueuedPayload[])) public msgsToDeliver; + + event UaForceResumeReceive(uint16 chainId, bytes srcAddress); + event PayloadCleared(uint16 srcChainId, bytes srcAddress, uint64 nonce, address dstAddress); + event PayloadStored(uint16 srcChainId, bytes srcAddress, address dstAddress, uint64 nonce, bytes payload, bytes reason); constructor(uint16 _chainId) { mockStaticNativeFee = 42; @@ -55,15 +76,17 @@ contract LZEndpointMock is ILayerZeroEndpoint { uint16 _chainId, bytes calldata _destination, bytes calldata _payload, - address payable, /*_refundAddress*/ - address, /*_zroPaymentAddress*/ - bytes memory dstGas + address payable, // _refundAddress + address, // _zroPaymentAddress + bytes memory _adapterParams ) external payable override { address destAddr = packedBytesToAddr(_destination); address lzEndpoint = lzEndpointLookup[destAddr]; require(lzEndpoint != address(0), "LayerZeroMock: destination LayerZero Endpoint not found"); + require(msg.value >= nativeFee * _payload.length, "LayerZeroMock: not enough native for fees"); + uint64 nonce; { nonce = ++outboundNonce[_chainId][msg.sender]; @@ -71,33 +94,78 @@ contract LZEndpointMock is ILayerZeroEndpoint { // Mock the relayer paying the dstNativeAddr the amount of extra native token { - uint256 dstNative; + uint extraGas; + uint dstNative; address dstNativeAddr; assembly { - dstNative := mload(add(dstGas, 66)) - dstNativeAddr := mload(add(dstGas, 86)) + extraGas := mload(add(_adapterParams, 34)) + dstNative := mload(add(_adapterParams, 66)) + dstNativeAddr := mload(add(_adapterParams, 86)) } - if (dstNativeAddr == 0x90F79bf6EB2c4f870365E785982E1f101E93b906) { - require(dstNative == 453, "Gas incorrect"); - require(1 != 1, "NativeGasParams check"); - } + // to simulate actually sending the ether, add a transfer call and ensure the LZEndpointMock contract has an ether balance } bytes memory bytesSourceUserApplicationAddr = addrToPackedBytes(address(msg.sender)); // cast this address to bytes - inboundNonce[_chainId][abi.encodePacked(msg.sender)] = nonce; - LZEndpointMock(lzEndpoint).receiveAndForward(destAddr, mockChainId, bytesSourceUserApplicationAddr, nonce, _payload); + // not using the extra gas parameter because this is a single tx call, not split between different chains + // LZEndpointMock(lzEndpoint).receivePayload(mockChainId, bytesSourceUserApplicationAddr, destAddr, nonce, extraGas, _payload); + LZEndpointMock(lzEndpoint).receivePayload(mockChainId, bytesSourceUserApplicationAddr, destAddr, nonce, 0, _payload); } - function receiveAndForward( - address _destAddr, + function receivePayload( uint16 _srcChainId, - bytes memory _srcAddress, + bytes calldata _srcAddress, + address _dstAddress, uint64 _nonce, - bytes memory _payload - ) external { - ILayerZeroReceiver(_destAddr).lzReceive(_srcChainId, _srcAddress, _nonce, _payload); // invoke lzReceive + uint, /*_gasLimit*/ + bytes calldata _payload + ) external override { + StoredPayload storage sp = storedPayload[_srcChainId][_srcAddress]; + + // assert and increment the nonce. no message shuffling + require(_nonce == ++inboundNonce[_srcChainId][_srcAddress], "LayerZero: wrong nonce"); + + // queue the following msgs inside of a stack to simulate a successful send on src, but not fully delivered on dst + if (sp.payloadHash != bytes32(0)) { + QueuedPayload[] storage msgs = msgsToDeliver[_srcChainId][_srcAddress]; + QueuedPayload memory newMsg = QueuedPayload(_dstAddress, _nonce, _payload); + + // warning, might run into gas issues trying to forward through a bunch of queued msgs + // shift all the msgs over so we can treat this like a fifo via array.pop() + if (msgs.length > 0) { + // extend the array + msgs.push(newMsg); + + // shift all the indexes up for pop() + for (uint i = 0; i < msgs.length - 1; i++) { + msgs[i + 1] = msgs[i]; + } + + // put the newMsg at the bottom of the stack + msgs[0] = newMsg; + } else { + msgs.push(newMsg); + } + } else if (nextMsgBLocked) { + storedPayload[_srcChainId][_srcAddress] = StoredPayload(uint64(_payload.length), _dstAddress, keccak256(_payload)); + emit PayloadStored(_srcChainId, _srcAddress, _dstAddress, _nonce, _payload, bytes("")); + // ensure the next msgs that go through are no longer blocked + nextMsgBLocked = false; + } else { + // we ignore the gas limit because this call is made in one tx due to being "same chain" + // ILayerZeroReceiver(_dstAddress).lzReceive{gas: _gasLimit}(_srcChainId, _srcAddress, _nonce, _payload); // invoke lzReceive + ILayerZeroReceiver(_dstAddress).lzReceive(_srcChainId, _srcAddress, _nonce, _payload); // invoke lzReceive + } + } + + // used to simulate messages received get stored as a payload + function blockNextMsg() external { + nextMsgBLocked = true; + } + + function getLengthOfQueue(uint16 _srcChainId, bytes calldata _srcAddress) external view returns (uint) { + return msgsToDeliver[_srcChainId][_srcAddress].length; } // @notice gets a quote in source native gas, for the amount that send() requires to pay for message delivery @@ -106,14 +174,8 @@ contract LZEndpointMock is ILayerZeroEndpoint { // @param _payload - the custom message to send over LayerZero // @param _payInZRO - if false, user app pays the protocol fee in native token // @param _adapterParam - parameters for the adapter service, e.g. send some dust native token to dstChain - function estimateFees( - uint16, - address, - bytes memory, - bool, - bytes memory - ) external override view returns (uint _nativeFee, uint _zroFee){ - _nativeFee = nativeFee; + function estimateFees(uint16, address, bytes memory _payload, bool, bytes memory) external view override returns (uint _nativeFee, uint _zroFee) { + _nativeFee = nativeFee * _payload.length; _zroFee = zroFee; } @@ -134,56 +196,114 @@ contract LZEndpointMock is ILayerZeroEndpoint { return data; } - function setConfig(uint16 /*_version*/, uint16 /*_chainId*/, uint /*_configType*/, bytes memory /*_config*/) override external { - } - function getConfig(uint16 /*_version*/, uint16 /*_chainId*/, address /*_ua*/, uint /*_configType*/) override pure external returns(bytes memory) { + function setConfig( + uint16, /*_version*/ + uint16, /*_chainId*/ + uint, /*_configType*/ + bytes memory /*_config*/ + ) external override {} + + function getConfig( + uint16, /*_version*/ + uint16, /*_chainId*/ + address, /*_ua*/ + uint /*_configType*/ + ) external pure override returns (bytes memory) { return ""; } - function receivePayload(uint16 _srcChainId, bytes calldata _srcAddress, address _dstAddress, uint64 _nonce, uint _gasLimit, bytes calldata _payload) external override {} + function setSendVersion( + uint16 /*version*/ + ) external override {} - function setSendVersion(uint16 /*version*/) override external { - } - function setReceiveVersion(uint16 /*version*/) override external { - } - function getSendVersion(address /*_userApplication*/) override external pure returns (uint16) { + function setReceiveVersion( + uint16 /*version*/ + ) external override {} + + function getSendVersion( + address /*_userApplication*/ + ) external pure override returns (uint16) { return 1; } - function getReceiveVersion(address /*_userApplication*/) override external pure returns (uint16){ + + function getReceiveVersion( + address /*_userApplication*/ + ) external pure override returns (uint16) { return 1; } - function getInboundNonce(uint16 _chainID, bytes calldata _srcAddress) override external view returns (uint64) { + function getInboundNonce(uint16 _chainID, bytes calldata _srcAddress) external view override returns (uint64) { return inboundNonce[_chainID][_srcAddress]; } - function getOutboundNonce(uint16 _chainID, address _srcAddress) override external view returns (uint64) { + function getOutboundNonce(uint16 _chainID, address _srcAddress) external view override returns (uint64) { return outboundNonce[_chainID][_srcAddress]; } - function forceResumeReceive(uint16 _srcChainId, bytes calldata _srcAddress) override external { - // This mock does not implement the forceResumeReceive + // simulates the relayer pushing through the rest of the msgs that got delayed due to the stored payload + function _clearMsgQue(uint16 _srcChainId, bytes calldata _srcAddress) internal { + QueuedPayload[] storage msgs = msgsToDeliver[_srcChainId][_srcAddress]; + + // warning, might run into gas issues trying to forward through a bunch of queued msgs + while (msgs.length > 0) { + QueuedPayload memory payload = msgs[msgs.length - 1]; + ILayerZeroReceiver(payload.dstAddress).lzReceive(_srcChainId, _srcAddress, payload.nonce, payload.payload); + msgs.pop(); + } } - function retryPayload(uint16 _srcChainId, bytes calldata _srcAddress, bytes calldata _payload) override pure external {} + function forceResumeReceive(uint16 _srcChainId, bytes calldata _srcAddress) external override { + StoredPayload storage sp = storedPayload[_srcChainId][_srcAddress]; + // revert if no messages are cached. safeguard malicious UA behaviour + require(sp.payloadHash != bytes32(0), "LayerZero: no stored payload"); + require(sp.dstAddress == msg.sender, "LayerZero: invalid caller"); + + // empty the storedPayload + sp.payloadLength = 0; + sp.dstAddress = address(0); + sp.payloadHash = bytes32(0); + + emit UaForceResumeReceive(_srcChainId, _srcAddress); + + // resume the receiving of msgs after we force clear the "stuck" msg + _clearMsgQue(_srcChainId, _srcAddress); + } + + function retryPayload(uint16 _srcChainId, bytes calldata _srcAddress, bytes calldata _payload) external override { + StoredPayload storage sp = storedPayload[_srcChainId][_srcAddress]; + require(sp.payloadHash != bytes32(0), "LayerZero: no stored payload"); + require(_payload.length == sp.payloadLength && keccak256(_payload) == sp.payloadHash, "LayerZero: invalid payload"); + + address dstAddress = sp.dstAddress; + // empty the storedPayload + sp.payloadLength = 0; + sp.dstAddress = address(0); + sp.payloadHash = bytes32(0); + + uint64 nonce = inboundNonce[_srcChainId][_srcAddress]; + + ILayerZeroReceiver(dstAddress).lzReceive(_srcChainId, _srcAddress, nonce, _payload); + emit PayloadCleared(_srcChainId, _srcAddress, nonce, dstAddress); + } - function hasStoredPayload(uint16, bytes memory) external pure override returns(bool) { - return true; + function hasStoredPayload(uint16 _srcChainId, bytes calldata _srcAddress) external view override returns (bool) { + StoredPayload storage sp = storedPayload[_srcChainId][_srcAddress]; + return sp.payloadHash != bytes32(0); } - function isSendingPayload() external override pure returns (bool) { + function isSendingPayload() external pure override returns (bool) { return false; } - function isReceivingPayload() external override pure returns (bool) { + function isReceivingPayload() external pure override returns (bool) { return false; } - function getSendLibraryAddress(address) external override view returns (address) { + function getSendLibraryAddress(address) external view override returns (address) { return address(this); } - function getReceiveLibraryAddress(address) external override view returns (address) { + function getReceiveLibraryAddress(address) external view override returns (address) { return address(this); } }