Skip to content

Instantly share code, notes, and snippets.

@mattn
Created October 17, 2023 13:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattn/aea612c6c42e413f923f5b3d3ab94953 to your computer and use it in GitHub Desktop.
Save mattn/aea612c6c42e413f923f5b3d3ab94953 to your computer and use it in GitHub Desktop.
package main
import (
"bufio"
"embed"
"errors"
"flag"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"io/fs"
"log"
"net/http"
"os"
"sort"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/mattn/go-tflite"
"github.com/nfnt/resize"
)
// assets is the static/ directory, embedded into the binary at build time.
// main serves it (re-rooted at "static" via fs.Sub) for GET requests.
//
//go:embed static
var assets embed.FS
// loadLabels reads filename and returns its lines, in file order, as class
// labels: main uses labels[i] as the name of model output index i.
// Line terminators ("\n" or "\r\n") are stripped. An empty file yields an
// empty, non-nil slice.
//
// It returns an error if the file cannot be opened or fully read. This
// includes a line longer than bufio.MaxScanTokenSize, at which bufio.Scanner
// would otherwise stop silently and leave the label list truncated.
func loadLabels(filename string) ([]string, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	labels := []string{}
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		labels = append(labels, scanner.Text())
	}
	// Scan returns false on both EOF and failure; only Err tells them apart.
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading labels %q: %w", filename, err)
	}
	return labels, nil
}
func main() {
var model_path, label_path string
flag.StringVar(&model_path, "model", "nsfw.tflite", "path to model file")
flag.StringVar(&label_path, "label", "labels.txt", "path to label file")
flag.Parse()
port := os.Getenv("PORT")
if port == "" {
port = "8080"
}
labels, err := loadLabels(label_path)
if err != nil {
log.Fatal(err)
}
model := tflite.NewModelFromFile(model_path)
if model == nil {
log.Fatal("cannot load model")
}
defer model.Delete()
options := tflite.NewInterpreterOptions()
options.SetNumThread(4)
options.SetErrorReporter(func(msg string, user_data interface{}) {
fmt.Println(msg)
}, nil)
defer options.Delete()
interpreter := tflite.NewInterpreter(model, options)
if interpreter == nil {
log.Println("cannot create interpreter")
return
}
defer interpreter.Delete()
status := interpreter.AllocateTensors()
if status != tflite.OK {
log.Println("allocate failed")
return
}
input := interpreter.GetInputTensor(0)
wanted_height := input.Dim(1)
wanted_width := input.Dim(2)
wanted_channels := input.Dim(3)
wanted_type := input.Type()
e := echo.New()
e.Use(middleware.CORS())
e.Use(middleware.Logger())
e.POST("/", func(c echo.Context) error {
imf, err := c.FormFile("image")
if err != nil {
return err
}
f, err := imf.Open()
if err != nil {
return err
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return err
}
resized := resize.Resize(uint(wanted_width), uint(wanted_height), img, resize.NearestNeighbor)
bounds := resized.Bounds()
dx, dy := bounds.Dx(), bounds.Dy()
if wanted_type == tflite.UInt8 {
bb := make([]byte, wanted_width*wanted_height*wanted_channels)
for y := 0; y < dy; y++ {
for x := 0; x < dx; x++ {
col := resized.At(x, y)
r, g, b, _ := col.RGBA()
bb[(y*dx+x)*3+0] = byte(float64(r) / 255.0)
bb[(y*dx+x)*3+1] = byte(float64(g) / 255.0)
bb[(y*dx+x)*3+2] = byte(float64(b) / 255.0)
}
}
copy(input.UInt8s(), bb)
} else if wanted_type == tflite.Float32 {
ff := make([]float32, wanted_width*wanted_height*wanted_channels)
for y := 0; y < dy; y++ {
for x := 0; x < dx; x++ {
col := resized.At(x, y)
r, g, b, _ := col.RGBA()
ff[(y*dx+x)*3+0] = float32(r) / 65535.0
ff[(y*dx+x)*3+1] = float32(g) / 65535.0
ff[(y*dx+x)*3+2] = float32(b) / 65535.0
}
}
copy(input.Float32s(), ff)
} else {
return errors.New("is not wanted type")
}
status = interpreter.Invoke()
if status != tflite.OK {
return errors.New("invoke failed")
}
output := interpreter.GetOutputTensor(0)
output_size := output.Dim(output.NumDims() - 1)
b := make([]byte, output_size)
type result struct {
index int
Score float64 `json:"score"`
Label string `json:"label"`
}
status = output.CopyToBuffer(&b[0])
if status != tflite.OK {
return errors.New("output failed")
}
results := []result{}
for i := 0; i < output_size; i++ {
score := float64(b[i]) / 255.0
if score < 0.2 {
continue
}
results = append(results, result{Score: score, index: i, Label: labels[i]})
}
sort.Slice(results, func(i, j int) bool {
return results[i].Score > results[j].Score
})
//json.NewEncoder(w).Encode(results)
return c.JSON(http.StatusOK, results)
})
sub, _ := fs.Sub(assets, "static")
e.GET("/*", echo.WrapHandler(http.FileServer(http.FS(sub))))
e.Logger.Fatal(e.Start(":" + port))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment