Created
June 18, 2009 21:16
-
-
Save r00k/132193 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ClassifyPublications
  # Trains a Classifier::Bayes instance with the categories 'Confirmed' and 'Rejected' on
  # citation/grant titles (and abstracts where present), then reports how many held-out
  # items it labels as something other than 'Rejected'.
  #
  # TODO: test corpus should be confirmed AND rejected items; would require a classifiable object
  #
  # Results:
  #
  # (Testing done with 10-fold cross-validation on 31k citations)
  #
  # Training on title and abstract, classifying on title = 90.8% accuracy
  # Training on title and abstract, classifying on title, with relevant and non grants trained on = 89.5% accuracy
  #
  # NOTE(review): the figures above were recorded when `run` reused one classifier across
  # all folds, so each test slice had already been trained on by an earlier fold. They
  # presumably overstate accuracy -- re-measure.
  require 'classifier'

  attr_accessor :classifier

  # NOTE: every corpus constant below is evaluated once, when this class is loaded.

  # Citations with which we train the classifier
  FULL_CONFIRMED_CORPUS = Citation.confirmed + Grant.relevant
  CONFIRMED_TRAINING_CORPUS_SIZE = 30000
  CONFIRMED_TRAINING_CORPUS = FULL_CONFIRMED_CORPUS[0...CONFIRMED_TRAINING_CORPUS_SIZE]
  REJECTED_TRAINING_CORPUS = Citation.rejected + Grant.non_relevant

  # Confirmed citations that are not trained on, but are used to test classifier accuracy
  TEST_CORPUS_SIZE = 900
  # `first(n)` takes exactly TEST_CORPUS_SIZE items; the previous inclusive range
  # `[0..TEST_CORPUS_SIZE]` took one too many.
  TEST_CORPUS = (FULL_CONFIRMED_CORPUS - CONFIRMED_TRAINING_CORPUS).first(TEST_CORPUS_SIZE)

  def initialize
    @classifier = new_classifier
  end

  # Cross-validate: for each slice yielded by Array#iterate_by_tenths, train on the
  # remaining confirmed items plus the whole rejected corpus, test on the slice, and
  # finally print the mean of the per-fold percentages.
  def run
    accuracy_results = []
    run_number = 1
    CONFIRMED_TRAINING_CORPUS.iterate_by_tenths do |tenth, rest|
      print "Run number: #{run_number} "
      # Start each fold from an untrained classifier. Otherwise training accumulates
      # across folds and the slice under test has already been seen by the classifier.
      @classifier = new_classifier
      train_confirmed(rest)
      train_rejected
      accuracy_results << check_classifier_accuracy(tenth)
      run_number += 1
    end
    puts "Average accuracy: #{accuracy_results.sum / accuracy_results.size.to_f}%"
  end

  # Train the 'Confirmed' category on each item's title, and on its abstract when the
  # item responds to #abstract and the abstract is not nil.
  def train_confirmed(training_corpus=CONFIRMED_TRAINING_CORPUS)
    training_corpus.each do |citation|
      @classifier.train_confirmed(citation.title)
      @classifier.train_confirmed(citation.abstract) if citation.respond_to?(:abstract) && citation.abstract
    end
  end

  # Train the 'Rejected' category on each item's title, and on its abstract when the
  # item responds to #abstract and the abstract is not nil.
  def train_rejected(training_corpus=REJECTED_TRAINING_CORPUS)
    training_corpus.each do |citation|
      @classifier.train_rejected(citation.title)
      @classifier.train_rejected(citation.abstract) if citation.respond_to?(:abstract) && citation.abstract
    end
  end

  # Classify every item in test_corpus by title only.
  # Returns the rounded percentage of items that were NOT classified as 'Rejected'.
  def check_classifier_accuracy(test_corpus=TEST_CORPUS)
    results = test_corpus.map { |citation| @classifier.classify(citation.title) }
    output_results(results, test_corpus)
  end

  private

  # Build an untrained classifier with the two categories this class trains on.
  def new_classifier
    Classifier::Bayes.new('Confirmed', 'Rejected')
  end

  # Every 'Rejected' result counts as wrong (the test corpus is expected to hold only
  # confirmed items). Prints the percentage outside the test environment and returns it.
  def output_results(results, test_corpus)
    number_wrong = results.count('Rejected')
    percent_correct = (((test_corpus.size - number_wrong) / test_corpus.size.to_f) * 100).round
    puts "#{percent_correct}% correct" unless Rails.env.test?
    percent_correct
  end

  # Time how long the run takes
  def time_run
    start = Time.now
    yield
    # Same environment check as output_results (was the bare RAILS_ENV constant).
    puts "Completed in #{(Time.now - start) / 60} minutes" unless Rails.env.test?
  end
end
class Array
  # Walk the array in consecutive, non-overlapping slices of ceil(length / 10) elements,
  # yielding each slice together with every element outside it.
  #
  # An array whose length is a multiple of 10 is yielded as exactly 10 slices; other
  # lengths give at most 10 slices, the last of which may be shorter (25 elements give
  # eight slices of 3 and one slice of 1). An empty array yields nothing.
  #
  # Yields:
  #   tenth - the current slice
  #   rest  - all other elements in their original order; built by position, so
  #           elements equal to ones in the slice are kept
  def iterate_by_tenths
    # Float division so ceil really rounds up (integer division truncated first), and
    # at least 1 so each_slice never receives 0 for arrays shorter than 10.
    slice_size = [(length / 10.0).ceil, 1].max
    offset = 0
    each_slice(slice_size) do |tenth|
      rest = take(offset) + drop(offset + tenth.size)
      yield(tenth, rest)
      offset += tenth.size
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require File.dirname(__FILE__) + "/../test_helper" | |
# Tests for ClassifyPublications (training and accuracy reporting) and for the
# Array#iterate_by_tenths helper it relies on.
class ClassifyPublicationsTest < Test::Unit::TestCase
  scenario :base

  def setup
    @classifier = ClassifyPublications.new
  end

  # Accuracy is the percentage of the corpus not classified as 'Rejected'.
  def test_check_classifier_accuracy
    confirmed_citation = Citation.confirmed.first
    @classifier.train_confirmed([confirmed_citation, Grant.relevant.first])
    rejected_citation = Citation.rejected.first
    @classifier.train_rejected([rejected_citation, Grant.non_relevant.first])

    assert_equal 100, @classifier.check_classifier_accuracy([confirmed_citation])
    assert_equal 0, @classifier.check_classifier_accuracy([rejected_citation])
    assert_equal 50, @classifier.check_classifier_accuracy([confirmed_citation, rejected_citation])
  end

  def test_handles_nil_or_not_present_abstracts
    # A citation whose abstract has been set to nil
    assert_nothing_raised do
      citation = Citation.first
      citation.abstract = nil
      citation.save(false)
      @classifier.train_confirmed([citation])
      @classifier.train_rejected([citation])
    end

    # A grant (grants have no abstract attribute at all)
    assert_nothing_raised do
      grant = Grant.first
      @classifier.train_confirmed([grant])
      @classifier.train_rejected([grant])
    end
  end

  # Ten elements: every slice holds one element and the slices rebuild the array.
  def test_iterate_by_tenths
    numbers = (1..10).to_a
    seen = []
    numbers.iterate_by_tenths do |tenth, rest|
      assert_equal 1, tenth.size
      assert_equal 9, rest.size
      seen.concat(tenth)
    end
    assert_equal numbers, seen
  end

  # Twenty-five elements: slices of three, except the final slice which holds only 25.
  def test_iterate_by_tenths_not_even
    numbers = (1..25).to_a
    seen = []
    numbers.iterate_by_tenths do |tenth, rest|
      # The last element ends up alone in its slice, so it gets its own expectations
      expected_tenth_size, expected_rest_size = tenth == [25] ? [1, 24] : [3, 22]
      assert_equal expected_tenth_size, tenth.size, tenth
      assert_equal expected_rest_size, rest.size, rest
      seen.concat(tenth)
    end
    assert_equal numbers, seen
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment