Skip to content

Instantly share code, notes, and snippets.

@soonraah
Created June 5, 2016 11:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save soonraah/5001a9eafa7880fb84716618bc0f62de to your computer and use it in GitHub Desktop.
Save soonraah/5001a9eafa7880fb84716618bc0f62de to your computer and use it in GitHub Desktop.
A base class of online learning for binary linear classification
package mlp.onlineml.classification.binary
import breeze.linalg.{DenseMatrix, DenseVector}
/**
 * A base class for online learning of binary linear classifiers.
 *
 * The model state is a weight vector `w` and a matrix `sigma`. `train` never modifies the
 * receiver: it returns a classifier obtained from `create`. Concrete subclasses supply the
 * update rule through `e`, `alpha` and `beta`, and the factory `create`.
 *
 * @param w weight vector
 * @param sigma covariance matrix
 */
abstract class LinearClassifier protected (val w: DenseVector[Double], val sigma: DenseMatrix[Double]) {

  /**
   * Margin threshold for a sample.
   *
   * `train` updates the model only when the signed margin `y.value * (w.t * x)` is strictly
   * below this value.
   *
   * @param x training sample
   * @return threshold the signed margin is compared against
   */
  protected def e(x: DenseVector[Double]): Double

  /**
   * Coefficient of the weight update performed by `train`.
   *
   * @param x training sample
   * @param y label for training sample
   * @return factor multiplying `y.value * (sigma * x)` before it is added to `w`
   */
  protected def alpha(x: DenseVector[Double], y: Label): Double

  /**
   * Coefficient of the `sigma` update performed by `train`.
   *
   * @param x training sample
   * @param y label for training sample
   * @return factor multiplying `sigma * x * x.t * sigma` before it is subtracted from `sigma`
   */
  protected def beta(x: DenseVector[Double], y: Label): Double

  /**
   * Online training
   *
   * When the signed margin of `x` is below `e(x)` the returned classifier carries the updated
   * `w` and `sigma`; otherwise it carries the current ones unchanged.
   *
   * @param x training sample
   * @param y label for training sample
   * @return updated classifier
   * @throws IllegalArgumentException if `x` and `w` differ in length
   */
  def train(x: DenseVector[Double], y: Label): LinearClassifier = {
    require(
      x.length == w.length,
      s"dimension mismatch: sample has length ${x.length} but weight vector has length ${w.length}"
    )
    if (y.value * (w.t * x) < e(x)) {
      // sigma * x is needed by both updates; compute it once.
      val sigmaX = sigma * x
      create(
        w + y.value * alpha(x, y) * sigmaX,
        // sigma * x * x.t * sigma == (sigma * x) * (sigma.t * x).t. Forming it as an outer
        // product of two vectors avoids a full matrix-matrix multiplication and does not
        // rely on sigma being symmetric.
        sigma - beta(x, y) * (sigmaX * (sigma.t * x).t)
      )
    } else {
      create(w, sigma)
    }
  }

  /**
   * Classifies a sample by the sign of its score `w.t * x`.
   *
   * @param x sample to classify
   * @return `Label(true)` when the score is strictly positive, `Label(false)` otherwise
   */
  def classify(x: DenseVector[Double]): Label = Label(w.t * x > 0)

  /**
   * Factory used by `train` to build the classifier it returns.
   *
   * @param w weight vector of the new classifier
   * @param sigma covariance matrix of the new classifier
   * @return a classifier holding the given state
   */
  protected def create(w: DenseVector[Double], sigma: DenseMatrix[Double]): LinearClassifier
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment