Skip to content

Instantly share code, notes, and snippets.

@squiidz
Created May 28, 2019 20:08
Show Gist options
  • Save squiidz/fe3ba9ac4b79fc3f7e541eec30da85d8 to your computer and use it in GitHub Desktop.
Save squiidz/fe3ba9ac4b79fc3f7e541eec30da85d8 to your computer and use it in GitHub Desktop.
MerkleTree
package merkletree
import (
"bytes"
"crypto/sha256"
"errors"
"fmt"
)
// Content represents the data that is stored and verified by the tree. Any type that
// implements this interface can be used as an item in the tree.
type Content interface {
// CalculateHash returns the hash of the content, or an error if it cannot be
// computed. The tree stores the result as the leaf hash when it is built and
// calls the method again whenever a leaf is re-verified.
CalculateHash() ([]byte, error)
// Equals reports whether other represents the same content as the receiver.
// The tree uses it to locate the leaf that holds a given content.
Equals(other Content) (bool, error)
}
// MerkleTree is the container for the tree. It holds a pointer to the root of the tree,
// a list of pointers to the leaf nodes, and the merkle root.
type MerkleTree struct {
// Root is the top node of the tree.
Root *Node
// merkleRoot is the root hash recorded when the tree was last (re)built. It is
// the reference value that VerifyTree compares a freshly computed root against.
merkleRoot []byte
// Leafs holds the leaf nodes in the order the contents were supplied. When the
// number of contents is odd, a duplicate of the last leaf is appended at the end.
Leafs []*Node
}
// Node represents a node, root, or leaf in the tree. It stores pointers to its immediate
// relationships, a hash, the content stored if it is a leaf, and other metadata.
type Node struct {
// Parent is the node one level up. The tree builder never assigns it on the root.
Parent *Node
// Left and Right are the children of an intermediate node; the builder leaves
// them unset on leaves. When a level has an odd number of nodes, the last
// parent's Left and Right point at the same child.
Left *Node
Right *Node
// leaf is true for nodes created directly from a Content.
leaf bool
// dup is true for the padding leaf that duplicates the last content when the
// number of contents is odd.
dup bool
// Hash is the hash stored for this node when the tree was built: the content
// hash for a leaf, or the SHA-256 of the children's hashes (left then right)
// for an intermediate node.
Hash []byte
// C is the content held by a leaf; the builder leaves it unset on intermediate nodes.
C Content
}
// verifyNode recomputes the hash of n from the leaves up, ignoring every stored
// Hash field on the way, and returns the result.
//
// For a leaf the hash is whatever the content's CalculateHash returns. For an
// intermediate node it is SHA-256(left || right), where left and right are the
// recomputed hashes of the children. Any error from CalculateHash is returned
// unchanged.
func (n *Node) verifyNode() ([]byte, error) {
	if n.leaf {
		return n.C.CalculateHash()
	}

	// The right subtree is evaluated first so that, when both sides fail, the
	// error that surfaces is the same one callers have always seen.
	rightBytes, err := n.Right.verifyNode()
	if err != nil {
		return nil, err
	}
	leftBytes, err := n.Left.verifyNode()
	if err != nil {
		return nil, err
	}

	// Feed the two halves to the hasher separately. Concatenating them with
	// append(leftBytes, rightBytes...) could write into spare capacity of a
	// slice owned by the Content implementation.
	h := sha256.New()
	if _, err := h.Write(leftBytes); err != nil {
		return nil, err
	}
	if _, err := h.Write(rightBytes); err != nil {
		return nil, err
	}
	return h.Sum(nil), nil
}
// calculateNodeHash computes the hash of n one level deep.
//
// For a leaf it returns the content's CalculateHash result. For an intermediate
// node it returns SHA-256(Left.Hash || Right.Hash) using the children's stored
// hashes; it does not recurse into the subtrees.
func (n *Node) calculateNodeHash() ([]byte, error) {
	if n.leaf {
		return n.C.CalculateHash()
	}

	// Write the stored hashes one after the other instead of joining them with
	// append(n.Left.Hash, ...): append may reuse Left.Hash's backing array and
	// scribble over memory that other slices (or other goroutines) still use.
	h := sha256.New()
	if _, err := h.Write(n.Left.Hash); err != nil {
		return nil, err
	}
	if _, err := h.Write(n.Right.Hash); err != nil {
		return nil, err
	}
	return h.Sum(nil), nil
}
// NewTree builds a Merkle tree over the contents in cs and returns it.
// The error from the underlying build step (for example when cs is empty or a
// content fails to hash) is returned unchanged, together with a nil tree.
func NewTree(cs []Content) (*MerkleTree, error) {
	rootNode, leafNodes, buildErr := buildWithContent(cs)
	if buildErr != nil {
		return nil, buildErr
	}
	return &MerkleTree{
		Root:       rootNode,
		merkleRoot: rootNode.Hash,
		Leafs:      leafNodes,
	}, nil
}
// GetMerklePath returns the Merkle path for the first leaf whose content equals
// content, walking from the leaf up to the root.
//
// The first result holds the sibling hash at every level; the second holds, for
// the same level, 1 when that sibling is the right child and 0 when it is the
// left child. The side is decided by comparing the current node's stored hash
// with the parent's left child's hash. If no leaf matches, all three results
// are nil. An error from Equals is returned as is.
func (m *MerkleTree) GetMerklePath(content Content) ([][]byte, []int64, error) {
	for _, leaf := range m.Leafs {
		match, err := leaf.C.Equals(content)
		if err != nil {
			return nil, nil, err
		}
		if !match {
			continue
		}

		var (
			siblings [][]byte
			sides    []int64
		)
		// Climb one level per iteration until the root (which has no parent).
		for node, parent := leaf, leaf.Parent; parent != nil; node, parent = parent, parent.Parent {
			if bytes.Equal(parent.Left.Hash, node.Hash) {
				// We are the left child, so the sibling sits on the right.
				siblings = append(siblings, parent.Right.Hash)
				sides = append(sides, 1)
				continue
			}
			// Otherwise the sibling is the left child.
			siblings = append(siblings, parent.Left.Hash)
			sides = append(sides, 0)
		}
		return siblings, sides, nil
	}
	return nil, nil, nil
}
// buildWithContent turns the contents in cs into a complete tree.
//
// It creates one leaf per content (hashing each with CalculateHash), pads the
// leaf list with a duplicate of the last leaf when the count is odd, and then
// builds the upper levels. It returns the root, the leaf list (including any
// padding leaf) and an error. An empty cs, a failing CalculateHash, or a failure
// in the upper levels yields a non-nil error and nil results.
func buildWithContent(cs []Content) (*Node, []*Node, error) {
	if len(cs) == 0 {
		return nil, nil, errors.New("error: cannot construct tree with no content")
	}

	var leaves []*Node
	for i := range cs {
		digest, err := cs[i].CalculateHash()
		if err != nil {
			return nil, nil, err
		}
		leaves = append(leaves, &Node{leaf: true, C: cs[i], Hash: digest})
	}

	// An odd number of leaves is evened out by repeating the last one, flagged
	// as a duplicate.
	if last := leaves[len(leaves)-1]; len(leaves)%2 != 0 {
		leaves = append(leaves, &Node{leaf: true, dup: true, C: last.C, Hash: last.Hash})
	}

	top, err := buildIntermediate(leaves)
	if err != nil {
		return nil, nil, err
	}
	return top, leaves, nil
}
// buildIntermediate builds every level above the nodes in nl and returns the
// resulting root.
//
// Nodes are paired left to right; each pair gets a new parent whose hash is
// SHA-256(left.Hash || right.Hash), and the children's Parent pointers are set
// to it. When a level has an odd number of nodes, the last node is paired with
// itself. The function then recurses on the list of parents until one node is
// left.
//
// An empty nl returns an error, and a single-element nl returns that element as
// the root; previously both cases recursed forever.
func buildIntermediate(nl []*Node) (*Node, error) {
	switch len(nl) {
	case 0:
		return nil, errors.New("cannot build tree levels from an empty node list")
	case 1:
		return nl[0], nil
	}

	parents := make([]*Node, 0, (len(nl)+1)/2)
	for i := 0; i < len(nl); i += 2 {
		left, right := nl[i], nl[i]
		if i+1 < len(nl) {
			right = nl[i+1]
		}

		// Hash the two child hashes with separate writes. Joining them with
		// append(left.Hash, right.Hash...) could overwrite memory behind
		// left.Hash if that slice has spare capacity.
		h := sha256.New()
		if _, err := h.Write(left.Hash); err != nil {
			return nil, err
		}
		if _, err := h.Write(right.Hash); err != nil {
			return nil, err
		}

		parent := &Node{
			Left:  left,
			Right: right,
			Hash:  h.Sum(nil),
		}
		left.Parent = parent
		right.Parent = parent
		parents = append(parents, parent)
	}
	return buildIntermediate(parents)
}
// MerkleRoot returns the unverified Merkle Root (hash of the root node) of the tree,
// as recorded when the tree was last built. The stored slice itself is returned,
// not a copy, so callers should treat it as read-only.
func (m *MerkleTree) MerkleRoot() []byte {
return m.merkleRoot
}
// RebuildTree rebuilds the whole tree from the contents currently held in its
// leaves and replaces Root, Leafs and the recorded Merkle root with the result.
//
// Padding leaves (the duplicate added when the content count is odd) are left
// out of the input and re-created by the build step. Without that, the
// duplicate would be fed back in as ordinary content and lose its dup marker
// after the first rebuild. The tree is left untouched if the build fails; the
// build error is returned.
func (m *MerkleTree) RebuildTree() error {
	cs := make([]Content, 0, len(m.Leafs))
	for _, l := range m.Leafs {
		if l.dup {
			continue
		}
		cs = append(cs, l.C)
	}

	root, leafs, err := buildWithContent(cs)
	if err != nil {
		return err
	}
	m.Root = root
	m.Leafs = leafs
	m.merkleRoot = root.Hash
	return nil
}
// RebuildTreeWith discards the current contents and rebuilds the tree from cs.
// The MerkleTree value itself is kept; its Root, Leafs and recorded Merkle root
// are replaced. If the build fails (for example because cs is empty) the error
// is returned and the tree is left as it was.
func (m *MerkleTree) RebuildTreeWith(cs []Content) error {
	newRoot, newLeafs, buildErr := buildWithContent(cs)
	if buildErr != nil {
		return buildErr
	}
	m.Root, m.Leafs, m.merkleRoot = newRoot, newLeafs, newRoot.Hash
	return nil
}
// VerifyTree recomputes the root hash from the leaf contents upward and reports
// whether it equals the Merkle root recorded when the tree was built.
//
// It returns an error if the tree has no root (for example a zero-value
// MerkleTree) or if any content fails to hash; in both cases the boolean is
// false.
func (m *MerkleTree) VerifyTree() (bool, error) {
	if m.Root == nil {
		return false, errors.New("cannot verify a tree that has no root")
	}
	calculatedMerkleRoot, err := m.Root.verifyNode()
	if err != nil {
		return false, err
	}
	return bytes.Equal(m.merkleRoot, calculatedMerkleRoot), nil
}
// VerifyContent reports whether content is in the tree and every node on the
// path from its leaf to the root still carries a consistent hash.
//
// The first leaf whose content equals content is used. For each ancestor, the
// hash is recomputed as SHA-256(left || right) from the children's
// calculateNodeHash results and compared with the ancestor's stored Hash. The
// result is false on the first mismatch, true if every level matches, and
// false if no leaf holds the content. Errors from Equals or from hashing are
// returned with a false result.
func (m *MerkleTree) VerifyContent(content Content) (bool, error) {
	for _, l := range m.Leafs {
		ok, err := l.C.Equals(content)
		if err != nil {
			return false, err
		}
		if !ok {
			continue
		}

		for p := l.Parent; p != nil; p = p.Parent {
			// Right before left, so the error that surfaces when both sides
			// fail is unchanged.
			rightBytes, err := p.Right.calculateNodeHash()
			if err != nil {
				return false, err
			}
			leftBytes, err := p.Left.calculateNodeHash()
			if err != nil {
				return false, err
			}

			// Separate writes avoid append(leftBytes, rightBytes...), which
			// could modify storage owned by the Content implementation.
			h := sha256.New()
			if _, err := h.Write(leftBytes); err != nil {
				return false, err
			}
			if _, err := h.Write(rightBytes); err != nil {
				return false, err
			}
			if !bytes.Equal(h.Sum(nil), p.Hash) {
				return false, nil
			}
		}
		return true, nil
	}
	return false, nil
}
// String returns a string representation of the tree. Only leaf nodes are
// included: each leaf is formatted with fmt.Sprint and followed by a newline,
// in the order the leaves are stored.
func (m *MerkleTree) String() string {
	// A buffer keeps the cost linear in the number of leaves; repeated string
	// concatenation would copy the accumulated text on every iteration.
	var buf bytes.Buffer
	for _, l := range m.Leafs {
		buf.WriteString(fmt.Sprint(l))
		buf.WriteByte('\n')
	}
	return buf.String()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment