Skip to content

Instantly share code, notes, and snippets.

@ByungSunBae
Created January 29, 2019 13:54
Show Gist options
  • Save ByungSunBae/1c31629a3b1119a27ae559fad616a30e to your computer and use it in GitHub Desktop.
Save ByungSunBae/1c31629a3b1119a27ae559fad616a30e to your computer and use it in GitHub Desktop.
Simple TensorFlow operation in Go (key point: summing the elements of a vector)
package main
import (
"fmt"
"github.com/kniren/gota/dataframe"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
"github.com/tensorflow/tensorflow/tensorflow/go/op"
"os"
)
// errcheck panics with e when e is non-nil and does nothing otherwise.
// It keeps main's happy path free of repetitive error-handling branches
// in this small example program.
func errcheck(e error) {
	if e == nil {
		return
	}
	panic(e)
}
func main() {
// example.csv is dummy data that made in python.
// So, you can generate data in python or R.
file, err := os.Open("example.csv")
errcheck(err)
defer file.Close()
example := dataframe.ReadCSV(file)
//fmt.Println(example)
root := op.NewScope()
x_pl := op.Placeholder(root.SubScope("feature"), tf.Float, op.PlaceholderShape(tf.MakeShape(-1, 1)))
y_pl := op.Placeholder(root.SubScope("target"), tf.Float, op.PlaceholderShape(tf.MakeShape(-1, 1)))
fmt.Println(x_pl.Op.Name(), y_pl.Op.Name())
multiplication := op.Mul(root, x_pl, y_pl)
// I spend time because of entering axis of op.Sum.
axis := op.Const(root, []int32{0, 1})
sum_vl := op.Sum(root, multiplication, axis)
graph, err := root.Finalize()
errcheck(err)
var sess *tf.Session
sess, err = tf.NewSession(graph, &tf.SessionOptions{})
errcheck(err)
x_col_tmp := example.Col("x").Float()
x_col := [400][1]float32{}
for idx, val := range x_col_tmp {
x_col[idx][0] = float32(val)
}
y_col_tmp := example.Col("y").Float()
y_col := [400][1]float32{}
for idx, val := range y_col_tmp {
y_col[idx][0] = float32(val)
}
var x_val, y_val *tf.Tensor
if x_val, err = tf.NewTensor(x_col); err != nil {
panic(err.Error())
}
if y_val, err = tf.NewTensor(y_col); err != nil {
panic(err.Error())
}
var results []*tf.Tensor
if results, err = sess.Run(map[tf.Output]*tf.Tensor{
x_pl: x_val,
y_pl: y_val,
}, []tf.Output{sum_vl}, nil); err != nil {
panic(err.Error())
}
for _, result := range results {
fmt.Println(result.Value().(float32))
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment