Skip to content

Instantly share code, notes, and snippets.

@salamanders
Created August 23, 2017 20:59
Show Gist options
  • Save salamanders/88dced678deff7cf3dcd7096a2c7dde5 to your computer and use it in GitHub Desktop.
Save salamanders/88dced678deff7cf3dcd7096a2c7dde5 to your computer and use it in GitHub Desktop.
JSAT Workout in Kotlin
import com.google.common.collect.ImmutableSet
import com.google.common.collect.Sets
import com.google.common.reflect.ClassPath
import jsat.classifiers.ClassificationModelEvaluation
import jsat.classifiers.Classifier
import jsat.classifiers.OneVSAll
import jsat.classifiers.svm.SupportVectorLearner
import jsat.classifiers.trees.DecisionStump
import jsat.distributions.kernels.RBFKernel
import jsat.exceptions.FailedToFitException
import jsat.io.LIBSVMLoader
import jsat.linear.distancemetrics.EuclideanDistance
import kotlinx.coroutines.experimental.*
import java.io.Serializable
import java.lang.reflect.Constructor
import java.lang.reflect.Modifier
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.util.*
import java.util.stream.Stream
import kotlin.streams.toList
/**
 * One benchmark run: a single [Classifier] constructor, one candidate argument set and one
 * LIBSVM data file.
 *
 * The classifier is instantiated reflectively on first access to [classifier], and the
 * cross-validated error rate is computed on first access to [errorRate]; both results are
 * cached by `lazy`.
 *
 * @property constructor reflective constructor used to instantiate the classifier.
 * @property dataPath LIBSVM-format file that is loaded when [errorRate] is first read.
 * @property params arguments handed to [constructor], in the set's iteration order.
 * @property trainingTime milliseconds elapsed between the start of the [errorRate] computation
 *   (data loading included) and the end of cross-validation; left untouched if evaluation fails.
 */
data class MLWorkout2(
    var constructor: Constructor<Classifier>,
    var dataPath: Path = Paths.get("."),
    var params: Set<Serializable> = setOf(),
    var trainingTime: Long = 0,
    private var dataSize: Int = 0
) {
    /**
     * Classifier built by invoking [constructor] with [params].
     *
     * @throws IllegalArgumentException if the number of [params] does not match the
     *   constructor's parameter count. `require` is used instead of `assert` so the check
     *   runs without `-ea` and fails with an [Exception] rather than an [AssertionError].
     */
    val classifier: Classifier by lazy {
        require(params.size == constructor.parameterCount) {
            "Expected ${constructor.parameterCount} constructor argument(s) but got ${params.size}"
        }
        constructor.newInstance(*params.toTypedArray())
    }

    /**
     * Error rate from a 10-fold cross-validation of [classifier] on the data set at [dataPath].
     *
     * If the classifier throws [FailedToFitException], the evaluation is retried with the
     * classifier wrapped in [OneVSAll]. Any other failure while loading or evaluating is
     * logged to stderr and reported as `-1.0`.
     */
    val errorRate: Double by lazy {
        val startTime = System.currentTimeMillis()
        // Small dataset, so we can afford the full cache mode.
        (classifier as? SupportVectorLearner)?.cacheMode = SupportVectorLearner.CacheMode.FULL
        try {
            val ds = LIBSVMLoader.loadC(dataPath.toFile())
            dataSize = ds.sampleSize * ds.numFeatures

            // Cross-validates `model` on `ds`, records the elapsed time, returns the error rate.
            fun evaluate(model: Classifier): Double {
                val cme = ClassificationModelEvaluation(model, ds)
                cme.evaluateCrossValidation(CROSS_VALIDATION_FOLDS)
                trainingTime = System.currentTimeMillis() - startTime
                return cme.errorRate
            }

            try {
                evaluate(classifier)
            } catch (ex: FailedToFitException) {
                val fallbackRate = evaluate(OneVSAll(classifier))
                println("Successful OneVSAll")
                fallbackRate
            }
        } catch (ex: Throwable) {
            // Deliberately broad: one failing classifier must not abort the whole workout.
            when (ex) {
                is NullPointerException,
                is UnsupportedOperationException,
                is ArithmeticException -> System.err.println("Expected: $ex")
                else -> {
                    System.err.println("Unexpected: $ex")
                    ex.printStackTrace()
                }
            }
            -1.0
        }
    }

    /** Tab-separated summary: classifier class, whitespace-stripped params, data file, data size. */
    override fun toString() =
        "${classifier.javaClass.simpleName}\t${params.toString().replace(WHITESPACE, "")}\t${dataPath.fileName}\t$dataSize"

    private companion object {
        /** Number of folds used for cross-validation. */
        const val CROSS_VALIDATION_FOLDS = 10

        /** Hoisted so the pattern is compiled once rather than on every [toString] call. */
        val WHITESPACE = Regex("\\s")
    }
}
/**
 * Builds one [MLWorkout2] for every combination of (concrete classifier constructor under
 * `jsat.classifiers`, `.libsvm` file under `./data`, candidate argument subset of matching
 * size), keeps those whose classifier can be instantiated, and evaluates them concurrently,
 * printing one result line per completed workout. Waiting stops after five hours.
 *
 * @param args command-line arguments (unused).
 * @throws IllegalStateException if no `.libsvm` file is found under `./data`.
 */
fun main(args: Array<String>) {
    // Files.walk holds open directory handles until the stream is closed.
    val dataFiles: List<Path> = Files.walk(Paths.get("./data")).use { paths ->
        paths.filter { path -> path.toString().endsWith(".libsvm") }.toList()
    }
    // `assert` is disabled unless the JVM runs with -ea; fail loudly instead of running nothing.
    check(dataFiles.isNotEmpty()) { "No .libsvm data files found under ./data" }

    val workouts = ClassPath.from(Thread.currentThread().contextClassLoader)
        .getTopLevelClassesRecursive("jsat.classifiers")
        .stream()
        .map { it.load() }
        .filter { !Modifier.isAbstract(it.modifiers) }
        .filter { Classifier::class.java.isAssignableFrom(it) }
        .map { it.constructors }
        .flatMap { Arrays.stream(it) as Stream<Constructor<Classifier>> }
        .map { MLWorkout2(it) }
        // One workout per (constructor, data file) pair.
        .flatMap { mlw -> dataFiles.stream().map { mlw.copy(dataPath = it) } }
        // One workout per candidate argument subset whose size matches the constructor arity.
        .flatMap { mlw ->
            Sets.powerSet(ImmutableSet.of(
                EuclideanDistance(),
                RBFKernel(0.5),
                DecisionStump(),
                5,
                179,
                1.5
            ))
                .stream()
                .filter { possibleParams -> possibleParams.size == mlw.constructor.parameterCount }
                .map { params -> mlw.copy(params = params) }
        }
        // Drop combinations whose constructor rejects the arguments.
        .filter { mlw ->
            try {
                mlw.classifier
                true
            } catch (ex: Exception) {
                false
            }
        }
        .toList()
        // On JDK 16+ `toList()` resolves to the member Stream.toList(), whose result is
        // unmodifiable; copy into a mutable list so the shuffle below cannot throw.
        .toMutableList()
    Collections.shuffle(workouts)
    println("Number of workouts: ${workouts.size}")

    runBlocking {
        val jobs: List<Deferred<MLWorkout2>> = workouts.map { workout ->
            async(CommonPool, start = CoroutineStart.DEFAULT) {
                workout.errorRate // force the lazy evaluation inside the pooled coroutine
                workout
            }
        }.onEach { job ->
            job.invokeOnCompletion { cause ->
                // getCompleted() throws if the job failed or was cancelled, so check first.
                if (cause == null) {
                    val workout = job.getCompleted()
                    println("$workout\t${workout.trainingTime}\t${workout.errorRate}")
                } else {
                    System.err.println("Workout did not complete: $cause")
                }
            }
        }
        println("Workouts mapped to Jobs and running")
        // Wait in order for everyone to be done - either until timeout, or all finished
        // TODO: Run the rest with a longer timeout
        try {
            // 5 hrs
            withTimeout(1000 * 60 * 60 * 5) {
                jobs.forEach { job ->
                    job.join()
                }
            }
        } catch (e: CancellationException) {
            println("Hit the timeout.")
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment