Skip to content

Instantly share code, notes, and snippets.

@MicahZoltu
Last active April 1, 2020 06:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MicahZoltu/fb50ade993e591878bb5733f49dca4a2 to your computer and use it in GitHub Desktop.
Save MicahZoltu/fb50ade993e591878bb5733f49dca4a2 to your computer and use it in GitHub Desktop.
Merkle Patricia Proof Validator in Solidity
pragma solidity 0.6.4;
library RLP {
    uint constant DATA_SHORT_START = 0x80;
    uint constant DATA_LONG_START = 0xB8;
    uint constant LIST_SHORT_START = 0xC0;
    uint constant LIST_LONG_START = 0xF8;
    uint constant DATA_LONG_OFFSET = 0xB7;
    uint constant LIST_LONG_OFFSET = 0xF7;

    // A view (pointer + length) onto RLP-encoded bytes that already live in memory.
    struct RLPItem {
        uint _unsafe_memPtr; // Pointer to the RLP-encoded bytes.
        uint _unsafe_length; // Number of bytes. This is the full length of the string.
    }

    // Cursor over the direct children of an RLP list item.
    struct Iterator {
        RLPItem _unsafe_item; // Item that's being iterated over.
        uint _unsafe_nextPtr; // Position of the next item in the list.
    }

    /* Iterator */

    /// @dev Returns the child item at the cursor and advances past it.
    /// Reverts when the iterator is exhausted.
    function next(Iterator memory self) internal pure returns (RLPItem memory subItem) {
        require(hasNext(self));
        uint256 ptr = self._unsafe_nextPtr;
        uint256 itemLength = _itemLength(ptr);
        subItem._unsafe_memPtr = ptr;
        subItem._unsafe_length = itemLength;
        self._unsafe_nextPtr = ptr + itemLength;
    }

    /// @dev Same as next(self); when `strict` is true it additionally reverts
    /// if the returned item fails the (partial) well-formedness check in _validate.
    function next(Iterator memory self, bool strict) internal pure returns (RLPItem memory subItem) {
        subItem = next(self);
        require(!strict || _validate(subItem));
    }

    /// @dev Whether the cursor still points inside the bytes of the iterated item.
    function hasNext(Iterator memory self) internal pure returns (bool) {
        RLP.RLPItem memory item = self._unsafe_item;
        return self._unsafe_nextPtr < item._unsafe_memPtr + item._unsafe_length;
    }

    /* RLPItem */

    /// @dev Creates an RLPItem pointing into an array of RLP encoded bytes (no copy).
    /// @param self The RLP encoded bytes.
    /// @return An RLPItem; the null item (0, 0) if `self` is empty.
    function toRLPItem(bytes memory self) internal pure returns (RLPItem memory) {
        uint len = self.length;
        if (len == 0) {
            return RLPItem(0, 0);
        }
        uint memPtr;
        assembly {
            memPtr := add(self, 0x20)
        }
        return RLPItem(memPtr, len);
    }

    /// @dev Creates an RLPItem from an array of RLP encoded bytes.
    /// @param self The RLP encoded bytes.
    /// @param strict Will throw if the data is not a single RLP item spanning all of `self`.
    /// @return An RLPItem
    function toRLPItem(bytes memory self, bool strict) internal pure returns (RLPItem memory) {
        RLP.RLPItem memory item = toRLPItem(self);
        if(strict) {
            uint len = self.length;
            // An empty array is not an RLP item; without this check the checks
            // below would read the header byte from scratch memory at address 0.
            require(len != 0);
            require(_payloadOffset(item) <= len);
            require(_itemLength(item._unsafe_memPtr) == len);
            require(_validate(item));
        }
        return item;
    }

    /// @dev Check if the RLP item is null (zero length).
    /// @param self The RLP item.
    /// @return 'true' if the item is null.
    function isNull(RLPItem memory self) internal pure returns (bool) {
        return self._unsafe_length == 0;
    }

    /// @dev Check if the RLP item is a list (header byte >= 0xC0).
    /// @param self The RLP item.
    /// @return 'true' if the item is a list.
    function isList(RLPItem memory self) internal pure returns (bool) {
        if (self._unsafe_length == 0)
            return false;
        uint memPtr = self._unsafe_memPtr;
        bool result;
        assembly {
            result := iszero(lt(byte(0, mload(memPtr)), 0xC0))
        }
        return result;
    }

    /// @dev Check if the RLP item is data (header byte < 0xC0).
    /// @param self The RLP item.
    /// @return 'true' if the item is data.
    function isData(RLPItem memory self) internal pure returns (bool) {
        if (self._unsafe_length == 0)
            return false;
        uint memPtr = self._unsafe_memPtr;
        bool result;
        assembly {
            result := lt(byte(0, mload(memPtr)), 0xC0)
        }
        return result;
    }

    /// @dev Check if the RLP item is an empty string (0x80) or an empty list (0xC0).
    /// @param self The RLP item.
    /// @return 'true' if the item is an empty string or list; 'false' for a null item.
    function isEmpty(RLPItem memory self) internal pure returns (bool) {
        if(isNull(self))
            return false;
        uint b0;
        uint memPtr = self._unsafe_memPtr;
        assembly {
            b0 := byte(0, mload(memPtr))
        }
        return (b0 == DATA_SHORT_START || b0 == LIST_SHORT_START);
    }

    /// @dev Get the number of items in an RLP encoded list.
    /// @param self The RLP item.
    /// @return The number of items (0 if `self` is not a list).
    function items(RLPItem memory self) internal pure returns (uint) {
        if (!isList(self))
            return 0;
        uint memPtr = self._unsafe_memPtr;
        uint pos = memPtr + _payloadOffset(self);
        uint last = memPtr + self._unsafe_length - 1;
        uint itms;
        while(pos <= last) {
            pos += _itemLength(pos);
            itms++;
        }
        return itms;
    }

    /// @dev Create an iterator over the children of a list item. Reverts if not a list.
    /// @param self The RLP item.
    /// @return An 'Iterator' over the item.
    function iterator(RLPItem memory self) internal pure returns (Iterator memory) {
        require(isList(self));
        uint ptr = self._unsafe_memPtr + _payloadOffset(self);
        Iterator memory it;
        it._unsafe_item = self;
        it._unsafe_nextPtr = ptr;
        return it;
    }

    /// @dev Return a copy of the full RLP encoding of the item (header included).
    /// Reverts on a null item.
    /// @param self The RLPItem.
    /// @return The bytes.
    function toBytes(RLPItem memory self) internal pure returns (bytes memory) {
        uint256 len = self._unsafe_length;
        require(len != 0);
        bytes memory bts;
        bts = new bytes(len);
        _copyToBytes(self._unsafe_memPtr, bts, len);
        return bts;
    }

    /// @dev Decode an RLPItem into its payload bytes (header stripped). This will
    /// not work if the RLPItem is a list.
    /// @param self The RLPItem.
    /// @return The decoded payload.
    function toData(RLPItem memory self) internal pure returns (bytes memory) {
        require(isData(self));
        (uint256 rStartPos, uint256 len) = _decode(self);
        bytes memory bts;
        bts = new bytes(len);
        _copyToBytes(rStartPos, bts, len);
        return bts;
    }

    /// @dev Get the list of sub-items from an RLP encoded list.
    /// Warning: This is inefficient, as it requires that the list is read twice.
    /// @param self The RLP item.
    /// @return Array of RLPItems.
    function toList(RLPItem memory self) internal pure returns (RLPItem[] memory) {
        require(isList(self));
        uint256 numItems = items(self);
        RLPItem[] memory list = new RLPItem[](numItems);
        RLP.Iterator memory it = iterator(self);
        uint idx;
        while(hasNext(it)) {
            list[idx] = next(it);
            idx++;
        }
        return list;
    }

    /// @dev Decode an RLPItem's payload into a string (bytes are copied as-is,
    /// no character validation). This will not work if the RLPItem is a list.
    /// @param self The RLPItem.
    /// @return The decoded string.
    function toAscii(RLPItem memory self) internal pure returns (string memory) {
        require(isData(self));
        (uint256 rStartPos, uint256 len) = _decode(self);
        bytes memory bts = new bytes(len);
        _copyToBytes(rStartPos, bts, len);
        string memory str = string(bts);
        return str;
    }

    /// @dev Decode an RLPItem's 1..32 byte payload as a big-endian uint.
    /// This will not work if the RLPItem is a list.
    /// @param self The RLPItem.
    /// @return The decoded uint.
    function toUint(RLPItem memory self) internal pure returns (uint) {
        require(isData(self));
        (uint256 rStartPos, uint256 len) = _decode(self);
        require(len <= 32);
        require(len != 0);
        uint data;
        assembly {
            // Load a full word and shift right so only the first `len` bytes remain.
            data := div(mload(rStartPos), exp(256, sub(32, len)))
        }
        return data;
    }

    /// @dev Decode a 1-byte payload (0x00 or 0x01) into a boolean.
    /// This will not work if the RLPItem is a list.
    /// @param self The RLPItem.
    /// @return The decoded boolean.
    function toBool(RLPItem memory self) internal pure returns (bool) {
        require(isData(self));
        (uint256 rStartPos, uint256 len) = _decode(self);
        require(len == 1);
        uint temp;
        assembly {
            temp := byte(0, mload(rStartPos))
        }
        require(temp <= 1);
        return temp == 1;
    }

    /// @dev Decode a 1-byte payload into a byte. This will not work if the
    /// RLPItem is a list.
    /// @param self The RLPItem.
    /// @return The decoded byte.
    function toByte(RLPItem memory self) internal pure returns (byte) {
        require(isData(self));
        (uint256 rStartPos, uint256 len) = _decode(self);
        require(len == 1);
        // The `byte` opcode yields a right-aligned value, but Solidity keeps
        // fixed-size bytes types left-aligned on the stack. Read into a uint8
        // and convert explicitly so the value is not lost.
        uint8 temp;
        assembly {
            temp := byte(0, mload(rStartPos))
        }
        return byte(temp);
    }

    /// @dev Decode an RLPItem into an int by reinterpreting toUint's result
    /// (no range check). This will not work if the RLPItem is a list.
    /// @param self The RLPItem.
    /// @return The decoded int.
    function toInt(RLPItem memory self) internal pure returns (int) {
        return int(toUint(self));
    }

    /// @dev Decode an RLPItem into a bytes32. Payloads shorter than 32 bytes are
    /// right-aligned (zero-padded on the left), as produced by toUint.
    /// This will not work if the RLPItem is a list.
    /// @param self The RLPItem.
    /// @return The decoded bytes32.
    function toBytes32(RLPItem memory self) internal pure returns (bytes32) {
        return bytes32(toUint(self));
    }

    /// @dev Decode a payload of exactly 20 bytes into an address. This will not
    /// work if the RLPItem is a list.
    /// @param self The RLPItem.
    /// @return The decoded address.
    function toAddress(RLPItem memory self) internal pure returns (address) {
        require(isData(self));
        (uint256 rStartPos, uint256 len) = _decode(self);
        require(len == 20);
        address data;
        assembly {
            // Keep the first 20 bytes of the loaded word.
            data := div(mload(rStartPos), exp(256, 12))
        }
        return data;
    }

    // Get the payload offset: number of header bytes before the payload.
    function _payloadOffset(RLPItem memory self) private pure returns (uint) {
        if(self._unsafe_length == 0)
            return 0;
        uint b0;
        uint memPtr = self._unsafe_memPtr;
        assembly {
            b0 := byte(0, mload(memPtr))
        }
        if(b0 < DATA_SHORT_START)
            return 0;
        if(b0 < DATA_LONG_START || (b0 >= LIST_SHORT_START && b0 < LIST_LONG_START))
            return 1;
        if(b0 < LIST_SHORT_START)
            return b0 - DATA_LONG_OFFSET + 1;
        return b0 - LIST_LONG_OFFSET + 1;
    }

    // Get the full length (header + payload) of the RLP item starting at memPtr.
    function _itemLength(uint memPtr) private pure returns (uint len) {
        uint b0;
        assembly {
            b0 := byte(0, mload(memPtr))
        }
        if (b0 < DATA_SHORT_START)
            len = 1;
        else if (b0 < DATA_LONG_START)
            len = b0 - DATA_SHORT_START + 1;
        else if (b0 < LIST_SHORT_START) {
            assembly {
                let bLen := sub(b0, 0xB7) // bytes length (DATA_LONG_OFFSET)
                let dLen := div(mload(add(memPtr, 1)), exp(256, sub(32, bLen))) // data length
                len := add(1, add(bLen, dLen)) // total length
            }
        }
        else if (b0 < LIST_LONG_START)
            len = b0 - LIST_SHORT_START + 1;
        else {
            assembly {
                let bLen := sub(b0, 0xF7) // bytes length (LIST_LONG_OFFSET)
                let dLen := div(mload(add(memPtr, 1)), exp(256, sub(32, bLen))) // data length
                len := add(1, add(bLen, dLen)) // total length
            }
        }
    }

    // Get start position and length of the payload of a data item.
    function _decode(RLPItem memory self) private pure returns (uint memPtr, uint len) {
        require(isData(self));
        uint b0;
        uint start = self._unsafe_memPtr;
        assembly {
            b0 := byte(0, mload(start))
        }
        if (b0 < DATA_SHORT_START) {
            // Single byte in [0x00, 0x7F] is its own encoding.
            memPtr = start;
            len = 1;
            return (memPtr, len);
        }
        if (b0 < DATA_LONG_START) {
            len = self._unsafe_length - 1;
            memPtr = start + 1;
        } else {
            uint bLen;
            assembly {
                bLen := sub(b0, 0xB7) // DATA_LONG_OFFSET
            }
            len = self._unsafe_length - 1 - bLen;
            memPtr = start + bLen + 1;
        }
        return (memPtr, len);
    }

    // Copies btsLen bytes from btsPtr into the data area of tgt, one 32-byte word
    // at a time. Assumes that enough memory has been allocated to store in target.
    function _copyToBytes(uint btsPtr, bytes memory tgt, uint btsLen) private pure {
        // Exploiting the fact that 'tgt' was the last thing to be allocated,
        // we can write entire words, and just overwrite any excess.
        assembly {
            let words := div(add(btsLen, 31), 32)
            let rOffset := btsPtr
            let wOffset := add(tgt, 0x20)
            // Loop while i < words (the previous `eq` condition skipped the
            // loop entirely whenever there was anything to copy).
            for { let i := 0 } lt(i, words) { i := add(i, 1) }
            {
                let offset := mul(i, 0x20)
                mstore(add(wOffset, offset), mload(add(rOffset, offset)))
            }
            // Zero the word right after the data to wipe any excess copied bytes.
            mstore(add(tgt, add(0x20, mload(tgt))), 0)
        }
    }

    // Check that an RLP item is valid. Only rejects one non-canonical form:
    // a 0x81 header followed by a byte < 0x80 (which must be encoded as itself).
    function _validate(RLPItem memory self) private pure returns (bool ret) {
        // Check that RLP is well-formed.
        uint b0;
        uint b1;
        uint memPtr = self._unsafe_memPtr;
        assembly {
            b0 := byte(0, mload(memPtr))
            b1 := byte(1, mload(memPtr))
        }
        if(b0 == DATA_SHORT_START + 1 && b1 < DATA_SHORT_START)
            return false;
        return true;
    }
}
contract MerklePatriciaProof {
    /*
     * @dev Verifies a merkle patricia proof.
     * @param value The terminating value in the trie (decoded payload, not RLP-wrapped).
     * @param encodedPath The hex-prefix encoded path in the trie leading to value.
     * @param rlpParentNodes The rlp encoded stack of nodes, root first.
     * @param root The root hash of the trie.
     * @return true if the proof shows `value` at `encodedPath`; false when the
     *         nodes do not follow the path. Reverts when a node hash or the
     *         terminating value does not match.
     */
    function verify(bytes memory value, bytes memory encodedPath, bytes memory rlpParentNodes, bytes32 root) public pure returns (bool) {
        RLP.RLPItem[] memory parentNodes = RLP.toList(RLP.toRLPItem(rlpParentNodes));
        RLP.RLPItem[] memory currentNodeList;
        bytes32 nodeKey = root;
        uint pathPtr = 0;
        bytes memory path = _getNibbleArray(encodedPath);
        if(path.length == 0) {return false;}
        for (uint i = 0; i < parentNodes.length; i++) {
            require(pathPtr <= path.length, "Path overflow");
            require(nodeKey == keccak256(RLP.toBytes(parentNodes[i])), "node doesn't match key");
            currentNodeList = RLP.toList(parentNodes[i]);
            if(currentNodeList.length == 17) {
                // Branch node: 16 child slots followed by a value slot.
                if(pathPtr == path.length) {
                    // Compare the decoded value, exactly like the leaf case below.
                    require(keccak256(RLP.toData(currentNodeList[16])) == keccak256(value), "terminating normal node hash doesn't match value hash");
                    return true;
                }
                uint8 nextPathNibble = uint8(path[pathPtr]);
                // Only slots 0..15 are children; slot 16 is the value.
                require(nextPathNibble < 16, "nibble too long");
                nodeKey = RLP.toBytes32(currentNodeList[nextPathNibble]);
                pathPtr += 1;
            } else if(currentNodeList.length == 2) {
                // Leaf or extension node: [hex-prefix encoded partial path, value or child hash].
                (bytes memory partialPath, bool isLeaf) = _decodeNodePath(RLP.toData(currentNodeList[0]));
                if(!_matchesPathAt(partialPath, path, pathPtr)) {return false;}
                // Advance exactly once by the nibbles this node consumed.
                pathPtr += partialPath.length;
                if(isLeaf) {
                    // A leaf must consume the rest of the path; its value can never be
                    // followed as a child hash.
                    if(pathPtr != path.length) {return false;}
                    require(keccak256(RLP.toData(currentNodeList[1])) == keccak256(value), "leaf array hash doesn't match value hash");
                    return true;
                }
                // Extension node: must consume at least one nibble.
                require(partialPath.length != 0, "invalid extension node");
                nodeKey = RLP.toBytes32(currentNodeList[1]);
            } else {
                return false;
            }
        }
        // Ran out of nodes before reaching a terminating value.
        return false;
    }

    // Splits a node's hex-prefix encoded path into its nibbles and its leaf flag
    // (flag nibble 2 or 3). An empty encoding yields an empty, non-leaf path.
    function _decodeNodePath(bytes memory encodedPartialPath) private pure returns (bytes memory partialPath, bool isLeaf) {
        partialPath = _getNibbleArray(encodedPartialPath);
        if(encodedPartialPath.length != 0) {
            uint8 hpNibble = uint8(_getNthNibbleOfBytes(0, encodedPartialPath));
            isLeaf = hpNibble == 2 || hpNibble == 3;
        }
    }

    // Returns true when partialPath equals path[pathPtr : pathPtr + partialPath.length];
    // returns false (instead of reading out of bounds) when path is too short.
    function _matchesPathAt(bytes memory partialPath, bytes memory path, uint pathPtr) private pure returns (bool) {
        if(pathPtr > path.length || partialPath.length > path.length - pathPtr) {
            return false;
        }
        for(uint i = 0; i < partialPath.length; i++) {
            if(partialPath[i] != path[pathPtr + i]) {
                return false;
            }
        }
        return true;
    }

    // Number of path nibbles consumed by a node's hex-prefix encoded partial path
    // starting at pathPtr, or 0 if it does not match (or would run past the end).
    // Note: a matching empty partial path also yields 0.
    function _nibblesToTraverse(bytes memory encodedPartialPath, bytes memory path, uint pathPtr) private pure returns (uint) {
        // encodedPartialPath has elements that are each two hex characters (1 byte), but partialPath
        // has elements that are each one hex character (1 nibble)
        bytes memory partialPath = _getNibbleArray(encodedPartialPath);
        return _matchesPathAt(partialPath, path, pathPtr) ? partialPath.length : 0;
    }

    // Expands hex-prefix encoded bytes into one nibble per byte, dropping the flag
    // nibble (and the padding nibble when the flag marks an even length).
    // bytes b must be hp encoded
    function _getNibbleArray(bytes memory b) private pure returns (bytes memory) {
        bytes memory nibbles;
        if(b.length > 0) {
            uint8 offset;
            uint8 hpNibble = uint8(_getNthNibbleOfBytes(0, b));
            if(hpNibble == 1 || hpNibble == 3) {
                // Odd length: the second nibble of the first byte is path data.
                nibbles = new bytes(b.length * 2 - 1);
                byte oddNibble = _getNthNibbleOfBytes(1, b);
                nibbles[0] = oddNibble;
                offset = 1;
            } else {
                nibbles = new bytes(b.length * 2 - 2);
                offset = 0;
            }
            for(uint i = offset; i < nibbles.length; i++) {
                nibbles[i] = _getNthNibbleOfBytes(i - offset + 2, b);
            }
        }
        return nibbles;
    }

    // Returns the n-th nibble of str (even n = high nibble, odd n = low nibble)
    // as a byte in the range 0x00..0x0F.
    function _getNthNibbleOfBytes(uint n, bytes memory str) private pure returns (byte) {
        return byte(n % 2 == 0 ? uint8(str[n / 2]) / 0x10 : uint8(str[n / 2]) % 0x10);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment