Skip to content

Instantly share code, notes, and snippets.

@0xJCN
Created April 9, 2023 21:28
Show Gist options
Save 0xJCN/b03b5f1f8cabc937c8bbe4d4a46b8d47 to your computer and use it in GitHub Desktop.

Diff for Factory.sol, with all optimizations applied

diff --git a/src/Factory.sol b/src/Factory.sol
index 09cbb4e..c3e4953 100644
--- a/src/Factory.sol
+++ b/src/Factory.sol
@@ -116,8 +116,39 @@ contract Factory is ERC721, Owned {
         }

         // deposit the nfts from the caller into the pool
-        for (uint256 i = 0; i < tokenIds.length; i++) {
-            ERC721(_nft).safeTransferFrom(msg.sender, address(privatePool), tokenIds[i]);
+        assembly {
+            // check tokenIds length > 0
+            if mload(tokenIds) {
+                // NOTE(review): `_nft` is re-read from calldata word 0x24 (2nd ABI argument); assumes create() is the external entry point
+                // a high-level call reverts when the target has no code; keep that guarantee here
+                if iszero(extcodesize(calldataload(0x24))) { revert(0x00, 0x00) }
+                // cache free memory pointer (0x40 is overwritten below)
+                let memptr := mload(0x40)
+                // cache end of tokenIds array and the position of its first item
+                let end := add(add(tokenIds, 0x20), mul(0x20, mload(tokenIds)))
+                let i := add(tokenIds, 0x20)
+                // build `safeTransferFrom(address,address,uint256)` calldata in memory[0x1c:0x80]:
+                // selector at 0x1c, from at 0x20, to at 0x40, tokenId at 0x60 (written per iteration)
+                mstore(0x00, 0x42842e0e)
+                mstore(0x20, caller())
+                mstore(0x40, privatePool)
+                // loop until i reaches end (tokenIds is non-empty here)
+                for {} 1 {} {
+                    mstore(0x60, mload(i))
+                    // call _nft; on failure bubble up its revert data, as a high-level call would
+                    if iszero(call(gas(), calldataload(0x24), 0x00, 0x1c, 0x64, 0x00, 0x00)) {
+                        returndatacopy(0x00, 0x00, returndatasize())
+                        revert(0x00, returndatasize())
+                    }
+                    // advance to the next token id
+                    i := add(i, 0x20)
+                    if iszero(lt(i, end)) { break }
+                }
+                // restore free memory pointer
+                mstore(0x40, memptr)
+                // restore zero slot
+                mstore(0x60, 0x00)
+            }
         }

Diff for EthRouter.sol, with all optimizations applied

diff --git a/src/EthRouter.sol b/src/EthRouter.sol
index 125001d..97fb7a8 100644
--- a/src/EthRouter.sol
+++ b/src/EthRouter.sol
@@ -98,25 +98,29 @@ contract EthRouter is ERC721TokenReceiver {
     /// @param payRoyalties Whether to pay royalties or not.
     function buy(Buy[] calldata buys, uint256 deadline, bool payRoyalties) public payable {
         // check that the deadline has not passed (if any)
-        if (block.timestamp > deadline && deadline != 0) {
-            revert DeadlinePassed();
+        if (block.timestamp > deadline) { // nested form of the original `&&`: reverts only for a non-zero, passed deadline
+            if (deadline != 0) {
+                revert DeadlinePassed();
+            }
         }

         // loop through and execute the the buys
         for (uint256 i = 0; i < buys.length; i++) {
-            if (buys[i].isPublicPool) {
+            Buy calldata _buy = buys[i]; // calldata reference to buys[i]; no copy is made
+            uint256[] calldata _tokenIds = _buy.tokenIds; // calldata slice of buys[i].tokenIds
+            if (_buy.isPublicPool) {
                 // execute the buy against a public pool
-                uint256 inputAmount = Pair(buys[i].pool).nftBuy{value: buys[i].baseTokenAmount}(
-                    buys[i].tokenIds, buys[i].baseTokenAmount, 0
+                uint256 inputAmount = Pair(_buy.pool).nftBuy{value: _buy.baseTokenAmount}(
+                    _tokenIds, _buy.baseTokenAmount, 0
                 );

                 // pay the royalties if buyer has opted-in
                 if (payRoyalties) {
-                    uint256 salePrice = inputAmount / buys[i].tokenIds.length;
-                    for (uint256 j = 0; j < buys[i].tokenIds.length; j++) {
+                    uint256 salePrice = inputAmount / _tokenIds.length; // NOTE(review): panics (division by zero) if tokenIds is empty
+                    for (uint256 j = 0; j < _tokenIds.length; j++) {
                         // get the royalty fee and recipient
                         (uint256 royaltyFee, address royaltyRecipient) =
-                            getRoyalty(buys[i].nft, buys[i].tokenIds[j], salePrice);
+                            getRoyalty(_buy.nft, _tokenIds[j], salePrice);

                         if (royaltyFee > 0) {
                             // transfer the royalty fee to the royalty recipient
@@ -126,14 +130,44 @@ contract EthRouter is ERC721TokenReceiver {
                 }
             } else {
                 // execute the buy against a private pool
-                PrivatePool(buys[i].pool).buy{value: buys[i].baseTokenAmount}(
-                    buys[i].tokenIds, buys[i].tokenWeights, buys[i].proof
+                PrivatePool(_buy.pool).buy{value: _buy.baseTokenAmount}(
+                    _tokenIds, _buy.tokenWeights, _buy.proof
                 );
             }

-            for (uint256 j = 0; j < buys[i].tokenIds.length; j++) {
-                // transfer the NFT to the caller
-                ERC721(buys[i].nft).safeTransferFrom(address(this), msg.sender, buys[i].tokenIds[j]);
+            assembly {
+                // check _tokenIds.length > 0
+                if _tokenIds.length {
+                    // `buys[i].nft` is the 2nd struct field, i.e. the calldata word at `_buy + 0x20`
+                    // a high-level call reverts when the target has no code; keep that guarantee here
+                    if iszero(extcodesize(calldataload(add(_buy, 0x20)))) { revert(0x00, 0x00) }
+                    // cache free memory pointer (0x40 is overwritten below)
+                    let memptr := mload(0x40)
+                    // cache end of _tokenIds array and the position of its first item
+                    let end := add(_tokenIds.offset, mul(0x20, _tokenIds.length))
+                    let j := _tokenIds.offset
+                    // build `safeTransferFrom(address,address,uint256)` calldata in memory[0x1c:0x80]:
+                    // selector at 0x1c, from at 0x20, to at 0x40, tokenId at 0x60 (written per iteration)
+                    mstore(0x00, 0x42842e0e)
+                    mstore(0x20, address())
+                    mstore(0x40, caller())
+                    // loop until j reaches end (_tokenIds is non-empty here)
+                    for {} 1 {} {
+                        mstore(0x60, calldataload(j))
+                        // call `buys[i].nft`; on failure bubble up its revert data, as a high-level call would
+                        if iszero(call(gas(), calldataload(add(_buy, 0x20)), 0x00, 0x1c, 0x64, 0x00, 0x00)) {
+                            returndatacopy(0x00, 0x00, returndatasize())
+                            revert(0x00, returndatasize())
+                        }
+                        // advance to the next token id
+                        j := add(j, 0x20)
+                        if iszero(lt(j, end)) { break }
+                    }
+                    // restore free memory pointer
+                    mstore(0x40, memptr)
+                    // restore zero slot
+                    mstore(0x60, 0x00)
+                }
             }
         }

@@ -151,39 +185,74 @@ contract EthRouter is ERC721TokenReceiver {
     /// @param payRoyalties Whether to pay royalties or not.
     function sell(Sell[] calldata sells, uint256 minOutputAmount, uint256 deadline, bool payRoyalties) public {
         // check that the deadline has not passed (if any)
-        if (block.timestamp > deadline && deadline != 0) {
-            revert DeadlinePassed();
+        if (block.timestamp > deadline) {
+            if (deadline != 0) {
+                revert DeadlinePassed();
+            }
         }

         // loop through and execute the sells
         for (uint256 i = 0; i < sells.length; i++) {
             // transfer the NFTs into the router from the caller
-            for (uint256 j = 0; j < sells[i].tokenIds.length; j++) {
-                ERC721(sells[i].nft).safeTransferFrom(msg.sender, address(this), sells[i].tokenIds[j]);
+            Sell calldata _sell = sells[i]; // calldata reference to sells[i]; no copy is made
+            uint256[] calldata _tokenIds = _sell.tokenIds; // calldata slice of sells[i].tokenIds
+            assembly {
+                // check _tokenIds.length > 0
+                if _tokenIds.length {
+                    // `sells[i].nft` is the 2nd struct field, i.e. the calldata word at `_sell + 0x20`
+                    // a high-level call reverts when the target has no code; keep that guarantee here
+                    if iszero(extcodesize(calldataload(add(_sell, 0x20)))) { revert(0x00, 0x00) }
+                    // cache free memory pointer (0x40 is overwritten below)
+                    let memptr := mload(0x40)
+                    // cache end of _tokenIds array and the position of its first item
+                    let end := add(_tokenIds.offset, mul(0x20, _tokenIds.length))
+                    let j := _tokenIds.offset
+                    // build `safeTransferFrom(address,address,uint256)` calldata in memory[0x1c:0x80]:
+                    // selector at 0x1c, from at 0x20, to at 0x40, tokenId at 0x60 (written per iteration)
+                    mstore(0x00, 0x42842e0e)
+                    mstore(0x20, caller())
+                    mstore(0x40, address())
+                    // loop until j reaches end (_tokenIds is non-empty here)
+                    for {} 1 {} {
+                        mstore(0x60, calldataload(j))
+                        // call `sells[i].nft`; on failure bubble up its revert data, as a high-level call would
+                        if iszero(call(gas(), calldataload(add(_sell, 0x20)), 0x00, 0x1c, 0x64, 0x00, 0x00)) {
+                            returndatacopy(0x00, 0x00, returndatasize())
+                            revert(0x00, returndatasize())
+                        }
+                        // advance to the next token id
+                        j := add(j, 0x20)
+                        if iszero(lt(j, end)) { break }
+                    }
+                    // restore free memory pointer
+                    mstore(0x40, memptr)
+                    // restore zero slot
+                    mstore(0x60, 0x00)
+                }
             }

             // approve the pair to transfer NFTs from the router
-            ERC721(sells[i].nft).setApprovalForAll(sells[i].pool, true);
+            ERC721(_sell.nft).setApprovalForAll(_sell.pool, true);

-            if (sells[i].isPublicPool) {
+            if (_sell.isPublicPool) {
                 // exceute the sell against a public pool
-                uint256 outputAmount = Pair(sells[i].pool).nftSell(
-                    sells[i].tokenIds,
+                uint256 outputAmount = Pair(_sell.pool).nftSell(
+                    _tokenIds,
                     0,
                     0,
-                    sells[i].publicPoolProofs,
+                    _sell.publicPoolProofs,
                     // ReservoirOracle.Message[] is the exact same as IStolenNftOracle.Message[] and can be
                     // decoded/encoded 1-to-1.
-                    abi.decode(abi.encode(sells[i].stolenNftProofs), (ReservoirOracle.Message[]))
+                    abi.decode(abi.encode(_sell.stolenNftProofs), (ReservoirOracle.Message[]))
                 );

                 // pay the royalties if seller has opted-in
                 if (payRoyalties) {
-                    uint256 salePrice = outputAmount / sells[i].tokenIds.length;
-                    for (uint256 j = 0; j < sells[i].tokenIds.length; j++) {
+                    uint256 salePrice = outputAmount / _tokenIds.length; // NOTE(review): panics (division by zero) if tokenIds is empty
+                    for (uint256 j = 0; j < _tokenIds.length; j++) {
                         // get the royalty fee and recipient
                         (uint256 royaltyFee, address royaltyRecipient) =
-                            getRoyalty(sells[i].nft, sells[i].tokenIds[j], salePrice);
+                            getRoyalty(_sell.nft, _tokenIds[j], salePrice);

                         if (royaltyFee > 0) {
                             // transfer the royalty fee to the royalty recipient
@@ -193,8 +262,8 @@ contract EthRouter is ERC721TokenReceiver {
                 }
             } else {
                 // execute the sell against a private pool
-                PrivatePool(sells[i].pool).sell(
-                    sells[i].tokenIds, sells[i].tokenWeights, sells[i].proof, sells[i].stolenNftProofs
+                PrivatePool(_sell.pool).sell( // `_sell` / `_tokenIds` are the cached calldata references to sells[i] and its tokenIds
+                    _tokenIds, _sell.tokenWeights, _sell.proof, _sell.stolenNftProofs
                 );
             }
         }
@@ -225,8 +294,10 @@ contract EthRouter is ERC721TokenReceiver {
         uint256 deadline
     ) public payable {
         // check deadline has not passed (if any)
-        if (block.timestamp > deadline && deadline != 0) {
-            revert DeadlinePassed();
+        if (block.timestamp > deadline) { // nested form of the original `&&`: reverts only for a non-zero, passed deadline
+            if (deadline != 0) {
+                revert DeadlinePassed();
+            }
         }

         // check pool price is in between min and max
@@ -236,8 +307,39 @@ contract EthRouter is ERC721TokenReceiver {
         }

         // transfer NFTs from caller
-        for (uint256 i = 0; i < tokenIds.length; i++) {
-            ERC721(nft).safeTransferFrom(msg.sender, address(this), tokenIds[i]);
+        assembly {
+            // check tokenIds.length > 0
+            if tokenIds.length {
+                // use the named `nft` parameter: calldataload(0x24) is only right when deposit() is the external entry point
+                // a high-level call reverts when the target has no code; keep that guarantee here
+                if iszero(extcodesize(nft)) { revert(0x00, 0x00) }
+                // cache free memory pointer (0x40 is overwritten below)
+                let memptr := mload(0x40)
+                // cache end of tokenIds array and the position of its first item
+                let end := add(tokenIds.offset, mul(0x20, tokenIds.length))
+                let i := tokenIds.offset
+                // build `safeTransferFrom(address,address,uint256)` calldata in memory[0x1c:0x80]:
+                // selector at 0x1c, from at 0x20, to at 0x40, tokenId at 0x60 (written per iteration)
+                mstore(0x00, 0x42842e0e)
+                mstore(0x20, caller())
+                mstore(0x40, address())
+                // loop until i reaches end (tokenIds is non-empty here)
+                for {} 1 {} {
+                    mstore(0x60, calldataload(i))
+                    // call nft; on failure bubble up its revert data, as a high-level call would
+                    if iszero(call(gas(), nft, 0x00, 0x1c, 0x64, 0x00, 0x00)) {
+                        returndatacopy(0x00, 0x00, returndatasize())
+                        revert(0x00, returndatasize())
+                    }
+                    // advance to the next token id
+                    i := add(i, 0x20)
+                    if iszero(lt(i, end)) { break }
+                }
+                // restore free memory pointer
+                mstore(0x40, memptr)
+                // restore zero slot
+                mstore(0x60, 0x00)
+            }
         }

         // approve pair to transfer NFTs from router
@@ -253,17 +355,52 @@ contract EthRouter is ERC721TokenReceiver {
     /// Set to 0 for deadline to be ignored.
     function change(Change[] calldata changes, uint256 deadline) public payable {
         // check deadline has not passed (if any)
-        if (block.timestamp > deadline && deadline != 0) {
-            revert DeadlinePassed();
+        if (block.timestamp > deadline) {
+            if (deadline != 0) {
+                revert DeadlinePassed();
+            }
         }

         // loop through and execute the changes
         for (uint256 i = 0; i < changes.length; i++) {
-            Change memory _change = changes[i];
+            Change calldata _change = changes[i];
+            uint256[] memory _inputTokenIds = _change.inputTokenIds;
+            uint256[] memory _outputTokenIds = _change.outputTokenIds;

             // transfer NFTs from caller
-            for (uint256 j = 0; j < changes[i].inputTokenIds.length; j++) {
-                ERC721(_change.nft).safeTransferFrom(msg.sender, address(this), _change.inputTokenIds[j]);
+            assembly {
+                // check _inputTokenIds.length > 0
+                if mload(_inputTokenIds) {
+                    // `changes[i].nft` is the 2nd struct field, i.e. the calldata word at `_change + 0x20`
+                    // a high-level call reverts when the target has no code; keep that guarantee here
+                    if iszero(extcodesize(calldataload(add(_change, 0x20)))) { revert(0x00, 0x00) }
+                    // cache free memory pointer (0x40 is overwritten below)
+                    let memptr := mload(0x40)
+                    // cache end of _inputTokenIds array and the position of its first item
+                    let end := add(add(_inputTokenIds, 0x20), mul(0x20, mload(_inputTokenIds)))
+                    let j := add(_inputTokenIds, 0x20)
+                    // build `safeTransferFrom(address,address,uint256)` calldata in memory[0x1c:0x80]:
+                    // selector at 0x1c, from at 0x20, to at 0x40, tokenId at 0x60 (written per iteration)
+                    mstore(0x00, 0x42842e0e)
+                    mstore(0x20, caller())
+                    mstore(0x40, address())
+                    // loop until j reaches end (_inputTokenIds is non-empty here)
+                    for {} 1 {} {
+                        mstore(0x60, mload(j))
+                        // call `changes[i].nft`; on failure bubble up its revert data, as a high-level call would
+                        if iszero(call(gas(), calldataload(add(_change, 0x20)), 0x00, 0x1c, 0x64, 0x00, 0x00)) {
+                            returndatacopy(0x00, 0x00, returndatasize())
+                            revert(0x00, returndatasize())
+                        }
+                        // advance to the next token id
+                        j := add(j, 0x20)
+                        if iszero(lt(j, end)) { break }
+                    }
+                    // restore free memory pointer
+                    mstore(0x40, memptr)
+                    // restore zero slot
+                    mstore(0x60, 0x00)
+                }
             }

             // approve pair to transfer NFTs from router
@@ -271,18 +408,49 @@ contract EthRouter is ERC721TokenReceiver {

             // execute change
             PrivatePool(_change.pool).change{value: msg.value}(
-                _change.inputTokenIds,
+                _inputTokenIds,
                 _change.inputTokenWeights,
                 _change.inputProof,
                 _change.stolenNftProofs,
-                _change.outputTokenIds,
+                _outputTokenIds,
                 _change.outputTokenWeights,
                 _change.outputProof
             );

             // transfer NFTs to caller
-            for (uint256 j = 0; j < changes[i].outputTokenIds.length; j++) {
-                ERC721(_change.nft).safeTransferFrom(address(this), msg.sender, _change.outputTokenIds[j]);
+            assembly {
+                // check _outputTokenIds.length > 0
+                if mload(_outputTokenIds) {
+                    // `changes[i].nft` is the 2nd struct field, i.e. the calldata word at `_change + 0x20`
+                    // a high-level call reverts when the target has no code; keep that guarantee here
+                    if iszero(extcodesize(calldataload(add(_change, 0x20)))) { revert(0x00, 0x00) }
+                    // cache free memory pointer (0x40 is overwritten below)
+                    let memptr := mload(0x40)
+                    // cache end of _outputTokenIds array and the position of its first item
+                    let end := add(add(_outputTokenIds, 0x20), mul(0x20, mload(_outputTokenIds)))
+                    let j := add(_outputTokenIds, 0x20)
+                    // build `safeTransferFrom(address,address,uint256)` calldata in memory[0x1c:0x80]:
+                    // selector at 0x1c, from at 0x20, to at 0x40, tokenId at 0x60 (written per iteration)
+                    mstore(0x00, 0x42842e0e)
+                    mstore(0x20, address())
+                    mstore(0x40, caller())
+                    // loop until j reaches end (_outputTokenIds is non-empty here)
+                    for {} 1 {} {
+                        mstore(0x60, mload(j))
+                        // call `changes[i].nft`; on failure bubble up its revert data, as a high-level call would
+                        if iszero(call(gas(), calldataload(add(_change, 0x20)), 0x00, 0x1c, 0x64, 0x00, 0x00)) {
+                            returndatacopy(0x00, 0x00, returndatasize())
+                            revert(0x00, returndatasize())
+                        }
+                        // advance to the next token id
+                        j := add(j, 0x20)
+                        if iszero(lt(j, end)) { break }
+                    }
+                    // restore free memory pointer
+                    mstore(0x40, memptr)
+                    // restore zero slot
+                    mstore(0x60, 0x00)
+                }
             }
         }

Diff for PrivatePool.sol, with all optimizations applied

diff --git a/src/PrivatePool.sol b/src/PrivatePool.sol
index 75991e1..2205880 100644
--- a/src/PrivatePool.sol
+++ b/src/PrivatePool.sol
@@ -215,31 +215,41 @@ contract PrivatePool is ERC721TokenReceiver {
     {
         // ~~~ Checks ~~~ //

+        // check that the caller sent 0 ETH if the base token is not ETH (now runs before proof validation; if both fail, this error wins)
+        // @audit: cannot cache `baseToken` here due to a `stack too deep` error
+        if (baseToken != address(0)) {
+            if (msg.value > 0) {
+                revert InvalidEthAmount();
+            }
+        }
+
         // calculate the sum of weights of the NFTs to buy
-        uint256 weightSum = sumWeightsAndValidateProof(tokenIds, tokenWeights, proof);
+        { // @audit: scoped block so `weightSum` is released at the closing brace, avoiding a `stack too deep` error
+            uint256 weightSum = sumWeightsAndValidateProof(tokenIds, tokenWeights, proof);

-        // calculate the required net input amount and fee amount
-        (netInputAmount, feeAmount, protocolFeeAmount) = buyQuote(weightSum);
+            // calculate the required net input amount and fee amount
+            (netInputAmount, feeAmount, protocolFeeAmount) = buyQuote(weightSum);

-        // check that the caller sent 0 ETH if the base token is not ETH
-        if (baseToken != address(0) && msg.value > 0) revert InvalidEthAmount();
+            // ~~~ Effects ~~~ //

-        // ~~~ Effects ~~~ //
-
-        // update the virtual reserves
-        virtualBaseTokenReserves += uint128(netInputAmount - feeAmount - protocolFeeAmount);
-        virtualNftReserves -= uint128(weightSum);
+            // update the virtual reserves
+            virtualBaseTokenReserves = virtualBaseTokenReserves + uint128(netInputAmount - feeAmount - protocolFeeAmount); // same checked arithmetic as `+=`
+            virtualNftReserves = virtualNftReserves - uint128(weightSum); // same checked arithmetic as `-=`
+        }

         // ~~~ Interactions ~~~ //

         // calculate the sale price (assume it's the same for each NFT even if weights differ)
         uint256 salePrice = (netInputAmount - feeAmount - protocolFeeAmount) / tokenIds.length;
         uint256 royaltyFeeAmount = 0;
+        // @audit: due to `stack too deep` error we can only have one extra stack variable here
+        // @audit: you can choose to cache `nft` or `payRoyalties`. Caching `payRoyalties` seems to save more gas
+        bool _payRoyalties = payRoyalties;
         for (uint256 i = 0; i < tokenIds.length; i++) {
             // transfer the NFT to the caller
             ERC721(nft).safeTransferFrom(address(this), msg.sender, tokenIds[i]);

-            if (payRoyalties) {
+            if (_payRoyalties) {
                 // get the royalty fee for the NFT
                 (uint256 royaltyFee,) = _getRoyalty(tokenIds[i], salePrice);

@@ -250,13 +260,14 @@ contract PrivatePool is ERC721TokenReceiver {

         // add the royalty fee amount to the net input aount
         netInputAmount += royaltyFeeAmount;
-
-        if (baseToken != address(0)) {
+
+        address _baseToken = baseToken; // single storage read, reused below (including the royalty payout loop)
+        if (_baseToken != address(0)) {
             // transfer the base token from the caller to the contract
-            ERC20(baseToken).safeTransferFrom(msg.sender, address(this), netInputAmount);
+            ERC20(_baseToken).safeTransferFrom(msg.sender, address(this), netInputAmount);

             // if the protocol fee is set then pay the protocol fee
-            if (protocolFeeAmount > 0) ERC20(baseToken).safeTransfer(address(factory), protocolFeeAmount);
+            if (protocolFeeAmount > 0) ERC20(_baseToken).safeTransfer(address(factory), protocolFeeAmount);
         } else {
             // check that the caller sent enough ETH to cover the net required input
             if (msg.value < netInputAmount) revert InvalidEthAmount();
@@ -268,17 +279,19 @@ contract PrivatePool is ERC721TokenReceiver {
             if (msg.value > netInputAmount) msg.sender.safeTransferETH(msg.value - netInputAmount);
         }

-        if (payRoyalties) {
+        if (_payRoyalties) {
             for (uint256 i = 0; i < tokenIds.length; i++) {
                 // get the royalty fee for the NFT
                 (uint256 royaltyFee, address recipient) = _getRoyalty(tokenIds[i], salePrice);

                 // transfer the royalty fee to the recipient if it's greater than 0
-                if (royaltyFee > 0 && recipient != address(0)) {
-                    if (baseToken != address(0)) {
-                        ERC20(baseToken).safeTransfer(recipient, royaltyFee);
-                    } else {
-                        recipient.safeTransferETH(royaltyFee);
+                if (royaltyFee > 0) { // nested form of the original `&&` check; same condition
+                    if (recipient != address(0)) {
+                        if (_baseToken != address(0)) {
+                            ERC20(_baseToken).safeTransfer(recipient, royaltyFee);
+                        } else {
+                            recipient.safeTransferETH(royaltyFee);
+                        }
                     }
                 }
             }
@@ -307,30 +320,35 @@ contract PrivatePool is ERC721TokenReceiver {
         // ~~~ Checks ~~~ //

         // calculate the sum of weights of the NFTs to sell
-        uint256 weightSum = sumWeightsAndValidateProof(tokenIds, tokenWeights, proof);
+        { // @audit: scoped block so `weightSum` is released at the closing brace, avoiding a `stack too deep` error
+            uint256 weightSum = sumWeightsAndValidateProof(tokenIds, tokenWeights, proof);

-        // calculate the net output amount and fee amount
-        (netOutputAmount, feeAmount, protocolFeeAmount) = sellQuote(weightSum);
+            // calculate the net output amount and fee amount
+            (netOutputAmount, feeAmount, protocolFeeAmount) = sellQuote(weightSum);

-        //  check the nfts are not stolen
-        if (useStolenNftOracle) {
-            IStolenNftOracle(stolenNftOracle).validateTokensAreNotStolen(nft, tokenIds, stolenNftProofs);
-        }
+            // check the nfts are not stolen
+            if (useStolenNftOracle) {
+                IStolenNftOracle(stolenNftOracle).validateTokensAreNotStolen(nft, tokenIds, stolenNftProofs);
+            }

-        // ~~~ Effects ~~~ //
+            // ~~~ Effects ~~~ //

-        // update the virtual reserves
-        virtualBaseTokenReserves -= uint128(netOutputAmount + protocolFeeAmount + feeAmount);
-        virtualNftReserves += uint128(weightSum);
+            // update the virtual reserves
+            virtualBaseTokenReserves = virtualBaseTokenReserves - uint128(netOutputAmount + protocolFeeAmount + feeAmount); // same checked arithmetic as `-=`
+            virtualNftReserves = virtualNftReserves + uint128(weightSum); // same checked arithmetic as `+=`
+        }

         // ~~~ Interactions ~~~ //

         uint256 royaltyFeeAmount = 0;
+        // @audit: due to `stack too deep` error we can only have one extra stack variable here
+        // @audit: you can choose to cache `nft` or `payRoyalties`. Caching `payRoyalties` seems to save more gas
+        bool _payRoyalties = payRoyalties;
         for (uint256 i = 0; i < tokenIds.length; i++) {
             // transfer each nft from the caller
             ERC721(nft).safeTransferFrom(msg.sender, address(this), tokenIds[i]);

-            if (payRoyalties) {
+            if (_payRoyalties) {
                 // calculate the sale price (assume it's the same for each NFT even if weights differ)
                 uint256 salePrice = (netOutputAmount + feeAmount + protocolFeeAmount) / tokenIds.length;

@@ -341,11 +359,13 @@ contract PrivatePool is ERC721TokenReceiver {
                 royaltyFeeAmount += royaltyFee;

                 // transfer the royalty fee to the recipient if it's greater than 0
-                if (royaltyFee > 0 && recipient != address(0)) {
-                    if (baseToken != address(0)) {
-                        ERC20(baseToken).safeTransfer(recipient, royaltyFee);
-                    } else {
-                        recipient.safeTransferETH(royaltyFee);
+                if (royaltyFee > 0) { // nested form of the original `&&` check; same condition
+                    if (recipient != address(0)) {
+                        if (baseToken != address(0)) { // storage read per royalty: `baseToken` is not cached in this loop (stack too deep)
+                            ERC20(baseToken).safeTransfer(recipient, royaltyFee);
+                        } else {
+                            recipient.safeTransferETH(royaltyFee);
+                        }
                     }
                 }
             }
@@ -353,8 +373,9 @@ contract PrivatePool is ERC721TokenReceiver {

         // subtract the royalty fee amount from the net output amount
         netOutputAmount -= royaltyFeeAmount;
-
-        if (baseToken == address(0)) {
+
+        address _baseToken = baseToken; // single storage read, reused for the payout transfers below
+        if (_baseToken == address(0)) {
             // transfer ETH to the caller
             msg.sender.safeTransferETH(netOutputAmount);

@@ -362,10 +383,10 @@ contract PrivatePool is ERC721TokenReceiver {
             if (protocolFeeAmount > 0) factory.safeTransferETH(protocolFeeAmount);
         } else {
             // transfer base tokens to the caller
-            ERC20(baseToken).transfer(msg.sender, netOutputAmount);
+            ERC20(_baseToken).safeTransfer(msg.sender, netOutputAmount); // checked transfer like the other base-token transfers; plain `transfer` ignored a `false` return

             // if the protocol fee is set then pay the protocol fee
-            if (protocolFeeAmount > 0) ERC20(baseToken).safeTransfer(address(factory), protocolFeeAmount);
+            if (protocolFeeAmount > 0) ERC20(_baseToken).safeTransfer(address(factory), protocolFeeAmount);
         }

         // emit the sell event
@@ -383,18 +404,23 @@ contract PrivatePool is ERC721TokenReceiver {
     /// @param outputTokenWeights The weights of the NFTs to receive.
     /// @param outputProof The merkle proof for the weights of each NFT to receive.
     function change(
-        uint256[] memory inputTokenIds,
+        uint256[] calldata inputTokenIds, // NOTE(review): calldata params on a public function reject memory arguments from internal callers — confirm there are none
         uint256[] memory inputTokenWeights,
         MerkleMultiProof memory inputProof,
         IStolenNftOracle.Message[] memory stolenNftProofs,
-        uint256[] memory outputTokenIds,
+        uint256[] calldata outputTokenIds,
         uint256[] memory outputTokenWeights,
         MerkleMultiProof memory outputProof
     ) public payable returns (uint256 feeAmount, uint256 protocolFeeAmount) {
         // ~~~ Checks ~~~ //

         // check that the caller sent 0 ETH if base token is not ETH
-        if (baseToken != address(0) && msg.value > 0) revert InvalidEthAmount();
+        address _baseToken = baseToken; // single storage read, reused in the Interactions section
+        if (_baseToken != address(0)) {
+            if (msg.value > 0) {
+                revert InvalidEthAmount();
+            }
+        }

         // check that NFTs are not stolen
         if (useStolenNftOracle) {
@@ -418,12 +444,12 @@ contract PrivatePool is ERC721TokenReceiver {

         // ~~~ Interactions ~~~ //

-        if (baseToken != address(0)) {
+        if (_baseToken != address(0)) { // `_baseToken` was cached from storage in the Checks section of change()
             // transfer the fee amount of base tokens from the caller
-            ERC20(baseToken).safeTransferFrom(msg.sender, address(this), feeAmount);
+            ERC20(_baseToken).safeTransferFrom(msg.sender, address(this), feeAmount);

             // if the protocol fee is non-zero then transfer the protocol fee to the factory
-            if (protocolFeeAmount > 0) ERC20(baseToken).safeTransferFrom(msg.sender, factory, protocolFeeAmount);
+            if (protocolFeeAmount > 0) ERC20(_baseToken).safeTransferFrom(msg.sender, factory, protocolFeeAmount);
         } else {
             // check that the caller sent enough ETH to cover the fee amount and protocol fee
             if (msg.value < feeAmount + protocolFeeAmount) revert InvalidEthAmount();
@@ -438,13 +464,76 @@ contract PrivatePool is ERC721TokenReceiver {
         }

         // transfer the input nfts from the caller
-        for (uint256 i = 0; i < inputTokenIds.length; i++) {
-            ERC721(nft).safeTransferFrom(msg.sender, address(this), inputTokenIds[i]);
+        address _nft = nft;
+        assembly {
+            // check inputTokenIds.length > 0
+            if inputTokenIds.length {
+                // cache free memory pointer
+                let memptr := mload(0x40)
+                // cache end of inputTokenIds array
+                let end := add(inputTokenIds.offset, mul(0x20, inputTokenIds.length))
+                // cache index where items in inputTokenIds start
+                let i := inputTokenIds.offset
+                // pre-load `safeTransferFrom` function signature into memory
+                mstore(0x00, 0x42842e0e)
+                // pre-load 1st param into memory
+                mstore(0x20, caller())
+                // pre-load 2nd param into memory
+                mstore(0x40, address())
+                // infinite loop
+                for {} 1 {} {
+                    // load 3rd param into memory (inputTokenIds[i])
+                    mstore(0x60, calldataload(i))
+                    // call nft
+                    let success := call(gas(), _nft, 0x00, 0x1c, 0x64, 0x00, 0x00)
+                    if iszero(success) {
+                        revert(0, 0)
+                    }
+                    // increment i
+                    i := add(i, 0x20)
+                    if iszero(lt(i, end)) { break }
+                }
+                // restore free memory pointer
+                mstore(0x40, memptr)
+                // restore zero slot
+                mstore(0x60, 0x00)
+            }
         }

         // transfer the output nfts to the caller
-        for (uint256 i = 0; i < outputTokenIds.length; i++) {
-            ERC721(nft).safeTransferFrom(address(this), msg.sender, outputTokenIds[i]);
+        assembly {
+            // check outputTokenIds.length > 0
+            if outputTokenIds.length {
+                // cache free memory pointer
+                let memptr := mload(0x40)
+                // cache end of outputTokenIds array
+                let end := add(outputTokenIds.offset, mul(0x20, outputTokenIds.length))
+                // cache index where items in outputTokenIds start
+                let i := outputTokenIds.offset
+                // pre-load `safeTransferFrom` function signature into memory
+                mstore(0x00, 0x42842e0e)
+                // pre-load 1st param into memory
+                mstore(0x20, address())
+                // pre-load 2nd param into memory
+                mstore(0x40, caller())
+                // infinite loop
+                for {} 1 {} {
+                    // load 3rd param into memory (outputTokenIds[i])
+                    mstore(0x60, calldataload(i))
+                    // call nft
+                    let success := call(gas(), _nft, 0x00, 0x1c, 0x64, 0x00, 0x00)
+                    if iszero(success) {
+                        revert(0, 0)
+                    }
+                    // increment i
+                    i := add(i, 0x20)
+                    if iszero(lt(i, end)) { break }
+                }
+                // restore free memory pointer
+                mstore(0x40, memptr)
+                // restore zero slot
+                mstore(0x60, 0x00)
+            }
         }

         // emit the change event
@@ -456,7 +545,7 @@ contract PrivatePool is ERC721TokenReceiver {
     /// @param target The address of the target contract.
     /// @param data The data to send to the target contract.
     /// @return returnData The return data of the transaction.
-    function execute(address target, bytes memory data) public payable onlyOwner returns (bytes memory) {
+    function execute(address target, bytes calldata data) public payable onlyOwner returns (bytes memory) { // NOTE(review): with a calldata parameter, internal callers can no longer pass memory bytes; confirm there are none
         // call the target with the value and data
         (bool success, bytes memory returnData) = target.call{value: msg.value}(data);

@@ -486,20 +575,54 @@ contract PrivatePool is ERC721TokenReceiver {

         // ensure the caller sent a valid amount of ETH if base token is ETH or that the caller sent 0 ETH if base token
         // is not ETH
-        if ((baseToken == address(0) && msg.value != baseTokenAmount) || (msg.value > 0 && baseToken != address(0))) {
+        address _baseToken = baseToken; // read once; reused by the base-token transfer below
+        if ((_baseToken == address(0) && msg.value != baseTokenAmount) || (msg.value > 0 && _baseToken != address(0))) {
             revert InvalidEthAmount();
         }

         // ~~~ Interactions ~~~ //

         // transfer the nfts from the caller
-        for (uint256 i = 0; i < tokenIds.length; i++) {
-            ERC721(nft).safeTransferFrom(msg.sender, address(this), tokenIds[i]);
+        assembly {
+            // skip all memory writes and calls when there is nothing to transfer
+            if tokenIds.length {
+                // read `nft` from storage, shifting by its byte offset and masking to 160 bits in case the slot is packed
+                let _nft := and(shr(mul(nft.offset, 8), sload(nft.slot)), 0xffffffffffffffffffffffffffffffffffffffff)
+                // a high-level call reverts when the target has no code; keep that guarantee
+                if iszero(extcodesize(_nft)) { revert(0x00, 0x00) }
+                // cache the free memory pointer, which the 2nd argument overwrites below
+                let memptr := mload(0x40)
+                // calldata cursor over tokenIds and the position one word past its last element
+                let i := tokenIds.offset
+                let end := add(i, mul(0x20, tokenIds.length))
+                // `safeTransferFrom(address,address,uint256)` selector; its 4 bytes sit at [0x1c, 0x20)
+                mstore(0x00, 0x42842e0e)
+                // from = caller, to = this pool; the token id is written on every iteration
+                mstore(0x20, caller())
+                mstore(0x40, address())
+                // the body runs at least once because the length was checked above
+                for {} 1 {} {
+                    // 3rd argument: tokenIds[i]
+                    mstore(0x60, calldataload(i))
+                    // call nft with the 0x64-byte payload at [0x1c, 0x80); bubble up its revert data
+                    if iszero(call(gas(), _nft, 0x00, 0x1c, 0x64, 0x00, 0x00)) {
+                        returndatacopy(0x00, 0x00, returndatasize())
+                        revert(0x00, returndatasize())
+                    }
+                    // advance to the next token id and stop after the last one
+                    i := add(i, 0x20)
+                    if iszero(lt(i, end)) { break }
+                }
+                // restore the free memory pointer
+                mstore(0x40, memptr)
+                // restore the zero slot, which Solidity expects to stay 0
+                mstore(0x60, 0x00)
+            }
         }

-        if (baseToken != address(0)) {
+        if (_baseToken != address(0)) {
             // transfer the base tokens from the caller
-            ERC20(baseToken).safeTransferFrom(msg.sender, address(this), baseTokenAmount);
+            ERC20(_baseToken).safeTransferFrom(msg.sender, address(this), baseTokenAmount);
         }

         // emit the deposit event
@@ -515,8 +638,39 @@ contract PrivatePool is ERC721TokenReceiver {
         // ~~~ Interactions ~~~ //

         // transfer the nfts to the caller
-        for (uint256 i = 0; i < tokenIds.length; i++) {
-            ERC721(_nft).safeTransferFrom(address(this), msg.sender, tokenIds[i]);
+        assembly {
+            // skip all memory writes and calls when there is nothing to transfer
+            if tokenIds.length {
+                // a high-level call reverts when the target has no code; keep that guarantee
+                if iszero(extcodesize(_nft)) { revert(0x00, 0x00) }
+                // cache the free memory pointer, which the 2nd argument overwrites below
+                let memptr := mload(0x40)
+                // calldata cursor over tokenIds and the position one word past its last element
+                let i := tokenIds.offset
+                let end := add(i, mul(0x20, tokenIds.length))
+                // `safeTransferFrom(address,address,uint256)` selector; its 4 bytes sit at [0x1c, 0x20)
+                mstore(0x00, 0x42842e0e)
+                // from = this pool, to = caller; the token id is written on every iteration
+                mstore(0x20, address())
+                mstore(0x40, caller())
+                // the body runs at least once because the length was checked above
+                for {} 1 {} {
+                    // 3rd argument: tokenIds[i]
+                    mstore(0x60, calldataload(i))
+                    // call the `_nft` argument (not msg.data, which only matches external calls); bubble up its revert data
+                    if iszero(call(gas(), _nft, 0x00, 0x1c, 0x64, 0x00, 0x00)) {
+                        returndatacopy(0x00, 0x00, returndatasize())
+                        revert(0x00, returndatasize())
+                    }
+                    // advance to the next token id and stop after the last one
+                    i := add(i, 0x20)
+                    if iszero(lt(i, end)) { break }
+                }
+                // restore the free memory pointer
+                mstore(0x40, memptr)
+                // restore the zero slot, which Solidity expects to stay 0
+                mstore(0x60, 0x00)
+            }
         }

         if (token == address(0)) {
@@ -632,7 +786,12 @@ contract PrivatePool is ERC721TokenReceiver {
         uint256 fee = flashFee(token, tokenId);

         // if base token is ETH then check that caller sent enough for the fee
-        if (baseToken == address(0) && msg.value < fee) revert InvalidEthAmount();
+        address _baseToken = baseToken; // read once; reused for the fee transfer below
+        if (_baseToken == address(0)) { // ETH pool: the fee must arrive as msg.value
+            if (msg.value < fee) {
+                revert InvalidEthAmount();
+            }
+        }

         // transfer the NFT to the borrower
         ERC721(token).safeTransferFrom(address(this), address(receiver), tokenId);
@@ -648,7 +807,7 @@ contract PrivatePool is ERC721TokenReceiver {
         ERC721(token).safeTransferFrom(address(receiver), address(this), tokenId);

         // transfer the fee from the borrower
-        if (baseToken != address(0)) ERC20(baseToken).transferFrom(msg.sender, address(this), fee);
+        if (_baseToken != address(0)) ERC20(_baseToken).safeTransferFrom(msg.sender, address(this), fee); // safeTransferFrom, as for every other base-token pull, so a token returning false cannot skip the fee

         return success;
     }
@@ -659,27 +818,56 @@ contract PrivatePool is ERC721TokenReceiver {
     /// @param proof The merkle proof for the weights of each NFT.
     /// @return sum The sum of the weights of each NFT.
     function sumWeightsAndValidateProof(
-        uint256[] memory tokenIds,
+        uint256[] calldata tokenIds,
         uint256[] memory tokenWeights,
         MerkleMultiProof memory proof
     ) public view returns (uint256) {
         // if the merkle root is not set then set the weight of each nft to be 1e18
-        if (merkleRoot == bytes32(0)) {
+        bytes32 _merkleRoot = merkleRoot;
+        if (_merkleRoot == bytes32(0)) {
             return tokenIds.length * 1e18;
         }

         uint256 sum;
         bytes32[] memory leafs = new bytes32[](tokenIds.length);
-        for (uint256 i = 0; i < tokenIds.length; i++) {
-            // create the leaf for the merkle proof
-            leafs[i] = keccak256(bytes.concat(keccak256(abi.encode(tokenIds[i], tokenWeights[i]))));
-
-            // sum each token weight
-            sum += tokenWeights[i];
+        assembly {
+            // nothing to hash or sum for an empty tokenIds array
+            if tokenIds.length {
+                // keep the bounds check of `tokenWeights[i]`: revert with Panic(0x32) if tokenWeights is too short
+                if lt(mload(tokenWeights), tokenIds.length) { mstore(0x00, 0x4e487b71) mstore(0x20, 0x32) revert(0x1c, 0x24) }
+                // calldata cursor over tokenIds and the position one word past its last element
+                let i := tokenIds.offset
+                let end := add(i, mul(0x20, tokenIds.length))
+                // memory cursor over the elements of tokenWeights
+                let j := add(tokenWeights, 0x20)
+                // memory cursor over the elements of leafs
+                let k := add(leafs, 0x20)
+                // the body runs at least once because the length was checked above
+                for {} 1 {} {
+                    // tokenWeights[i]
+                    let weight := mload(j)
+                    // abi.encode(tokenIds[i], tokenWeights[i]) in scratch space [0x00, 0x40)
+                    mstore(0x00, calldataload(i))
+                    mstore(0x20, weight)
+                    // inner hash, then hash that 32-byte value again (keccak256(bytes.concat(...)))
+                    mstore(0x00, keccak256(0x00, 0x40))
+                    // store the double hash in `leafs[i]`
+                    mstore(k, keccak256(0x00, 0x20))
+                    // sum each token weight; keep checked-arithmetic semantics with Panic(0x11) on overflow
+                    sum := add(sum, weight)
+                    if lt(sum, weight) { mstore(0x00, 0x4e487b71) mstore(0x20, 0x11) revert(0x1c, 0x24) }
+                    // advance the tokenIds, tokenWeights and leafs cursors by one word each
+                    i := add(i, 0x20)
+                    j := add(j, 0x20)
+                    k := add(k, 0x20)
+                    // stop after the last token id
+                    if iszero(lt(i, end)) { break }
+                }
+            }
         }

         // validate that the weights are valid against the merkle proof
-        if (!MerkleProofLib.verifyMultiProof(proof.proof, merkleRoot, leafs, proof.flags)) {
+        if (!MerkleProofLib.verifyMultiProof(proof.proof, _merkleRoot, leafs, proof.flags)) {
             revert InvalidMerkleProof();
         }

@@ -730,7 +918,8 @@ contract PrivatePool is ERC721TokenReceiver {
     /// @return protocolFeeAmount The protocol fee amount.
     function changeFeeQuote(uint256 inputAmount) public view returns (uint256 feeAmount, uint256 protocolFeeAmount) {
         // multiply the changeFee to get the fee per NFT (4 decimals of accuracy)
-        uint256 exponent = baseToken == address(0) ? 18 - 4 : ERC20(baseToken).decimals() - 4;
+        address _baseToken = baseToken; // single read: the replaced expression referenced baseToken twice for ERC20 pools
+        uint256 exponent = _baseToken == address(0) ? 18 - 4 : ERC20(_baseToken).decimals() - 4; // NOTE(review): underflows when decimals() < 4 (a revert under checked arithmetic) - pre-existing behaviour, confirm acceptable
         uint256 feePerNft = changeFee * 10 ** exponent;

         feeAmount = inputAmount * feePerNft / 1e18;
@@ -741,7 +930,8 @@ contract PrivatePool is ERC721TokenReceiver {
     /// @return price The price of the pool.
     function price() public view returns (uint256) {
         // ensure that the exponent is always to 18 decimals of accuracy
-        uint256 exponent = baseToken == address(0) ? 18 : (36 - ERC20(baseToken).decimals());
+        address _baseToken = baseToken; // single read: the replaced expression referenced baseToken twice for ERC20 pools
+        uint256 exponent = _baseToken == address(0) ? 18 : (36 - ERC20(_baseToken).decimals()); // NOTE(review): assumes decimals() <= 36, otherwise 36 - decimals() underflows - pre-existing, confirm
         return (virtualBaseTokenReserves * 10 ** exponent) / virtualNftReserves;
     }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment