|
package main |
|
|
|
import ( |
|
"bufio" |
|
"flag" |
|
"fmt" |
|
"image/color" |
|
"log" |
|
"os" |
|
|
|
"go-hep.org/x/hep/fit" |
|
"gonum.org/v1/plot" |
|
"gonum.org/v1/plot/plotter" |
|
"gonum.org/v1/plot/vg/draw" |
|
) |
|
|
|
// iterations is the number of gradient-descent steps linearRegression
// performs. main binds it to the -n command-line flag (default 1000).
var iterations int
|
|
|
func main() { |
|
flag.IntVar(&iterations, "n", 1000, "number of iterations") |
|
flag.Parse() |
|
|
|
xys, err := readData("data.txt") |
|
if err != nil { |
|
log.Fatalf("could not read data.txt: %v", err) |
|
} |
|
_ = xys |
|
|
|
err = plotData("out.png", xys) |
|
if err != nil { |
|
log.Fatalf("could not plot data: %v", err) |
|
} |
|
} |
|
|
|
// xy is a single 2-D point with unexported coordinates.
// NOTE(review): nothing in the visible code references this type; the loop
// variables named xy elsewhere in the file shadow it.
type xy struct {
	x float64
	y float64
}
|
|
|
func readData(path string) (plotter.XYs, error) { |
|
f, err := os.Open(path) |
|
if err != nil { |
|
return nil, err |
|
} |
|
defer f.Close() |
|
|
|
var xys plotter.XYs |
|
s := bufio.NewScanner(f) |
|
for s.Scan() { |
|
var x, y float64 |
|
_, err := fmt.Sscanf(s.Text(), "%f,%f", &x, &y) |
|
if err != nil { |
|
log.Printf("discarding bad data point %q: %v", s.Text(), err) |
|
continue |
|
} |
|
xys = append(xys, struct{ X, Y float64 }{x, y}) |
|
} |
|
if err := s.Err(); err != nil { |
|
return nil, fmt.Errorf("could not scan: %v", err) |
|
} |
|
return xys, nil |
|
} |
|
|
|
func plotData(path string, xys plotter.XYs) error { |
|
f, err := os.Create(path) |
|
if err != nil { |
|
return fmt.Errorf("could not create %s: %v", path, err) |
|
} |
|
|
|
p, err := plot.New() |
|
if err != nil { |
|
return fmt.Errorf("could not create plot: %v", err) |
|
} |
|
|
|
// create scatter with all data points |
|
s, err := plotter.NewScatter(xys) |
|
if err != nil { |
|
return fmt.Errorf("could not create scatter: %v", err) |
|
} |
|
s.GlyphStyle.Shape = draw.CrossGlyph{} |
|
s.Color = color.RGBA{R: 255, A: 255} |
|
p.Add(s) |
|
|
|
x, c := linearRegression(xys, 0.01) |
|
x, c = minimize(xys) |
|
|
|
// create fake linear regression result |
|
l, err := plotter.NewLine(plotter.XYs{ |
|
{3, 3*x + c}, {20, 20*x + c}, |
|
}) |
|
if err != nil { |
|
return fmt.Errorf("could not create line: %v", err) |
|
} |
|
p.Add(l) |
|
|
|
wt, err := p.WriterTo(256, 256, "png") |
|
if err != nil { |
|
return fmt.Errorf("could not create writer: %v", err) |
|
} |
|
_, err = wt.WriteTo(f) |
|
if err != nil { |
|
return fmt.Errorf("could not write to %s: %v", path, err) |
|
} |
|
|
|
if err := f.Close(); err != nil { |
|
return fmt.Errorf("could not close %s: %v", path, err) |
|
} |
|
return nil |
|
} |
|
|
|
func minimize(xys plotter.XYs) (m, c float64) { |
|
xs := make([]float64, len(xys)) |
|
ys := make([]float64, len(xys)) |
|
for i, xy := range xys { |
|
xs[i] = xy.X |
|
ys[i] = xy.Y |
|
} |
|
|
|
res, err := fit.Curve1D( |
|
fit.Func1D{ |
|
F: func(x float64, ps []float64) float64 { |
|
return ps[0] + ps[1]*x |
|
}, |
|
Ps: []float64{c, m}, |
|
X: xs, |
|
Y: ys, |
|
}, |
|
nil, nil, |
|
) |
|
if err != nil { |
|
log.Fatal(err) |
|
} |
|
m = res.X[1] |
|
c = res.X[0] |
|
|
|
fmt.Printf("cost(%.2f, %.2f) = %.2f\n", m, c, res.F/float64(len(xys))) |
|
return m, c |
|
} |
|
|
|
func linearRegression(xys plotter.XYs, alpha float64) (m, c float64) { |
|
for i := 0; i < iterations; i++ { |
|
dm, dc := computeGradient(xys, m, c) |
|
m += -dm * alpha |
|
c += -dc * alpha |
|
fmt.Printf("cost(%.2f, %.2f) = %.2f\n", m, c, computeCost(xys, m, c)) |
|
} |
|
|
|
fmt.Printf("cost(%.2f, %.2f) = %.2f\n", m, c, computeCost(xys, m, c)) |
|
|
|
return m, c |
|
} |
|
|
|
func computeCost(xys plotter.XYs, m, c float64) float64 { |
|
// cost = 1/N * sum((y - (m*x+c))^2) |
|
s := 0.0 |
|
for _, xy := range xys { |
|
d := xy.Y - (xy.X*m + c) |
|
s += d * d |
|
} |
|
return s / float64(len(xys)) |
|
} |
|
|
|
func computeGradient(xys plotter.XYs, m, c float64) (dm, dc float64) { |
|
// cost = 1/N * sum((y - (m*x+c))^2) |
|
// cost/dm = 2/N * sum(-x * (y - (m*x+c))) |
|
// cost/dc = 2/N * sum(-(y - (m*x+c))) |
|
for _, xy := range xys { |
|
d := xy.Y - (xy.X*m + c) |
|
dm += -xy.X * d |
|
dc += -d |
|
} |
|
n := float64(len(xys)) |
|
return 2 / n * dm, 2 / n * dc |
|
} |