Skip to content

Instantly share code, notes, and snippets.

@takkkun
Created December 5, 2012 14:58
Show Gist options
  • Save takkkun/4216136 to your computer and use it in GitHub Desktop.
Save takkkun/4216136 to your computer and use it in GitHub Desktop.
/**
 * A fixed-dimension real vector whose operations return the concrete subtype `V`
 * (F-bounded, so combining two `Vector1D`s yields a `Vector1D`).
 *
 * @tparam V the concrete implementing type
 */
trait Vector[V <: Vector[V]] extends Product {
  /** Component-wise sum with `vector`. */
  def +(vector: V): V

  /** Component-wise product with `vector`. */
  def *(vector: V): V

  /** Every component multiplied by `coefficient`. */
  def *(coefficient: Double): V

  /** The components, in order. */
  def toList: List[Double]

  /** Dot product: the sum of the component-wise products. */
  def innerProduct(vector: V): Double = {
    val products = (this * vector).toList
    products reduceLeft { (acc, x) => acc + x }
  }
}
/** A one-component vector. */
case class Vector1D(_1: Double) extends Vector[Vector1D] {
  def +(vector: Vector1D): Vector1D = new Vector1D(this._1 + vector._1)

  def *(vector: Vector1D): Vector1D = new Vector1D(this._1 * vector._1)

  def *(coefficient: Double): Vector1D = new Vector1D(this._1 * coefficient)

  def toList: List[Double] = _1 :: Nil
}
/** A two-component vector `(_1, _2)`. */
case class Vector2D(_1: Double, _2: Double) extends Vector[Vector2D] {
  def +(vector: Vector2D): Vector2D = {
    val x = _1 + vector._1
    val y = _2 + vector._2
    new Vector2D(x, y)
  }

  def *(vector: Vector2D): Vector2D = {
    val x = _1 * vector._1
    val y = _2 * vector._2
    new Vector2D(x, y)
  }

  def *(coefficient: Double): Vector2D = {
    val x = _1 * coefficient
    val y = _2 * coefficient
    new Vector2D(x, y)
  }

  def toList: List[Double] = _1 :: _2 :: Nil
}
/** A three-component vector `(_1, _2, _3)`. */
case class Vector3D(_1: Double, _2: Double, _3: Double) extends Vector[Vector3D] {
  def +(vector: Vector3D): Vector3D =
    new Vector3D(
      this._1 + vector._1,
      this._2 + vector._2,
      this._3 + vector._3
    )

  def *(vector: Vector3D): Vector3D =
    new Vector3D(
      this._1 * vector._1,
      this._2 * vector._2,
      this._3 * vector._3
    )

  def *(coefficient: Double): Vector3D =
    new Vector3D(
      this._1 * coefficient,
      this._2 * coefficient,
      this._3 * coefficient
    )

  def toList: List[Double] = _1 :: _2 :: _3 :: Nil
}
/**
 * A vector with one extra leading scalar component `head`, followed by the components of `vector`.
 *
 * @tparam V the concrete type of the trailing vector part
 */
case class AugmentedVector[V <: Vector[V]](head: Double, vector: V) {
  /** Component-wise sum, head included. */
  def +(other: AugmentedVector[V]): AugmentedVector[V] = {
    val h = head + other.head
    val v = vector + other.vector
    new AugmentedVector(h, v)
  }

  /** Component-wise product, head included. */
  def *(other: AugmentedVector[V]): AugmentedVector[V] = {
    val h = head * other.head
    val v = vector * other.vector
    new AugmentedVector(h, v)
  }

  /** Every component, head included, multiplied by `coefficient`. */
  def *(coefficient: Double): AugmentedVector[V] = {
    val h = head * coefficient
    val v = vector * coefficient
    new AugmentedVector(h, v)
  }

  /** Dot product over all components: the head product plus the inner vectors' dot product. */
  def innerProduct(other: AugmentedVector[V]): Double = {
    val headPart = head * other.head
    headPart + vector.innerProduct(other.vector)
  }

  /** Renders flat, e.g. `AugmentedVector(1.0,5.0)`, inlining the inner vector's fields. */
  override def toString: String =
    (Iterator(head) ++ vector.productIterator).mkString(productPrefix + "(", ",", ")")
}
/**
 * A named class (category) of training patterns.
 *
 * @param name     display name of the class
 * @param patterns the member patterns
 */
case class Class[V <: Vector[V]](name: String, patterns: List[V]) {
  /** Each pattern prefixed with a 1.0 head and tagged (via [[Pattern]]) with this class. */
  def augment = {
    val tagged = for (p <- patterns) yield new AugmentedVector[V](1.0, p) with Pattern
    tagged.toSeq
  }

  /** Mixin that records which class a pattern was created from. */
  trait Pattern {
    def patternOf: Class[V] = Class.this
  }
}
/**
 * Two-class perceptron trainer.
 *
 * For a weight `w` and an augmented pattern `x` (leading 1.0 head), the discriminant is `g = w . x`.
 * Training drives `g > 0` for every `class1` pattern and `g < 0` for every `class2` pattern. These
 * are the same strict sign conventions [[Discriminator]] uses, so every training pattern is
 * classified by the returned discriminator.
 *
 * Note: there is no iteration cap. If the two classes are not linearly separable, `train` never
 * returns.
 */
case class Perceptron[V <: Vector[V]](class1: Class[V], class2: Class[V]) {
  import scala.annotation.tailrec
  import scala.util.Random.shuffle

  /**
   * Learns a weight with the perceptron rule, logging each step to stdout.
   *
   * @param initialWeight starting weight for the vector part (the head weight starts at 0.0)
   * @param learningRate  step size of each correction; must be positive
   * @return the trained discriminator and the number of weight updates performed
   * @throws IllegalArgumentException if `learningRate` is not positive (or is NaN)
   */
  def train(initialWeight: V, learningRate: Double): (Discriminator[V], Int) = {
    // A zero/negative/NaN rate cannot move the weight toward a solution, so training would never end.
    require(learningRate > 0, "learningRate must be positive: " + learningRate)
    val patterns = class1.augment ++ class2.augment
    val (weight, count) = adjust(patterns, AugmentedVector(0.0, initialWeight), learningRate, 0, class1, class2, patterns)
    (new Discriminator(class1, class2, weight), count)
  }

  /**
   * Walks `tryingPatterns`, correcting `weight` whenever a pattern falls on the wrong side.
   *
   * After a correction, the walk restarts over every pattern: the others in shuffled order, then
   * the pattern that triggered the correction last, so it is re-checked against the new weight.
   * The walk therefore ends only after one uninterrupted pass over all patterns with no correction.
   *
   * @param keisu the learning rate passed in from `train`
   * @param count number of weight-changing updates so far
   * @return the final weight and the update count
   */
  @tailrec
  private def adjust(
    tryingPatterns: Seq[AugmentedVector[V] with Class[V]#Pattern],
    weight: AugmentedVector[V],
    keisu: Double,
    count: Int,
    class1: Class[V],
    class2: Class[V],
    allPatterns: Seq[AugmentedVector[V] with Class[V]#Pattern]
  ): (AugmentedVector[V], Int) =
    tryingPatterns match {
      case Seq(pattern, rest @ _*) =>
        print(pattern)
        val g = weight innerProduct pattern
        // Desired sign of g: +1 for class1 patterns, -1 for class2 patterns.
        val label = if (pattern.patternOf == class1) 1 else -1
        if (g * label > 0) {
          println(" => discriminated")
          adjust(rest, weight, keisu, count, class1, class2, allPatterns)
        } else {
          // g == 0 (or NaN) also lands here: a boundary point is not a correct classification.
          val newWeight = updateWeight(pattern, weight, keisu, label)
          println(" => could not discriminate")
          println("retry by " + newWeight)
          // Drop only this exact instance (reference equality) so duplicate-valued patterns survive,
          // and re-append it so it is verified against the new weight before training can finish.
          val retry = shuffle(allPatterns.filterNot(_ eq pattern)) :+ pattern
          val newCount = if (weight == newWeight) count else count + 1
          adjust(retry, newWeight, keisu, newCount, class1, class2, allPatterns)
        }
      case _ => (weight, count)
    }

  /** Perceptron rule: `weight + keisu * label * pattern`, where `label` is +1 or -1. */
  private def updateWeight(pattern: AugmentedVector[V], weight: AugmentedVector[V], keisu: Double, label: Int) =
    weight + pattern * (keisu * label)
}
/**
 * Classifies a pattern by the sign of its discriminant against a trained weight.
 *
 * @param weight augmented weight; the pattern is augmented with a 1.0 head before the dot product
 */
case class Discriminator[V <: Vector[V]](class1: Class[V], class2: Class[V], weight: AugmentedVector[V]) {
  /** `Some(class1)` for a positive discriminant, `Some(class2)` for a negative one, otherwise `None`. */
  def apply(pattern: V): Option[Class[V]] = {
    val g = weight.innerProduct(new AugmentedVector(1.0, pattern))
    if (g > 0) Some(class1)
    else if (g < 0) Some(class2)
    else None
  }
}
object Perceptron {
  /** Lets a bare `Double` stand in for a [[Vector1D]], so the sample data below can be plain numbers. */
  implicit def scalarToVector1D(scalar: Double): Vector1D = Vector1D(scalar)

  /** Lets a `(Double, Double)` tuple stand in for a [[Vector2D]]. */
  implicit def pointToVector2D(point: (Double, Double)): Vector2D = Vector2D(point._1, point._2)

  /**
   * Trains a perceptron on a fixed 1-D sample and prints the learned weight and the update count.
   *
   * @param args `args(0)` is the learning rate
   */
  def main(args: Array[String]): Unit = {
    if (args.isEmpty) {
      System.err.println("usage: Perceptron <learningRate>")
      sys.exit(1)
    }
    val learningRate = args(0).toDouble

    val c1 = Class[Vector1D]("class1", List(
      5.0,
      4.0,
      3.6,
      7.0
    ))
    val c2 = Class[Vector1D]("class2", List(
      3.5,
      2.0,
      3.55,
      1.0,
      -1.0
    ))
    val perceptron = Perceptron(c1, c2)
    val (discriminator, count) = perceptron.train(Vector1D(0.0), learningRate)
    println("Weight: " + discriminator.weight)
    println("Count: " + count)

    // To classify a query point, pass it as args(1):
    // discriminator(Vector1D(args(1).toDouble)) match {
    //   case Some(klass) => println(klass.name)
    //   case None => println("don't identify")
    // }

    // 2-D variant (relies on the pointToVector2D conversion above):
    // val c1 = Class[Vector2D]("class1", List(
    //   (-2.0, 2.0)
    // ))
    // val c2 = Class[Vector2D]("class2", List(
    //   (2.0, -2.0)
    // ))
    // val (identify, _) = Perceptron(c1, c2).train(Vector2D(0.0, 0.0), learningRate)
    // identify(Vector2D(args(1).toDouble, args(2).toDouble)) match {
    //   case Some(klass) => println(klass.name)
    //   case None => println("don't identify")
    // }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment