@salamanders
Last active March 31, 2016 05:32
Ultra-hacky attempt to give EdwardRaff/JSAT a workout by blindly trying every classifier
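// Requires the JSAT library and Google Guava on the classpath, plus a LIBSVM-format data file (diabetes.libsvm by default) in the working directory.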
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import com.google.common.reflect.ClassPath;
import com.google.common.reflect.ClassPath.ClassInfo;
import com.google.common.util.concurrent.SimpleTimeLimiter;
import com.google.common.util.concurrent.TimeLimiter;
import com.google.common.util.concurrent.UncheckedTimeoutException;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.classifiers.bayesian.NaiveBayes;
import jsat.classifiers.boosting.ModestAdaBoost;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.classifiers.svm.SupportVectorLearner.CacheMode;
import jsat.classifiers.trees.DecisionStump;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.exceptions.FailedToFitException;
import jsat.io.LIBSVMLoader;
import jsat.parameters.RandomSearch;
/**
* @author Benjamin Hill
*/
public class TryAllClassifiers {
private static final String DATA_FILE = "diabetes.libsvm"; // dataset options: diabetes, mushrooms
private static final Logger LOG = Logger.getLogger(TryAllClassifiers.class.getName());
private static final int GUESSED_SMALL_PARAM = 50;
private static final TimeLimiter TIME_LIMITER = new SimpleTimeLimiter();
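// Note: assumes a 2016-era Guava; newer Guava releases replace this no-arg constructor with SimpleTimeLimiter.create(ExecutorService) and drop the boolean "interruptible" callWithTimeout overload used below.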
private static final Class<? extends Classifier> WEAK_CLASSIFIER = DecisionStump.class;
/**
* Tries to instantiate the given class by probing a handful of common constructor signatures.
*
* @param classifierClass a concrete Classifier implementation found on the classpath
* @return an instance, or null if no supported constructor was found
*/
private static Classifier buildClassifierInstance(Class<Classifier> classifierClass) {
// No arg
try {
return classifierClass.newInstance();
} catch (final InstantiationException | IllegalAccessException e) {
// no luck
}
// Single int
try {
return classifierClass.getConstructor(Integer.TYPE).newInstance(GUESSED_SMALL_PARAM);
} catch (final NoSuchMethodException e) {
// ignore
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) {
System.err.println(
"Failed to construct: " + classifierClass.getCanonicalName() + " " + e.getClass() + " " + e.getMessage());
}
// Single Classifier argument - seed with an instance of the weak learner (DecisionStump)
try {
return classifierClass.getConstructor(Classifier.class).newInstance(WEAK_CLASSIFIER.newInstance());
} catch (final NoSuchMethodException e) {
// ignore
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) {
System.err.println("Failed to construct: " + classifierClass.getCanonicalName() + "\t" + e.getMessage());
}
// Classifier + Iterations
try {
return classifierClass.getConstructor(Classifier.class, Integer.TYPE).newInstance(WEAK_CLASSIFIER.newInstance(),
GUESSED_SMALL_PARAM);
} catch (final NoSuchMethodException e) {
// ignore
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) {
System.err.println("Failed to construct: " + classifierClass.getCanonicalName() + "\t" + e.getMessage());
}
// KernelTrick
try {
return classifierClass.getConstructor(KernelTrick.class).newInstance(new RBFKernel(0.5));
} catch (final NoSuchMethodException e) {
// ignore
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) {
System.err.println("Failed to construct: " + classifierClass.getCanonicalName() + "\t" + e.getMessage());
}
// KernelTrick + param
try {
return classifierClass.getConstructor(KernelTrick.class, Integer.TYPE).newInstance(new RBFKernel(0.5),
GUESSED_SMALL_PARAM);
} catch (final NoSuchMethodException e) {
// ignore
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) {
System.err.println("Failed to construct: " + classifierClass.getCanonicalName() + "\t" + e.getMessage());
}
System.err.println("Unable to find a way to construct: " + classifierClass.getCanonicalName());
return null;
}
/**
* Scans the jsat.classifiers package and attempts to instantiate every concrete Classifier found.
*
* @return the set of classifiers that could be instantiated
* @throws IOException if the classpath cannot be scanned
*/
@SuppressWarnings("unchecked")
private static Set<Classifier> getModels() throws IOException {
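// ClassPath.getTopLevelClassesRecursive only reports top-level classes, so nested Classifier implementations are not discovered.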
return ClassPath.from(Thread.currentThread().getContextClassLoader())
.getTopLevelClassesRecursive("jsat.classifiers").stream().map(ClassInfo::load) // ClassInfo to Class
.filter(Classifier.class::isAssignableFrom)
.filter(possibleModelClass -> !Modifier.isAbstract(possibleModelClass.getModifiers()))
.map(possibleModelClass -> (Class<Classifier>) possibleModelClass) // cast only
.map(TryAllClassifiers::buildClassifierInstance) // the hard work
.filter(inst -> inst != null) // successes only
.collect(Collectors.toSet());
}
/**
* Trains and evaluates the model with a generous but bounded time budget (10 minutes).
*
* @param model the classifier to evaluate
* @return a tab-separated result line
*/
private static String trainModel(final Classifier model) {
return trainAndTest(model, 10, TimeUnit.MINUTES);
}
/**
* Evaluates the model with 10-fold cross validation and returns bare stats as a tab-separated line.
* Tuning via RandomSearch.autoAddParameters is sketched below but currently disabled (see the TODO).
*
* @param model the classifier to evaluate
* @param time maximum wall-clock time allowed
* @param unit unit for the time limit
* @return model class name, elapsed milliseconds, and cross-validation error rate
*/
private static String trainAndTest(final Classifier model, final int time, final TimeUnit unit) {
Callable<String> myCallable = () -> {
final long startTime = System.currentTimeMillis();
if (model instanceof SupportVectorLearner) {
((SupportVectorLearner) model).setCacheMode(CacheMode.FULL);// Small dataset, so we can do this
}
final ClassificationDataSet dataset = LIBSVMLoader.loadC(new File(DATA_FILE));
ClassificationModelEvaluation cme = new ClassificationModelEvaluation(model, dataset);
cme.evaluateCrossValidation(10);
final double originalErrorRate = cme.getErrorRate();
// TODO: tunedErrorRate
/*
double tunedErrorRate = 0;
try {
final List<ClassificationDataSet> splits = dataset.randomSplit(0.75, 0.25);
final ClassificationDataSet train = splits.get(0);
final ClassificationDataSet test = splits.get(1);
final RandomSearch search = new RandomSearch(model, 3);
// this method adds parameters, and returns the number of parameters added
if (search.autoAddParameters(train) > 0) {
// that way we only do the search if there are any parameters to actually tune
search.trainC(dataset);
Classifier tunedModel = search.getTrainedClassifier();
cme = new ClassificationModelEvaluation(tunedModel, train);
cme.evaluateTestSet(test);
tunedErrorRate = cme.getErrorRate();
}
} catch (final FailedToFitException ex) {
// ignore, it doesn't like tuning.
}
*/
final long elapsedTime = System.currentTimeMillis() - startTime;
return String.format("%s\t%d\t%.3f", model.getClass().getName(), elapsedTime, originalErrorRate);
};
try {
return TIME_LIMITER.callWithTimeout(myCallable, time, unit, true);
} catch (final InterruptedException | UncheckedTimeoutException e) {
return String.format("%s\tTIMEOUT", model.getClass().getName());
} catch (final Throwable ex) {
return String.format("%s\tERROR\t%s", model.getClass().getName(), ex.getMessage());
}
}
/**
* Evaluates every discovered classifier and prints one tab-separated result line per model.
*
* @param args unused
* @throws Exception if the classpath scan fails
*/
public static void main(String[] args) throws Exception {
System.out.println(String.format("%s\t%s\t%s", "Model", "timeMs", "errorRate")); // tunedErrorRate column omitted until the tuning TODO above is enabled
getModels().stream().parallel() // may overlap with the internal parallel training
.map(TryAllClassifiers::trainModel).forEach(System.out::println);
}
}
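For reference, a minimal sketch of the JSAT evaluation pattern the harness above automates, assuming the same diabetes.libsvm file and JSAT's no-arg NaiveBayes (the class name and file path here are illustrative, not part of the original gist):

import java.io.File;

import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.bayesian.NaiveBayes;
import jsat.io.LIBSVMLoader;

public class SingleClassifierExample {
    public static void main(String[] args) throws Exception {
        // Load the LIBSVM-format dataset used by the harness above.
        final ClassificationDataSet data = LIBSVMLoader.loadC(new File("diabetes.libsvm"));
        // Evaluate one classifier with 10-fold cross validation, as trainAndTest does.
        final ClassificationModelEvaluation eval = new ClassificationModelEvaluation(new NaiveBayes(), data);
        eval.evaluateCrossValidation(10);
        System.out.printf("NaiveBayes\t%.3f%n", eval.getErrorRate());
    }
}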