This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'ruby-fann'
# Trains a FANN neural network on message/reaction data and saves the
# resulting network to disk.
class FannTrainer
  # Upper bound on the number of training epochs.
  MAX_EPOCHS = 3000
  # Number of epochs between two progress reports printed by FANN.
  EPOCHS_BETWEEN_REPORTS = 1
  # Training stops early once the mean-squared-error drops below this value.
  DESIRED_MSE = 0.001
  # Location the trained network is written to (relative to the working directory).
  NETWORK_PATH = './memoji.net'

  # Builds the training set from +data+ via FannTrainerPresenter, trains a
  # network with a single hidden layer of 80 neurons and saves it to
  # NETWORK_PATH.
  #
  # @param data [Array<Hash>] raw messages, handed unchanged to
  #   FannTrainerPresenter.new
  # @return whatever RubyFann::Standard#save returns
  def perform(data)
    presenter = FannTrainerPresenter.new(data)
    train = RubyFann::TrainData.new(
      inputs: presenter.inputs,
      desired_outputs: presenter.desired_outputs
    )
    fann = RubyFann::Standard.new(
      num_inputs: FannTrainerPresenter::INPUTS_COUNT,
      hidden_neurons: [80],
      num_outputs: FannTrainerPresenter::OUTPUTS_COUNT
    )
    # 3000 max epochs, 1 epoch between reports, 0.001 desired MSE.
    fann.train_on_data(train, MAX_EPOCHS, EPOCHS_BETWEEN_REPORTS, DESIRED_MSE)
    fann.save(NETWORK_PATH)
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Converts raw messages into the numeric vectors consumed by the FANN network
# and maps network outputs back to reactions.
#
# Each message is expected to be a Hash with a 'body' and a 'reactions' key.
# NOTE(review): 'reactions' is assumed to be a flat Array of reaction names
# and 'body' a String -- confirm against the exported data.
class FannTrainerPresenter
  # Number of most frequent words used as input features.
  INPUTS_COUNT = 200
  # Number of most frequent reactions used as output classes.
  OUTPUTS_COUNT = 50

  # @param raw_data [Array<Hash>] messages; those without any reaction, or
  #   without any of the most frequent reactions, are dropped.
  def initialize(raw_data)
    @raw_data = raw_data.reject { |m| m['reactions'].blank? }
    # top_reactions is memoized on first use, i.e. it is computed from the
    # messages that have at least one reaction.
    @raw_data = @raw_data.reject { |m| (m['reactions'] & top_reactions).blank? }
  end

  # @return [Array<Array<Integer>>] one feature vector per retained message
  def inputs
    @raw_data.map { |message| input_from_body(message['body']) }
  end

  # @return [Array<Array<Integer>>] one 0/1 vector per retained message, with
  #   one entry per top reaction (1 when the message carries that reaction)
  def desired_outputs
    @raw_data.map do |message|
      top_reactions.map do |top_reaction|
        message['reactions'].include?(top_reaction) ? 1 : 0
      end
    end
  end

  # Builds the feature vector for a message body: one entry per top word
  # holding the number of occurrences of that word among the cleaned tokens
  # (tokens are de-duplicated by #clean, so each entry is 0 or 1).
  #
  # @param str [String, nil] message body
  # @return [Array<Integer>]
  def input_from_body(str)
    words = clean(str)
    top_words.map { |top_word| words.count(top_word) }
  end

  # Pairs every top reaction with the network output at the same position.
  #
  # @param output [Array<Numeric>] network output, indexed like top_reactions
  # @return [Array<Array>] [reaction, score] pairs sorted by ascending score
  def emojis_from_output(output)
    top_reactions.zip(output).sort_by { |_reaction, score| score }
  end

  private

  # Tokenizes, removes English and French stop words, stems and de-duplicates.
  # A nil body is treated as an empty string.
  def clean(str)
    tokens = PragmaticTokenizer::Tokenizer.new(punctuation: :none).tokenize(str.to_s.downcase)
    tokens = filter_en.filter(tokens)
    tokens = filter_fr.filter(tokens)
    tokens.map { |t| stemmer.stem(t) }.uniq
  end

  def stemmer
    @stemmer ||= Lingua::Stemmer.new(:language => "en")
  end

  # Each language gets its own memoization slot; sharing a single instance
  # variable would make both methods return whichever filter was built first.
  def filter_en
    @filter_en ||= Stopwords::Snowball::Filter.new "en"
  end

  def filter_fr
    @filter_fr ||= Stopwords::Snowball::Filter.new "fr"
  end

  # The INPUTS_COUNT most frequent cleaned words across all retained messages,
  # least frequent first. The result is cached on disk under data/, keyed by a
  # digest of the retained messages and both size constants.
  def top_words
    @top_words ||= begin
      digest = Digest::SHA1.hexdigest(@raw_data.to_json + "#{INPUTS_COUNT} #{OUTPUTS_COUNT}")
      filename = Rails.root.join("data", "cache-#{digest}.json")
      if File.exist?(filename)
        JSON.parse(File.read(filename))
      else
        words = @raw_data.flat_map { |d| clean(d['body']) }
        most_frequent(words, INPUTS_COUNT).tap do |result|
          File.write(filename, JSON.dump(result))
        end
      end
    end
  end

  # The OUTPUTS_COUNT most frequent reactions across the messages held in
  # @raw_data at the time of the first call, least frequent first.
  def top_reactions
    @top_reactions ||= most_frequent(@raw_data.flat_map { |d| d['reactions'] }, OUTPUTS_COUNT)
  end

  # Returns up to +limit+ distinct items of +items+ with the highest number of
  # occurrences, ordered by ascending occurrence count.
  def most_frequent(items, limit)
    counts = items.group_by(&:itself).transform_values(&:count)
    counts.sort_by { |_item, count| count }.last(limit).map(&:first)
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Rake tasks that train the reaction-prediction network and evaluate it on a
# held-out slice of the data.
namespace :train do
  desc 'Train the network on data/slack_reactions.json, holding out the last 10% as data/test.json'
  task perform: :environment do
    data = JSON.parse(File.read("#{Rails.root}/data/slack_reactions.json"))
    # Remove the last 10% of the messages from the training set and keep them
    # aside for the check task.
    test_data = data.pop((data.length * 0.1).floor)
    File.write(Rails.root.join("data", "test.json"), JSON.dump(test_data))
    FannTrainer.new.perform(data)
  end

  desc 'Evaluate the saved network against data/test.json'
  task check: :environment do
    fann = RubyFann::Standard.new(filename: './memoji.net')
    all_data = JSON.parse(File.read("#{Rails.root}/data/slack_reactions.json"))
    test_data = JSON.parse(File.read("#{Rails.root}/data/test.json"))

    # The presenter derives its word and reaction vocabularies from the data it
    # is given, so it must be built from the same messages the network was
    # trained on: everything except the held-out tail written by train:perform.
    training_size = [all_data.length - test_data.length, 0].max
    presenter = FannTrainerPresenter.new(all_data.first(training_size))

    runs = test_data.map do |data|
      output = fann.run(presenter.input_from_body(data['body']))
      # Pairs are sorted by ascending score; reverse to get the best first.
      ranked = presenter.emojis_from_output(output).reverse
      # Test messages are not filtered, so 'reactions' may be missing.
      actual = Array(data['reactions'])
      confident = ranked.first(6).select { |(_emoji, percent)| percent > 0.5 }.map(&:first)

      # A run succeeds when one of the (at most six) best predictions scoring
      # above 0.5 is among the reactions the message actually received.
      success = (actual & confident).present?
      puts "#{data['body']} -> #{(ranked.map(&:first) & actual).first}" if success
      success
    end

    success_count = runs.count(true)
    failure_count = runs.count(false)
    # Avoid printing NaN when there is no test data.
    percentage = runs.empty? ? 0.0 : success_count.to_f / runs.count * 100
    puts "Success: #{success_count}"
    puts "Fail: #{failure_count}"
    puts "% : #{percentage}"
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment