Skip to content

Instantly share code, notes, and snippets.

@salamanders
Created August 13, 2016 19:13
Show Gist options
  • Save salamanders/cd42f99b8483e8d0d89f6edfa5b43a10 to your computer and use it in GitHub Desktop.
Save salamanders/cd42f99b8483e8d0d89f6edfa5b43a10 to your computer and use it in GitHub Desktop.
Guava Table<Long,String,String> to JSAT DataSet (Classification or Regression)
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.logging.Logger;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Iterables;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import jsat.classifiers.CategoricalData;
/**
 * Parsed view of a single table column.
 *
 * Holds the column name, the cell values coerced to one common type (Long, Double or String) keyed by row id, and -
 * for columns treated as categorical - a label/index lookup plus the matching JSAT {@link CategoricalData}.
 *
 * Equality and hashing are based on the column name only.
 *
 * @author benjaminhill@gmail
 *
 */
public class ColumnInfo {
  protected static final Logger LOG = Logger.getLogger(ColumnInfo.class.getName());
  public static final Gson GSON = new GsonBuilder().setPrettyPrinting().create();

  /** A Long column with more than this many distinct values is treated as numeric instead of categorical. */
  private static final int MAX_LONG_CATEGORIES = 20;

  /** Maximum number of parsed cells included in {@link #toString()}. */
  private static final int SAMPLE_SIZE = 25;

  /**
   * Utility: stringify every non-null element, drop duplicates and sort.
   *
   * Sorting is lexicographic on the string form, so numeric values order as text ("10" before "2").
   *
   * @param from
   *          any collection of objects, nulls are skipped
   * @return sorted unique List of Strings
   */
  private static List<String> collectionToSortedUniqueStringList(final Collection<?> from) {
    final Set<String> uniqueValues = new HashSet<>();
    for (final Object o : from) {
      if (null != o) {
        uniqueValues.add(String.valueOf(o));
      }
    }
    final List<String> uniqueValuesList = new ArrayList<>(uniqueValues);
    Collections.sort(uniqueValuesList);
    return uniqueValuesList;
  }

  /**
   * Parse a column from string-wrapped values into Longs, Doubles, or Strings, so that every cell of the result has
   * the same type: the narrowest of Long -> Double -> String that can represent all cells.
   *
   * Null and empty cells are dropped, so the result never contains null values.
   *
   * @param rawColumn
   *          raw cell text keyed by row id
   * @return parsed cells keyed by row id, all of one type
   */
  private static SortedMap<Long, Object> parseColumn(final Map<Long, String> rawColumn) {
    final SortedMap<Long, Object> parsedColumn = new TreeMap<>();
    // First pass: parse each cell with the widest type seen so far. Cells parsed before a widening keep their
    // narrower type until the second pass.
    Class<?> currentWorstClass = Long.class;
    for (final Entry<Long, String> cell : rawColumn.entrySet()) {
      final Object parsedObject = parseToLowestObject(cell.getValue(), currentWorstClass);
      if (null == parsedObject) {
        continue;
      }
      parsedColumn.put(cell.getKey(), parsedObject);
      if (!currentWorstClass.equals(parsedObject.getClass())) {
        LOG.info("During column parse, bumping up lowest class from " + currentWorstClass.getSimpleName() + " to "
            + parsedObject.getClass().getSimpleName() + " because of '" + cell.getValue() + "'");
        currentWorstClass = parsedObject.getClass();
      }
    }
    // Second pass: widen the cells that were parsed before the final type was known.
    for (final Entry<Long, String> cell : rawColumn.entrySet()) {
      final Object currentValue = parsedColumn.get(cell.getKey());
      if (null == currentValue || currentWorstClass.equals(currentValue.getClass())) {
        continue;
      }
      if (Double.class.equals(currentWorstClass)) {
        // The only narrower type than Double is Long.
        parsedColumn.put(cell.getKey(), ((Long) currentValue).doubleValue());
      } else if (String.class.equals(currentWorstClass)) {
        // Use the raw text, not the re-stringified number: cells seen after the widening are stored raw, so
        // "01" must stay "01" everywhere instead of becoming "1" only for the early rows.
        parsedColumn.put(cell.getKey(), cell.getValue());
      } else {
        throw new RuntimeException(
            "No idea how to turn a " + currentValue.getClass().getName() + " into a " + currentWorstClass.getName());
      }
    }
    return parsedColumn;
  }

  /**
   * Utility function to parse a string to the most basic possible type - Long, Double, or String - that is not
   * narrower than {@code currentClass}.
   *
   * @param value
   *          raw cell text
   * @param currentClass
   *          narrowest type still allowed (Long.class, Double.class or String.class)
   * @return the parsed value, or null when {@code value} is null or empty
   */
  private static Object parseToLowestObject(final String value, final Class<?> currentClass) {
    if (null == value || value.isEmpty()) {
      return null;
    }
    if (String.class.equals(currentClass)) {
      return value;
    }
    if (Long.class.equals(currentClass)) {
      try {
        return Long.valueOf(value);
      } catch (final NumberFormatException ignored) {
        // Not a Long: fall through and try Double.
      }
    }
    try {
      return Double.valueOf(value);
    } catch (final NumberFormatException ignored) {
      // Not a number at all: fall through and keep the text.
    }
    return value;
  }

  private final CategoricalData categoricalData;
  private final BiMap<String, Integer> lookup = HashBiMap.create();
  private final String name;
  private final SortedMap<Long, Object> parsedData;
  private final Class<?> type;

  /**
   * Column will be parsed and saved as lowest common type (Long/Double/String).
   *
   * @param columnName
   *          name of the column
   * @param columnData
   *          raw cell text keyed by row id; null and empty cells are ignored
   * @throws java.util.NoSuchElementException
   *           if the column has no non-empty cell, because its type cannot be determined
   */
  public ColumnInfo(final String columnName, final Map<Long, String> columnData) {
    this.name = columnName;
    parsedData = parseColumn(columnData);
    type = Iterables.find(parsedData.values(), Predicates.notNull()).getClass();
    constructLabelLookups();
    categoricalData = isLookup() ? constructJSATCategoricalData() : null;
  }

  /**
   * Build the JSAT CategoricalData for this column from the label lookup: one category per distinct label, with the
   * option names set to the labels.
   *
   * @return the CategoricalData describing this column
   */
  private CategoricalData constructJSATCategoricalData() {
    Preconditions.checkState(!lookup.isEmpty(), "Tried to get CategoricalData for non-lookup column");
    // The constructor argument is the number of categories, i.e. distinct labels - not the number of rows.
    final CategoricalData cd = new CategoricalData(lookup.size());
    cd.setCategoryName(name);
    for (final Entry<String, Integer> cell : lookup.entrySet()) {
      cd.setOptionName(cell.getKey(), cell.getValue());
    }
    return cd;
  }

  /**
   * Fill the label -> index lookup for columns that are treated as categorical.
   *
   * String columns are always categorical. Long columns are categorical unless they have more than
   * {@link #MAX_LONG_CATEGORIES} distinct values. Double columns never are. Indexes follow the sorted order of the
   * labels' string form.
   */
  private void constructLabelLookups() {
    if (!String.class.equals(type) && !Long.class.equals(type)) {
      return;
    }
    final List<String> uniqueValuesList = collectionToSortedUniqueStringList(parsedData.values());
    // Bail if too many long values
    if (Long.class.equals(type) && uniqueValuesList.size() > MAX_LONG_CATEGORIES) {
      return;
    }
    for (int i = 0; i < uniqueValuesList.size(); i++) {
      lookup.put(uniqueValuesList.get(i), i);
    }
  }

  /**
   * @return the JSAT description of this categorical column
   * @throws NullPointerException
   *           if this column is not a lookup (categorical) column
   */
  public CategoricalData getCategoricalData() {
    Preconditions.checkNotNull(categoricalData, "Tried to get CategoricalData for non-lookup column");
    return categoricalData;
  }

  public String getName() {
    return name;
  }

  /**
   * @return the common type of all parsed cells: Long.class, Double.class or String.class
   */
  public Class<?> getType() {
    return type;
  }

  /**
   * @return true when this column is categorical, i.e. its values are mapped to integer indexes
   */
  public boolean isLookup() {
    return !lookup.isEmpty();
  }

  /**
   * Value of this column for one row. Takes into account lookup tables: for lookup columns the result is the integer
   * index of the row's label, otherwise it is the parsed Long/Double itself.
   *
   * @param rowId
   *          row id; any Number is accepted and compared by its long value
   * @return the numeric value; null when a non-lookup column has no value for the row
   * @throws NullPointerException
   *           if a lookup column has no value for the row
   */
  public Number getRowValue(final Number rowId) {
    // Keys are Longs; looking up e.g. an Integer directly would fail inside the TreeMap comparison.
    final Object value = parsedData.get(rowId.longValue());
    if (!isLookup()) {
      return (Number) value;
    }
    // Check before stringifying: String.valueOf(null) is "null", which could match a real "null" label.
    Preconditions.checkNotNull(value, "Column '%s' has no value for row %s", name, rowId);
    final Integer intValue = lookup.get(String.valueOf(value));
    Preconditions.checkNotNull(intValue, "Column '%s' has no lookup id for value '%s' (row %s)", name, value, rowId);
    return intValue;
  }

  /**
   * @param id
   *          lookup index as returned by {@link #getRowValue(Number)} for a lookup column
   * @return the label mapped to that index, or null if there is none
   */
  public String getKeyFromLookupId(final int id) {
    return lookup.inverse().get(id);
  }

  /**
   * @return read-only view of the ids of all rows that have a value in this column, in ascending order
   */
  public Set<Long> getAllRowKeys() {
    return Collections.unmodifiableSet(parsedData.keySet());
  }

  @Override
  public int hashCode() {
    // Null-safe so it stays consistent with equals, which accepts a null name.
    return name == null ? 0 : name.hashCode();
  }

  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof ColumnInfo)) {
      return false;
    }
    final ColumnInfo other = (ColumnInfo) obj;
    return name == null ? other.name == null : name.equals(other.name);
  }

  /**
   * @return pretty-printed JSON with the name, the first letter of the type, the lookup and a sample of the first
   *         {@link #SAMPLE_SIZE} parsed cells
   */
  @Override
  public String toString() {
    // Sample by count, not by key range, so columns whose row ids do not start at 0 still show data.
    final Map<Long, Object> sample = new TreeMap<>();
    for (final Entry<Long, Object> cell : parsedData.entrySet()) {
      if (sample.size() >= SAMPLE_SIZE) {
        break;
      }
      sample.put(cell.getKey(), cell.getValue());
    }
    final Map<String, Object> tmp = new TreeMap<>();
    tmp.put("name", name);
    tmp.put("type", type.getSimpleName().substring(0, 1));
    tmp.put("lookup", lookup);
    tmp.put("sample", sample);
    return GSON.toJson(tmp);
  }
}
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.logging.Logger;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.DenseVector;
import jsat.regression.RegressionDataSet;
import jsat.utils.DoubleList;
/**
 * Convert a Guava Table into a JSAT-friendly data source.
 *
 * Cribbed mainly from JSAT/JSAT/src/jsat/ARFFLoader.java and JSAT/src/jsat/io/CSV.java There has to be a better way to
 * do this.
 *
 * Every column of the table is parsed into a {@link ColumnInfo} at construction time; {@link #getDataSet(String)}
 * then builds a classification or regression data set with one data point per row id.
 *
 * @author benjaminhill@gmail
 *
 */
public class TableDataLoader {
  protected static final Logger LOG = Logger.getLogger(TableDataLoader.class.getName());
  public static final Gson GSON = new GsonBuilder().create();

  /** Parsed columns keyed (and therefore ordered) by column name. */
  private final SortedMap<String, ColumnInfo> columns = new TreeMap<>();

  /**
   * Parse every column of the table.
   *
   * @param rawTable
   *          a Table of String-encoded data, row id x column name -> cell text
   * @throws NullPointerException
   *           if {@code rawTable} is null
   */
  public TableDataLoader(final Table<Long, String, String> rawTable) {
    Preconditions.checkNotNull(rawTable);
    for (final Entry<String, Map<Long, String>> column : rawTable.columnMap().entrySet()) {
      final ColumnInfo ci = new ColumnInfo(column.getKey(), column.getValue());
      columns.put(ci.getName(), ci);
      LOG.info(ci.toString());
    }
  }

  /**
   * Value of one cell, failing with a descriptive message instead of a bare NullPointerException further down.
   *
   * @throws NullPointerException
   *           if the column has no value for the row
   */
  private static Number requireValue(final ColumnInfo ci, final Number rowId) {
    return Preconditions.checkNotNull(ci.getRowValue(rowId), "Row %s has no value in column '%s'", rowId,
        ci.getName());
  }

  /** Log the state of the row that could not be turned into a data point. */
  private static void logRowFailure(final ColumnInfo targetCi, final Number rowId, final DoubleList numericFeats,
      final List<Integer> catFeats) {
    LOG.severe("NO go:" + GSON.toJson(ImmutableMap.of("targetCi", targetCi.getName(), "rowId", rowId,
        "numericFeats", numericFeats, "catFeats", catFeats)));
  }

  /**
   * Append the feature values of one row, in column-name order, skipping the target column. Lookup columns
   * contribute their integer index to {@code catFeats}, all others their double value to {@code numericFeats}.
   *
   * @throws NullPointerException
   *           if any feature column has no value for the row
   */
  private void columnsToFeats(final ColumnInfo targetCi, final Number rowId, final DoubleList numericFeats,
      final List<Integer> catFeats) {
    for (final ColumnInfo ci : columns.values()) {
      if (targetCi == ci) {
        continue;
      }
      final Number value = requireValue(ci, rowId);
      if (ci.isLookup()) {
        catFeats.add(value.intValue());
      } else {
        numericFeats.add(value.doubleValue());
      }
    }
  }

  /**
   * Build a data set that predicts one column from all the others.
   *
   * A target column that is a lookup (categorical) column yields a ClassificationDataSet, any other target a
   * RegressionDataSet. One data point is added per row id found in any column, so every such row needs a value in
   * every column.
   *
   * @param outputColumnName
   *          which column you are trying to predict, may be Integer/String for Classification, Doubles for Regression
   * @return the classification or regression data set
   * @throws NullPointerException
   *           if {@code outputColumnName} is null, or a row is missing a value in some column
   * @throws IllegalStateException
   *           if there is no column named {@code outputColumnName}
   */
  public DataSet<?> getDataSet(final String outputColumnName) {
    Preconditions.checkNotNull(outputColumnName);
    Preconditions.checkState(columns.containsKey(outputColumnName), "output column doesn't exist.");
    final ColumnInfo targetCi = columns.get(outputColumnName);
    final List<CategoricalData> inputCDs = new ArrayList<>();
    final SortedSet<Number> allRowKeys = new TreeSet<>();
    for (final ColumnInfo ci : columns.values()) {
      allRowKeys.addAll(ci.getAllRowKeys());
      if (ci.isLookup() && targetCi != ci) {
        inputCDs.add(ci.getCategoricalData());
      }
    }
    // Feature counts exclude the target column, whichever kind it is.
    final int categoricalColumnCount = inputCDs.size();
    final int numericalColumnCount = columns.size() - categoricalColumnCount - 1;
    Preconditions.checkState(categoricalColumnCount >= 0);
    Preconditions.checkState(numericalColumnCount >= 0);
    LOG.info("Forking on output type (name:" + targetCi.getName() + ", isLookup:" + targetCi.isLookup() + ", cat:"
        + categoricalColumnCount + ", num:" + numericalColumnCount + ")");
    if (targetCi.isLookup()) {
      LOG.info("Classification");
      return tableToDataSet_Classification(targetCi, inputCDs, allRowKeys, categoricalColumnCount,
          numericalColumnCount);
    }
    LOG.info("Regression");
    return tableToDataSet_Regression(targetCi, inputCDs, allRowKeys, categoricalColumnCount, numericalColumnCount);
  }

  /**
   * All done with the prep, time to build the actual data set (Classification)
   *
   * @param targetCi
   *          the lookup column to predict
   * @param inputCDs
   *          CategoricalData of every categorical feature column, in column-name order
   * @param allRowKeys
   *          ids of the rows to add
   * @param categoricalColumnCount
   *          number of categorical feature columns
   * @param numericalColumnCount
   *          number of numeric feature columns
   * @return the populated data set
   */
  private ClassificationDataSet tableToDataSet_Classification(final ColumnInfo targetCi,
      final List<CategoricalData> inputCDs, final SortedSet<Number> allRowKeys, final int categoricalColumnCount,
      final int numericalColumnCount) {
    final CategoricalData targetCD = targetCi.getCategoricalData();
    final ClassificationDataSet cds = new ClassificationDataSet(numericalColumnCount,
        inputCDs.toArray(new CategoricalData[inputCDs.size()]), targetCD);
    // Add all data points
    for (final Number rowId : allRowKeys) {
      final DoubleList numericFeats = new DoubleList(numericalColumnCount);
      final List<Integer> catFeats = new ArrayList<>(categoricalColumnCount);
      try {
        final int outputClass = requireValue(targetCi, rowId).intValue();
        columnsToFeats(targetCi, rowId, numericFeats, catFeats);
        cds.addDataPoint(new DenseVector(numericFeats), Ints.toArray(catFeats), outputClass);
      } catch (final RuntimeException re) {
        // Log which row broke before propagating; the exception itself is unchanged.
        logRowFailure(targetCi, rowId, numericFeats, catFeats);
        throw re;
      }
    }
    return cds;
  }

  /**
   * All done with the prep, time to build the actual data set (Regression)
   *
   * @param targetCi
   *          the numeric column to predict
   * @param inputCDs
   *          CategoricalData of every categorical feature column, in column-name order
   * @param allRowKeys
   *          ids of the rows to add
   * @param categoricalColumnCount
   *          number of categorical feature columns
   * @param numericalColumnCount
   *          number of numeric feature columns
   * @return the populated data set
   */
  private RegressionDataSet tableToDataSet_Regression(final ColumnInfo targetCi, final List<CategoricalData> inputCDs,
      final SortedSet<Number> allRowKeys, final int categoricalColumnCount, final int numericalColumnCount) {
    final RegressionDataSet rds = new RegressionDataSet(numericalColumnCount,
        inputCDs.toArray(new CategoricalData[inputCDs.size()]));
    // Add all data points
    for (final Number rowId : allRowKeys) {
      final DoubleList numericFeats = new DoubleList(numericalColumnCount);
      final List<Integer> catFeats = new ArrayList<>(categoricalColumnCount);
      try {
        // Target and feature extraction sit inside the try: a missing cell is the most likely failure.
        final double outputValue = requireValue(targetCi, rowId).doubleValue();
        columnsToFeats(targetCi, rowId, numericFeats, catFeats);
        rds.addDataPoint(new DenseVector(numericFeats), Ints.toArray(catFeats), outputValue);
      } catch (final RuntimeException re) {
        logRowFailure(targetCi, rowId, numericFeats, catFeats);
        throw re;
      }
    }
    return rds;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment