Skip to content

Instantly share code, notes, and snippets.

@mayhewsw
Created October 5, 2015 20:08
Show Gist options
  • Save mayhewsw/d7298d578540819264a2 to your computer and use it in GitHub Desktop.
Save mayhewsw/d7298d578540819264a2 to your computer and use it in GitHub Desktop.
package edu.illinois.cs.cogcomp.tutorial;
import edu.illinois.cs.cogcomp.lbjava.learn.Lexicon;
import edu.illinois.cs.cogcomp.lbjava.learn.SparseNetworkLearner;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.*;
import java.util.stream.Stream;
/**
 * Extracts and inspects per-label feature weights from an LBJava {@link SparseNetworkLearner}.
 *
 * <p>Weights are recovered by parsing the text produced by {@code SparseNetworkLearner.write}
 * and pairing each weight line with the feature at the same index in the learner's
 * {@link Lexicon}.
 */
public class FeatureWeightCalculator {

    /**
     * Offset from a {@code label} line to that label's first weight line. The two lines in
     * between are skipped. NOTE(review): derived from the parser, not from LBJava's documented
     * output format; confirm against the version of LBJava in use.
     */
    private static final int WEIGHTS_OFFSET = 3;

    /** Number of features printed from each end of the sorted weight list in {@link #main}. */
    private static final int TOP_K = 10;

    /** Label inspected by {@link #main} when no label is given on the command line. */
    private static final String DEFAULT_LABEL = "Neutral";

    /**
     * Returns a copy of {@code map} whose iteration order is ascending by value.
     *
     * @param map the map to sort; not modified
     * @param <K> key type
     * @param <V> value type, compared with its natural ordering
     * @return a new {@link LinkedHashMap} with the same mappings, ordered by ascending value
     *     (ties keep {@code map}'s iteration order, since the stream sort is stable)
     */
    public static <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
        Map<K, V> result = new LinkedHashMap<>();
        // forEachOrdered (not forEach) is what guarantees the entries are inserted in the
        // sorted encounter order; forEach makes no ordering promise.
        map.entrySet().stream()
                .sorted(Map.Entry.comparingByValue())
                .forEachOrdered(e -> result.put(e.getKey(), e.getValue()));
        return result;
    }

    /**
     * Reads the weight vector of every label in a trained network learner.
     *
     * <p>The learner is serialized with {@code c.write(PrintStream)} and the text is scanned for
     * lines starting with {@code "label"}. For each such line, the label name is the trimmed
     * text after the first {@code ':'}. The weights are read from the {@code lexicon.size()}
     * lines starting {@value #WEIGHTS_OFFSET} lines below the label line, one number per line,
     * paired with lexicon features in index order. Scanning then resumes on the line after the
     * last weight.
     *
     * @param c the trained learner whose weights should be extracted
     * @return map from label to (feature name -&gt; weight); empty if no label line is found
     * @throws IllegalStateException if a label line has no {@code ':'}, or the serialized text
     *     ends before all of a label's weight lines have been read
     * @throws NumberFormatException if a weight line does not parse as a {@code double}
     */
    public static Map<String, Map<String, Double>> getFeatureWeights(SparseNetworkLearner c) {
        // \r?\n tolerates CRLF line endings in the serialized text.
        String[] lines = serialize(c).split("\\r?\\n");
        // Feature names are identical for every label, so resolve them once.
        String[] featureNames = featureNames(c.getLexicon());

        Map<String, Map<String, Double>> weightsByLabel = new HashMap<>();
        for (int i = 0; i < lines.length; i++) {
            if (!lines[i].startsWith("label")) {
                continue;
            }
            String label = parseLabel(lines[i]);
            int first = i + WEIGHTS_OFFSET;
            int end = first + featureNames.length;
            if (end > lines.length) {
                throw new IllegalStateException("Serialized learner ends early: label '" + label
                        + "' expects weight lines " + first + ".." + (end - 1)
                        + " but only " + lines.length + " lines are present");
            }

            Map<String, Double> weights = new HashMap<>();
            for (int k = first; k < end; k++) {
                weights.put(featureNames[k - first], Double.parseDouble(lines[k]));
            }
            weightsByLabel.put(label, weights);

            // Continue on the first line after this label's weights. The loop increment
            // moves i from end - 1 to end, so a label line there is not skipped.
            i = end - 1;
        }
        return weightsByLabel;
    }

    /** Captures the text that {@code c.write(PrintStream)} emits, as a String. */
    private static String serialize(SparseNetworkLearner c) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        // Closing the PrintStream flushes it, so the buffer is guaranteed to be complete.
        // The stream and buffer.toString() both use the platform default charset, so they agree.
        try (PrintStream out = new PrintStream(buffer)) {
            c.write(out);
        }
        return buffer.toString();
    }

    /** Resolves the name of every lexicon index via {@code lookupKey(j).toStringNoPackage()}. */
    private static String[] featureNames(Lexicon lexicon) {
        String[] names = new String[lexicon.size()];
        for (int j = 0; j < names.length; j++) {
            names[j] = lexicon.lookupKey(j).toStringNoPackage();
        }
        return names;
    }

    /**
     * Extracts the label name from a label line: the trimmed text after the first ':'.
     * Later colons are kept as part of the name.
     */
    private static String parseLabel(String line) {
        int colon = line.indexOf(':');
        if (colon < 0) {
            throw new IllegalStateException("Malformed label line (no ':'): " + line);
        }
        return line.substring(colon + 1).trim();
    }

    /**
     * Prints the lowest- and highest-weighted features for one label of
     * {@code CsvSentimentClassifier}.
     *
     * <p>Output: up to {@value #TOP_K} lowest-weighted features in ascending order, a blank
     * line, then up to {@value #TOP_K} highest-weighted features in descending order. Each
     * feature is printed as its name on one line followed by its weight on the next.
     *
     * @param args optional; {@code args[0]} is the label to inspect (default {@code "Neutral"})
     */
    public static void main(String[] args) {
        String targetLabel = args.length > 0 ? args[0] : DEFAULT_LABEL;

        Map<String, Map<String, Double>> weightsByLabel =
                getFeatureWeights(new CsvSentimentClassifier());
        Map<String, Double> weights = weightsByLabel.get(targetLabel);
        if (weights == null) {
            System.err.println("No weights for label '" + targetLabel
                    + "'; available labels: " + weightsByLabel.keySet());
            return;
        }

        List<Map.Entry<String, Double>> sorted =
                new ArrayList<>(sortByValue(weights).entrySet());
        // Clamp so labels with fewer than TOP_K features do not index out of bounds.
        int k = Math.min(TOP_K, sorted.size());

        printEntries(sorted.subList(0, k));
        System.out.println();

        List<Map.Entry<String, Double>> highest =
                new ArrayList<>(sorted.subList(sorted.size() - k, sorted.size()));
        Collections.reverse(highest);
        printEntries(highest);
    }

    /** Prints each entry as two lines: the feature name, then its weight. */
    private static void printEntries(List<Map.Entry<String, Double>> entries) {
        for (Map.Entry<String, Double> e : entries) {
            System.out.println(e.getKey());
            System.out.println(e.getValue());
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment