Skip to content

Instantly share code, notes, and snippets.

@matteo-grella
Created June 8, 2022 14:52
Show Gist options
  • Save matteo-grella/6c1da99f3248b997612531fd8d531350 to your computer and use it in GitHub Desktop.
Save matteo-grella/6c1da99f3248b997612531fd8d531350 to your computer and use it in GitHub Desktop.
package main
import (
"fmt"
"log"
"reflect"
"github.com/nlpodyssey/spago/ag"
"github.com/nlpodyssey/spago/gd"
"github.com/nlpodyssey/spago/gd/sgd"
"github.com/nlpodyssey/spago/initializers"
"github.com/nlpodyssey/spago/losses"
"github.com/nlpodyssey/spago/mat"
"github.com/nlpodyssey/spago/mat/float"
"github.com/nlpodyssey/spago/mat/rand"
"github.com/nlpodyssey/spago/nn"
)
// epochs is how many full passes training makes over the synthetic dataset;
// examples is how many training samples are generated for each pass.
const epochs, examples = 100, 1000
// Linear is a fully connected layer whose Forward computes W·x + B.
// NewLinear allocates W with dimensions (out, in) and B as a vector of length out.
type Linear struct {
// Embedding nn.Module is presumably what lets spago treat Linear as a model
// (it is passed to gd.NewOptimizer and nn.DumpToFile) — confirm in spago docs.
nn.Module
// W is the weight matrix. The spago struct tag presumably classifies the
// parameter for the library's reflection-based handling — TODO confirm.
W nn.Param `spago:"type:weights"`
// B is the bias vector, tagged as biases for spago (same assumption as W).
B nn.Param `spago:"type:biases"`
}
// NewLinear builds a Linear layer that maps in inputs to out outputs, with
// element type T. W is backed by an (out, in) matrix from mat.NewEmptyDense
// and B by a length-out vector from mat.NewEmptyVecDense. No random
// initialization happens here; chain InitWithRandomWeights for that.
func NewLinear[T float.DType](in, out int) *Linear {
	weights := nn.NewParam(mat.NewEmptyDense[T](out, in))
	biases := nn.NewParam(mat.NewEmptyVecDense[T](out))

	layer := &Linear{}
	layer.W = weights
	layer.B = biases
	return layer
}
// InitWithRandomWeights runs Xavier-uniform initialization (gain 1.0) on the
// matrix held by W. Randomness comes from rand.NewLockedRand(seed), so a fixed
// seed should give reproducible weights. B is not touched. The receiver is
// returned so the call can be chained after NewLinear.
func (m *Linear) InitWithRandomWeights(seed uint64) *Linear {
	const gain = 1.0
	target := m.W.Value()
	rng := rand.NewLockedRand(seed)
	initializers.XavierUniform(target, gain, rng)
	return m
}
// Forward applies the affine map W·x + B to the input node x and returns the
// resulting graph node.
func (m *Linear) Forward(x ag.Node) ag.Node {
	wx := ag.Mul(m.W, x)
	return ag.Add(wx, m.B)
}
// main trains a 1→1 Linear layer to fit y = 3x + 1 on inputs normalized to
// [0, 1), then prints the learned parameters. It then demonstrates
// persistence: first it dumps and reloads the whole model, then it dumps and
// reloads only the weight matrix W. Any I/O failure terminates the program
// via log.Fatalf, naming the failed operation and file.
func main() {
	// Output paths for the persistence demo. They are written relative to the
	// current working directory and left in place after the run.
	const (
		modelPath   = "model"
		weightsPath = "w"
	)

	m := NewLinear[float64](1, 1).InitWithRandomWeights(42)
	// NOTE(review): arguments presumably are learning rate 0.001, momentum 0.9,
	// Nesterov=true — confirm against sgd.NewConfig.
	optimizer := gd.NewOptimizer(m, sgd.New[float64](sgd.NewConfig(0.001, 0.9, true)))

	normalize := func(x float64) float64 { return x / float64(examples) }
	objective := func(x float64) float64 { return 3*x + 1 }
	criterion := losses.MSE

	// learn runs one forward/backward pass for a single example and returns
	// its loss. Backward is deferred so the loss value is read before the
	// graph is released.
	learn := func(input, expected float64) float64 {
		x, target := ag.Scalar(input), ag.Scalar(expected)
		y := m.Forward(x)
		loss := criterion(y, target, true)
		defer ag.Backward(loss) // free the memory of the graph before return
		return loss.Value().Scalar().F64()
	}

	for epoch := 0; epoch < epochs; epoch++ {
		for i := 0; i < examples; i++ {
			x := normalize(float64(i))
			loss := learn(x, objective(x))
			if i%100 == 0 {
				fmt.Printf("Loss: %.6f\n", loss)
			}
		}
		// A single parameter update per epoch, applied after every example
		// in the epoch has been backpropagated.
		optimizer.Do()
	}

	fmt.Printf("\nW: %.2f | B: %.2f\n\n", m.W.Value().Scalar().F64(), m.B.Value().Scalar().F64())
	fmt.Printf("%#v", m)

	// Round-trip the whole model.
	if err := nn.DumpToFile(m, modelPath); err != nil {
		log.Fatalf("dumping model to %q: %v", modelPath, err)
	}
	m2, err := nn.LoadFromFile[*Linear](modelPath)
	if err != nil {
		log.Fatalf("loading model from %q: %v", modelPath, err)
	}
	fmt.Printf("\nW: %.2f | B: %.2f\n\n", m2.W.Value().Scalar().F64(), m2.B.Value().Scalar().F64())
	fmt.Println(reflect.TypeOf(m2).Kind())

	// Round-trip only the weight matrix W (not the full model).
	if err := nn.DumpToFile(m.W.Value(), weightsPath); err != nil {
		log.Fatalf("dumping weights to %q: %v", weightsPath, err)
	}
	w, err := nn.LoadFromFile[mat.Dense[float64]](weightsPath)
	if err != nil {
		log.Fatalf("loading weights from %q: %v", weightsPath, err)
	}
	fmt.Printf("\nW: %.2f\n", w.Scalar().F64())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment