Skip to content

Instantly share code, notes, and snippets.

@mattn
Last active May 9, 2019 02:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattn/f5aa1b96753c76075f2666235919a8f3 to your computer and use it in GitHub Desktop.
Save mattn/f5aa1b96753c76075f2666235919a8f3 to your computer and use it in GitHub Desktop.
package main
import (
"flag"
"fmt"
"image/png"
"io/ioutil"
"log"
"math"
"os"
"sort"
"time"
"github.com/nfnt/resize"
"github.com/owulveryck/onnx-go"
"github.com/owulveryck/onnx-go/backend/x/gorgonnx"
"gorgonia.org/tensor"
)
// width and height give the spatial size, in pixels, of the input tensor
// that main builds with shape (1, 3, height, width).
const (
	width  = 1200
	height = 1200
)
// emotionTable holds the emotion labels indexed by class position.
// classify labels input[i] with emotionTable[i], so this order has to match
// the order of the scores passed to classify.
var emotionTable = []string{
	0: "neutral",
	1: "happiness",
	2: "surprise",
	3: "sadness",
	4: "anger",
	5: "disgust",
	6: "fear",
	7: "contempt",
}
// main runs one inference pass of an ONNX model on a PNG image.
//
// Flags:
//
//	-model  path to the ONNX model file (default "model.onnx")
//	-input  path to the PNG image (default "file.png")
//
// The image is resized to width×height, normalized per channel, and written
// into a (1, 3, height, width) float32 tensor that is bound to the model's
// first input. After the gorgonnx backend runs, the computation time, the
// number of outputs and the data of every output tensor are printed.
// Any failure is fatal.
func main() {
	model := flag.String("model", "model.onnx", "path to the model file")
	input := flag.String("input", "file.png", "path to the input file")
	flag.Parse()

	// Create a model whose execution backend is the gorgonnx graph.
	backend := gorgonnx.NewGraph()
	m := onnx.NewModel(backend)

	// Read the ONNX model and decode it into m.
	raw, err := ioutil.ReadFile(*model)
	if err != nil {
		log.Fatal(err)
	}
	if err := m.UnmarshalBinary(raw); err != nil {
		log.Fatal(err)
	}

	f, err := os.Open(*input)
	if err != nil {
		log.Fatal(err)
	}
	img, err := png.Decode(f)
	// Close explicitly rather than defer: log.Fatal exits via os.Exit,
	// which skips deferred calls. The file is only read, so the Close
	// error carries no useful information.
	f.Close()
	if err != nil {
		log.Fatal(err)
	}

	// Scale the image to exactly the tensor's spatial size. Every tensor
	// element is then written, and no pixel index can exceed the tensor bounds.
	img = resize.Resize(width, height, img, resize.Bilinear)
	bounds := img.Bounds()

	// Per-channel (R, G, B) normalization constants.
	// NOTE(review): these look like the widely used ImageNet mean/std
	// values; confirm they match the model's training preprocessing.
	mean := [3]float32{0.485, 0.456, 0.406}
	std := [3]float32{0.229, 0.224, 0.225}

	// Fill a planar (channel-major) backing slice directly:
	// index = c*height*width + y*width + x.
	const plane = height * width
	data := make([]float32, 3*plane)
	for y := 0; y < height; y++ {
		for x := 0; x < width; x++ {
			// Offset by bounds.Min: image coordinates need not start at (0, 0).
			r, g, b, _ := img.At(bounds.Min.X+x, bounds.Min.Y+y).RGBA()
			// RGBA reports channels in [0, 0xffff]; divide by 0xffff to map to [0, 1].
			for c, v := range [3]uint32{r, g, b} {
				data[c*plane+y*width+x] = (float32(v)/0xffff - mean[c]) / std[c]
			}
		}
	}
	inputT := tensor.New(tensor.WithShape(1, 3, height, width), tensor.WithBacking(data))

	// Bind the first input; which input index to use depends on the model.
	if err := m.SetInput(0, inputT); err != nil {
		log.Fatal(err)
	}

	start := time.Now()
	if err := backend.Run(); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Computation time: %v\n", time.Since(start))

	output, err := m.GetOutputTensors()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(output))
	// Print each output without asserting a Go type at a fixed index.
	// The number of outputs and their dtypes depend on the model.
	for _, o := range output {
		fmt.Println(o.Data())
	}
}
// softmax converts raw scores (logits) into probabilities that sum to 1.
//
// The maximum score is subtracted before exponentiation. This leaves the
// result mathematically unchanged, but keeps math.Exp from overflowing to
// +Inf for large scores, which would otherwise turn every output into NaN.
// Each exponential is computed once. An empty input yields an empty
// (non-nil) slice.
func softmax(input []float32) []float32 {
	output := make([]float32, len(input))
	if len(input) == 0 {
		return output
	}

	maxVal := input[0]
	for _, v := range input[1:] {
		if v > maxVal {
			maxVal = v
		}
	}

	exps := make([]float64, len(input))
	var sumExp float64
	for i, v := range input {
		e := math.Exp(float64(v) - float64(maxVal))
		exps[i] = e
		sumExp += e
	}
	for i, e := range exps {
		output[i] = float32(e / sumExp)
	}
	return output
}
// classify pairs each score in input with the label at the same index of
// emotionTable and returns the pairs sorted by descending weight.
//
// Scores whose index lies beyond emotionTable are labelled "unknown(<i>)"
// instead of causing an index-out-of-range panic. The sort is stable, so
// entries with equal weight keep their input order.
func classify(input []float32) emotions {
	result := make(emotions, len(input))
	for i, w := range input {
		var label string
		if i < len(emotionTable) {
			label = emotionTable[i]
		} else {
			label = fmt.Sprintf("unknown(%d)", i)
		}
		result[i] = emotion{
			emotion: label,
			weight:  w,
		}
	}
	sort.Stable(sort.Reverse(result))
	return result
}
// emotion is a single (label, score) pair.
type emotion struct {
	emotion string  // emotion label
	weight  float32 // score associated with the label
}

// emotions implements sort.Interface, ordering entries by ascending weight.
// classify wraps it in sort.Reverse to obtain descending order.
type emotions []emotion

// Len returns the number of entries.
func (es emotions) Len() int {
	return len(es)
}

// Swap exchanges entries i and j in place.
func (es emotions) Swap(i, j int) {
	es[j], es[i] = es[i], es[j]
}

// Less reports whether entry i has a strictly smaller weight than entry j.
func (es emotions) Less(i, j int) bool {
	return es[j].weight > es[i].weight
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment