Created
November 1, 2018 21:16
-
-
Save aaronedell/ac14b5d21624ffeca87ef6b4b7b528a3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"encoding/base64" | |
"flag" | |
"image" | |
"image/color" | |
"image/draw" | |
"image/png" | |
"io/ioutil" | |
"log" | |
"os" | |
"os/signal" | |
"path/filepath" | |
"strconv" | |
"github.com/machinebox/mb/internal/stream/images" | |
"github.com/machinebox/sdk-go/classificationbox" | |
"github.com/pkg/errors" | |
"golang.org/x/image/font" | |
"golang.org/x/image/font/basicfont" | |
"golang.org/x/image/math/fixed" | |
) | |
// To split a video into frames: $ ffmpeg -i path/to/video.mp4 path/to/output/label%04d.jpg -hide_banner
// To put a video back together from frames: $ ffmpeg -i [label]%04d.png -c:v libx264 -r 30 -pix_fmt yuv420p out.mp4
func main() { | |
ctx := context.Background() | |
// trap Ctrl+C and call cancel on the context | |
ctx, cancel := context.WithCancel(ctx) | |
c := make(chan os.Signal, 1) | |
signal.Notify(c, os.Interrupt) | |
defer func() { | |
signal.Stop(c) | |
cancel() | |
}() | |
go func() { | |
select { | |
case <-c: | |
cancel() | |
case <-ctx.Done(): | |
} | |
}() | |
if err := run(ctx); err != nil { | |
log.Fatalln(err) | |
} | |
} | |
func run(ctx context.Context) error { | |
var ( | |
src = flag.String("src", ".", "source of dataset") | |
modelID = flag.String("id", "1234567890", "Classificationbox modelID") | |
) | |
flag.Parse() | |
cb := classificationbox.New("http://localhost:8080") | |
frames, err := ioutil.ReadDir(*src) | |
if err != nil { | |
return err | |
} | |
absSrc, abserr := filepath.Abs(*src) | |
if abserr != nil { | |
absSrc = *src | |
} | |
for _, f := range frames { | |
absSrcLocation := filepath.Join(absSrc, f.Name()) | |
predictedClass, err := predictImage(ctx, cb, *modelID, absSrcLocation) | |
if err != nil { | |
return err | |
} | |
//log.Printf(predictedClass) | |
if _, err := os.Stat("tagged"); os.IsNotExist(err) { | |
os.Mkdir("tagged", 0777) | |
} | |
z, err := createTaggedimage(absSrcLocation, predictedClass, "tagged/"+f.Name()) | |
if err != nil { | |
log.Fatalln(err) | |
} | |
log.Println(z) | |
continue | |
} | |
return err | |
} | |
func createTaggedimage(src string, label string, outpath string) (string, error) { | |
file, err := os.Open(src) // For read access. | |
if err != nil { | |
log.Fatal(err) | |
} | |
img, err := images.Decode(file) | |
if err != nil { | |
log.Fatal(err) | |
} | |
b := img.Bounds() | |
m := image.NewRGBA(image.Rect(0, 0, b.Dx(), b.Dy())) | |
draw.Draw(m, m.Bounds(), img, b.Min, draw.Src) | |
addLabel(m, 20, 30, label) | |
f, err := os.Create(outpath) | |
if err != nil { | |
log.Fatalln(err) | |
} | |
defer f.Close() | |
if err := png.Encode(f, m); err != nil { | |
log.Fatalln(err) | |
} | |
return "", nil | |
} | |
func addLabel(m *image.RGBA, x, y int, label string) { | |
col := color.RGBA{200, 100, 0, 255} | |
point := fixed.Point26_6{fixed.Int26_6(x * 64), fixed.Int26_6(y * 64)} | |
d := &font.Drawer{ | |
Dst: m, | |
Src: image.NewUniform(col), | |
Face: basicfont.Face7x13, | |
Dot: point, | |
} | |
log.Println(label) | |
d.DrawString(label) | |
} | |
func predictImage(ctx context.Context, cb *classificationbox.Client, modelID string, image string) (string, error) { | |
base64, err := base64Image(image) | |
if err != nil { | |
return "", err | |
} | |
req := classificationbox.PredictRequest{ | |
Inputs: []classificationbox.Feature{ | |
classificationbox.FeatureImageBase64("image", base64), | |
}, | |
} | |
resp, err := cb.Predict(ctx, modelID, req) | |
if err != nil { | |
return "", errors.Wrap(err, "predict") | |
} | |
h := resp.Classes[0].Score * 100 | |
k := strconv.FormatFloat(h, 'f', 6, 64) | |
z := resp.Classes[0].ID + " - " + k + "%" | |
return z, nil | |
} | |
// base64Image reads the entire file at path and returns its bytes encoded
// with the standard, padded base64 alphabet. On an open or read failure it
// returns an empty string and that error.
func base64Image(path string) (string, error) {
	file, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	var raw []byte
	if raw, err = ioutil.ReadAll(file); err != nil {
		return "", err
	}
	encoded := base64.StdEncoding.EncodeToString(raw)
	return encoded, nil
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Of course - and sorry for the headache!