Created
August 23, 2017 20:59
-
-
Save salamanders/88dced678deff7cf3dcd7096a2c7dde5 to your computer and use it in GitHub Desktop.
JSAT Workout in Kotlin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.google.common.collect.ImmutableSet | |
import com.google.common.collect.Sets | |
import com.google.common.reflect.ClassPath | |
import jsat.classifiers.ClassificationModelEvaluation | |
import jsat.classifiers.Classifier | |
import jsat.classifiers.OneVSAll | |
import jsat.classifiers.svm.SupportVectorLearner | |
import jsat.classifiers.trees.DecisionStump | |
import jsat.distributions.kernels.RBFKernel | |
import jsat.exceptions.FailedToFitException | |
import jsat.io.LIBSVMLoader | |
import jsat.linear.distancemetrics.EuclideanDistance | |
import kotlinx.coroutines.experimental.* | |
import java.io.Serializable | |
import java.lang.reflect.Constructor | |
import java.lang.reflect.Modifier | |
import java.nio.file.Files | |
import java.nio.file.Path | |
import java.nio.file.Paths | |
import java.util.* | |
import java.util.stream.Stream | |
import kotlin.streams.toList | |
/**
 * One benchmark run: a single [Classifier] constructor, one candidate argument set,
 * and one LIBSVM data file to cross-validate against.
 *
 * Both [classifier] and [errorRate] are lazy, so nothing is instantiated or trained
 * until first access. They are declared in the class body, so `copy()` yields an
 * instance with fresh, un-evaluated delegates.
 *
 * @property constructor reflective constructor used to build the classifier
 * @property dataPath LIBSVM file that [errorRate] loads and evaluates on
 * @property params arguments passed to [constructor], in the set's iteration order
 * @property trainingTime milliseconds from the start of [errorRate] to the end of a
 *   successful cross-validation; stays at its initial value when evaluation fails
 * @property dataSize sample count times feature count of the loaded data set; filled in by
 *   [errorRate] and reported by [toString]
 */
data class MLWorkout2(
    var constructor: Constructor<Classifier>,
    var dataPath: Path = Paths.get("."),
    var params: Set<Serializable> = setOf(),
    var trainingTime: Long = 0,
    private var dataSize: Int = 0
) {
    /**
     * The classifier built by invoking [constructor] with [params].
     *
     * @throws IllegalArgumentException if the number of params does not match the
     *   constructor's arity (or if reflection rejects the argument types)
     */
    val classifier: Classifier by lazy {
        // `require` rather than `assert`: assert is skipped unless the JVM runs with -ea,
        // and when it does fire it raises an Error that `catch (ex: Exception)` misses.
        require(params.size == constructor.parameterCount) {
            "Constructor expects ${constructor.parameterCount} argument(s) but ${params.size} were supplied"
        }
        constructor.newInstance(*params.toTypedArray())
    }

    /**
     * Error rate from 10-fold cross-validation of [classifier] on [dataPath], or `-1.0`
     * if loading or evaluation threw.
     *
     * When the classifier fails to fit directly, it is retried wrapped in [OneVSAll].
     * Side effects: sets [trainingTime] and the reported data size.
     */
    val errorRate: Double by lazy {
        val startTime = System.currentTimeMillis()

        // Copy to a local so the type check smart-casts (delegated properties cannot).
        val model = classifier
        if (model is SupportVectorLearner) {
            model.cacheMode = SupportVectorLearner.CacheMode.FULL // Small dataset, so we can do this
        }

        try {
            val ds = LIBSVMLoader.loadC(dataPath.toFile())
            dataSize = ds.sampleSize * ds.numFeatures

            // Shared by the direct attempt and the OneVSAll fallback.
            fun crossValidate(candidate: Classifier): Double {
                val cme = ClassificationModelEvaluation(candidate, ds)
                cme.evaluateCrossValidation(10)
                trainingTime = System.currentTimeMillis() - startTime
                return cme.errorRate
            }

            try {
                crossValidate(model)
            } catch (ex: FailedToFitException) {
                val fallbackRate = crossValidate(OneVSAll(model))
                println("Successful OneVSAll")
                fallbackRate
            }
        } catch (ex: Throwable) {
            // Deliberately broad: one misbehaving classifier must not abort the whole sweep.
            when (ex) {
                is NullPointerException,
                is UnsupportedOperationException,
                is ArithmeticException -> {
                    System.err.println("Expected: $ex")
                }
                else -> {
                    System.err.println("Unexpected: $ex")
                    ex.printStackTrace()
                }
            }
            -1.0
        }
    }

    /**
     * Tab-separated summary: classifier class name, params (whitespace stripped),
     * data file name, data size. Note that this forces [classifier] to be instantiated.
     */
    override fun toString() = "${classifier.javaClass.simpleName}\t${params.toString().replace(Regex("\\s"), "")}\t${dataPath.fileName}\t$dataSize"
}
/**
 * Brute-force benchmark over the classifiers found on the classpath under `jsat.classifiers`.
 *
 * For every public constructor of every concrete [Classifier], every `.libsvm` file under
 * `./data`, and every subset of a fixed pool of candidate arguments whose size matches the
 * constructor's arity, an [MLWorkout2] is created. Combinations that cannot be instantiated
 * are dropped; the rest are shuffled and evaluated concurrently, printing one tab-separated
 * result line per finished workout. Waiting stops after five hours.
 *
 * @param args unused
 * @throws IllegalStateException if no `.libsvm` file exists under `./data`
 */
fun main(args: Array<String>) {
    // Files.walk keeps directory handles open until the stream is closed, so scope it with `use`.
    val dataFiles: List<Path> = Files.walk(Paths.get("./data")).use { paths ->
        paths.filter { path -> path.toString().endsWith(".libsvm") }.toList()
    }
    // `check` instead of `assert`: assertions are disabled on the JVM unless -ea is passed.
    check(dataFiles.isNotEmpty()) { "No .libsvm data files found under ./data" }

    val workouts = ClassPath.from(Thread.currentThread().contextClassLoader)
        .getTopLevelClassesRecursive("jsat.classifiers")
        .stream()
        .map { it.load() }
        .filter { !Modifier.isAbstract(it.modifiers) }
        .filter { Classifier::class.java.isAssignableFrom(it) }
        .map { it.constructors }
        // Unchecked cast; the isAssignableFrom filter above is what makes it hold.
        .flatMap { Arrays.stream(it) as Stream<Constructor<Classifier>> }
        .map { MLWorkout2(it) }
        // One workout per (constructor, data file) pair.
        .map { mlw -> dataFiles.stream().map { mlw.copy(dataPath = it) } }
        .flatMap { it }
        // One workout per candidate argument subset of the right size.
        .map { mlw ->
            Sets.powerSet(ImmutableSet.of(
                EuclideanDistance(),
                RBFKernel(0.5),
                DecisionStump(),
                5,
                179,
                1.5
            ))
                .stream()
                .filter { possibleParams -> possibleParams.size == mlw.constructor.parameterCount }
                .map { params -> mlw.copy(params = params) }
        }
        .flatMap { mlw -> mlw }
        // Keep only combinations the constructor actually accepts.
        .filter { mlw ->
            try {
                mlw.classifier
                true
            } catch (ex: Exception) {
                false
            }
        }
        .toList()

    Collections.shuffle(workouts)
    println("Number of workouts: ${workouts.size}")

    runBlocking {
        val jobs: List<Deferred<MLWorkout2>> = workouts.map {
            async(CommonPool, start = CoroutineStart.DEFAULT) {
                it.errorRate
                it
            }
        }.onEach { job ->
            job.invokeOnCompletion { cause ->
                if (cause == null) {
                    val workout = job.getCompleted()
                    println("$workout\t${workout.trainingTime}\t${workout.errorRate}")
                } else {
                    // getCompleted() would rethrow from inside the handler; report instead.
                    System.err.println("Workout did not complete: $cause")
                }
            }
        }
        println("Workouts mapped to Jobs and running")

        // Wait in order for everyone to be done - either until timeout, or all finished
        // TODO: Run the rest with a longer timeout
        try {
            // 5 hrs
            withTimeout(1000 * 60 * 60 * 5) {
                jobs.forEach {
                    it.join()
                }
            }
        } catch (e: CancellationException) {
            println("Hit the timeout.")
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment