Skip to content

Instantly share code, notes, and snippets.

@Intrepidd
Created March 20, 2020 13:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Intrepidd/9d6be0882d73e13dfb83240fd6ba0190 to your computer and use it in GitHub Desktop.
Save Intrepidd/9d6be0882d73e13dfb83240fd6ba0190 to your computer and use it in GitHub Desktop.
require 'ruby-fann'
# Trains a FANN network on prepared message data and saves it to disk.
class FannTrainer
  # Arguments handed to RubyFann's train_on_data, in call order:
  # maximum number of epochs, epochs between reports, and the desired
  # mean-squared error.
  MAX_EPOCHS = 3000
  EPOCHS_BETWEEN_REPORTS = 1
  DESIRED_MSE = 0.001

  # Builds the training set from +data+, trains a network with a single hidden
  # layer of 80 neurons, and writes the result to './memoji.net'.
  #
  # @param data [Array<Hash>] raw messages, passed as-is to FannTrainerPresenter
  # @return whatever RubyFann's +save+ returns
  def perform(data)
    presenter = FannTrainerPresenter.new(data)
    train = RubyFann::TrainData.new(
      inputs: presenter.inputs,
      desired_outputs: presenter.desired_outputs
    )
    fann = RubyFann::Standard.new(
      num_inputs: FannTrainerPresenter::INPUTS_COUNT,
      hidden_neurons: [80],
      num_outputs: FannTrainerPresenter::OUTPUTS_COUNT
    )
    fann.train_on_data(train, MAX_EPOCHS, EPOCHS_BETWEEN_REPORTS, DESIRED_MSE)
    fann.save('./memoji.net')
  end
end
# Turns raw messages (hashes read through their 'body' and 'reactions' keys)
# into the numeric vectors fed to, and read back from, the neural network.
class FannTrainerPresenter
  # Maximum number of most-frequent stemmed words used as input features.
  INPUTS_COUNT = 200
  # Maximum number of most-frequent reactions used as output classes.
  OUTPUTS_COUNT = 50

  # @param raw_data [Array<Hash>] messages with 'body' and 'reactions' entries
  def initialize(raw_data)
    # First drop messages without any reaction; top_reactions is memoized on
    # its first use below, so it is computed from this reacted-to subset.
    @raw_data = raw_data.reject { |m| m['reactions'].blank? }
    # Then keep only the messages that carry at least one top reaction.
    @raw_data = @raw_data.reject { |m| (m['reactions'] & top_reactions).blank? }
  end

  # @return [Array<Array<Integer>>] one input vector per retained message
  def inputs
    @raw_data.map { |message| input_from_body(message['body']) }
  end

  # @return [Array<Array<Integer>>] one 0/1 vector per retained message, with
  #   one slot per top reaction (1 when the message has that reaction)
  def desired_outputs
    @raw_data.map do |message|
      top_reactions.map do |top_reaction|
        message['reactions'].include?(top_reaction) ? 1 : 0
      end
    end
  end

  # Builds the input vector for a message body: one slot per top word, holding
  # how many times that word appears among the body's cleaned tokens (which
  # are de-duplicated by #clean).
  #
  # @param str [String] message text
  # @return [Array<Integer>]
  def input_from_body(str)
    words = clean(str)
    top_words.map { |top_word| words.count(top_word) }
  end

  # Pairs every top reaction with the network output at the same position.
  #
  # @param output [Array<Numeric>] network output, indexed like top_reactions
  # @return [Array<Array>] [reaction, score] pairs sorted by ascending score
  def emojis_from_output(output)
    top_reactions.zip(output).sort_by { |_reaction, score| score }
  end

  private

  # Lower-cases and tokenizes +str+, strips English and French stopwords,
  # stems what is left and removes duplicates.
  def clean(str)
    tokens = PragmaticTokenizer::Tokenizer.new(punctuation: :none).tokenize(str.downcase)
    tokens = filter_en.filter(tokens)
    tokens = filter_fr.filter(tokens)
    tokens.map { |t| stemmer.stem(t) }.uniq
  end

  def stemmer
    @stemmer ||= Lingua::Stemmer.new(:language => "en")
  end

  # Each language gets its own memoization slot; sharing a single instance
  # variable would make the second filter silently reuse the first one.
  def filter_en
    @filter_en ||= Stopwords::Snowball::Filter.new "en"
  end

  def filter_fr
    @filter_fr ||= Stopwords::Snowball::Filter.new "fr"
  end

  # The most frequent stemmed words across all retained message bodies, least
  # frequent first. The list is cached in a JSON file under Rails.root/data
  # whose name is derived from a SHA1 of the retained data and the two size
  # constants, so a different data set or size produces a different cache file.
  def top_words
    @top_words ||= begin
      hash = Digest::SHA1.hexdigest(@raw_data.to_json + "#{INPUTS_COUNT} #{OUTPUTS_COUNT}")
      filename = Rails.root.join("data", "cache-#{hash}.json")
      if File.exist?(filename)
        JSON.parse(File.read(filename))
      else
        counts = @raw_data.flat_map { |d| clean(d['body']) }.group_by(&:itself).transform_values(&:count)
        counts.sort_by { |_word, count| count }.last(INPUTS_COUNT).map(&:first).tap do |result|
          File.write(filename, JSON.dump(result))
        end
      end
    end
  end

  # The most frequent reactions across the data held at first call, least
  # frequent first.
  def top_reactions
    @top_reactions ||= begin
      counts = @raw_data.flat_map { |d| d['reactions'] }.group_by(&:itself).transform_values(&:count)
      counts.sort_by { |_reaction, count| count }.last(OUTPUTS_COUNT).map(&:first)
    end
  end
end
namespace :train do
  # Holds out the last 10% of the reactions export as test data, stores it in
  # data/test.json, and trains the network on the remaining entries.
  task perform: :environment do
    data = JSON.parse(File.read("#{Rails.root}/data/slack_reactions.json"))
    test_data = data.pop((data.length * 0.1).floor)
    File.write(Rails.root.join("data", "test.json"), JSON.dump(test_data))
    FannTrainer.new.perform(data)
  end

  # Runs the saved network against the held-out test data and prints how many
  # messages had one of their real reactions among the confident predictions.
  task check: :environment do
    fann = RubyFann::Standard.new(filename: './memoji.net')

    # Rebuild the presenter from the same subset that train:perform trained on
    # (everything but the held-out tail). Building it from the full export can
    # yield different top words/reactions, and thus a different vector layout
    # than the one the network was trained with.
    training_data = JSON.parse(File.read("#{Rails.root}/data/slack_reactions.json"))
    training_data.pop((training_data.length * 0.1).floor)
    presenter = FannTrainerPresenter.new(training_data)

    test_data = JSON.parse(File.read("#{Rails.root}/data/test.json"))
    runs = test_data.map do |data|
      output = fann.run(presenter.input_from_body(data['body']))
      # emojis_from_output sorts ascending, so reverse to get best-first.
      ranked = presenter.emojis_from_output(output).reverse
      confident = ranked[0..5].select { |(_emoji, percent)| percent > 0.5 }.map(&:first)
      success = (data['reactions'] & confident).present?
      puts "#{data['body']} -> #{(ranked.map(&:first) & data['reactions']).first}" if success
      success
    end

    success_count = runs.count(true)
    failure_count = runs.count(false)
    # Avoid printing NaN when there is no test data at all.
    percentage = runs.empty? ? 0.0 : success_count.to_f / runs.count * 100
    puts "Success: #{success_count}"
    puts "Fail: #{failure_count}"
    puts "% : #{percentage}"
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment