Skip to content

Instantly share code, notes, and snippets.

@owulveryck
Last active December 6, 2024 13:55
Show Gist options
  • Save owulveryck/19a5ba9553ff8209b3b4227b5325041b to your computer and use it in GitHub Desktop.
Save owulveryck/19a5ba9553ff8209b3b4227b5325041b to your computer and use it in GitHub Desktop.
Linear regression on the iris dataset with Gorgonia and Gota
sepal_length sepal_width petal_length petal_width species
5.1 3.5 1.4 0.2 setosa
4.9 3.0 1.4 0.2 setosa
4.7 3.2 1.3 0.2 setosa
4.6 3.1 1.5 0.2 setosa
5.0 3.6 1.4 0.2 setosa
5.4 3.9 1.7 0.4 setosa
4.6 3.4 1.4 0.3 setosa
5.0 3.4 1.5 0.2 setosa
4.4 2.9 1.4 0.2 setosa
4.9 3.1 1.5 0.1 setosa
5.4 3.7 1.5 0.2 setosa
4.8 3.4 1.6 0.2 setosa
4.8 3.0 1.4 0.1 setosa
4.3 3.0 1.1 0.1 setosa
5.8 4.0 1.2 0.2 setosa
5.7 4.4 1.5 0.4 setosa
5.4 3.9 1.3 0.4 setosa
5.1 3.5 1.4 0.3 setosa
5.7 3.8 1.7 0.3 setosa
5.1 3.8 1.5 0.3 setosa
5.4 3.4 1.7 0.2 setosa
5.1 3.7 1.5 0.4 setosa
4.6 3.6 1.0 0.2 setosa
5.1 3.3 1.7 0.5 setosa
4.8 3.4 1.9 0.2 setosa
5.0 3.0 1.6 0.2 setosa
5.0 3.4 1.6 0.4 setosa
5.2 3.5 1.5 0.2 setosa
5.2 3.4 1.4 0.2 setosa
4.7 3.2 1.6 0.2 setosa
4.8 3.1 1.6 0.2 setosa
5.4 3.4 1.5 0.4 setosa
5.2 4.1 1.5 0.1 setosa
5.5 4.2 1.4 0.2 setosa
4.9 3.1 1.5 0.1 setosa
5.0 3.2 1.2 0.2 setosa
5.5 3.5 1.3 0.2 setosa
4.9 3.1 1.5 0.1 setosa
4.4 3.0 1.3 0.2 setosa
5.1 3.4 1.5 0.2 setosa
5.0 3.5 1.3 0.3 setosa
4.5 2.3 1.3 0.3 setosa
4.4 3.2 1.3 0.2 setosa
5.0 3.5 1.6 0.6 setosa
5.1 3.8 1.9 0.4 setosa
4.8 3.0 1.4 0.3 setosa
5.1 3.8 1.6 0.2 setosa
4.6 3.2 1.4 0.2 setosa
5.3 3.7 1.5 0.2 setosa
5.0 3.3 1.4 0.2 setosa
7.0 3.2 4.7 1.4 versicolor
6.4 3.2 4.5 1.5 versicolor
6.9 3.1 4.9 1.5 versicolor
5.5 2.3 4.0 1.3 versicolor
6.5 2.8 4.6 1.5 versicolor
5.7 2.8 4.5 1.3 versicolor
6.3 3.3 4.7 1.6 versicolor
4.9 2.4 3.3 1.0 versicolor
6.6 2.9 4.6 1.3 versicolor
5.2 2.7 3.9 1.4 versicolor
5.0 2.0 3.5 1.0 versicolor
5.9 3.0 4.2 1.5 versicolor
6.0 2.2 4.0 1.0 versicolor
6.1 2.9 4.7 1.4 versicolor
5.6 2.9 3.6 1.3 versicolor
6.7 3.1 4.4 1.4 versicolor
5.6 3.0 4.5 1.5 versicolor
5.8 2.7 4.1 1.0 versicolor
6.2 2.2 4.5 1.5 versicolor
5.6 2.5 3.9 1.1 versicolor
5.9 3.2 4.8 1.8 versicolor
6.1 2.8 4.0 1.3 versicolor
6.3 2.5 4.9 1.5 versicolor
6.1 2.8 4.7 1.2 versicolor
6.4 2.9 4.3 1.3 versicolor
6.6 3.0 4.4 1.4 versicolor
6.8 2.8 4.8 1.4 versicolor
6.7 3.0 5.0 1.7 versicolor
6.0 2.9 4.5 1.5 versicolor
5.7 2.6 3.5 1.0 versicolor
5.5 2.4 3.8 1.1 versicolor
5.5 2.4 3.7 1.0 versicolor
5.8 2.7 3.9 1.2 versicolor
6.0 2.7 5.1 1.6 versicolor
5.4 3.0 4.5 1.5 versicolor
6.0 3.4 4.5 1.6 versicolor
6.7 3.1 4.7 1.5 versicolor
6.3 2.3 4.4 1.3 versicolor
5.6 3.0 4.1 1.3 versicolor
5.5 2.5 4.0 1.3 versicolor
5.5 2.6 4.4 1.2 versicolor
6.1 3.0 4.6 1.4 versicolor
5.8 2.6 4.0 1.2 versicolor
5.0 2.3 3.3 1.0 versicolor
5.6 2.7 4.2 1.3 versicolor
5.7 3.0 4.2 1.2 versicolor
5.7 2.9 4.2 1.3 versicolor
6.2 2.9 4.3 1.3 versicolor
5.1 2.5 3.0 1.1 versicolor
5.7 2.8 4.1 1.3 versicolor
6.3 3.3 6.0 2.5 virginica
5.8 2.7 5.1 1.9 virginica
7.1 3.0 5.9 2.1 virginica
6.3 2.9 5.6 1.8 virginica
6.5 3.0 5.8 2.2 virginica
7.6 3.0 6.6 2.1 virginica
4.9 2.5 4.5 1.7 virginica
7.3 2.9 6.3 1.8 virginica
6.7 2.5 5.8 1.8 virginica
7.2 3.6 6.1 2.5 virginica
6.5 3.2 5.1 2.0 virginica
6.4 2.7 5.3 1.9 virginica
6.8 3.0 5.5 2.1 virginica
5.7 2.5 5.0 2.0 virginica
5.8 2.8 5.1 2.4 virginica
6.4 3.2 5.3 2.3 virginica
6.5 3.0 5.5 1.8 virginica
7.7 3.8 6.7 2.2 virginica
7.7 2.6 6.9 2.3 virginica
6.0 2.2 5.0 1.5 virginica
6.9 3.2 5.7 2.3 virginica
5.6 2.8 4.9 2.0 virginica
7.7 2.8 6.7 2.0 virginica
6.3 2.7 4.9 1.8 virginica
6.7 3.3 5.7 2.1 virginica
7.2 3.2 6.0 1.8 virginica
6.2 2.8 4.8 1.8 virginica
6.1 3.0 4.9 1.8 virginica
6.4 2.8 5.6 2.1 virginica
7.2 3.0 5.8 1.6 virginica
7.4 2.8 6.1 1.9 virginica
7.9 3.8 6.4 2.0 virginica
6.4 2.8 5.6 2.2 virginica
6.3 2.8 5.1 1.5 virginica
6.1 2.6 5.6 1.4 virginica
7.7 3.0 6.1 2.3 virginica
6.3 3.4 5.6 2.4 virginica
6.4 3.1 5.5 1.8 virginica
6.0 3.0 4.8 1.8 virginica
6.9 3.1 5.4 2.1 virginica
6.7 3.1 5.6 2.4 virginica
6.9 3.1 5.1 2.3 virginica
5.8 2.7 5.1 1.9 virginica
6.8 3.2 5.9 2.3 virginica
6.7 3.3 5.7 2.5 virginica
6.7 3.0 5.2 2.3 virginica
6.3 2.5 5.0 1.9 virginica
6.5 3.0 5.2 2.0 virginica
6.2 3.4 5.4 2.3 virginica
5.9 3.0 5.1 1.8 virginica
package main
import (
"fmt"
"log"
"math"
"os"
"github.com/go-gota/gota/dataframe"
"github.com/go-gota/gota/series"
"gonum.org/v1/gonum/mat"
"gorgonia.org/gorgonia"
"gorgonia.org/tensor"
)
// main fits a linear regression of the encoded species on the iris features
// with Gorgonia and prints the progress of every iteration.
//
// Reference: https://www.kaggle.com/amarpandey/implementing-linear-regression-on-iris-dataset/notebook
//
// Any failure while building or running the model is fatal and the program
// exits with a non-zero status.
func main() {
	if err := train(); err != nil {
		log.Fatal(err)
	}
}

// train builds the expression graph pred = x·θ, cost = mean((pred-y)²),
// differentiates the cost with respect to θ, and runs 100000 steps of a
// vanilla gradient-descent solver over the whole data set. It returns the
// first error encountered. Returning instead of calling log.Fatal inline
// guarantees the deferred machine.Close runs on every exit path.
func train() error {
	xT, yT := getXY()
	g := gorgonia.NewGraph()
	x := gorgonia.NodeFromAny(g, xT, gorgonia.WithName("x"))
	y := gorgonia.NodeFromAny(g, yT, gorgonia.WithName("y"))
	// One weight per column of x (features plus the bias column).
	theta := gorgonia.NewVector(
		g,
		gorgonia.Float64,
		gorgonia.WithName("theta"),
		gorgonia.WithShape(xT.Shape()[1]),
		gorgonia.WithInit(gorgonia.Uniform(0, 1)))
	pred := must(gorgonia.Mul(x, theta))
	// Gorgonia might delete values from nodes so we are going to save it
	// and print it out later
	var predicted gorgonia.Value
	gorgonia.Read(pred, &predicted)
	squaredError := must(gorgonia.Square(must(gorgonia.Sub(pred, y))))
	cost := must(gorgonia.Mean(squaredError))
	if _, err := gorgonia.Grad(cost, theta); err != nil {
		return fmt.Errorf("backpropagating: %w", err)
	}
	machine := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(theta))
	defer machine.Close()
	model := []gorgonia.ValueGrad{theta}
	solver := gorgonia.NewVanillaSolver(gorgonia.WithLearnRate(0.001))
	const iter = 100000
	for i := 0; i < iter; i++ {
		if err := machine.RunAll(); err != nil {
			return fmt.Errorf("iteration %d: running graph: %w", i, err)
		}
		if err := solver.Step(model); err != nil {
			return fmt.Errorf("iteration %d: solver step: %w", i, err)
		}
		// "\r" rewrites the same terminal line on every iteration.
		fmt.Printf("theta: %2.2f Iter: %v Cost: %2.3f Accuracy: %2.2f \r",
			theta.Value(),
			i,
			cost.Value(),
			accuracy(predicted.Data().([]float64), y.Value().Data().([]float64)))
		machine.Reset() // Reset is necessary in a loop like this
	}
	fmt.Println("")
	return nil
}
// accuracy returns the fraction of samples for which
// math.Round(prediction[i]-y[i]) == 0, i.e. the prediction lies within about
// 0.5 of the target. It serves as a crude classification score for a
// regression whose targets are integer class codes.
//
// prediction and y must have the same length. A mismatch is a programming
// error and panics with a descriptive message. Empty input yields 0.
func accuracy(prediction, y []float64) float64 {
	if len(prediction) != len(y) {
		panic(fmt.Sprintf("accuracy: %d predictions for %d targets", len(prediction), len(y)))
	}
	if len(y) == 0 {
		return 0
	}
	correct := 0
	for i, p := range prediction {
		if math.Round(p-y[i]) == 0 {
			correct++
		}
	}
	return float64(correct) / float64(len(y))
}
// getXY loads iris.csv from the current working directory and returns the
// design matrix and the target vector as tensors.
//
// The design matrix holds every column except "species", followed by an
// appended "bias" column of ones. The target is the "species" column encoded
// as setosa=1, virginica=2, versicolor=3, reshaped into a vector with one
// entry per row. Summary statistics of both frames are printed to stdout.
// Any open, parse, unknown-species or reshape error is fatal.
func getXY() (*tensor.Dense, *tensor.Dense) {
	f, err := os.Open("iris.csv")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	df := dataframe.ReadCSV(f)
	// gota reports failures through the frame's Err field rather than a
	// returned error, so it has to be checked explicitly.
	if df.Err != nil {
		log.Fatalf("reading iris.csv: %v", df.Err)
	}
	xDF := df.Drop("species")
	toValue := func(s series.Series) series.Series {
		records := s.Records()
		floats := make([]float64, len(records))
		for i, r := range records {
			switch r {
			case "setosa":
				floats[i] = 1
			case "virginica":
				floats[i] = 2
			case "versicolor":
				floats[i] = 3
			default:
				log.Fatalf("unknown iris: %v\n", r)
			}
		}
		return series.Floats(floats)
	}
	yDF := df.Select("species").Capply(toValue)
	if yDF.Err != nil {
		log.Fatalf("building species target: %v", yDF.Err)
	}
	numRows, _ := xDF.Dims()
	xDF = xDF.Mutate(series.New(one(numRows), series.Float, "bias"))
	if xDF.Err != nil {
		log.Fatalf("building feature matrix: %v", xDF.Err)
	}
	fmt.Println(xDF.Describe())
	fmt.Println(yDF.Describe())
	xT := tensor.FromMat64(mat.DenseCopyOf(&matrix{xDF}))
	yT := tensor.FromMat64(mat.DenseCopyOf(&matrix{yDF}))
	// Get rid of the last dimension to create a vector
	if err = yT.Reshape(numRows); err != nil {
		log.Fatalf("reshaping target to a vector: %v", err)
	}
	return xT, yT
}
// matrix adapts a gota DataFrame to gonum's mat.Matrix interface so that it
// can be copied into a *mat.Dense. Dims is provided by the embedded DataFrame.
type matrix struct {
	dataframe.DataFrame
}

// At returns the element at row i, column j converted to float64.
func (mx matrix) At(i, j int) float64 {
	cell := mx.Elem(i, j)
	return cell.Float()
}

// T returns a transposed view of mx without copying its contents.
func (mx matrix) T() mat.Matrix {
	return mat.Transpose{Matrix: mx}
}
// must unwraps the (node, error) pair returned by a gorgonia graph
// constructor. It returns the node on success and panics with err otherwise,
// because a failure here means the computation graph itself is mis-specified.
func must(n *gorgonia.Node, err error) *gorgonia.Node {
	if err == nil {
		return n
	}
	panic(err)
}
// one returns a freshly allocated slice of length size with every element
// set to 1.0. It is used to build the constant bias column.
func one(size int) []float64 {
	ones := make([]float64, size)
	for i := range ones {
		ones[i] = 1
	}
	return ones
}
package main
import (
"fmt"
"log"
"os"
"github.com/go-gota/gota/dataframe"
"github.com/go-gota/gota/series"
"gonum.org/v1/gonum/mat"
)
func main() {
fa := mat.Formatted(getThetaNormal(), mat.Prefix(" "), mat.Squeeze())
fmt.Printf("ϴ: %v\n", fa)
}
// getXYMat loads iris.csv from the current working directory and returns the
// design matrix X and the target column y as gonum matrices.
//
// X holds every column except "species", followed by an appended "bias"
// column of ones. y is the "species" column encoded as setosa=1,
// virginica=2, versicolor=3 (one row per sample, one column). Summary
// statistics of both frames are printed to stdout. Any open, parse or
// unknown-species error is fatal.
func getXYMat() (*mat.Dense, *mat.Dense) {
	f, err := os.Open("iris.csv")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	df := dataframe.ReadCSV(f)
	// gota reports failures through the frame's Err field rather than a
	// returned error, so it has to be checked explicitly.
	if df.Err != nil {
		log.Fatalf("reading iris.csv: %v", df.Err)
	}
	xDF := df.Drop("species")
	toValue := func(s series.Series) series.Series {
		records := s.Records()
		floats := make([]float64, len(records))
		for i, r := range records {
			switch r {
			case "setosa":
				floats[i] = 1
			case "virginica":
				floats[i] = 2
			case "versicolor":
				floats[i] = 3
			default:
				log.Fatalf("unknown iris: %v\n", r)
			}
		}
		return series.Floats(floats)
	}
	yDF := df.Select("species").Capply(toValue)
	if yDF.Err != nil {
		log.Fatalf("building species target: %v", yDF.Err)
	}
	numRows, _ := xDF.Dims()
	xDF = xDF.Mutate(series.New(one(numRows), series.Float, "bias"))
	if xDF.Err != nil {
		log.Fatalf("building feature matrix: %v", xDF.Err)
	}
	fmt.Println(xDF.Describe())
	fmt.Println(yDF.Describe())
	return mat.DenseCopyOf(&matrix{xDF}), mat.DenseCopyOf(&matrix{yDF})
}
// one returns a freshly allocated slice of length size with every element
// set to 1.0. It is used to build the constant bias column.
func one(size int) []float64 {
	ones := make([]float64, size)
	for i := range ones {
		ones[i] = 1
	}
	return ones
}
// getThetaNormal fits the linear model y ≈ X·θ in closed form by solving the
// normal equations (XᵀX)·θ = Xᵀy, with X and y taken from getXYMat. It
// returns θ as a column with one entry per column of X.
//
// The system is solved directly instead of forming (XᵀX)⁻¹ explicitly, which
// is cheaper and numerically more stable. gonum reports a singular or
// near-singular system through Solve's error, in which case θ would be
// meaningless, so that error is fatal.
func getThetaNormal() *mat.Dense {
	x, y := getXYMat()
	// x is already a *mat.Dense; its transpose is a view, so no copy is needed.
	xt := x.T()
	var xtx mat.Dense
	xtx.Mul(xt, x)
	var xty mat.Dense
	xty.Mul(xt, y)
	var theta mat.Dense
	if err := theta.Solve(&xtx, &xty); err != nil {
		log.Fatalf("solving normal equations: %v", err)
	}
	return &theta
}
// matrix adapts a gota DataFrame to gonum's mat.Matrix interface so that it
// can be copied into a *mat.Dense. Dims is provided by the embedded DataFrame.
type matrix struct {
	dataframe.DataFrame
}

// At returns the element at row i, column j converted to float64.
func (mx matrix) At(i, j int) float64 {
	cell := mx.Elem(i, j)
	return cell.Float()
}

// T returns a transposed view of mx without copying its contents.
func (mx matrix) T() mat.Matrix {
	return mat.Transpose{Matrix: mx}
}
[7x6] DataFrame
column sepal_length sepal_width petal_length petal_width bias
0: mean 5.843333 3.054000 3.758667 1.198667 1.000000
1: stddev 0.828066 0.433594 1.764420 0.763161 0.000000
2: min 4.300000 2.000000 1.000000 0.100000 1.000000
3: 25% 5.100000 2.800000 1.600000 0.300000 1.000000
4: 50% 5.800000 3.000000 4.300000 1.300000 1.000000
5: 75% 6.400000 3.300000 5.100000 1.800000 1.000000
6: max 7.900000 4.400000 6.900000 2.500000 1.000000
<string> <float> <float> <float> <float> <float>
[7x2] DataFrame
column species
0: mean 2.000000
1: stddev 0.819232
2: min 1.000000
3: 25% 1.000000
4: 50% 2.000000
5: 75% 3.000000
6: max 3.000000
<string> <float>
ϴ: ⎡-0.08718768910924979⎤
⎢ -0.6831785613306529⎥
⎢ 0.44128274494996056⎥
⎢-0.41983988087491575⎥
⎣ 3.4405073828555714⎦
[7x6] DataFrame
column sepal_length sepal_width petal_length petal_width bias
0: mean 5.843333 3.054000 3.758667 1.198667 1.000000
1: stddev 0.828066 0.433594 1.764420 0.763161 0.000000
2: min 4.300000 2.000000 1.000000 0.100000 1.000000
3: 25% 5.100000 2.800000 1.600000 0.300000 1.000000
4: 50% 5.800000 3.000000 4.300000 1.300000 1.000000
5: 75% 6.400000 3.300000 5.100000 1.800000 1.000000
6: max 7.900000 4.400000 6.900000 2.500000 1.000000
<string> <float> <float> <float> <float> <float>
[7x2] DataFrame
column species
0: mean 2.000000
1: stddev 0.819232
2: min 1.000000
3: 25% 1.000000
4: 50% 2.000000
5: 75% 3.000000
6: max 3.000000
<string> <float>
theta: [-0.00 -0.63 0.43 -0.45 2.86] Iter: 99999 Cost: 0.289 Accuracy: 0.58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment