Created
August 7, 2018 05:56
-
-
Save ytbilly3636/ee364dbb9ada0768d603576d922aafb1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding:utf-8 -*- | |
''' | |
Requirements | |
* chainer | |
* matplotlib | |
* MeCab | |
* OpenJTalk | |
* requests_oauthlib | |
* Twitter API | |
''' | |
import sys | |
import six | |
import numpy as np | |
import chainer | |
from chainer import links as L | |
from chainer import functions as F | |
from chainer import Chain | |
from chainer import Variable | |
from chainer import serializers | |
import matplotlib.pyplot as plt | |
import MeCab | |
import json | |
from requests_oauthlib import OAuth1Session | |
import subprocess | |
import warnings | |
warnings.filterwarnings("ignore") | |
# Open JTalk: synthesize Japanese speech for the given text and play it.
def jtalk(t):
    """Synthesize *t* with Open JTalk and play the result via aplay.

    Parameters
    ----------
    t : str or bytes
        Text to speak.  It is encoded to UTF-8 before being written to
        Open JTalk's stdin, since a subprocess pipe is a byte stream.
    """
    # TODO: fetch the "mei" htsvoice separately
    open_jtalk = ['open_jtalk']
    mech = ['-x', '/var/lib/mecab/dic/open-jtalk/naist-jdic']
    htsvoice = ['-m', '/usr/share/hts-voice/mei/mei_normal.htsvoice']
    speed = ['-r', '1.0']
    outwav = ['-ow', 'open_jtalk.wav']
    cmd = open_jtalk + mech + htsvoice + speed + outwav
    c = subprocess.Popen(cmd, stdin=subprocess.PIPE)
    # The pipe expects bytes; encode text input (fixes a TypeError on
    # Python 3 when a str is passed).
    if not isinstance(t, bytes):
        t = t.encode('utf-8')
    c.stdin.write(t)
    c.stdin.close()
    c.wait()
    # Play the generated wav and wait for playback to finish so that
    # successive jtalk() calls do not talk over each other.
    aplay = ['aplay', '-q', 'open_jtalk.wav']
    wr = subprocess.Popen(aplay)
    wr.wait()
# LSTM language model
class LSTMNetwork(Chain):
    """Single-layer LSTM word model: EmbedID -> LSTM -> Linear readout."""

    def __init__(self, num_id, num_hid=100):
        # em1: word-id embedding, ls1: recurrent cell, ln1: readout
        # projecting back onto the vocabulary.
        super(LSTMNetwork, self).__init__(
            em1=L.EmbedID(num_id, num_hid),
            ls1=L.LSTM(num_hid, num_hid),
            ln1=L.Linear(num_hid, num_id)
        )

    def __call__(self, x):
        # ReLU after the embedding and after the LSTM; the readout
        # returns raw (unnormalized) scores per word id.
        hidden = F.relu(self.ls1(F.relu(self.em1(x))))
        return self.ln1(hidden)

    def reset(self):
        # Clear the recurrent state before feeding a new sequence.
        self.ls1.reset_state()
# Dataset (Twitter API + MeCab)
class Tweet2Words(object):
    """Fetch tweets via the Twitter API and turn them into sequences of
    word-dictionary indices, using MeCab for tokenization."""

    def __init__(self):
        # TODO: fill in consumer key/secret and access token/secret
        CK = ''
        CS = ''
        AT = ''
        AS = ''
        self.session = OAuth1Session(CK, CS, AT, AS)
        self.mecab = MeCab.Tagger('-Ochasen')

    # Fetch the given number of tweets from the given account.
    def __fetchTweets(self, count, screen_name='620haruka_y'):
        print('Fetching {} tweets ...'.format(count))
        url = 'https://api.twitter.com/1.1/statuses/user_timeline.json'
        params = {
            'screen_name': screen_name,
            'count': count,
            'exclude_replies': False,
            'include_rts': True
        }
        res = self.session.get(url, params=params)
        if res.status_code != 200:
            sys.exit('ERROR: Could not fetch tweets')
        tweets = json.loads(res.text)
        # The API may return fewer tweets than requested.
        print('{} tweets were fetched actually'.format(len(tweets)))
        print('Done')
        return tweets

    # Build a word dictionary from the given sentences.
    def __createDict(self, sentences):
        all_sentences = ' '.join(sentences)
        node = self.mecab.parseToNode(all_sentences)
        self.dict = []
        # Parallel word -> index map so that membership tests and index
        # lookups are O(1) instead of O(n) list scans.
        self._word2idx = {}
        while node:
            word = node.surface
            # Register the word only on first sight; insertion order is
            # preserved so self.dict indices match self._word2idx.
            if word not in self._word2idx:
                self._word2idx[word] = len(self.dict)
                self.dict.append(word)
            node = node.next
        # Sentinel entry for out-of-dictionary words.
        self._word2idx[None] = len(self.dict)
        self.dict.append(None)
        # Persist the dictionary for later demo runs.
        np.save('dict_' + str(len(self.dict)) + '.npy', self.dict)

    # Use the two helpers above to build word time-series data.
    def __createDataset(self, size):
        tweets = self.__fetchTweets(count=size)
        tweets = [tweet['text'].encode('utf-8') for tweet in tweets]
        self.__createDict(tweets)
        print('Creating dataset...')
        self.sentences = []
        unknown = len(self.dict) - 1  # index of the None sentinel
        for tweet in tweets:
            sentence = []
            node = self.mecab.parseToNode(tweet)
            while node:
                # O(1) dict lookup replaces the original
                # `self.dict.index(word)` O(n) scan per word.
                sentence.append(self._word2idx.get(node.surface, unknown))
                node = node.next
            self.sentences.append(sentence)
        print('Done')

    # Return the dataset (list of index sequences, one per tweet).
    def getDataset(self, size):
        self.__createDataset(size=size)
        return self.sentences

    # Return the dictionary size.
    def lenDict(self):
        return len(self.dict)

    # Return the word stored at the given dictionary index.
    def index2word(self, index):
        return self.dict[index]
# Trainer
class Trainer:
    """Train the LSTM on fetched tweets and periodically check how well
    it can reproduce them."""

    def __init__(self, size):
        self.t2w = Tweet2Words()
        self.data = self.t2w.getDataset(size=size)
        self.net = LSTMNetwork(num_id=self.t2w.lenDict())
        self.cls = L.Classifier(self.net)
        self.opt = chainer.optimizers.Adam()
        self.opt.setup(self.cls)

    # Train on one tweet (one word-index sequence).
    def train_sequence(self, seq):
        self.net.reset()
        sum_loss = 0.0
        for i in six.moves.range(len(seq) - 1):
            self.cls.zerograds()
            # x: current word, t: the word that follows it
            x = Variable(np.asarray([seq[i]], dtype=np.int32))
            t = Variable(np.asarray([seq[i + 1]], dtype=np.int32))
            # Learn to predict t given x (next-word prediction).
            loss = self.cls(x, t)
            sum_loss += loss.data
            loss.backward()
            self.opt.update()
        return sum_loss

    # Test on one tweet: feed the first few words, predict the rest.
    def test_sequence(self, seq_part, times_est):
        self.net.reset()
        res_seq = []
        # Feed the given prefix into the network.
        for i in six.moves.range(len(seq_part)):
            x = Variable(np.asarray([seq_part[i]], dtype=np.int32))
            y = self.net(x)
            res_seq.append(seq_part[i])
        # First prediction: the word following the prefix.
        res_seq.append(np.argmax(y.data))
        # Predict the remaining words, feeding each prediction back in.
        for i in six.moves.range(times_est):
            x = Variable(np.asarray([np.argmax(y.data)], dtype=np.int32))
            y = self.net(x)
            res_seq.append(np.argmax(y.data))
        return res_seq

    # Check whether the tweets can be reproduced.
    def validation(self, pre_words):
        correct = 0.0
        trial = 5 if len(self.data) > 5 else len(self.data)
        for i, d in enumerate(self.data):
            sentence_est = self.test_sequence(d[0:pre_words], len(d) - pre_words - 1)
            if sentence_est == d:
                correct += 1.0
            # Print only the first `trial` tweets as samples.
            # (Was `i > trial`, which printed trial+1 of them.)
            if i >= trial:
                continue
            sys.stdout.write(str(i) + ':')
            voice = ''
            for word_est in sentence_est:
                # index2word may return the None sentinel for unknown
                # words; substitute '' so concatenation cannot raise.
                word = self.t2w.index2word(word_est) or ''
                sys.stdout.write(word + ' ')
                voice += word
            sys.stdout.write('\n')
            # Speak the first reproduced tweet.
            if i == 0:
                jtalk(voice)
        # Score is defined as matched tweets / all tweets.
        return correct / len(self.data)

    # Run the training loop.
    def run(self, pre_words):
        # Roughly 75 epochs.
        max_iter = 75 * len(self.data)
        perm = np.random.permutation(max_iter) % len(self.data)
        # Live loss plot via matplotlib.
        grx = np.arange(max_iter)
        gry = np.zeros(max_iter)
        plt.clf()
        plt.plot(grx, gry)
        plt.pause(0.001)
        for i in six.moves.range(max_iter):
            # Run a validation pass every 100 tweets trained.
            if i % 100 == 0:
                print('---- Trial {} ----'.format(i))
                score = self.validation(pre_words)
                print('------------------')
                print('Score: {}'.format(score))
            # Train on one tweet and plot the loss.
            loss = self.train_sequence(self.data[perm[i]])
            gry[i] = loss
            plt.clf()
            plt.plot(grx, gry)
            plt.pause(0.001)
        # Save the model.
        serializers.save_npz('model_' + str(len(self.data)) + '.mod', self.net)
        # Save the loss curve.  Use len(self.data) instead of the global
        # `size` the original read (a NameError unless run as a script),
        # which also matches the model filename above.
        plt.savefig('haruka_' + str(len(self.data)) + '.png')
# Demo: generate and speak a short word sequence from a saved model.
def demo(path2mod, path2dict):
    """Load a saved dictionary and model, seed the network with a random
    word, generate ten more words, then print and speak the result."""
    _dict = np.load(path2dict)
    net = LSTMNetwork(num_id=len(_dict))
    serializers.load_npz(path2mod, net)
    net.reset()
    # Start from a random word id.
    r = np.random.randint(0, len(_dict), size=(1, 1)).astype(np.int32)
    txt = _dict[r[0][0]]
    # Run the network ten steps, feeding each prediction back in.
    for _ in six.moves.range(10):
        y = net(Variable(r))
        r = np.asarray([np.argmax(y.data)]).astype(np.int32)
        txt += _dict[r[0]]
    # Show and speak the generated text.
    print(txt)
    jtalk(txt)
if __name__ == '__main__':
    argv = sys.argv
    # With two arguments, run the demo on a saved model and dictionary.
    if len(argv) > 2:
        demo(path2mod=argv[1], path2dict=argv[2])
        sys.exit()
    # Otherwise train on 10, 20, 30 and 40 tweets in turn.
    for s in six.moves.range(10, 50, 10):
        size = s
        t = Trainer(size=size)
        t.run(pre_words=6)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment