Skip to content

Instantly share code, notes, and snippets.

@arcnavier
Last active March 4, 2019 18:07
Show Gist options
  • Save arcnavier/cc780b9d6e72803c318c9a1d1b0a0ee9 to your computer and use it in GitHub Desktop.
Save arcnavier/cc780b9d6e72803c318c9a1d1b0a0ee9 to your computer and use it in GitHub Desktop.
import ds.Document;
import ds.Query;
import scorer.AScorer;
import scorer.BM25Scorer;
import scorer.BaselineScorer;
import utils.IndexUtils;
import utils.LoadHandler;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Manual smoke test for the term-frequency scorers.
 *
 * <p>Builds one hand-made document and two queries, feeds them to a scorer and prints the
 * resulting term-frequency tables to standard output for visual inspection. Nothing is asserted;
 * the output has to be checked by eye.
 */
public class Test {
    private static final String INDEX_DIR = "index";

    private IndexUtils utils;
    // Stays null when the training data cannot be loaded (see the constructor).
    private Map<Query, Map<String, Document>> queryDict = null;
    private String sigPath = "data/signal.train";

    // Test data
    // q has a repeated term and is used for the query term-frequency table;
    // q2 overlaps with the fixture document and is used for the document tables.
    private Query q, q2;
    private Document d;
    private List<Integer> list, list2;

    /**
     * Runs the BM25 scorer check; the baseline scorer check is left disabled.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Test t = new Test();
        // t.testAScorer();
        t.testBM25Scorer();
    }

    /**
     * Creates the index utilities, tries to load the training data and builds the fixture
     * queries and document.
     *
     * <p>A failure while loading the training data is reported on standard output and the
     * constructor carries on with {@code queryDict} left as {@code null}; {@link #testAScorer()}
     * does not use it.
     */
    private Test() {
        utils = new IndexUtils(INDEX_DIR);
        try {
            queryDict = LoadHandler.loadTrainData(sigPath);
        } catch (Exception e) {
            System.out.println("Error while reading: " + sigPath);
            e.printStackTrace();
        }

        q = new Query("Hello Indian Mifan Mifan");
        q2 = new Query("stanford robert crown school");

        // getDocFreq test case for title tf
        d = new Document("http://liblog.law.stanford.edu/computer-labs-in-the-law-library-faq/");
        d.title = "stanford stanford robert stanford crown library";

        // getDocFreq test case for body tf
        list = new ArrayList<>();
        list2 = new ArrayList<>();
        list.add(1);
        list.add(2);
        list.add(3);
        list.add(4);
        d.body_hits = new HashMap<>();
        // "stanford" and "robert" are registered with the same four-entry list instance.
        d.body_hits.put("stanford", list);
        d.body_hits.put("robert", list);
        list2.add(1);
        d.body_hits.put("crown", list2);
        d.body_length = 5;
    }

    /**
     * Prints a blank line, a heading and then one {@code term value} line per map entry, with the
     * value formatted to four decimal places.
     *
     * @param termFreq term to frequency map to print
     * @param title label placed in front of " term frequency table" in the heading
     */
    private void printTermFrequency(Map<String, Double> termFreq, String title) {
        System.out.println();
        System.out.println(title + " term frequency table");
        for (Map.Entry<String, Double> term : termFreq.entrySet()) {
            System.out.printf("%s %.4f\n", term.getKey(), term.getValue());
        }
    }

    /**
     * Prints the title and body term-frequency tables of the fixture document for {@code q2}, and
     * the query term-frequency table for {@code q}, as produced by a {@link BaselineScorer}.
     */
    public void testAScorer() {
        AScorer scorer = new BaselineScorer();
        System.out.println("====== Test AScorer ======");
        Map<String, Map<String, Double>> docTermFreqs = scorer.getDocTermFreqs(d, q2);
        Map<String, Double> queryTermFreqs = scorer.getQueryFreqs(q);

        // DEBUG: print term - raw term frequency table
        printTermFrequency(docTermFreqs.get("title"), "Title");
        printTermFrequency(docTermFreqs.get("body"), "Body");
        printTermFrequency(queryTermFreqs, "Query");
    }

    /**
     * Prints the average field lengths reported by a {@link BM25Scorer}, then the title and body
     * term-frequency tables of the fixture document before and after
     * {@code normalizeTFs}, plus the query term-frequency table for {@code q}.
     */
    public void testBM25Scorer() {
        System.out.println("====== Test BM25Scorer ======");
        BM25Scorer scorer = new BM25Scorer(utils, queryDict);
        Map<String, Map<String, Double>> docTermFreqs = scorer.getDocTermFreqs(d, q2);
        Map<String, Double> queryTermFreqs = scorer.getQueryFreqs(q);

        System.out.println("Title average length: " + scorer.getAvgLengths().get("title"));
        System.out.println("Body average length: " + scorer.getAvgLengths().get("body"));

        printTermFrequency(docTermFreqs.get("title"), "BM25 Title");
        printTermFrequency(docTermFreqs.get("body"), "BM25 Body");
        // Previously computed but never shown; printed like in testAScorer.
        printTermFrequency(queryTermFreqs, "BM25 Query");

        // Normalize with q2, the query the tables above were built from. Passing q here
        // paired the document tables with an unrelated query.
        scorer.normalizeTFs(docTermFreqs, d, q2);

        printTermFrequency(docTermFreqs.get("title"), "Normalized BM25 Title");
        printTermFrequency(docTermFreqs.get("body"), "Normalized BM25 Body");
    }
}
import ds.Document;
import ds.Query;
import scorer.AScorer;
import scorer.BM25Scorer;
import scorer.BaselineScorer;
import utils.IndexUtils;
import utils.LoadHandler;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Manual smoke test for the term-frequency scorers.
 *
 * <p>Builds one hand-made document and two queries, feeds them to a scorer and prints the
 * resulting term-frequency tables to standard output for visual inspection. Nothing is asserted;
 * the output has to be checked by eye.
 */
public class Test {
    private static final String INDEX_DIR = "index";

    private IndexUtils utils;
    // Stays null when the training data cannot be loaded (see the constructor).
    private Map<Query, Map<String, Document>> queryDict = null;
    private String sigPath = "data/signal.train";

    // Test data
    // q has a repeated term and is used for the query term-frequency table;
    // q2 overlaps with the fixture document and is used for the document tables.
    private Query q, q2;
    private Document d;
    private List<Integer> list, list2;

    /**
     * Runs the BM25 scorer check; the baseline scorer check is left disabled.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Test t = new Test();
        // t.testAScorer();
        t.testBM25Scorer();
    }

    /**
     * Creates the index utilities, tries to load the training data and builds the fixture
     * queries and document.
     *
     * <p>A failure while loading the training data is reported on standard output and the
     * constructor carries on with {@code queryDict} left as {@code null}; {@link #testAScorer()}
     * does not use it.
     */
    private Test() {
        utils = new IndexUtils(INDEX_DIR);
        try {
            queryDict = LoadHandler.loadTrainData(sigPath);
        } catch (Exception e) {
            System.out.println("Error while reading: " + sigPath);
            e.printStackTrace();
        }

        q = new Query("Hello Indian Mifan Mifan");
        q2 = new Query("stanford robert crown school");

        // getDocFreq test case for title tf
        d = new Document("http://liblog.law.stanford.edu/computer-labs-in-the-law-library-faq/");
        d.title = "stanford stanford robert stanford crown library";

        // getDocFreq test case for body tf
        list = new ArrayList<>();
        list2 = new ArrayList<>();
        list.add(1);
        list.add(2);
        list.add(3);
        list.add(4);
        d.body_hits = new HashMap<>();
        // "stanford" and "robert" are registered with the same four-entry list instance.
        d.body_hits.put("stanford", list);
        d.body_hits.put("robert", list);
        list2.add(1);
        d.body_hits.put("crown", list2);
        d.body_length = 5;
    }

    /**
     * Prints a blank line, a heading and then one {@code term value} line per map entry, with the
     * value formatted to four decimal places.
     *
     * @param termFreq term to frequency map to print
     * @param title label placed in front of " term frequency table" in the heading
     */
    private void printTermFrequency(Map<String, Double> termFreq, String title) {
        System.out.println();
        System.out.println(title + " term frequency table");
        for (Map.Entry<String, Double> term : termFreq.entrySet()) {
            System.out.printf("%s %.4f\n", term.getKey(), term.getValue());
        }
    }

    /**
     * Prints the title and body term-frequency tables of the fixture document for {@code q2}, and
     * the query term-frequency table for {@code q}, as produced by a {@link BaselineScorer}.
     */
    public void testAScorer() {
        AScorer scorer = new BaselineScorer();
        System.out.println("====== Test AScorer ======");
        Map<String, Map<String, Double>> docTermFreqs = scorer.getDocTermFreqs(d, q2);
        Map<String, Double> queryTermFreqs = scorer.getQueryFreqs(q);

        // DEBUG: print term - raw term frequency table
        printTermFrequency(docTermFreqs.get("title"), "Title");
        printTermFrequency(docTermFreqs.get("body"), "Body");
        printTermFrequency(queryTermFreqs, "Query");
    }

    /**
     * Prints the average field lengths reported by a {@link BM25Scorer}, then the title and body
     * term-frequency tables of the fixture document before and after
     * {@code normalizeTFs}, plus the query term-frequency table for {@code q}.
     */
    public void testBM25Scorer() {
        System.out.println("====== Test BM25Scorer ======");
        BM25Scorer scorer = new BM25Scorer(utils, queryDict);
        Map<String, Map<String, Double>> docTermFreqs = scorer.getDocTermFreqs(d, q2);
        Map<String, Double> queryTermFreqs = scorer.getQueryFreqs(q);

        System.out.println("Title average length: " + scorer.getAvgLengths().get("title"));
        System.out.println("Body average length: " + scorer.getAvgLengths().get("body"));

        printTermFrequency(docTermFreqs.get("title"), "BM25 Title");
        printTermFrequency(docTermFreqs.get("body"), "BM25 Body");
        // Previously computed but never shown; printed like in testAScorer.
        printTermFrequency(queryTermFreqs, "BM25 Query");

        // Normalize with q2, the query the tables above were built from. Passing q here
        // paired the document tables with an unrelated query.
        scorer.normalizeTFs(docTermFreqs, d, q2);

        printTermFrequency(docTermFreqs.get("title"), "Normalized BM25 Title");
        printTermFrequency(docTermFreqs.get("body"), "Normalized BM25 Body");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment