Last active
July 4, 2021 18:33
-
-
Save smola/18d3477cfe66063a933c7b26f9feb1a6 to your computer and use it in GitHub Desktop.
Cross validation for github/linguist
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'set'

require 'linguist'
require 'parallel'
include Linguist

# Passing --all as the first argument disables the extension skip list below.
all = ARGV.first == '--all'
ARGV.shift if all

# Extensions that eval leaves out of the cross validation.
$skip_extensions = Set.new
unless all
  # Skip extensions with catch-all rule
  Heuristics.all.each do |h|
    # Reads the heuristic's @rules instance variable directly.
    rules = h.instance_variable_get(:@rules)
    last_rule = rules && rules.last
    # A heuristic without rules has no catch-all, so it is never skipped.
    next unless last_rule && last_rule['pattern'].is_a?(AlwaysMatch)

    $skip_extensions.merge(h.extensions)
  end
end

# Read every sample's contents once up front; eval uses sample[:data] both
# for training and for classification.
$samples = []
Samples.each do |sample|
  sample[:data] = File.read(sample[:path])
  $samples << sample
end
# Cross-validates one sample: trains the classifier on every other sample of
# the candidate languages, then checks whether the held-out sample is
# classified as its own language.
#
# NOTE(review): this top-level definition shadows Kernel#eval for the whole
# process; the name is kept because the existing caller uses it.
#
# sample - Hash with :path, :extname, :language and :data entries.
#
# Returns an empty Array when the sample is skipped, otherwise a one-element
# Array holding "<path> GOOD" or "<path> BAD(<predicted language>)".
def eval(sample)
  return [] if $skip_extensions.include?(sample[:extname])

  # A filename that maps to exactly one language is not evaluated.
  return [] if Language.find_by_filename(sample[:path]).length == 1

  # Only paths whose extension maps to several languages are evaluated.
  candidates = Language.find_by_extension(sample[:path]).map(&:name)
  return [] if candidates.length <= 1

  # Test only languages with at least 2 samples
  same_language = $samples.count { |other| other[:language] == sample[:language] }
  return [] if same_language <= 1

  # Hold the evaluated sample out and train only on the candidate languages.
  train_samples = $samples.select do |other|
    other != sample && candidates.include?(other[:language])
  end

  # Restrict the candidates to languages that actually have training data.
  languages = train_samples.map { |train_sample| train_sample[:language] }.uniq
  return [] if languages.length <= 1

  db = {}
  train_samples.each do |train_sample|
    Classifier.train!(db, train_sample[:language], train_sample[:data])
  end
  # Not every classifier version exposes a finalization step.
  Classifier.finalize_train!(db) if Classifier.respond_to?(:finalize_train!)

  results = Classifier.classify(db, sample[:data], languages)
  # The first result is taken as the prediction (presumably best match first —
  # confirm against Classifier.classify). An empty result reports BAD().
  best = results.first
  predicted = best && best[0]

  if sample[:language] == predicted
    ["#{sample[:path]} GOOD"]
  else
    ["#{sample[:path]} BAD(#{predicted})"]
  end
end
# Evaluate every loaded sample through Parallel.flat_map; each call contributes
# zero or one result line.
results = Parallel.flat_map($samples) { |sample| eval(sample) }

# Print line by line: `puts` on an empty Array would emit a blank line.
results.each { |res| puts res }
next if languages.length <= 1 | |
# Test only languages with at least 2 samples | |
n_samples = 0 | |
Samples.each do |other_sample| | |
if other_sample[:language] == sample[:language] | |
n_samples += 1 | |
end | |
end | |
next if n_samples <= 1 | |
train_samples = [] | |
Samples.each do |train_sample| | |
next if sample == train_sample | |
next if not languages.include? train_sample[:language] | |
train_samples.push(train_sample) | |
end | |
languages = Set.new(train_samples.map { |s| s[:language] }).to_a | |
next if languages.length <= 1 | |
db = {} | |
train_samples.each do |train_sample| | |
data = File.read(train_sample[:path]) | |
Classifier.train! db, train_sample[:language], data | |
end | |
results = Classifier.classify(db, File.read(sample[:path]), languages) | |
if sample[:language] == results.first[0] | |
puts "#{sample[:path]} GOOD" | |
else | |
puts "#{sample[:path]} BAD(#{results.first[0]})" | |
end | |
end |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment