Skip to content

Instantly share code, notes, and snippets.

@mattn

mattn/main.go Secret

Created May 15, 2019 16:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattn/5ae333847399209d75f1d3e52631f002 to your computer and use it in GitHub Desktop.
Save mattn/5ae333847399209d75f1d3e52631f002 to your computer and use it in GitHub Desktop.
package main
import (
"fmt"
"io/ioutil"
"log"
"math/rand"
"os"
"time"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// model_t bundles a TensorFlow graph imported from a serialized GraphDef,
// the session that executes it, and handles to the named operations this
// program feeds, fetches, or runs (createModel fills them in by name).
type model_t struct {
	graph   *tf.Graph
	session *tf.Session

	// Feed/fetch endpoints: first output (index 0) of the ops named
	// "input", "target" and "output".
	input, target, output tf.Output

	// Operations that are run for their effect rather than fetched:
	// "init", "train", "save/control_dependency" and "save/restore_all".
	initOp, trainOp, saveOp, restoreOp *tf.Operation

	// First output of "save/Const"; createCheckpoint feeds it a string
	// tensor holding the checkpoint prefix.
	checkpointFile tf.Output
}
// createModel reads the serialized GraphDef in graph_def_filename, imports it
// into a new graph, binds the operations this program uses by name, and opens
// a session on that graph.
//
// The graph must contain ops named "input", "target", "output", "init",
// "train", "save/control_dependency", "save/restore_all" and "save/Const".
// A missing op is reported here as an error; it is not left as a nil
// *tf.Operation that would only fail later inside Session.Run. The session is
// created last, so no session is left unclosed when an earlier step fails.
func createModel(graph_def_filename string) (*model_t, error) {
	def, err := ioutil.ReadFile(graph_def_filename)
	if err != nil {
		return nil, err
	}

	model := &model_t{graph: tf.NewGraph()}
	if err := model.graph.Import(def, ""); err != nil {
		return nil, err
	}

	// Each tf.Output keeps its zero Index, i.e. the op's first output.
	bindings := []struct {
		name string
		dst  **tf.Operation
	}{
		{"input", &model.input.Op},
		{"target", &model.target.Op},
		{"output", &model.output.Op},
		{"init", &model.initOp},
		{"train", &model.trainOp},
		{"save/control_dependency", &model.saveOp},
		{"save/restore_all", &model.restoreOp},
		{"save/Const", &model.checkpointFile.Op},
	}
	for _, b := range bindings {
		// Graph.Operation returns nil when no op has the given name.
		op := model.graph.Operation(b.name)
		if op == nil {
			return nil, fmt.Errorf("operation %q not found in graph %q", b.name, graph_def_filename)
		}
		*b.dst = op
	}

	model.session, err = tf.NewSession(model.graph, &tf.SessionOptions{})
	if err != nil {
		return nil, err
	}
	return model, nil
}
// directoryExists reports whether filename names an existing directory.
// It returns false if the path does not exist, cannot be stat'ed, or refers
// to something other than a directory (for example a regular file).
func directoryExists(filename string) bool {
	info, err := os.Stat(filename)
	return err == nil && info.IsDir()
}
// createCheckpoint saves (save == true) or restores (save == false) the
// model's variables. It feeds checkpoint_prefix as a string tensor into
// model.checkpointFile and runs model.saveOp or model.restoreOp respectively.
// The prefix is presumably the path prefix of the checkpoint files, e.g.
// "./checkpoints/checkpoint"; confirm against the graph's saver.
//
// Despite its name, this function also performs restores. Any returned error
// states which of the two actions failed and for which prefix.
func createCheckpoint(model *model_t, checkpoint_prefix string, save bool) error {
	action, op := "restoring", model.restoreOp
	if save {
		action, op = "saving", model.saveOp
	}

	t, err := tf.NewTensor(checkpoint_prefix)
	if err != nil {
		return fmt.Errorf("%s checkpoint %q: %v", action, checkpoint_prefix, err)
	}

	feeds := map[tf.Output]*tf.Tensor{model.checkpointFile: t}
	if _, err := model.session.Run(feeds, nil, []*tf.Operation{op}); err != nil {
		return fmt.Errorf("%s checkpoint %q: %v", action, checkpoint_prefix, err)
	}
	return nil
}
// predict feeds batch to model.input as a [len(batch)][1][1] float32 tensor,
// fetches model.output, and prints a header followed by one line per input
// (the x value and its prediction) to standard output.
//
// It returns an error if the input tensor cannot be built, the session run
// fails, the fetched value is not a [][][]float32, or the number of
// predictions differs from the number of inputs. The model is expected to
// produce exactly one prediction per input.
func predict(model *model_t, batch []float32) error {
	in := make([][1][1]float32, len(batch))
	for i, v := range batch {
		in[i][0][0] = v
	}
	t, err := tf.NewTensor(in)
	if err != nil {
		return err
	}

	result, err := model.session.Run(
		map[tf.Output]*tf.Tensor{model.input: t},
		[]tf.Output{model.output},
		nil)
	if err != nil {
		return err
	}

	// Check the fetched value; a blind assertion would panic if the "output"
	// op had a different dtype or rank.
	predictions, ok := result[0].Value().([][][]float32)
	if !ok {
		return fmt.Errorf("unexpected output type %T, want [][][]float32", result[0].Value())
	}
	// The loop below indexes batch with the predictions' indices.
	if len(predictions) != len(batch) {
		return fmt.Errorf("got %d predictions for %d inputs", len(predictions), len(batch))
	}

	// fmt (stdout) rather than builtin println (stderr), so the header and
	// the rows go to the same stream.
	fmt.Println("Predictions:")
	for i, p := range predictions {
		fmt.Printf("\t x = %f, predicted y = %f\n", batch[i], p[0][0])
	}
	return nil
}
// train runs the graph's "train" op (model.trainOp) once, presumably a single
// optimizer step. The mini-batch has ten x values drawn from math/rand's
// Float32 (uniform in [0, 1)), each paired with y = 3x + 2. Both are fed as
// [10][1][1] float32 tensors to model.input and model.target.
func train(model *model_t) error {
	var xs, ys [10][1][1]float32
	for i := range xs {
		x := rand.Float32()
		xs[i][0][0] = x
		ys[i][0][0] = 3.0*x + 2.0
	}

	xTensor, err := tf.NewTensor(xs)
	if err != nil {
		return err
	}
	yTensor, err := tf.NewTensor(ys)
	if err != nil {
		return err
	}

	feeds := map[tf.Output]*tf.Tensor{
		model.input:  xTensor,
		model.target: yTensor,
	}
	_, err = model.session.Run(feeds, nil, []*tf.Operation{model.trainOp})
	return err
}
// initializeModel runs model.initOp (the op named "init" when the model comes
// from createModel) without feeding or fetching any tensors, and returns the
// session's error, if any. It presumably initializes the graph's variables;
// confirm against the graph that produced the GraphDef.
func initializeModel(model *model_t) error {
	if _, err := model.session.Run(nil, nil, []*tf.Operation{model.initOp}); err != nil {
		return err
	}
	return nil
}
// main seeds math/rand and runs the load/train/save cycle. Any failure is
// logged and the process exits with a non-zero status.
func main() {
	// Only needed on toolchains before Go 1.20; later releases seed the
	// global source automatically.
	rand.Seed(time.Now().UnixNano())
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run does the following, in order:
//   - loads graph.pb;
//   - restores weights from ./checkpoints/checkpoint if the checkpoints
//     directory exists, otherwise runs the init op;
//   - prints predictions for a few inputs;
//   - trains for a fixed number of steps;
//   - prints predictions again;
//   - saves a checkpoint.
//
// The work lives here rather than in main because log.Fatal calls os.Exit,
// which skips deferred calls. Returning to main lets the deferred session
// Close run on every path. A Close failure is reported unless an earlier
// error is already being returned.
func run() (err error) {
	const (
		graphDefFilename = "graph.pb"
		checkpointDir    = "checkpoints"
		checkpointPrefix = "./" + checkpointDir + "/checkpoint"
		trainSteps       = 200
	)

	// Decide up front whether to resume from a previous run.
	restore := directoryExists(checkpointDir)

	fmt.Println("Loading graph")
	model, err := createModel(graphDefFilename)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := model.session.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if restore {
		fmt.Println("Restoring weights from checkpoint (remove the checkpoints directory to reset)")
		err = createCheckpoint(model, checkpointPrefix, false)
	} else {
		err = initializeModel(model)
	}
	if err != nil {
		return err
	}

	data := []float32{1.0, 2.0, 3.0}
	fmt.Println("Initial predictions")
	if err = predict(model, data); err != nil {
		return err
	}

	fmt.Println("Training for a few steps")
	for i := 0; i < trainSteps; i++ {
		if err = train(model); err != nil {
			return err
		}
	}

	fmt.Println("Updated predictions")
	if err = predict(model, data); err != nil {
		return err
	}

	return createCheckpoint(model, checkpointPrefix, true)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment