Skip to content

Instantly share code, notes, and snippets.

@owulveryck owulveryck/iris.csv
Last active Oct 31, 2019

Embed
What would you like to do?
Linear regression on iris dataset with Gorgonia and gota
sepal_length sepal_width petal_length petal_width species
5.1 3.5 1.4 0.2 setosa
4.9 3.0 1.4 0.2 setosa
4.7 3.2 1.3 0.2 setosa
4.6 3.1 1.5 0.2 setosa
5.0 3.6 1.4 0.2 setosa
5.4 3.9 1.7 0.4 setosa
4.6 3.4 1.4 0.3 setosa
5.0 3.4 1.5 0.2 setosa
4.4 2.9 1.4 0.2 setosa
4.9 3.1 1.5 0.1 setosa
5.4 3.7 1.5 0.2 setosa
4.8 3.4 1.6 0.2 setosa
4.8 3.0 1.4 0.1 setosa
4.3 3.0 1.1 0.1 setosa
5.8 4.0 1.2 0.2 setosa
5.7 4.4 1.5 0.4 setosa
5.4 3.9 1.3 0.4 setosa
5.1 3.5 1.4 0.3 setosa
5.7 3.8 1.7 0.3 setosa
5.1 3.8 1.5 0.3 setosa
5.4 3.4 1.7 0.2 setosa
5.1 3.7 1.5 0.4 setosa
4.6 3.6 1.0 0.2 setosa
5.1 3.3 1.7 0.5 setosa
4.8 3.4 1.9 0.2 setosa
5.0 3.0 1.6 0.2 setosa
5.0 3.4 1.6 0.4 setosa
5.2 3.5 1.5 0.2 setosa
5.2 3.4 1.4 0.2 setosa
4.7 3.2 1.6 0.2 setosa
4.8 3.1 1.6 0.2 setosa
5.4 3.4 1.5 0.4 setosa
5.2 4.1 1.5 0.1 setosa
5.5 4.2 1.4 0.2 setosa
4.9 3.1 1.5 0.1 setosa
5.0 3.2 1.2 0.2 setosa
5.5 3.5 1.3 0.2 setosa
4.9 3.1 1.5 0.1 setosa
4.4 3.0 1.3 0.2 setosa
5.1 3.4 1.5 0.2 setosa
5.0 3.5 1.3 0.3 setosa
4.5 2.3 1.3 0.3 setosa
4.4 3.2 1.3 0.2 setosa
5.0 3.5 1.6 0.6 setosa
5.1 3.8 1.9 0.4 setosa
4.8 3.0 1.4 0.3 setosa
5.1 3.8 1.6 0.2 setosa
4.6 3.2 1.4 0.2 setosa
5.3 3.7 1.5 0.2 setosa
5.0 3.3 1.4 0.2 setosa
7.0 3.2 4.7 1.4 versicolor
6.4 3.2 4.5 1.5 versicolor
6.9 3.1 4.9 1.5 versicolor
5.5 2.3 4.0 1.3 versicolor
6.5 2.8 4.6 1.5 versicolor
5.7 2.8 4.5 1.3 versicolor
6.3 3.3 4.7 1.6 versicolor
4.9 2.4 3.3 1.0 versicolor
6.6 2.9 4.6 1.3 versicolor
5.2 2.7 3.9 1.4 versicolor
5.0 2.0 3.5 1.0 versicolor
5.9 3.0 4.2 1.5 versicolor
6.0 2.2 4.0 1.0 versicolor
6.1 2.9 4.7 1.4 versicolor
5.6 2.9 3.6 1.3 versicolor
6.7 3.1 4.4 1.4 versicolor
5.6 3.0 4.5 1.5 versicolor
5.8 2.7 4.1 1.0 versicolor
6.2 2.2 4.5 1.5 versicolor
5.6 2.5 3.9 1.1 versicolor
5.9 3.2 4.8 1.8 versicolor
6.1 2.8 4.0 1.3 versicolor
6.3 2.5 4.9 1.5 versicolor
6.1 2.8 4.7 1.2 versicolor
6.4 2.9 4.3 1.3 versicolor
6.6 3.0 4.4 1.4 versicolor
6.8 2.8 4.8 1.4 versicolor
6.7 3.0 5.0 1.7 versicolor
6.0 2.9 4.5 1.5 versicolor
5.7 2.6 3.5 1.0 versicolor
5.5 2.4 3.8 1.1 versicolor
5.5 2.4 3.7 1.0 versicolor
5.8 2.7 3.9 1.2 versicolor
6.0 2.7 5.1 1.6 versicolor
5.4 3.0 4.5 1.5 versicolor
6.0 3.4 4.5 1.6 versicolor
6.7 3.1 4.7 1.5 versicolor
6.3 2.3 4.4 1.3 versicolor
5.6 3.0 4.1 1.3 versicolor
5.5 2.5 4.0 1.3 versicolor
5.5 2.6 4.4 1.2 versicolor
6.1 3.0 4.6 1.4 versicolor
5.8 2.6 4.0 1.2 versicolor
5.0 2.3 3.3 1.0 versicolor
5.6 2.7 4.2 1.3 versicolor
5.7 3.0 4.2 1.2 versicolor
5.7 2.9 4.2 1.3 versicolor
6.2 2.9 4.3 1.3 versicolor
5.1 2.5 3.0 1.1 versicolor
5.7 2.8 4.1 1.3 versicolor
6.3 3.3 6.0 2.5 virginica
5.8 2.7 5.1 1.9 virginica
7.1 3.0 5.9 2.1 virginica
6.3 2.9 5.6 1.8 virginica
6.5 3.0 5.8 2.2 virginica
7.6 3.0 6.6 2.1 virginica
4.9 2.5 4.5 1.7 virginica
7.3 2.9 6.3 1.8 virginica
6.7 2.5 5.8 1.8 virginica
7.2 3.6 6.1 2.5 virginica
6.5 3.2 5.1 2.0 virginica
6.4 2.7 5.3 1.9 virginica
6.8 3.0 5.5 2.1 virginica
5.7 2.5 5.0 2.0 virginica
5.8 2.8 5.1 2.4 virginica
6.4 3.2 5.3 2.3 virginica
6.5 3.0 5.5 1.8 virginica
7.7 3.8 6.7 2.2 virginica
7.7 2.6 6.9 2.3 virginica
6.0 2.2 5.0 1.5 virginica
6.9 3.2 5.7 2.3 virginica
5.6 2.8 4.9 2.0 virginica
7.7 2.8 6.7 2.0 virginica
6.3 2.7 4.9 1.8 virginica
6.7 3.3 5.7 2.1 virginica
7.2 3.2 6.0 1.8 virginica
6.2 2.8 4.8 1.8 virginica
6.1 3.0 4.9 1.8 virginica
6.4 2.8 5.6 2.1 virginica
7.2 3.0 5.8 1.6 virginica
7.4 2.8 6.1 1.9 virginica
7.9 3.8 6.4 2.0 virginica
6.4 2.8 5.6 2.2 virginica
6.3 2.8 5.1 1.5 virginica
6.1 2.6 5.6 1.4 virginica
7.7 3.0 6.1 2.3 virginica
6.3 3.4 5.6 2.4 virginica
6.4 3.1 5.5 1.8 virginica
6.0 3.0 4.8 1.8 virginica
6.9 3.1 5.4 2.1 virginica
6.7 3.1 5.6 2.4 virginica
6.9 3.1 5.1 2.3 virginica
5.8 2.7 5.1 1.9 virginica
6.8 3.2 5.9 2.3 virginica
6.7 3.3 5.7 2.5 virginica
6.7 3.0 5.2 2.3 virginica
6.3 2.5 5.0 1.9 virginica
6.5 3.0 5.2 2.0 virginica
6.2 3.4 5.4 2.3 virginica
5.9 3.0 5.1 1.8 virginica
package main
import (
"fmt"
"log"
"math"
"os"
"github.com/go-gota/gota/dataframe"
"github.com/go-gota/gota/series"
"gonum.org/v1/gonum/mat"
"gorgonia.org/gorgonia"
"gorgonia.org/tensor"
)
// main fits a linear model y ≈ x·θ on the iris data with plain gradient
// descent in Gorgonia, printing θ, the cost and the accuracy on a single
// progress line as it trains.
//
// Approach adapted from:
// https://www.kaggle.com/amarpandey/implementing-linear-regression-on-iris-dataset/notebook
func main() {
	xT, yT := getXY()
	g := gorgonia.NewGraph()
	x := gorgonia.NodeFromAny(g, xT, gorgonia.WithName("x"))
	y := gorgonia.NodeFromAny(g, yT, gorgonia.WithName("y"))
	// One weight per column of x, randomly initialised in [0, 1).
	theta := gorgonia.NewVector(
		g,
		gorgonia.Float64,
		gorgonia.WithName("theta"),
		gorgonia.WithShape(xT.Shape()[1]),
		gorgonia.WithInit(gorgonia.Uniform(0, 1)))
	pred := must(gorgonia.Mul(x, theta))
	// Gorgonia might delete values from nodes so we are going to save it
	// and print it out later
	var predicted gorgonia.Value
	gorgonia.Read(pred, &predicted)
	// Cost is the mean squared error between predictions and labels.
	squaredError := must(gorgonia.Square(must(gorgonia.Sub(pred, y))))
	cost := must(gorgonia.Mean(squaredError))
	// Add the symbolic gradient d(cost)/d(theta) to the graph before the
	// machine is built, so each RunAll also computes it.
	if _, err := gorgonia.Grad(cost, theta); err != nil {
		log.Fatalf("Failed to backpropagate: %v", err)
	}
	machine := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(theta))
	defer machine.Close()
	model := []gorgonia.ValueGrad{theta}
	solver := gorgonia.NewVanillaSolver(gorgonia.WithLearnRate(0.001))
	const iter = 100000
	for i := 0; i < iter; i++ {
		if err := machine.RunAll(); err != nil {
			// A failed pass leaves theta and the cost meaningless; exit
			// non-zero instead of reporting success.
			log.Fatalf("Error during iteration: %v: %v", i, err)
		}
		if err := solver.Step(model); err != nil {
			log.Fatal(err)
		}
		// The trailing \r rewrites the same terminal line each iteration.
		fmt.Printf("theta: %2.2f Iter: %v Cost: %2.3f Accuracy: %2.2f \r",
			theta.Value(),
			i,
			cost.Value(),
			accuracy(predicted.Data().([]float64), y.Value().Data().([]float64)))
		machine.Reset() // Reset is necessary in a loop like this
	}
	fmt.Println("")
}
// accuracy returns the fraction of predictions whose distance to the matching
// label is strictly less than 0.5 (math.Round of the difference is zero).
//
// It returns 0 for empty input. It panics with a descriptive message if
// prediction and y have different lengths, since the result would otherwise
// be silently wrong or fail with an opaque index error.
func accuracy(prediction, y []float64) float64 {
	if len(prediction) != len(y) {
		panic(fmt.Sprintf("accuracy: %d predictions for %d labels", len(prediction), len(y)))
	}
	if len(y) == 0 {
		return 0
	}
	var ok int
	for i, p := range prediction {
		if math.Round(p-y[i]) == 0 {
			ok++
		}
	}
	return float64(ok) / float64(len(y))
}
// getXY loads iris.csv and returns the design matrix and the label vector as
// tensors.
//
// x holds every column except "species" plus a trailing "bias" column of
// ones; y is a 1-D tensor with one numeric label per row (setosa=1,
// virginica=2, versicolor=3). Summary statistics of both frames are printed
// to stdout. Any I/O, parsing or reshaping failure terminates the program.
func getXY() (*tensor.Dense, *tensor.Dense) {
	f, err := os.Open("iris.csv")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	df := dataframe.ReadCSV(f)
	// gota reports failures through the Err field rather than a return value.
	if df.Err != nil {
		log.Fatalf("reading iris.csv: %v", df.Err)
	}
	xDF := df.Drop("species")
	// NOTE(review): in iris.csv versicolor's measurements mostly sit between
	// setosa's and virginica's, so encoding versicolor as 3 may hurt a linear
	// fit — confirm this ordering is intentional.
	toValue := func(s series.Series) series.Series {
		records := s.Records()
		floats := make([]float64, len(records))
		for i, r := range records {
			switch r {
			case "setosa":
				floats[i] = 1
			case "virginica":
				floats[i] = 2
			case "versicolor":
				floats[i] = 3
			default:
				log.Fatalf("unknown iris: %v\n", r)
			}
		}
		return series.Floats(floats)
	}
	yDF := df.Select("species").Capply(toValue)
	numRows, _ := xDF.Dims()
	xDF = xDF.Mutate(series.New(one(numRows), series.Float, "bias"))
	if xDF.Err != nil {
		log.Fatalf("building features: %v", xDF.Err)
	}
	if yDF.Err != nil {
		log.Fatalf("building labels: %v", yDF.Err)
	}
	fmt.Println(xDF.Describe())
	fmt.Println(yDF.Describe())
	xT := tensor.FromMat64(mat.DenseCopyOf(&matrix{xDF}))
	yT := tensor.FromMat64(mat.DenseCopyOf(&matrix{yDF}))
	// Get rid of the last dimension to create a vector
	if err := yT.Reshape(numRows); err != nil {
		log.Fatalf("reshaping labels: %v", err)
	}
	return xT, yT
}
// matrix adapts a gota DataFrame to gonum's mat.Matrix interface so it can be
// copied into a *mat.Dense. Dims is promoted from the embedded DataFrame.
type matrix struct {
	dataframe.DataFrame
}

// Compile-time check that matrix satisfies mat.Matrix.
var _ mat.Matrix = matrix{}

// At returns the DataFrame cell at row i, column j converted to float64.
func (mx matrix) At(i, j int) float64 {
	cell := mx.Elem(i, j)
	return cell.Float()
}

// T returns an implicit (non-copying) transpose of the matrix.
func (mx matrix) T() mat.Matrix {
	var tr mat.Transpose
	tr.Matrix = mx
	return tr
}
// must unwraps a (node, error) pair from a Gorgonia graph-building call,
// panicking with the error when it is non-nil.
func must(n *gorgonia.Node, err error) *gorgonia.Node {
	if err == nil {
		return n
	}
	panic(err)
}
// one returns a freshly allocated slice of length size filled with 1.0.
func one(size int) []float64 {
	ones := make([]float64, size)
	for i := range ones {
		ones[i] = 1
	}
	return ones
}
package main
import (
"fmt"
"log"
"os"
"github.com/go-gota/gota/dataframe"
"github.com/go-gota/gota/series"
"gonum.org/v1/gonum/mat"
)
func main() {
fa := mat.Formatted(getThetaNormal(), mat.Prefix(" "), mat.Squeeze())
fmt.Printf("ϴ: %v\n", fa)
}
// getXYMat loads iris.csv and returns the design matrix X and the label
// column y as gonum matrices.
//
// X holds every column except "species" plus a trailing "bias" column of
// ones; y has one numeric label per row (setosa=1, virginica=2,
// versicolor=3). Summary statistics of both frames are printed to stdout.
// Any I/O or parsing failure terminates the program.
func getXYMat() (*mat.Dense, *mat.Dense) {
	f, err := os.Open("iris.csv")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	df := dataframe.ReadCSV(f)
	// gota reports failures through the Err field rather than a return value.
	if df.Err != nil {
		log.Fatalf("reading iris.csv: %v", df.Err)
	}
	xDF := df.Drop("species")
	// NOTE(review): in iris.csv versicolor's measurements mostly sit between
	// setosa's and virginica's, so encoding versicolor as 3 may hurt a linear
	// fit — confirm this ordering is intentional.
	toValue := func(s series.Series) series.Series {
		records := s.Records()
		floats := make([]float64, len(records))
		for i, r := range records {
			switch r {
			case "setosa":
				floats[i] = 1
			case "virginica":
				floats[i] = 2
			case "versicolor":
				floats[i] = 3
			default:
				log.Fatalf("unknown iris: %v\n", r)
			}
		}
		return series.Floats(floats)
	}
	yDF := df.Select("species").Capply(toValue)
	numRows, _ := xDF.Dims()
	xDF = xDF.Mutate(series.New(one(numRows), series.Float, "bias"))
	if xDF.Err != nil {
		log.Fatalf("building features: %v", xDF.Err)
	}
	if yDF.Err != nil {
		log.Fatalf("building labels: %v", yDF.Err)
	}
	fmt.Println(xDF.Describe())
	fmt.Println(yDF.Describe())
	return mat.DenseCopyOf(&matrix{xDF}), mat.DenseCopyOf(&matrix{yDF})
}
// one builds a slice holding size copies of 1.0.
func one(size int) []float64 {
	out := make([]float64, 0, size)
	for len(out) < size {
		out = append(out, 1.0)
	}
	return out
}
// getThetaNormal fits the linear model y ≈ X·θ in closed form using the
// normal equation θ = (XᵀX)⁻¹Xᵀy, with X and y taken from getXYMat.
//
// Rather than forming the inverse explicitly, it solves the linear system
// (XᵀX)θ = Xᵀy. This is the numerically preferred route. A singular or
// near-singular XᵀX is reported as an error, which terminates the program
// instead of silently returning garbage.
func getThetaNormal() *mat.Dense {
	x, y := getXYMat()
	// x.T() is an implicit transpose; no copy of x is needed.
	xt := x.T()
	var xtx mat.Dense
	xtx.Mul(xt, x)
	var xty mat.Dense
	xty.Mul(xt, y)
	var theta mat.Dense
	if err := theta.Solve(&xtx, &xty); err != nil {
		log.Fatalf("solving normal equation: %v", err)
	}
	return &theta
}
// matrix wraps a gota DataFrame so gonum can read it as a mat.Matrix;
// Dims is supplied by the embedded DataFrame.
type matrix struct {
	dataframe.DataFrame
}

// Compile-time check that matrix satisfies mat.Matrix.
var _ mat.Matrix = matrix{}

// At converts the DataFrame element at (i, j) to a float64.
func (d matrix) At(i, j int) float64 { return d.DataFrame.Elem(i, j).Float() }

// T returns the implicit transpose of d.
func (d matrix) T() mat.Matrix { return mat.Transpose{Matrix: d} }
[7x6] DataFrame
column sepal_length sepal_width petal_length petal_width bias
0: mean 5.843333 3.054000 3.758667 1.198667 1.000000
1: stddev 0.828066 0.433594 1.764420 0.763161 0.000000
2: min 4.300000 2.000000 1.000000 0.100000 1.000000
3: 25% 5.100000 2.800000 1.600000 0.300000 1.000000
4: 50% 5.800000 3.000000 4.300000 1.300000 1.000000
5: 75% 6.400000 3.300000 5.100000 1.800000 1.000000
6: max 7.900000 4.400000 6.900000 2.500000 1.000000
<string> <float> <float> <float> <float> <float>
[7x2] DataFrame
column species
0: mean 2.000000
1: stddev 0.819232
2: min 1.000000
3: 25% 1.000000
4: 50% 2.000000
5: 75% 3.000000
6: max 3.000000
<string> <float>
ϴ: ⎡-0.08718768910924979⎤
⎢ -0.6831785613306529⎥
⎢ 0.44128274494996056⎥
⎢-0.41983988087491575⎥
⎣ 3.4405073828555714⎦
[7x6] DataFrame
column sepal_length sepal_width petal_length petal_width bias
0: mean 5.843333 3.054000 3.758667 1.198667 1.000000
1: stddev 0.828066 0.433594 1.764420 0.763161 0.000000
2: min 4.300000 2.000000 1.000000 0.100000 1.000000
3: 25% 5.100000 2.800000 1.600000 0.300000 1.000000
4: 50% 5.800000 3.000000 4.300000 1.300000 1.000000
5: 75% 6.400000 3.300000 5.100000 1.800000 1.000000
6: max 7.900000 4.400000 6.900000 2.500000 1.000000
<string> <float> <float> <float> <float> <float>
[7x2] DataFrame
column species
0: mean 2.000000
1: stddev 0.819232
2: min 1.000000
3: 25% 1.000000
4: 50% 2.000000
5: 75% 3.000000
6: max 3.000000
<string> <float>
theta: [-0.00 -0.63 0.43 -0.45 2.86] Iter: 99999 Cost: 0.289 Accuracy: 0.58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.