Skip to content

Instantly share code, notes, and snippets.

@actsasgeek
Created June 9, 2011 17:58
Show Gist options
  • Save actsasgeek/1017311 to your computer and use it in GitHub Desktop.
Save actsasgeek/1017311 to your computer and use it in GitHub Desktop.
Full source code for kNN with k = 1
import scala.io.Source;
import scala.collection.mutable.ListBuffer;
import java.io.File;
/**
 * A single example: a vector of numeric feature values plus an optional class label.
 *
 * @param featureValues the numeric features, in column order
 * @param classLabel    the class this instance belongs to, or `None` when unlabelled
 *                      (e.g. a query that has not been classified yet)
 */
case class Instance( featureValues: List[Double], classLabel: Option[String] = None) {

  /** Returns a copy of this instance carrying `assignedClassLabel`; the features are unchanged. */
  def assignClassLabel( assignedClassLabel: Option[String]): Instance =
    copy( classLabel = assignedClassLabel)

  /** Euclidean distance between this instance's features and `otherInstance`'s features. */
  def distanceTo( otherInstance: Instance): Double =
    euclideanDistance( featureValues, otherInstance.featureValues)

  /**
   * Euclidean (L2) distance between two feature vectors: sqrt(sum((a_i - b_i)^2)).
   *
   * @throws IllegalArgumentException if the vectors have different lengths; `zip` would
   *         otherwise silently drop the extra components and report a misleading distance.
   */
  def euclideanDistance( thisVector: List[ Double], thatVector: List[ Double]): Double = {
    require( thisVector.length == thatVector.length,
      "cannot compare feature vectors of different dimensions: " +
        thisVector.length + " vs " + thatVector.length)
    val sumOfSquares = thisVector.zip( thatVector).map { case ( a, b) =>
      val difference = a - b
      difference * difference
    }.sum
    math.sqrt( sumOfSquares)
  }

  override def toString(): String = {
    "<'"+classLabel.getOrElse( "None")+"' is ["+featureValues.mkString( ", ")+"]>"
  }
}

object Instance {
  /**
   * Parses one CSV line of the form `f1,f2,...,fn,label` into a labelled [[Instance]].
   * Every field except the last is parsed as a Double feature; the last field is the class label.
   * Fields are trimmed, so "5.1, Iris-setosa" and "5.1,Iris-setosa" yield the same label.
   *
   * @throws IllegalArgumentException if the line contains no fields at all (e.g. ",,,")
   * @throws NumberFormatException    if a feature field is not a valid number
   */
  def parseString( instanceAsCSVString: String): Instance = {
    val fields = instanceAsCSVString.split( ",").map( _.trim)
    require( fields.nonEmpty, "no fields in CSV line: '" + instanceAsCSVString + "'")
    val featureValues = fields.init.map( _.toDouble).toList
    new Instance( featureValues, Some( fields.last))
  }
}
/**
 * 1-nearest-neighbour classifier over a fixed library of labelled instances.
 *
 * @param library the labelled examples that queries are compared against
 */
class NearestNeighbor( library: List[Instance]) {
  /**
   * Returns `query` relabelled with the class of the library instance closest to it,
   * as measured by `Instance.distanceTo`. On ties, the earliest instance in `library` wins.
   *
   * @throws NoSuchElementException if the library is empty
   */
  def classify( query: Instance): Instance = {
    if ( library.isEmpty)
      throw new NoSuchElementException( "cannot classify: the nearest-neighbour library is empty")
    // One linear scan finds the minimum; sorting every distance is unnecessary.
    // The strict `<` keeps the earlier example on ties, matching the previous stable sort.
    val ( _, nearestExample) = library
      .map( example => ( query.distanceTo( example), example))
      .reduceLeft( ( best, candidate) => if ( candidate._1 < best._1) candidate else best)
    query.assignClassLabel( nearestExample.classLabel)
  }
}
/** Factory and command-line entry point for [[NearestNeighbor]]. */
object NearestNeighbor {
  /** Library file read by `main` when no path is given on the command line. */
  private val DefaultLibraryFileName = "library.data"

  /** Builds a classifier from a CSV library file (one `f1,...,fn,label` instance per line). */
  def create( libraryFileName: String): NearestNeighbor = {
    val instances = getInstancesFromFile( libraryFileName)
    val library = createLibraryFromCSVs( instances)
    new NearestNeighbor( library)
  }

  /** Reads every line of `libraryFileName` and closes the file before returning. */
  def getInstancesFromFile( libraryFileName: String): List[ String] = {
    val source = Source.fromFile( new File( libraryFileName))
    // getLines() is lazy over the open file, so materialise the list before closing it.
    try source.getLines().toList
    finally source.close()
  }

  /**
   * Parses CSV lines into instances, skipping blank lines. A blank line (common at the end
   * of data files) would otherwise parse into an instance with no features and an empty label.
   */
  def createLibraryFromCSVs( instances: List[ String]): List[Instance] =
    instances.filter( _.trim.nonEmpty).map( Instance.parseString( _))

  /**
   * Classifies a sample all-zero query against the library and prints the result.
   * The library path may be given as the first argument; it defaults to "library.data".
   */
  def main( args: Array[ String]): Unit = {
    val libraryFileName = args.headOption.getOrElse( DefaultLibraryFileName)
    val nearestNeighbor = NearestNeighbor.create( libraryFileName)
    val query = new Instance( List( 0.0, 0.0, 0.0, 0.0))
    val classifiedQuery = nearestNeighbor.classify( query)
    println( classifiedQuery)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment