Last active
June 26, 2016 09:40
-
-
Save gomao9/cd0889436e6ed7ea2589a90b7f5b9007 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'jubatus/classifier/client' | |
require 'open-uri' | |
require 'json' | |
require 'kconv' | |
require 'rubimas' | |
URL = 'path/to/training_data.json' | |
HOST = "127.0.0.1" | |
PORT = 9199 | |
NAME = "" | |
SAMPLE_TEXT = %w( | |
う~ん…この衣装、小さくて頭が入らないですね~。…あ、私のじゃないんですか~。 | |
プロデューサーさん、もっと笑顔でど~んって待ち構えましょう~。焦ってもお仕事はこないです~。 | |
今日はレッスンよりみんなでお散歩しましょ~。 | |
あ…っ! 突然ですけど、私、お天気おねえさんのお仕事がやりたいです~。 | |
えぇ~? 今の話…アイドルのことを話してたんですか~? | |
なんとーもうこんな時間ですか? びっくりですね~。 | |
今日はオーディションですか~? もちろん覚えてましたよ~。 | |
早口言葉も得意ですよ~。ゆっくり丁寧に喋れば、絶対にカミませんから~。 | |
今日は一日、発声練習に集中しますね~。ノドが渇くから、たくさんお水を用意しないと…! | |
私もウデには覚えがあるので、社長が相手でも負けませんよ? なんの話って…将棋の話ですよ~。 | |
プロデューサーさんにサンドイッチを作ってきましたよ~美味しいですか? | |
プロデューサーさん、私、トークの役に立つ小話をたくさん集めたので、今日はトークの練習しましょう~。 | |
私がTVに映る時は、隣で一緒にみてくださいね~。 | |
今日はちょっとヒマですね~。プロデューサーさんも、こっちでゆっくりお茶でも飲みませんか? | |
もう~、プロデューサーさんのお話はちゃ~んと聞いてますよ? 昨日の時代劇の話で…違いましたか~。 | |
美希ちゃんのお昼寝を見ていると、私もなんだか…プロデューサーさん肩かしてください…Zzz。 | |
) | |
def text_normalize(text) | |
Kconv.kconv(text, 'utf-8').gsub(/[\r\n]/,"").gsub('-', 'ー').gsub(/^『/, "") | |
end | |
def document_tag_pairs | |
hashes = JSON.parse(open(URL).read).flatten | |
pairs = hashes.map{|hash| [text_normalize(hash['text']), hash['idol_no']] } | |
pairs.uniq{|text, _| text } | |
end | |
def train(client) | |
train_data = document_tag_pairs.map do |document, tag| | |
[tag, Jubatus::Common::Datum.new(text: document)] | |
end | |
client.train(train_data) | |
end | |
def test(client) | |
test_data = SAMPLE_TEXTS.map do |text| | |
text = text_normalize(text) | |
Jubatus::Common::Datum.new(text: text) | |
end | |
results = client.classify(test_data) | |
results.zip(test_data).each do |result, input| | |
p input.string_values.to_h['text'] | |
puts format_result(result) | |
puts | |
end | |
end | |
def format_result(result) | |
max = result.max_by(3){|r| r.score } | |
max.map do |r| | |
name = 765.pro.find_by_id(r.label.to_i).name.shorten | |
space = 8 - name.length * 2 | |
"%s:%4.2f" % [name + (' ' * space), r.score] | |
end | |
end | |
def client | |
@client ||= Jubatus::Classifier::Client::Classifier.new(HOST, PORT, NAME) | |
end | |
case ARGV.first | |
when 'train' | |
train(client) | |
when 'test' | |
test(client) | |
end |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment