Skip to content

Instantly share code, notes, and snippets.

@rauljordan
Created March 7, 2022 05:53
Show Gist options
  • Save rauljordan/f6bc796403eeaa4438550318eb8e37f9 to your computer and use it in GitHub Desktop.
SSZ Generic Iterables Go 1.18 Beta
package main
import (
"encoding/binary"
"fmt"
)
// BYTES_PER_LENGTH_OFFSET is the number of bytes List.MarshalSSZ reserves at
// the front of its output for the little-endian uint32 element count, and the
// extra size List.SSZBytesLength accounts for.
// The ALL_CAPS spelling presumably mirrors the constant's name in the SSZ
// specification (kept for compatibility; Go style would be MixedCaps).
const BYTES_PER_LENGTH_OFFSET uint64 = 4
type (
	// Marshaler is implemented by values that can serialize themselves into
	// their SSZ byte encoding.
	Marshaler interface {
		MarshalSSZ() []byte
	}

	// SSZItem is an SSZ-encodable value that can also report the size, in
	// bytes, of its encoding. List and Vector sum these sizes to pre-allocate
	// their output buffers, so SSZBytesLength is expected to equal
	// len(MarshalSSZ()).
	SSZItem interface {
		Marshaler
		SSZBytesLength() uint64
	}
)
// Uint64 is an SSZ uint64: a fixed-size integer encoded as 8 little-endian
// bytes.
type Uint64 uint64

// SSZBytesLength reports the fixed encoded size of a Uint64, 8 bytes.
func (u Uint64) SSZBytesLength() uint64 {
	const size = 8
	return size
}

// MarshalSSZ encodes u as 8 little-endian bytes in a freshly allocated slice.
func (u Uint64) MarshalSSZ() []byte {
	var out [8]byte
	binary.LittleEndian.PutUint64(out[:], uint64(u))
	return out[:]
}
// Uint32 is an SSZ uint32: a fixed-size integer encoded as 4 little-endian
// bytes.
type Uint32 uint32

// SSZBytesLength reports the fixed encoded size of a Uint32, 4 bytes.
func (u Uint32) SSZBytesLength() uint64 {
	const size = 4
	return size
}

// MarshalSSZ encodes u as 4 little-endian bytes in a freshly allocated slice.
func (u Uint32) MarshalSSZ() []byte {
	var out [4]byte
	binary.LittleEndian.PutUint32(out[:], uint32(u))
	return out[:]
}
// List is an SSZ list (reported as variable-length by IsFixedLength) whose
// elements all have type K.
//
// The element type is K itself rather than the SSZItem interface. As a
// result, the compiler rejects mixing element types (e.g. a Uint32 inside a
// List[Uint64]), and indexing yields a K instead of an interface value.
type List[K SSZItem] []K

// Vector is an SSZ vector (reported as fixed-length by IsFixedLength) whose
// elements all have type K.
type Vector[K SSZItem] []K

// Iterable is a constraint satisfied by List[K] and Vector[K].
// NOTE(review): not referenced by any of the visible code.
type Iterable[K SSZItem] interface {
	List[K] | Vector[K]
}
// IsFixedLength reports whether the collection has a fixed size; a List
// always reports false.
func (l List[K]) IsFixedLength() bool {
	return false
}

// SSZFixedLength returns 0 for a List, which has no fixed length.
func (l List[K]) SSZFixedLength() uint64 {
	return 0
}

// IsFixedLength reports whether the collection has a fixed size; a Vector
// always reports true.
// NOTE(review): in SSZ a vector of variable-size elements is itself
// variable-size. This method does not inspect K, so confirm that is intended.
func (v Vector[K]) IsFixedLength() bool {
	return true
}

// SSZFixedLength returns the number of elements in v. This is an element
// count, not a byte count.
func (v Vector[K]) SSZFixedLength() uint64 {
	count := len(v)
	return uint64(count)
}
// SSZBytesLength returns the size of l's encoding: the summed encoded sizes
// of its elements plus BYTES_PER_LENGTH_OFFSET bytes for the element-count
// prefix that List.MarshalSSZ writes.
func (l List[K]) SSZBytesLength() uint64 {
	prefix := BYTES_PER_LENGTH_OFFSET
	return prefix + sumBytes(l)
}

// SSZBytesLength returns the size of v's encoding: the summed encoded sizes
// of its elements. Vectors carry no length prefix.
func (v Vector[K]) SSZBytesLength() uint64 {
	total := sumBytes(v)
	return total
}
// MarshalSSZ encodes l as a 4-byte little-endian element count followed by
// the concatenated encodings of its elements.
// NOTE(review): the SSZ spec serializes a standalone list as only the
// concatenated elements (4-byte offsets appear for variable-size parts inside
// containers). Confirm that this count prefix is intentional.
func (l List[K]) MarshalSSZ() []byte {
	headerLen := int(BYTES_PER_LENGTH_OFFSET)
	out := make([]byte, l.SSZBytesLength())
	binary.LittleEndian.PutUint32(out, uint32(len(l)))
	return collectSSZBuffer(out, headerLen, l)
}

// MarshalSSZ encodes v as the concatenated encodings of its elements, with
// no length prefix.
func (v Vector[K]) MarshalSSZ() []byte {
	return collectSSZBuffer(make([]byte, v.SSZBytesLength()), 0, v)
}
// sumBytes returns the total encoded size, in bytes, of items, computed as
// the sum of each element's SSZBytesLength. An empty or nil slice yields 0.
func sumBytes[K SSZItem](items []K) uint64 {
	var total uint64
	for i := range items {
		total += items[i].SSZBytesLength()
	}
	return total
}
// collectSSZBuffer writes the SSZ encodings of items, in order, into buf
// starting at offset cursor, and returns the resulting slice. Bytes already
// in buf[:cursor] (e.g. a length prefix) are preserved. cursor must not
// exceed len(buf).
//
// Callers normally pre-size buf to cursor plus the total encoded size of
// items. In that case every byte is written in place and no reallocation
// occurs. The loop is driven by items rather than by len(buf), so a mismatch
// between the precomputed size and the actual encodings cannot panic. An
// undersized buf grows as needed, and an oversized one is trimmed to the
// bytes actually written.
func collectSSZBuffer[K SSZItem](buf []byte, cursor int, items []K) []byte {
	out := buf[:cursor]
	for _, item := range items {
		out = append(out, item.MarshalSSZ()...)
	}
	return out
}
// Encode returns the SSZ encoding of items by delegating to its MarshalSSZ
// method.
func Encode[T Marshaler](items T) []byte {
	encoded := items.MarshalSSZ()
	return encoded
}
// main encodes a Vector and a List of each integer width (Uint64, Uint32)
// with Encode, and prints each value alongside its hex encoding.
func main() {
	a := Vector[Uint64]{
		Uint64(1),
		Uint64(2),
		Uint64(3),
	}
	encoded := Encode(a)
	fmt.Printf("Vector[uint64] encoding %v, got %#x\n", a, encoded)
	b := List[Uint64]{
		Uint64(1),
		Uint64(2),
		Uint64(3),
	}
	encoded = Encode(b)
	fmt.Printf("List[uint64] encoding %v, got %#x\n", b, encoded)
	c := Vector[Uint32]{
		Uint32(1),
		Uint32(2),
		Uint32(3),
	}
	encoded = Encode(c)
	fmt.Printf("Vector[uint32] encoding %v, got %#x\n", c, encoded)
	d := List[Uint32]{
		Uint32(1),
		Uint32(2),
		Uint32(3),
	}
	encoded = Encode(d)
	// Print d, the List that was just encoded, so the value matches its encoding.
	fmt.Printf("List[uint32] encoding %v, got %#x\n", d, encoded)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment