-
-
Save mattn/5ae333847399209d75f1d3e52631f002 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"io/ioutil" | |
"log" | |
"math/rand" | |
"os" | |
"time" | |
tf "github.com/tensorflow/tensorflow/tensorflow/go" | |
) | |
type model_t struct { | |
graph *tf.Graph | |
session *tf.Session | |
input tf.Output | |
target tf.Output | |
output tf.Output | |
initOp *tf.Operation | |
trainOp *tf.Operation | |
saveOp *tf.Operation | |
restoreOp *tf.Operation | |
checkpointFile tf.Output | |
} | |
func createModel(graph_def_filename string) (*model_t, error) { | |
model := &model_t{} | |
model.graph = tf.NewGraph() | |
var err error | |
// create the session. | |
sessionOpts := &tf.SessionOptions{} | |
model.session, err = tf.NewSession(model.graph, sessionOpts) | |
if err != nil { | |
return nil, err | |
} | |
b, err := ioutil.ReadFile(graph_def_filename) | |
if err != nil { | |
return nil, err | |
} | |
err = model.graph.Import(b, "") | |
if err != nil { | |
return nil, err | |
} | |
model.input.Op = model.graph.Operation("input") | |
model.input.Index = 0 | |
model.target.Op = model.graph.Operation("target") | |
model.target.Index = 0 | |
model.output.Op = model.graph.Operation("output") | |
model.output.Index = 0 | |
model.initOp = model.graph.Operation("init") | |
model.trainOp = model.graph.Operation("train") | |
model.saveOp = model.graph.Operation("save/control_dependency") | |
model.restoreOp = model.graph.Operation("save/restore_all") | |
model.checkpointFile.Op = model.graph.Operation("save/Const") | |
model.checkpointFile.Index = 0 | |
return model, nil | |
} | |
// directoryExists reports whether filename names an existing directory
// (symlinks are followed). It returns false in three cases: the path does not
// exist, it cannot be stat'ed (for example because of permissions), or it
// refers to something other than a directory, such as a regular file.
func directoryExists(filename string) bool {
	info, err := os.Stat(filename)
	return err == nil && info.IsDir()
}
func createCheckpoint(model *model_t, checkpoint_prefix string, save bool) error { | |
t, err := tf.NewTensor(checkpoint_prefix) | |
if err != nil { | |
return err | |
} | |
var op *tf.Operation | |
if save { | |
op = model.saveOp | |
} else { | |
op = model.restoreOp | |
} | |
_, err = model.session.Run( | |
map[tf.Output]*tf.Tensor{model.checkpointFile: t}, | |
nil, | |
[]*tf.Operation{op}) | |
if err != nil { | |
return err | |
} | |
return nil | |
} | |
func predict(model *model_t, batch []float32) error { | |
b := make([][1][1]float32, len(batch)) | |
for i, v := range batch { | |
b[i][0][0] = v | |
} | |
t, err := tf.NewTensor(b) | |
if err != nil { | |
return err | |
} | |
result, err := model.session.Run( | |
map[tf.Output]*tf.Tensor{model.input: t}, | |
[]tf.Output{model.output}, | |
nil) | |
if err != nil { | |
return err | |
} | |
predictions := result[0].Value().([][][]float32) | |
println("Predictions:") | |
for i := 0; i < len(predictions); i++ { | |
fmt.Printf("\t x = %f, predicted y = %f\n", batch[i], predictions[i][0][0]) | |
} | |
return nil | |
} | |
func train(model *model_t) error { | |
var inputs [10][1][1]float32 | |
var targets [10][1][1]float32 | |
for i := 0; i < len(inputs); i++ { | |
inputs[i][0][0] = rand.Float32() | |
targets[i][0][0] = 3.0*inputs[i][0][0] + 2.0 | |
} | |
x, err := tf.NewTensor(inputs) | |
if err != nil { | |
return err | |
} | |
y, err := tf.NewTensor(targets) | |
if err != nil { | |
return err | |
} | |
_, err = model.session.Run( | |
map[tf.Output]*tf.Tensor{ | |
model.input: x, | |
model.target: y, | |
}, | |
nil, | |
[]*tf.Operation{model.trainOp}) | |
return err | |
} | |
func initializeModel(model *model_t) error { | |
_, err := model.session.Run( | |
nil, | |
nil, | |
[]*tf.Operation{model.initOp}) | |
return err | |
} | |
func main() { | |
rand.Seed(time.Now().UnixNano()) | |
const graph_def_filename = "graph.pb" | |
const checkpoint_prefix = "./checkpoints/checkpoint" | |
exists := directoryExists("checkpoints") | |
println("Loading graph") | |
model, err := createModel(graph_def_filename) | |
if err != nil { | |
log.Fatal(err) | |
} | |
if exists { | |
println("Restoring weights from checkpoint (remove the checkpoints directory to reset)") | |
err = createCheckpoint(model, checkpoint_prefix, false) | |
if err != nil { | |
log.Fatal(err) | |
} | |
} else { | |
// initialize_model | |
err = initializeModel(model) | |
if err != nil { | |
log.Fatal(err) | |
} | |
} | |
data := []float32{1.0, 2.0, 3.0} | |
println("Initial predictions") | |
err = predict(model, data) | |
if err != nil { | |
log.Fatal(err) | |
} | |
println("Training for a few steps") | |
for i := 0; i < 200; i++ { | |
err = train(model) | |
if err != nil { | |
log.Fatal(err) | |
} | |
} | |
println("Updated predictions") | |
err = predict(model, data) | |
if err != nil { | |
log.Fatal(err) | |
} | |
err = createCheckpoint(model, checkpoint_prefix, true) | |
if err != nil { | |
log.Fatal(err) | |
} | |
model.session.Close() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment