Skip to content

Instantly share code, notes, and snippets.

@unixpickle
Last active March 12, 2017 02:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save unixpickle/53f6308d6d6939c6b3a31a46faabb8f6 to your computer and use it in GitHub Desktop.
Save unixpickle/53f6308d6d6939c6b3a31a46faabb8f6 to your computer and use it in GitHub Desktop.
Linear bayes
// Use Bayes' rule to compute a probability distribution
// over linear models, then use the linear models on a
// simple classification problem.
package main
import (
"image"
"image/color"
"image/png"
"log"
"math"
"math/rand"
"os"
"github.com/unixpickle/essentials"
)
const (
	// ImageSize is the width and height, in pixels, of the rendered
	// output image.
	ImageSize = 256

	// NumDist is how many parameter sets are drawn from the prior to
	// build the weighted distribution over linear models.
	NumDist = 1000

	// Stddev is the standard deviation of the zero-mean normal prior
	// placed on every model parameter.
	Stddev = 3
)
// Sample is one labeled training point in the plane. Class is the label:
// 1 for the positive class and 0 for the negative class in this
// program's data.
type Sample struct {
	X, Y  float64
	Class int
}
// Params holds the coefficients of one linear model. The model's score
// for a point (x, y) is X*x + Y*y + Bias.
type Params struct {
	X, Y float64
	Bias float64
}
// SampleParams draws one linear model from the prior. X, Y and Bias are
// each a separate draw from a zero-mean normal distribution scaled to
// standard deviation Stddev, taken in that order.
func SampleParams() *Params {
	drawn := new(Params)
	drawn.X = rand.NormFloat64() * Stddev
	drawn.Y = rand.NormFloat64() * Stddev
	drawn.Bias = rand.NormFloat64() * Stddev
	return drawn
}
// Prob returns the probability that the logistic model p assigns to the
// label of sample. Let z = p.X*sample.X + p.Y*sample.Y + p.Bias.
//
// The result is sigmoid(z) = 1/(1+exp(-z)) when sample.Class is 1, and
// sigmoid(-z) = 1 - sigmoid(z) for any other Class value (which is
// therefore treated as class 0).
func (p *Params) Prob(sample Sample) float64 {
	comb := p.X*sample.X + p.Y*sample.Y + p.Bias
	if sample.Class != 1 {
		// sigmoid(-z) equals 1 - sigmoid(z) but avoids the cancellation
		// of subtracting from 1.
		comb = -comb
	}
	return 1 / (1 + math.Exp(-comb))
}
// TotalProb returns the joint likelihood of samples under p: the product
// of p.Prob over every sample, which treats the samples as independent.
// An empty slice yields 1.
//
// The product is formed directly on raw probabilities, so with many
// samples (or a poorly fitting p) it can underflow to 0.
func (p *Params) TotalProb(samples []Sample) float64 {
	product := 1.0
	for i := range samples {
		product *= p.Prob(samples[i])
	}
	return product
}
// ParamDist is a discrete distribution over sampled linear models. It
// maps each model to its probability weight.
type ParamDist map[*Params]float64

// NewParamDist approximates the posterior over linear models given
// samples. It draws NumDist parameter sets from the prior (SampleParams)
// and weights each by its likelihood of samples. The weights are
// normalized to sum to 1.
//
// Likelihoods are accumulated in log space and shifted by the largest
// one before exponentiating. Multiplying raw probabilities underflows to
// 0 once there are enough samples; if every model underflowed, the
// normalizing sum would be 0 and every weight would become NaN. After the
// shift, the most likely model has an unnormalized weight of exactly 1,
// so the sum is at least 1.
func NewParamDist(samples []Sample) ParamDist {
	models := make([]*Params, NumDist)
	logLikes := make([]float64, NumDist)
	maxLogLike := math.Inf(-1)
	for i := range models {
		models[i] = SampleParams()
		logLikes[i] = logLikelihood(models[i], samples)
		maxLogLike = math.Max(maxLogLike, logLikes[i])
	}

	dist := make(ParamDist, NumDist)
	var sum float64
	for i, p := range models {
		w := math.Exp(logLikes[i] - maxLogLike)
		dist[p] = w
		sum += w
	}
	for p := range dist {
		dist[p] /= sum
	}
	return dist
}

// logLikelihood returns the natural log of the product of Params.Prob
// over samples, computed term by term so that it does not underflow.
// Any Class other than 1 is treated as class 0, matching Params.Prob.
func logLikelihood(p *Params, samples []Sample) float64 {
	var total float64
	for _, s := range samples {
		z := p.X*s.X + p.Y*s.Y + p.Bias
		if s.Class != 1 {
			z = -z
		}
		total += logSigmoid(z)
	}
	return total
}

// logSigmoid returns log(1/(1+exp(-z))). It branches on the sign of z so
// that the argument to Exp is never positive, which avoids overflow.
func logSigmoid(z float64) float64 {
	if z >= 0 {
		return -math.Log1p(math.Exp(-z))
	}
	return z - math.Log1p(math.Exp(z))
}
// Classify returns the probability that the point (x, y) belongs to
// class 1, averaged over the distribution. Each model contributes its
// logistic output 1/(1+exp(-score)), multiplied by its weight in p.
func (p ParamDist) Classify(x, y float64) float64 {
	total := 0.0
	for model, w := range p {
		score := model.X*x + model.Y*y + model.Bias
		total += w / (1 + math.Exp(-score))
	}
	return total
}
// main does four things:
//   - fits a weighted distribution over linear classifiers to a small
//     hand-labeled dataset;
//   - colors each pixel of an ImageSize x ImageSize image by the
//     distribution's probability of class 1 at the matching point of
//     [0, plotExtent] x [0, plotExtent];
//   - writes the image to output.png;
//   - passes any error to essentials.Die.
func main() {
	samples := []Sample{
		{1, 0, 1},
		{1, 1, 1},
		{3, 1, 1},
		{5, 1, 1},
		{2, 2, 1},
		{1, 3, 1},
		{3, 3, 1},
		{2, 4, 1},
		{2, 6, 1},
		{8, 4, 0},
		{8, 6, 0},
		{6, 9, 0},
		{9, 9, 0},
	}

	log.Println("Learning classifier distribution...")
	dist := NewParamDist(samples)

	log.Println("Coloring image...")
	img := renderClassification(dist.Classify)

	log.Println("Writing output.png...")
	if err := writePNG("output.png", img); err != nil {
		essentials.Die(err)
	}
}

// plotExtent is the side length, in data units, of the square region
// [0, plotExtent] x [0, plotExtent] that the output image covers.
const plotExtent = 10

// renderClassification builds an ImageSize x ImageSize image in which
// each pixel is colored by classify's probability of class 1 at the
// matching data point:
//   - red grows as the probability approaches 0;
//   - blue grows as it approaches 1;
//   - green stays constant.
//
// Pixel (0, 0) maps to data point (0, 0) and pixel (ImageSize-1,
// ImageSize-1) maps to (plotExtent, plotExtent). Because image rows grow
// downward, data y = 0 is the top row.
func renderClassification(classify func(x, y float64) float64) *image.RGBA {
	img := image.NewRGBA(image.Rect(0, 0, ImageSize, ImageSize))
	for y := 0; y < ImageSize; y++ {
		for x := 0; x < ImageSize; x++ {
			prob := classify(float64(x)*plotExtent/(ImageSize-1),
				float64(y)*plotExtent/(ImageSize-1))
			// Named px, not color, so the image/color package stays
			// visible in this scope.
			px := color.RGBA{
				R: uint8(0xff - 0xff*prob + 0.5),
				G: 0x70,
				B: uint8(0xff*prob + 0.5),
				A: 0xff,
			}
			img.SetRGBA(x, y, px)
		}
	}
	return img
}

// writePNG encodes img as PNG into the file at path, creating or
// truncating it. The error from closing the file is returned rather than
// dropped, so a failure while finalizing the written file is reported.
func writePNG(path string, img image.Image) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	if err := png.Encode(f, img); err != nil {
		// The encode error is the one worth reporting; closing is
		// best-effort here.
		_ = f.Close()
		return err
	}
	return f.Close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment