Created
December 5, 2012 14:58
-
-
Save takkkun/4216136 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * A fixed-dimension vector of `Double` components, F-bounded so that the
 * arithmetic operations return the concrete implementing type `V` rather
 * than the bare trait.
 *
 * Extending `Product` lets callers walk the components generically through
 * `productIterator` (the case-class implementations get this for free).
 *
 * @tparam V the concrete vector type implementing this trait
 */
trait Vector[V <: Vector[V]] extends Product {
  /** Sum with `vector` (every implementation in this file adds component-wise). */
  def +(vector: V): V

  /** Product with `vector` (every implementation in this file multiplies component-wise). */
  def *(vector: V): V

  /** This vector scaled by `coefficient`. */
  def *(coefficient: Double): V

  /** The components of this vector, in order. */
  def toList: List[Double]

  /**
   * Inner (dot) product: the components of `this * vector`, summed left to right.
   * Throws `UnsupportedOperationException` if `toList` is empty.
   */
  def innerProduct(vector: V): Double = {
    val products: List[Double] = (this * vector).toList
    products.reduce(_ + _)
  }
}
/** A one-dimensional vector whose single component is `_1`. */
case class Vector1D(_1: Double) extends Vector[Vector1D] {
  /** Component-wise sum. */
  def +(vector: Vector1D): Vector1D = copy(_1 = this._1 + vector._1)

  /** Component-wise product. */
  def *(vector: Vector1D): Vector1D = copy(_1 = this._1 * vector._1)

  /** Scales the component by `coefficient`. */
  def *(coefficient: Double): Vector1D = copy(_1 = this._1 * coefficient)

  /** The single component as a one-element list. */
  def toList: List[Double] = this._1 :: Nil
}
/** A two-dimensional vector with components `_1` and `_2`. */
case class Vector2D(_1: Double, _2: Double) extends Vector[Vector2D] {
  /** Component-wise sum. */
  def +(vector: Vector2D): Vector2D =
    copy(_1 = this._1 + vector._1, _2 = this._2 + vector._2)

  /** Component-wise product. */
  def *(vector: Vector2D): Vector2D =
    copy(_1 = this._1 * vector._1, _2 = this._2 * vector._2)

  /** Scales both components by `coefficient`. */
  def *(coefficient: Double): Vector2D =
    copy(_1 = this._1 * coefficient, _2 = this._2 * coefficient)

  /** The components `_1`, `_2` in order. */
  def toList: List[Double] = this._1 :: this._2 :: Nil
}
/** A three-dimensional vector with components `_1`, `_2` and `_3`. */
case class Vector3D(_1: Double, _2: Double, _3: Double) extends Vector[Vector3D] {
  /** Component-wise sum. */
  def +(vector: Vector3D): Vector3D =
    copy(_1 = this._1 + vector._1, _2 = this._2 + vector._2, _3 = this._3 + vector._3)

  /** Component-wise product. */
  def *(vector: Vector3D): Vector3D =
    copy(_1 = this._1 * vector._1, _2 = this._2 * vector._2, _3 = this._3 * vector._3)

  /** Scales every component by `coefficient`. */
  def *(coefficient: Double): Vector3D =
    copy(_1 = this._1 * coefficient, _2 = this._2 * coefficient, _3 = this._3 * coefficient)

  /** The components `_1`, `_2`, `_3` in order. */
  def toList: List[Double] = this._1 :: this._2 :: this._3 :: Nil
}
/**
 * A vector `V` prefixed with one extra scalar component, `head`.
 *
 * Patterns in this file are augmented with `head = 1.0`, so in
 * `innerProduct` the weight's `head` acts as a constant (bias) term.
 *
 * @tparam V the underlying vector type
 */
case class AugmentedVector[V <: Vector[V]](head: Double, vector: V) {
  /** Sum of both parts: heads are added and vectors are added. */
  def +(other: AugmentedVector[V]): AugmentedVector[V] =
    copy(head = head + other.head, vector = vector + other.vector)

  /** Product of both parts: heads are multiplied and vectors are multiplied. */
  def *(other: AugmentedVector[V]): AugmentedVector[V] =
    copy(head = head * other.head, vector = vector * other.vector)

  /** Scales both the head and the vector by `coefficient`. */
  def *(coefficient: Double): AugmentedVector[V] =
    copy(head = head * coefficient, vector = vector * coefficient)

  /** Inner product over all components, with the head term first. */
  def innerProduct(other: AugmentedVector[V]): Double = {
    val headTerm = head * other.head
    headTerm + vector.innerProduct(other.vector)
  }

  /** Flat rendering, e.g. `AugmentedVector(1.0,5.0)` instead of the nested case-class form. */
  override def toString: String = {
    val components = head +: vector.productIterator.toList
    components.mkString(s"$productPrefix(", ",", ")")
  }
}
/**
 * A named category of training patterns.
 *
 * NOTE: this user-defined `Class` shadows `Predef.Class` (`java.lang.Class`)
 * everywhere it is in scope.
 *
 * @param name     display name of the category
 * @param patterns raw (un-augmented) member patterns
 */
case class Class[V <: Vector[V]](name: String, patterns: List[V]) {
  /**
   * Every pattern augmented with a leading 1.0 and tagged, through the
   * [[Pattern]] mixin, with the `Class` it came from.
   */
  def augment: Seq[AugmentedVector[V] with Pattern] = {
    val tagged = patterns.map { raw => new AugmentedVector(1.0, raw) with Pattern }
    tagged.toSeq
  }

  /** Mixin that records which `Class` instance produced an augmented pattern. */
  trait Pattern {
    def patternOf: Class[V] = Class.this
  }
}
/**
 * A two-class perceptron over augmented vectors.
 *
 * Training aims for a strictly positive inner product between the learned
 * weight and every `class1` pattern, and a strictly negative one for every
 * `class2` pattern. That is exactly the rule [[Discriminator]] applies, so the
 * trained discriminator classifies every training pattern correctly.
 */
case class Perceptron[V <: Vector[V]](class1: Class[V], class2: Class[V]) {
  import scala.util.Random.shuffle

  /**
   * Runs the perceptron learning rule until every training pattern is
   * classified correctly, printing progress to stdout.
   *
   * NOTE(review): there is no iteration limit, so this never returns when the
   * two classes are not linearly separable.
   *
   * @param initialWeight starting weight; the augmented (bias) component starts at 0.0
   * @param learningRate  step size of each correction; must be positive
   * @return the trained discriminator and the number of weight corrections made
   * @throws IllegalArgumentException if `learningRate` is not positive, or if an
   *                                  inner product becomes NaN during training
   */
  def train(initialWeight: V, learningRate: Double): (Discriminator[V], Int) = {
    // A zero, negative or NaN rate can never fix a misclassification,
    // so training would otherwise loop forever.
    require(learningRate > 0, s"learningRate must be positive: $learningRate")
    val patterns = class1.augment ++ class2.augment
    val (weight, count) = adjust(patterns, AugmentedVector(0.0, initialWeight), learningRate, 0, patterns)
    (new Discriminator(class1, class2, weight), count)
  }

  /**
   * Checks `tryingPatterns` in order against `weight`.
   *
   * A correctly classified pattern is dropped from the list. A misclassified
   * one triggers a weight correction, after which *every* pattern is
   * re-checked in a fresh random order, including the one just corrected for,
   * since one step may not be enough to fix it. Returns once the remaining
   * list is exhausted with no misclassification, meaning every pattern was
   * verified against the final weight.
   */
  @scala.annotation.tailrec
  private def adjust(
    tryingPatterns: Seq[AugmentedVector[V] with Class[V]#Pattern],
    weight: AugmentedVector[V],
    learningRate: Double,
    count: Int,
    allPatterns: Seq[AugmentedVector[V] with Class[V]#Pattern]
  ): (AugmentedVector[V], Int) =
    tryingPatterns match {
      case Seq(pattern, rest @ _*) =>
        print(pattern)
        val g = weight innerProduct pattern
        require(!g.isNaN, s"inner product with $pattern is NaN")
        val isClass1 = pattern.patternOf == class1
        // Strict inequalities keep training consistent with Discriminator,
        // which reports neither class when g == 0.
        val discriminated = if (isClass1) g > 0 else g < 0
        if (discriminated) {
          println(" => discriminated")
          adjust(rest, weight, learningRate, count, allPatterns)
        } else {
          val newWeight = updateWeight(pattern, weight, learningRate, if (isClass1) 1 else -1)
          println(" => could not discriminate")
          println("retry by " + newWeight)
          adjust(shuffle(allPatterns), newWeight, learningRate, count + 1, allPatterns)
        }
      case _ => (weight, count)
    }

  /** Perceptron correction: moves `weight` by `learningRate * label` times `pattern`. */
  private def updateWeight(
    pattern: AugmentedVector[V],
    weight: AugmentedVector[V],
    learningRate: Double,
    label: Int
  ): AugmentedVector[V] =
    weight + pattern * (learningRate * label)
}
/**
 * Classifies a raw pattern by the sign of its inner product with `weight`,
 * after augmenting the pattern with a leading 1.0.
 *
 * @param class1 returned for a strictly positive inner product
 * @param class2 returned for a strictly negative inner product
 * @param weight the trained augmented weight
 */
case class Discriminator[V <: Vector[V]](class1: Class[V], class2: Class[V], weight: AugmentedVector[V]) {
  /**
   * @return `Some(class1)` if the inner product is positive, `Some(class2)` if
   *         it is negative, and `None` otherwise (exactly zero, or NaN).
   */
  def apply(pattern: V): Option[Class[V]] = {
    val augmented = AugmentedVector(1.0, pattern)
    val score = weight.innerProduct(augmented)
    if (score > 0) Some(class1)
    else if (score < 0) Some(class2)
    else None
  }
}
/**
 * Command-line demo: trains a perceptron on two hard-coded 1-D classes.
 *
 * Usage: `Perceptron <learningRate> [pattern]`
 *  - `learningRate`: positive step size passed to `Perceptron.train`.
 *  - `pattern` (optional): a number to classify with the trained discriminator.
 */
object Perceptron {
  import scala.language.implicitConversions

  private val Usage = "usage: Perceptron <learningRate> [pattern]"

  /** Lets a bare `Double` be written where a [[Vector1D]] is expected. */
  implicit def scalarToVector1D(scalar: Double): Vector1D = Vector1D(scalar)

  /** Lets a `(Double, Double)` pair be written where a [[Vector2D]] is expected. */
  implicit def pointToVector2D(point: Tuple2[Double, Double]): Vector2D = Vector2D(point._1, point._2)

  /** Parses `s` as a `Double`, or `None` if it is not a number. */
  private def parseDouble(s: String): Option[Double] = scala.util.Try(s.toDouble).toOption

  def main(args: Array[String]): Unit = {
    // Reject a missing/non-numeric/non-positive rate up front instead of
    // crashing with an exception (or training forever).
    val learningRate = args.headOption.flatMap(parseDouble) match {
      case Some(rate) if rate > 0 => rate
      case _ =>
        System.err.println(Usage + "  (learningRate must be a positive number)")
        sys.exit(1)
    }

    // The Double literals are lifted to Vector1D by scalarToVector1D.
    val c1 = Class[Vector1D]("class1", List(
      5.0,
      4.0,
      3.6,
      7.0
    ))
    val c2 = Class[Vector1D]("class2", List(
      3.5,
      2.0,
      3.55,
      1.0,
      -1.0
    ))

    val perceptron = Perceptron(c1, c2)
    val (discriminator, count) = perceptron.train(Vector1D(0.0), learningRate)
    println("Weight: " + discriminator.weight)
    println("Count: " + count)

    // Optionally classify a second argument with the trained discriminator.
    args.lift(1).foreach { arg =>
      parseDouble(arg) match {
        case Some(x) => println(discriminator(Vector1D(x)).fold("don't identify")(_.name))
        case None    => System.err.println(s"not a number: $arg")
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment