Skip to content

Instantly share code, notes, and snippets.

@d4l3k
Created June 11, 2020 02:40
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save d4l3k/ef2edb288608d2037abfd57e9fb138b9 to your computer and use it in GitHub Desktop.
Save d4l3k/ef2edb288608d2037abfd57e9fb138b9 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"encoding/binary"
"flag"
"io"
"io/ioutil"
"log"
"os"
"./fbs/ap"
caffepb "./proto"
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
)
// fbsFile is the path of the flatbuffer file to decode. It is set with the
// -fbs command-line flag and defaults to the fisheye int8 model.
var fbsFile = flag.String("fbs", "share/vision/fisheye_int8.fbs", "flatbuffer file to load")
// main adds file:line prefixes to the standard logger, parses the
// command-line flags and runs run2. Any error returned by run2 is logged
// with %+v and terminates the process.
func main() {
	log.SetFlags(log.Flags() | log.Lshortfile)
	flag.Parse()

	err := run2()
	if err != nil {
		log.Fatalf("%+v", err)
	}
}
// getDims collects every dimension entry of w, in index order.
// It returns a nil slice when w reports zero dimensions.
func getDims(w ap.Weights) []uint32 {
	var out []uint32
	for idx, count := 0, w.DimsLength(); idx < count; idx++ {
		out = append(out, w.Dims(idx))
	}
	return out
}
// getWeights collects every data element of tensor w, in index order.
// It returns a nil slice when w reports zero elements.
func getWeights(w ap.Tensor) []float32 {
	var out []float32
	for idx, count := 0, w.DataLength(); idx < count; idx++ {
		out = append(out, w.Data(idx))
	}
	return out
}
// run2 reads the file named by the -fbs flag, walks it through the ap
// accessors and logs each layer, the dimensions of each of its weights and
// the tensor values behind them.
//
// It returns the file read error, or an error naming the first layer,
// weights entry or tensor that could not be loaded.
func run2() error {
	raw, err := ioutil.ReadFile(*fbsFile)
	if err != nil {
		return err
	}

	root := ap.GetRootAsRoot(raw, 0)
	layerCount := root.LayersLength()
	log.Printf("root %+v", layerCount)

	for li := 0; li < layerCount; li++ {
		var layer ap.Layer
		if ok := root.Layers(&layer, li); !ok {
			return errors.Errorf("failed to load layer %d", li)
		}

		weightCount := layer.WeightsLength()
		log.Printf("layer %d %s: %d weights", li, layer.Name(), weightCount)

		for wi := 0; wi < weightCount; wi++ {
			var entry ap.Weights
			if ok := layer.Weights(&entry, wi); !ok {
				return errors.Errorf("failed to load weights %d", wi)
			}
			log.Printf("weight %d::%d %+v %v", li, wi, getDims(entry), entry.A())

			var tensor ap.Tensor
			if entry.Tensor(&tensor) == nil {
				return errors.Errorf("failed to load tensor")
			}
			log.Printf(" - %v", getWeights(tensor))
		}
	}
	return nil
}
// Reader wraps a seekable stream and keeps its own record of the current
// position, so that relative offsets decoded from the stream can be turned
// into absolute positions.
type Reader struct {
// reader is the underlying stream that every read and seek is issued on.
reader io.ReadSeeker
// offset is the tracked position; the Read* methods advance it and Seek
// overwrites it.
offset int64
// length is the value reported by Len.
length int64
}
// Len returns the length the Reader was constructed with.
func (r *Reader) Len() int64 {
return r.length
}
// ReadOffset reads a little-endian uint32 at the current position and
// returns the absolute position obtained by adding that value to the
// position the value itself was read from. The tracked offset advances by
// four bytes on success.
//
// On a read failure it returns 0 and the error wrapped with the position.
func (r *Reader) ReadOffset() (int64, error) {
	base := r.offset

	var raw [4]byte
	if _, err := io.ReadFull(r.reader, raw[:]); err != nil {
		return 0, errors.Wrapf(err, "reading %d", base)
	}
	r.offset = base + 4

	rel := binary.LittleEndian.Uint32(raw[:])
	return base + int64(rel), nil
}
// ReadSOffset reads a little-endian signed 32-bit offset at the current
// position and returns the absolute position it refers to. The tracked
// offset advances by four bytes on success.
//
// FlatBuffers stores the table-to-vtable link as a signed value that is
// subtracted from the table position, so the result is
// position - soffset. Negative values (target after the table) give the
// same result as the previous unsigned arithmetic; positive and zero values,
// which used to land 4 GiB past the table, now resolve correctly.
//
// On a read failure it returns 0 and the error wrapped with the position.
func (r *Reader) ReadSOffset() (int64, error) {
	curOffset := r.offset
	var soffset int32
	if err := binary.Read(r.reader, binary.LittleEndian, &soffset); err != nil {
		return 0, errors.Wrapf(err, "reading %d", curOffset)
	}
	r.offset += 4
	return curOffset - int64(soffset), nil
}
// ReadUint16 reads a little-endian uint16 at the current position and
// returns it as an int. The tracked offset advances by two bytes on success.
//
// On a read failure it returns 0 and the error wrapped with the position.
func (r *Reader) ReadUint16() (int, error) {
	start := r.offset

	var raw [2]byte
	if _, err := io.ReadFull(r.reader, raw[:]); err != nil {
		return 0, errors.Wrapf(err, "reading %d", start)
	}
	r.offset = start + 2

	return int(binary.LittleEndian.Uint16(raw[:])), nil
}
// ReadUint32 reads a little-endian uint32 at the current position and
// returns it as an int. The tracked offset advances by four bytes on success.
//
// On a read failure it returns 0 and the error wrapped with the position.
func (r *Reader) ReadUint32() (int, error) {
	start := r.offset

	var raw [4]byte
	if _, err := io.ReadFull(r.reader, raw[:]); err != nil {
		return 0, errors.Wrapf(err, "reading %d", start)
	}
	r.offset = start + 4

	return int(binary.LittleEndian.Uint32(raw[:])), nil
}
// ReadFloat32 reads a little-endian float32 at the current position. The
// tracked offset advances by four bytes on success.
//
// On a read failure it returns 0 and the error wrapped with the position.
func (r *Reader) ReadFloat32() (float32, error) {
	var value float32
	start := r.offset

	err := binary.Read(r.reader, binary.LittleEndian, &value)
	if err != nil {
		return 0, errors.Wrapf(err, "reading %d", start)
	}

	r.offset = start + 4
	return value, nil
}
// Seek moves the underlying stream to the absolute position offset and
// records it as the tracked offset. If the underlying seek fails the tracked
// offset is left unchanged and the wrapped error is returned.
func (r *Reader) Seek(offset int64) error {
	_, err := r.reader.Seek(offset, io.SeekStart)
	if err != nil {
		return errors.Wrapf(err, "seeking %d", offset)
	}

	r.offset = offset
	return nil
}
// Offset returns the position currently tracked by the Reader.
func (r *Reader) Offset() int64 {
return r.offset
}
// PrintDebug logs up to n bytes starting at the current position, both as a
// byte slice and as text, and then seeks back so the position is unchanged.
//
// Reaching the end of the stream is not an error here: whatever bytes are
// available (possibly none) are logged. Any other read failure, or a failure
// to seek back, terminates the process via log.Fatalf.
func (r *Reader) PrintDebug(n int) {
	offset := r.Offset()
	buf := make([]byte, n)

	// io.ReadFull keeps reading until buf is full or the stream ends; a
	// short or empty read near the end of the file only shortens the dump.
	read, err := io.ReadFull(r.reader, buf)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		log.Fatalf("%+v", err)
	}
	buf = buf[:read]

	log.Printf("%d: %+v |%s|", offset, buf, buf)
	if err := r.Seek(offset); err != nil {
		log.Fatalf("%+v", err)
	}
}
// loadNet reads share/vision/fisheye.prototxt and parses it as a text-format
// caffepb.NetParameter. It returns the read or parse error unchanged.
func loadNet() (*caffepb.NetParameter, error) {
	text, err := ioutil.ReadFile("share/vision/fisheye.prototxt")
	if err != nil {
		return nil, err
	}

	net := new(caffepb.NetParameter)
	if err := proto.UnmarshalText(string(text), net); err != nil {
		return nil, err
	}
	return net, nil
}
// vtable holds what ReadVTable decodes at a given position.
type vtable struct {
// Position is the absolute offset the vtable was read from.
Position int64
// VTableSize is the first uint16 of the vtable.
VTableSize int
// TableSize is the second uint16 of the vtable.
TableSize int
// Entries are the uint16 values that follow the two size fields.
Entries []int
}
// ReadVTable decodes the vtable stored at the absolute position offset and
// then seeks back to wherever the Reader was before the call.
//
// The vtable starts with two uint16 values (its own size and the table
// size) followed by (size-4)/2 uint16 entries. If any seek or read fails the
// error is returned and the Reader is left where the failure occurred.
func (r *Reader) ReadVTable(offset int64) (*vtable, error) {
	saved := r.Offset()
	if err := r.Seek(offset); err != nil {
		return nil, err
	}

	vtableSize, err := r.ReadUint16()
	if err != nil {
		return nil, err
	}
	tableSize, err := r.ReadUint16()
	if err != nil {
		return nil, err
	}

	// Everything after the 4-byte header is a list of 2-byte entries.
	var entries []int
	for remaining := (vtableSize - 4) / 2; remaining > 0; remaining-- {
		entry, err := r.ReadUint16()
		if err != nil {
			return nil, err
		}
		entries = append(entries, entry)
	}

	if err := r.Seek(saved); err != nil {
		return nil, err
	}
	return &vtable{
		Position:   offset,
		VTableSize: vtableSize,
		TableSize:  tableSize,
		Entries:    entries,
	}, nil
}
// ReadTable decodes the table at the current position and recursively
// decodes the tables it refers to, logging what it finds.
//
// The kind of table is recognised purely by the absolute position of its
// vtable, using three hard-coded positions; any other vtable position yields
// an "unknown vtable" error. Read and seek errors are returned unchanged.
func (r *Reader) ReadTable() error {
	// Absolute vtable positions this decoder recognises.
	// NOTE(review): these look specific to one particular input file —
	// confirm before using the decoder on any other file.
	const (
		vectorPtrVTable = 1866322 // table holding a pointer to a vector
		layerVTable     = 1866070 // layer entry: name plus vector of subtables
		weightsVTable   = 1866292 // weights entry: dims plus data table
	)

	log.Printf("loading table %d", r.Offset())
	vtableAddr, err := r.ReadSOffset()
	if err != nil {
		return err
	}
	vt, err := r.ReadVTable(vtableAddr)
	if err != nil {
		return err
	}
	log.Printf("vtable %#v", vt)
	//r.PrintDebug(100)

	switch vt.Position {
	case vectorPtrVTable: // ptr to vector
		log.Printf("loading ptr to vector...")
		r.PrintDebug(128)
		vec, err := r.ReadOffset()
		if err != nil {
			return err
		}
		log.Printf("offset %d", vec)
		if err := r.Seek(vec); err != nil {
			return err
		}
		vecLen, err := r.ReadUint32()
		if err != nil {
			return err
		}
		log.Printf("vec length = %d", vecLen)

		var subtables []int64
		for i := 0; i < vecLen; i++ {
			subtable, err := r.ReadOffset()
			if err != nil {
				return err
			}
			log.Printf("subtable %d", subtable)
			if subtable >= r.Len() {
				log.Printf("invalid subtable hmm %d", subtable)
				// Only the very first entry is allowed to be out of range.
				if i == 0 {
					continue
				}
				return errors.Errorf("too many invalid")
			}
			subtables = append(subtables, subtable)
		}

		// The first collected subtable is skipped, as before. The length
		// check avoids a slice-bounds panic when nothing was collected
		// (empty vector, or a single out-of-range first entry).
		if len(subtables) > 0 {
			subtables = subtables[1:]
		}
		for _, subtable := range subtables {
			log.Printf("loading subtable %d", subtable)
			if err := r.Seek(subtable); err != nil {
				return err
			}
			if err := r.ReadTable(); err != nil {
				return err
			}
		}
		return nil

	case layerVTable: // layer
		log.Printf("loading layer entry...")
		dataAddr, err := r.ReadOffset()
		if err != nil {
			return err
		}
		nameAddr, err := r.ReadOffset()
		if err != nil {
			return err
		}
		if err := r.Seek(nameAddr); err != nil {
			return err
		}
		name, err := r.ReadString()
		if err != nil {
			return err
		}
		log.Printf("name = %q", name)

		if err := r.Seek(dataAddr); err != nil {
			return err
		}
		subtableCount, err := r.ReadUint32()
		if err != nil {
			return err
		}
		log.Printf("subtable count = %d", subtableCount)

		// Collect all offsets first: decoding a subtable moves the Reader,
		// so the vector must be fully read before recursing.
		var subtables []int64
		for i := 0; i < subtableCount; i++ {
			subtable, err := r.ReadOffset()
			if err != nil {
				return err
			}
			subtables = append(subtables, subtable)
		}
		for _, subtable := range subtables {
			if err := r.Seek(subtable); err != nil {
				return err
			}
			if err := r.ReadTable(); err != nil {
				return err
			}
		}
		return nil

	case weightsVTable:
		log.Printf("loading weights entry...")
		datatable, err := r.ReadOffset() // table of vec of bytes
		if err != nil {
			return err
		}
		if _, err := r.ReadUint32(); err != nil { // unknown (always seems to be 1)
			return err
		}
		dimsAddr, err := r.ReadOffset()
		if err != nil {
			return err
		}
		if err := r.Seek(dimsAddr); err != nil {
			return err
		}
		dims, err := r.ReadUint32s()
		if err != nil {
			return err
		}
		log.Printf("dims = %+v", dims)

		// The data lives behind a nested table, which must itself use the
		// pointer-to-vector vtable.
		if err := r.Seek(datatable); err != nil {
			return err
		}
		dataVTableAddr, err := r.ReadSOffset()
		if err != nil {
			return err
		}
		dataVT, err := r.ReadVTable(dataVTableAddr)
		if err != nil {
			return err
		}
		if dataVT.Position != vectorPtrVTable {
			return errors.Errorf("got unexpected table")
		}
		dataAddr, err := r.ReadOffset()
		if err != nil {
			return err
		}
		if err := r.Seek(dataAddr); err != nil {
			return err
		}
		data, err := r.ReadFloat32s()
		if err != nil {
			return err
		}
		log.Printf("read data (%d) = %v", len(data), data)
		r.PrintDebug(100)
		return nil

	default:
		return errors.Errorf("unknown vtable %#v", vt)
	}
}
// ReadUint32s reads a uint32 element count followed by that many uint32
// values and returns the values as ints. It returns a nil slice when the
// count is zero, and nil plus the error if any read fails.
func (r *Reader) ReadUint32s() ([]int, error) {
	count, err := r.ReadUint32()
	if err != nil {
		return nil, err
	}

	var values []int
	for remaining := count; remaining > 0; remaining-- {
		value, err := r.ReadUint32()
		if err != nil {
			return nil, err
		}
		values = append(values, value)
	}
	return values, nil
}
// ReadFloat32s reads a uint32 element count followed by that many float32
// values. It returns a nil slice when the count is zero, and nil plus the
// error if any read fails.
func (r *Reader) ReadFloat32s() ([]float32, error) {
	count, err := r.ReadUint32()
	if err != nil {
		return nil, err
	}

	var values []float32
	for remaining := count; remaining > 0; remaining-- {
		value, err := r.ReadFloat32()
		if err != nil {
			return nil, err
		}
		values = append(values, value)
	}
	return values, nil
}
// ReadBytes reads a uint32 length followed by that many bytes and returns
// the bytes. The tracked offset advances past both the length prefix and
// whatever part of the payload was actually consumed.
//
// The payload is read with io.ReadFull, because a single Read call is
// allowed to return fewer bytes than requested even when more are
// available. If the stream ends before the payload is complete, the
// underlying io.EOF / io.ErrUnexpectedEOF is returned wrapped with the
// position and the expected and actual byte counts.
func (r *Reader) ReadBytes() ([]byte, error) {
	length, err := r.ReadUint32()
	if err != nil {
		return nil, err
	}

	start := r.offset
	buf := make([]byte, length)
	n, err := io.ReadFull(r.reader, buf)
	// Keep the tracked offset in step with the stream even on a short read.
	r.offset += int64(n)
	if err != nil {
		return nil, errors.Wrapf(err, "reading %d bytes at %d: got %d", length, start, n)
	}
	return buf, nil
}
// ReadString reads a length-prefixed byte sequence via ReadBytes and returns
// it as a string. On failure it returns the empty string and the error.
func (r *Reader) ReadString() (string, error) {
	raw, err := r.ReadBytes()
	if err == nil {
		return string(raw), nil
	}
	return "", err
}
// run is an exploratory, hand-rolled decoder for the file named by -fbs.
//
// It loads the layer names from the prototxt via loadNet, scans the file
// treating every 4-byte word as a possible table start, and then decodes the
// root table with ReadTable and returns.
//
// NOTE: main calls run2, not run. Everything after the unconditional
// `return nil` that follows the ReadTable call is unreachable and is kept
// only as a record of earlier experiments.
func run() error {
net, err := loadNet()
if err != nil {
return err
}
// Collect the layer names declared in the prototxt.
var layers []string
for _, layer := range net.Layer {
layers = append(layers, *layer.Name)
}
log.Printf("layers (%d): %+v", len(layers), layers)
f, err := os.Open(*fbsFile)
if err != nil {
return err
}
defer f.Close()
// The whole file is read into data (used for its length and for byte
// searches); the same handle is then rewound and used as the Reader's stream.
data, err := ioutil.ReadAll(f)
if err != nil {
return err
}
r := Reader{
reader: f,
offset: 0,
length: int64(len(data)),
}
if err := r.Seek(0); err != nil {
return err
}
log.Printf("finding tables")
// vtables maps a vtable position to its decoded contents, vtablesCounts
// counts how many candidate tables point at each position, and tables lists
// the positions of the candidate tables themselves.
vtables := map[int64]vtable{}
vtablesCounts := map[int64]int{}
var tables []int64
// Scan the file one 4-byte word at a time until EOF, interpreting each word
// as a table's vtable offset.
for {
offset := r.Offset()
position, err := r.ReadSOffset()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return err
}
// Only words that resolve to a position inside the file are candidates.
if position >= 0 && position < int64(len(data)) {
entry, err := r.ReadVTable(position)
if err != nil {
return err
}
tables = append(tables, offset)
vtables[position] = *entry
vtablesCounts[position] += 1
log.Printf("candidate %d: %#v", offset, entry)
}
}
log.Printf("found %d distinct vtables", len(vtables))
log.Printf("found %d distinct tables", len(tables))
log.Printf("found table counts %+v", vtablesCounts)
log.Printf("loading root table")
// The first word of the file is the offset of the root table.
if err := r.Seek(0); err != nil {
return err
}
rootTableOffset, err := r.ReadOffset()
if err != nil {
return err
}
if err := r.Seek(rootTableOffset); err != nil {
return err
}
if err := r.ReadTable(); err != nil {
return err
}
// NOTE: this return is unconditional, so nothing below it ever runs.
return nil
if err := r.Seek(0); err != nil {
return err
}
// find layer names in flatbuffer
layerLocs := map[string]int{}
for _, layer := range layers {
search := 0
for {
match := bytes.Index(data[search:], []byte(layer))
if match < 0 {
return errors.Errorf("failed to find %q", layer)
}
match += search
search = match + 1
// The 4 bytes before a match are read as the string's length prefix; the
// match is accepted only if that prefix equals len(layer).
length := match - 4
if err := r.Seek(int64(length)); err != nil {
return err
}
elements, err := r.ReadUint32()
if err != nil {
return err
}
if elements != len(layer) {
log.Printf("# elements doesn't match str: %d != len(%q), %s", elements, layer, data[match:match+elements])
continue
}
log.Printf("found %q = %d", layer, match)
layerLocs[layer] = length
break
}
}
log.Printf("found %d layerLocs", len(layerLocs))
// Find references to layer strings
layerRefs := map[string]int64{}
for layer, target := range layerLocs {
if err := r.Seek(0); err != nil {
return err
}
// Scan word by word for an offset that resolves to the string's position.
for {
off := r.Offset()
match, err := r.ReadOffset()
if err != nil {
return err
}
if match == int64(target) {
log.Printf("found %q ref at %d", layer, off)
// NOTE(review): this stores match (equal to target), not off, the
// position logged on the line above — confirm which was intended.
layerRefs[layer] = match
break
}
}
}
// Find VTables for table
for layer, ref := range layerRefs {
found := false
// Step backwards from ref in 4-byte steps (up to 100 bytes) looking for a
// word whose soffset resolves below the end of the file.
for i := 0; i < 100; i += 4 {
attempt := ref - int64(i)
if err := r.Seek(attempt); err != nil {
return err
}
soffset, err := r.ReadSOffset()
if err != nil {
return err
}
if soffset < int64(len(data)) {
log.Printf("%q (layerref %d): %d soffset %d", layer, ref, attempt, soffset)
if _, err := r.ReadVTable(soffset); err != nil {
return err
}
r.PrintDebug(64)
found = true
break
}
}
if !found {
return errors.Errorf("failed to find %q", layer)
}
}
if err := r.Seek(rootTableOffset); err != nil {
return err
}
log.Printf("root table offset %d", rootTableOffset)
if err := r.Seek(rootTableOffset); err != nil {
return err
}
// Decode the root table's vtable by hand: two uint16 sizes followed by
// vtableLength/2 - 1 uint16 entries.
tableOffset := r.Offset()
vTableOffset, err := r.ReadSOffset()
if err != nil {
return err
}
log.Printf("vtable offset %d", vTableOffset)
if err := r.Seek(vTableOffset); err != nil {
return err
}
vtableLength, err := r.ReadUint16()
if err != nil {
return err
}
log.Printf("vtable length %d", vtableLength)
tableLength, err := r.ReadUint16()
if err != nil {
return err
}
log.Printf("table length %d", tableLength)
var entries []int
for i := 0; i < (vtableLength/2 - 1); i++ {
offset, err := r.ReadUint16()
if err != nil {
return err
}
entries = append(entries, offset)
}
log.Printf("entries %+v", entries)
// Jump to the position given by the first entry, relative to the table
// start, and log the two uint16 values found there.
if err := r.Seek(tableOffset + int64(entries[0])); err != nil {
return err
}
val1, err := r.ReadUint16()
if err != nil {
return err
}
val2, err := r.ReadUint16()
if err != nil {
return err
}
log.Printf("table entries %d %d", val1, val2)
r.PrintDebug(300)
return nil
}
@knowpwr
Copy link

knowpwr commented Sep 25, 2022

When I do go run decode.go I receive the following errors:

decode.go:12:2: "./fbs/ap" is relative, but relative import paths are not supported in module mode
decode.go:14:2: "./proto" is relative, but relative import paths are not supported in module mode
decode.go:15:2: no required module provides package github.com/gogo/protobuf/proto: go.mod file not found in current directory or any parent directory; see 'go help modules'
decode.go:16:2: no required module provides package github.com/pkg/errors: go.mod file not found in current directory or any parent directory; see 'go help modules'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment