Linear regression on iris dataset with Gorgonia and gota
sepal_length | sepal_width | petal_length | petal_width | species | |
---|---|---|---|---|---|
5.1 | 3.5 | 1.4 | 0.2 | setosa | |
4.9 | 3.0 | 1.4 | 0.2 | setosa | |
4.7 | 3.2 | 1.3 | 0.2 | setosa | |
4.6 | 3.1 | 1.5 | 0.2 | setosa | |
5.0 | 3.6 | 1.4 | 0.2 | setosa | |
5.4 | 3.9 | 1.7 | 0.4 | setosa | |
4.6 | 3.4 | 1.4 | 0.3 | setosa | |
5.0 | 3.4 | 1.5 | 0.2 | setosa | |
4.4 | 2.9 | 1.4 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
5.4 | 3.7 | 1.5 | 0.2 | setosa | |
4.8 | 3.4 | 1.6 | 0.2 | setosa | |
4.8 | 3.0 | 1.4 | 0.1 | setosa | |
4.3 | 3.0 | 1.1 | 0.1 | setosa | |
5.8 | 4.0 | 1.2 | 0.2 | setosa | |
5.7 | 4.4 | 1.5 | 0.4 | setosa | |
5.4 | 3.9 | 1.3 | 0.4 | setosa | |
5.1 | 3.5 | 1.4 | 0.3 | setosa | |
5.7 | 3.8 | 1.7 | 0.3 | setosa | |
5.1 | 3.8 | 1.5 | 0.3 | setosa | |
5.4 | 3.4 | 1.7 | 0.2 | setosa | |
5.1 | 3.7 | 1.5 | 0.4 | setosa | |
4.6 | 3.6 | 1.0 | 0.2 | setosa | |
5.1 | 3.3 | 1.7 | 0.5 | setosa | |
4.8 | 3.4 | 1.9 | 0.2 | setosa | |
5.0 | 3.0 | 1.6 | 0.2 | setosa | |
5.0 | 3.4 | 1.6 | 0.4 | setosa | |
5.2 | 3.5 | 1.5 | 0.2 | setosa | |
5.2 | 3.4 | 1.4 | 0.2 | setosa | |
4.7 | 3.2 | 1.6 | 0.2 | setosa | |
4.8 | 3.1 | 1.6 | 0.2 | setosa | |
5.4 | 3.4 | 1.5 | 0.4 | setosa | |
5.2 | 4.1 | 1.5 | 0.1 | setosa | |
5.5 | 4.2 | 1.4 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
5.0 | 3.2 | 1.2 | 0.2 | setosa | |
5.5 | 3.5 | 1.3 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
4.4 | 3.0 | 1.3 | 0.2 | setosa | |
5.1 | 3.4 | 1.5 | 0.2 | setosa | |
5.0 | 3.5 | 1.3 | 0.3 | setosa | |
4.5 | 2.3 | 1.3 | 0.3 | setosa | |
4.4 | 3.2 | 1.3 | 0.2 | setosa | |
5.0 | 3.5 | 1.6 | 0.6 | setosa | |
5.1 | 3.8 | 1.9 | 0.4 | setosa | |
4.8 | 3.0 | 1.4 | 0.3 | setosa | |
5.1 | 3.8 | 1.6 | 0.2 | setosa | |
4.6 | 3.2 | 1.4 | 0.2 | setosa | |
5.3 | 3.7 | 1.5 | 0.2 | setosa | |
5.0 | 3.3 | 1.4 | 0.2 | setosa | |
7.0 | 3.2 | 4.7 | 1.4 | versicolor | |
6.4 | 3.2 | 4.5 | 1.5 | versicolor | |
6.9 | 3.1 | 4.9 | 1.5 | versicolor | |
5.5 | 2.3 | 4.0 | 1.3 | versicolor | |
6.5 | 2.8 | 4.6 | 1.5 | versicolor | |
5.7 | 2.8 | 4.5 | 1.3 | versicolor | |
6.3 | 3.3 | 4.7 | 1.6 | versicolor | |
4.9 | 2.4 | 3.3 | 1.0 | versicolor | |
6.6 | 2.9 | 4.6 | 1.3 | versicolor | |
5.2 | 2.7 | 3.9 | 1.4 | versicolor | |
5.0 | 2.0 | 3.5 | 1.0 | versicolor | |
5.9 | 3.0 | 4.2 | 1.5 | versicolor | |
6.0 | 2.2 | 4.0 | 1.0 | versicolor | |
6.1 | 2.9 | 4.7 | 1.4 | versicolor | |
5.6 | 2.9 | 3.6 | 1.3 | versicolor | |
6.7 | 3.1 | 4.4 | 1.4 | versicolor | |
5.6 | 3.0 | 4.5 | 1.5 | versicolor | |
5.8 | 2.7 | 4.1 | 1.0 | versicolor | |
6.2 | 2.2 | 4.5 | 1.5 | versicolor | |
5.6 | 2.5 | 3.9 | 1.1 | versicolor | |
5.9 | 3.2 | 4.8 | 1.8 | versicolor | |
6.1 | 2.8 | 4.0 | 1.3 | versicolor | |
6.3 | 2.5 | 4.9 | 1.5 | versicolor | |
6.1 | 2.8 | 4.7 | 1.2 | versicolor | |
6.4 | 2.9 | 4.3 | 1.3 | versicolor | |
6.6 | 3.0 | 4.4 | 1.4 | versicolor | |
6.8 | 2.8 | 4.8 | 1.4 | versicolor | |
6.7 | 3.0 | 5.0 | 1.7 | versicolor | |
6.0 | 2.9 | 4.5 | 1.5 | versicolor | |
5.7 | 2.6 | 3.5 | 1.0 | versicolor | |
5.5 | 2.4 | 3.8 | 1.1 | versicolor | |
5.5 | 2.4 | 3.7 | 1.0 | versicolor | |
5.8 | 2.7 | 3.9 | 1.2 | versicolor | |
6.0 | 2.7 | 5.1 | 1.6 | versicolor | |
5.4 | 3.0 | 4.5 | 1.5 | versicolor | |
6.0 | 3.4 | 4.5 | 1.6 | versicolor | |
6.7 | 3.1 | 4.7 | 1.5 | versicolor | |
6.3 | 2.3 | 4.4 | 1.3 | versicolor | |
5.6 | 3.0 | 4.1 | 1.3 | versicolor | |
5.5 | 2.5 | 4.0 | 1.3 | versicolor | |
5.5 | 2.6 | 4.4 | 1.2 | versicolor | |
6.1 | 3.0 | 4.6 | 1.4 | versicolor | |
5.8 | 2.6 | 4.0 | 1.2 | versicolor | |
5.0 | 2.3 | 3.3 | 1.0 | versicolor | |
5.6 | 2.7 | 4.2 | 1.3 | versicolor | |
5.7 | 3.0 | 4.2 | 1.2 | versicolor | |
5.7 | 2.9 | 4.2 | 1.3 | versicolor | |
6.2 | 2.9 | 4.3 | 1.3 | versicolor | |
5.1 | 2.5 | 3.0 | 1.1 | versicolor | |
5.7 | 2.8 | 4.1 | 1.3 | versicolor | |
6.3 | 3.3 | 6.0 | 2.5 | virginica | |
5.8 | 2.7 | 5.1 | 1.9 | virginica | |
7.1 | 3.0 | 5.9 | 2.1 | virginica | |
6.3 | 2.9 | 5.6 | 1.8 | virginica | |
6.5 | 3.0 | 5.8 | 2.2 | virginica | |
7.6 | 3.0 | 6.6 | 2.1 | virginica | |
4.9 | 2.5 | 4.5 | 1.7 | virginica | |
7.3 | 2.9 | 6.3 | 1.8 | virginica | |
6.7 | 2.5 | 5.8 | 1.8 | virginica | |
7.2 | 3.6 | 6.1 | 2.5 | virginica | |
6.5 | 3.2 | 5.1 | 2.0 | virginica | |
6.4 | 2.7 | 5.3 | 1.9 | virginica | |
6.8 | 3.0 | 5.5 | 2.1 | virginica | |
5.7 | 2.5 | 5.0 | 2.0 | virginica | |
5.8 | 2.8 | 5.1 | 2.4 | virginica | |
6.4 | 3.2 | 5.3 | 2.3 | virginica | |
6.5 | 3.0 | 5.5 | 1.8 | virginica | |
7.7 | 3.8 | 6.7 | 2.2 | virginica | |
7.7 | 2.6 | 6.9 | 2.3 | virginica | |
6.0 | 2.2 | 5.0 | 1.5 | virginica | |
6.9 | 3.2 | 5.7 | 2.3 | virginica | |
5.6 | 2.8 | 4.9 | 2.0 | virginica | |
7.7 | 2.8 | 6.7 | 2.0 | virginica | |
6.3 | 2.7 | 4.9 | 1.8 | virginica | |
6.7 | 3.3 | 5.7 | 2.1 | virginica | |
7.2 | 3.2 | 6.0 | 1.8 | virginica | |
6.2 | 2.8 | 4.8 | 1.8 | virginica | |
6.1 | 3.0 | 4.9 | 1.8 | virginica | |
6.4 | 2.8 | 5.6 | 2.1 | virginica | |
7.2 | 3.0 | 5.8 | 1.6 | virginica | |
7.4 | 2.8 | 6.1 | 1.9 | virginica | |
7.9 | 3.8 | 6.4 | 2.0 | virginica | |
6.4 | 2.8 | 5.6 | 2.2 | virginica | |
6.3 | 2.8 | 5.1 | 1.5 | virginica | |
6.1 | 2.6 | 5.6 | 1.4 | virginica | |
7.7 | 3.0 | 6.1 | 2.3 | virginica | |
6.3 | 3.4 | 5.6 | 2.4 | virginica | |
6.4 | 3.1 | 5.5 | 1.8 | virginica | |
6.0 | 3.0 | 4.8 | 1.8 | virginica | |
6.9 | 3.1 | 5.4 | 2.1 | virginica | |
6.7 | 3.1 | 5.6 | 2.4 | virginica | |
6.9 | 3.1 | 5.1 | 2.3 | virginica | |
5.8 | 2.7 | 5.1 | 1.9 | virginica | |
6.8 | 3.2 | 5.9 | 2.3 | virginica | |
6.7 | 3.3 | 5.7 | 2.5 | virginica | |
6.7 | 3.0 | 5.2 | 2.3 | virginica | |
6.3 | 2.5 | 5.0 | 1.9 | virginica | |
6.5 | 3.0 | 5.2 | 2.0 | virginica | |
6.2 | 3.4 | 5.4 | 2.3 | virginica | |
5.9 | 3.0 | 5.1 | 1.8 | virginica |
package main | |
import ( | |
"fmt" | |
"log" | |
"math" | |
"os" | |
"github.com/go-gota/gota/dataframe" | |
"github.com/go-gota/gota/series" | |
"gonum.org/v1/gonum/mat" | |
"gorgonia.org/gorgonia" | |
"gorgonia.org/tensor" | |
) | |
// https://www.kaggle.com/amarpandey/implementing-linear-regression-on-iris-dataset/notebook | |
// | |
func main() { | |
xT, yT := getXY() | |
g := gorgonia.NewGraph() | |
x := gorgonia.NodeFromAny(g, xT, gorgonia.WithName("x")) | |
y := gorgonia.NodeFromAny(g, yT, gorgonia.WithName("y")) | |
theta := gorgonia.NewVector( | |
g, | |
gorgonia.Float64, | |
gorgonia.WithName("theta"), | |
gorgonia.WithShape(xT.Shape()[1]), | |
gorgonia.WithInit(gorgonia.Uniform(0, 1))) | |
pred := must(gorgonia.Mul(x, theta)) | |
// Gorgonia might delete values from nodes so we are going to save it | |
// and print it out later | |
var predicted gorgonia.Value | |
gorgonia.Read(pred, &predicted) | |
squaredError := must(gorgonia.Square(must(gorgonia.Sub(pred, y)))) | |
cost := must(gorgonia.Mean(squaredError)) | |
if _, err := gorgonia.Grad(cost, theta); err != nil { | |
log.Fatalf("Failed to backpropagate: %v", err) | |
} | |
machine := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(theta)) | |
defer machine.Close() | |
model := []gorgonia.ValueGrad{theta} | |
solver := gorgonia.NewVanillaSolver(gorgonia.WithLearnRate(0.001)) | |
iter := 100000 | |
var err error | |
for i := 0; i < iter; i++ { | |
if err = machine.RunAll(); err != nil { | |
fmt.Printf("Error during iteration: %v: %v\n", i, err) | |
break | |
} | |
if err = solver.Step(model); err != nil { | |
log.Fatal(err) | |
} | |
fmt.Printf("theta: %2.2f Iter: %v Cost: %2.3f Accuracy: %2.2f \r", | |
theta.Value(), | |
i, | |
cost.Value(), | |
accuracy(predicted.Data().([]float64), y.Value().Data().([]float64))) | |
machine.Reset() // Reset is necessary in a loop like this | |
} | |
fmt.Println("") | |
} | |
// accuracy returns the fraction of labels in y that are matched by the
// corresponding entry of prediction. A prediction matches its label when
// |prediction[i] - y[i]| < 0.5.
//
// Only indices present in both slices are compared. Labels without a
// prediction count as misses, and extra predictions are ignored. An empty y
// yields 0 instead of the NaN a 0/0 division would produce.
func accuracy(prediction, y []float64) float64 {
	if len(y) == 0 {
		return 0
	}
	n := len(prediction)
	if len(y) < n {
		n = len(y)
	}
	var ok float64
	for i := 0; i < n; i++ {
		if math.Abs(prediction[i]-y[i]) < 0.5 {
			ok++
		}
	}
	return ok / float64(len(y))
}
func getXY() (*tensor.Dense, *tensor.Dense) { | |
f, err := os.Open("iris.csv") | |
if err != nil { | |
log.Fatal(err) | |
} | |
defer f.Close() | |
df := dataframe.ReadCSV(f) | |
xDF := df.Drop("species") | |
toValue := func(s series.Series) series.Series { | |
records := s.Records() | |
floats := make([]float64, len(records)) | |
for i, r := range records { | |
switch r { | |
case "setosa": | |
floats[i] = 1 | |
case "virginica": | |
floats[i] = 2 | |
case "versicolor": | |
floats[i] = 3 | |
default: | |
log.Fatalf("unknown iris: %v\n", r) | |
} | |
} | |
return series.Floats(floats) | |
} | |
yDF := df.Select("species").Capply(toValue) | |
numRows, _ := xDF.Dims() | |
xDF = xDF.Mutate(series.New(one(numRows), series.Float, "bias")) | |
fmt.Println(xDF.Describe()) | |
fmt.Println(yDF.Describe()) | |
xT := tensor.FromMat64(mat.DenseCopyOf(&matrix{xDF})) | |
yT := tensor.FromMat64(mat.DenseCopyOf(&matrix{yDF})) | |
// Get rid of the last dimension to create a vector | |
yT.Reshape(numRows) | |
return xT, yT | |
} | |
type matrix struct { | |
dataframe.DataFrame | |
} | |
func (m matrix) At(i, j int) float64 { | |
return m.Elem(i, j).Float() | |
} | |
func (m matrix) T() mat.Matrix { | |
return mat.Transpose{Matrix: m} | |
} | |
func must(n *gorgonia.Node, err error) *gorgonia.Node { | |
if err != nil { | |
panic(err) | |
} | |
return n | |
} | |
// one returns a slice of length size with every element set to 1.0. It is used
// to build the constant bias column.
func one(size int) []float64 {
	ones := make([]float64, size)
	for i := range ones {
		ones[i] = 1
	}
	return ones
}
package main | |
import ( | |
"fmt" | |
"log" | |
"os" | |
"github.com/go-gota/gota/dataframe" | |
"github.com/go-gota/gota/series" | |
"gonum.org/v1/gonum/mat" | |
) | |
func main() { | |
fa := mat.Formatted(getThetaNormal(), mat.Prefix(" "), mat.Squeeze()) | |
fmt.Printf("ϴ: %v\n", fa) | |
} | |
func getXYMat() (*mat.Dense, *mat.Dense) { | |
f, err := os.Open("iris.csv") | |
if err != nil { | |
log.Fatal(err) | |
} | |
defer f.Close() | |
df := dataframe.ReadCSV(f) | |
xDF := df.Drop("species") | |
toValue := func(s series.Series) series.Series { | |
records := s.Records() | |
floats := make([]float64, len(records)) | |
for i, r := range records { | |
switch r { | |
case "setosa": | |
floats[i] = 1 | |
case "virginica": | |
floats[i] = 2 | |
case "versicolor": | |
floats[i] = 3 | |
default: | |
log.Fatalf("unknown iris: %v\n", r) | |
} | |
} | |
return series.Floats(floats) | |
} | |
yDF := df.Select("species").Capply(toValue) | |
numRows, _ := xDF.Dims() | |
xDF = xDF.Mutate(series.New(one(numRows), series.Float, "bias")) | |
fmt.Println(xDF.Describe()) | |
fmt.Println(yDF.Describe()) | |
return mat.DenseCopyOf(&matrix{xDF}), mat.DenseCopyOf(&matrix{yDF}) | |
} | |
// one builds a float64 slice of the requested length filled with 1.0. It is
// used as the bias column of the design matrix.
func one(size int) []float64 {
	column := make([]float64, size)
	for idx := range column {
		column[idx] = 1.0
	}
	return column
}
func getThetaNormal() *mat.Dense { | |
x, y := getXYMat() | |
xt := mat.DenseCopyOf(x).T() | |
var xtx mat.Dense | |
xtx.Mul(xt, x) | |
var invxtx mat.Dense | |
invxtx.Inverse(&xtx) | |
var xty mat.Dense | |
xty.Mul(xt, y) | |
var output mat.Dense | |
output.Mul(&invxtx, &xty) | |
return &output | |
} | |
type matrix struct { | |
dataframe.DataFrame | |
} | |
func (m matrix) At(i, j int) float64 { | |
return m.Elem(i, j).Float() | |
} | |
func (m matrix) T() mat.Matrix { | |
return mat.Transpose{Matrix: m} | |
} |
[7x6] DataFrame | |
column sepal_length sepal_width petal_length petal_width bias | |
0: mean 5.843333 3.054000 3.758667 1.198667 1.000000 | |
1: stddev 0.828066 0.433594 1.764420 0.763161 0.000000 | |
2: min 4.300000 2.000000 1.000000 0.100000 1.000000 | |
3: 25% 5.100000 2.800000 1.600000 0.300000 1.000000 | |
4: 50% 5.800000 3.000000 4.300000 1.300000 1.000000 | |
5: 75% 6.400000 3.300000 5.100000 1.800000 1.000000 | |
6: max 7.900000 4.400000 6.900000 2.500000 1.000000 | |
<string> <float> <float> <float> <float> <float> | |
[7x2] DataFrame | |
column species | |
0: mean 2.000000 | |
1: stddev 0.819232 | |
2: min 1.000000 | |
3: 25% 1.000000 | |
4: 50% 2.000000 | |
5: 75% 3.000000 | |
6: max 3.000000 | |
<string> <float> | |
ϴ: ⎡-0.08718768910924979⎤ | |
⎢ -0.6831785613306529⎥ | |
⎢ 0.44128274494996056⎥ | |
⎢-0.41983988087491575⎥ | |
⎣ 3.4405073828555714⎦ |
[7x6] DataFrame | |
column sepal_length sepal_width petal_length petal_width bias | |
0: mean 5.843333 3.054000 3.758667 1.198667 1.000000 | |
1: stddev 0.828066 0.433594 1.764420 0.763161 0.000000 | |
2: min 4.300000 2.000000 1.000000 0.100000 1.000000 | |
3: 25% 5.100000 2.800000 1.600000 0.300000 1.000000 | |
4: 50% 5.800000 3.000000 4.300000 1.300000 1.000000 | |
5: 75% 6.400000 3.300000 5.100000 1.800000 1.000000 | |
6: max 7.900000 4.400000 6.900000 2.500000 1.000000 | |
<string> <float> <float> <float> <float> <float> | |
[7x2] DataFrame | |
column species | |
0: mean 2.000000 | |
1: stddev 0.819232 | |
2: min 1.000000 | |
3: 25% 1.000000 | |
4: 50% 2.000000 | |
5: 75% 3.000000 | |
6: max 3.000000 | |
<string> <float> | |
theta: [-0.00 -0.63 0.43 -0.45 2.86] Iter: 99999 Cost: 0.289 Accuracy: 0.58 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment