@ytbilly3636
Created August 7, 2018 05:56
# -*- coding:utf-8 -*-
'''
Requirements
* chainer
* matplotlib
* MeCab
* OpenJTalk
* requests_oauthlib
* Twitter API credentials (consumer key/secret and access token/secret)
'''
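
# Usage (inferred from the __main__ block at the bottom of this file):
#   training: python <this_script>.py
#   demo:     python <this_script>.py <path/to/model.mod> <path/to/dict.npy>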
import sys
import six
import numpy as np
import chainer
from chainer import links as L
from chainer import functions as F
from chainer import Chain
from chainer import Variable
from chainer import serializers
import matplotlib.pyplot as plt
import MeCab
import json
from requests_oauthlib import OAuth1Session
import subprocess
import warnings
warnings.filterwarnings("ignore")
# OpenJTalk: synthesize speech from text and play it
def jtalk(t):
    # TODO: obtain the "mei" htsvoice separately
    open_jtalk = ['open_jtalk']
    mech = ['-x', '/var/lib/mecab/dic/open-jtalk/naist-jdic']
    htsvoice = ['-m', '/usr/share/hts-voice/mei/mei_normal.htsvoice']
    speed = ['-r', '1.0']
    outwav = ['-ow', 'open_jtalk.wav']
    cmd = open_jtalk + mech + htsvoice + speed + outwav
    c = subprocess.Popen(cmd, stdin=subprocess.PIPE)
    c.stdin.write(t)
    c.stdin.close()
    c.wait()
    aplay = ['aplay', '-q', 'open_jtalk.wav']
    wr = subprocess.Popen(aplay)
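
# A minimal usage sketch for jtalk (assumes open_jtalk and aplay are
# installed and the dictionary/voice paths above exist; this is Python 2-era
# code, so the literal below is already a UTF-8 byte string):
#   jtalk('こんにちは')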
# LSTM
class LSTMNetwork(Chain):
    def __init__(self, num_id, num_hid=100):
        super(LSTMNetwork, self).__init__(
            em1=L.EmbedID(num_id, num_hid),  # word ID -> embedding
            ls1=L.LSTM(num_hid, num_hid),    # recurrent layer (stateful)
            ln1=L.Linear(num_hid, num_id)    # hidden -> logits over word IDs
        )

    def __call__(self, x):
        h_em1 = F.relu(self.em1(x))
        h_ls1 = F.relu(self.ls1(h_em1))
        h_ln1 = self.ln1(h_ls1)
        return h_ln1

    def reset(self):
        self.ls1.reset_state()
        return
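
# The LSTM keeps its hidden state between calls, so a tweet is fed one word
# ID per call, and reset() must be called before each new sequence. A minimal
# sketch (hypothetical vocabulary size and word ID):
#   net = LSTMNetwork(num_id=1000)
#   net.reset()
#   y = net(Variable(np.asarray([42], dtype=np.int32)))  # logits, shape (1, 1000)
#   next_id = np.argmax(y.data)                          # most likely next word ID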
# Dataset (Twitter API + MeCab)
class Tweet2Words(object):
    def __init__(self):
        # TODO: fill in your consumer key/secret (CK, CS) and
        # access token/secret (AT, AS)
        CK = ''
        CS = ''
        AT = ''
        AS = ''
        self.session = OAuth1Session(CK, CS, AT, AS)
        self.mecab = MeCab.Tagger('-Ochasen')
    # Fetch the specified number of tweets from the specified account
    def __fetchTweets(self, count, screen_name='620haruka_y'):
        print('Fetching {} tweets ...'.format(count))
        url = 'https://api.twitter.com/1.1/statuses/user_timeline.json'
        params = {
            'screen_name': screen_name,
            'count': count,
            'exclude_replies': False,
            'include_rts': True
        }
        res = self.session.get(url, params=params)
        if res.status_code != 200:
            sys.exit('ERROR: Could not fetch tweets')
        tweets = json.loads(res.text)
        print('Actually fetched {} tweets'.format(len(tweets)))
        print('Done')
        return tweets
    # Build a dictionary of the words contained in a list of sentences
    def __createDict(self, sentences):
        all_sentences = ' '.join(sentences)
        node = self.mecab.parseToNode(all_sentences)
        self.dict = []
        while node:
            # word: surface form of the current morpheme
            word = node.surface
            # Add the word if it is not in the dictionary yet
            if word not in self.dict:
                self.dict.append(word)
            node = node.next
        # Fallback entry for out-of-vocabulary words
        self.dict.append(None)
        # Save the dictionary
        np.save('dict_' + str(len(self.dict)) + '.npy', self.dict)
    # Use the two functions above to build word-index sequences
    def __createDataset(self, size):
        tweets = self.__fetchTweets(count=size)
        tweets = [tweet['text'].encode('utf-8') for tweet in tweets]
        self.__createDict(tweets)
        print('Creating dataset...')
        self.sentences = []
        for tweet in tweets:
            sentence = []
            node = self.mecab.parseToNode(tweet)
            while node:
                word = node.surface
                if word in self.dict:
                    sentence.append(self.dict.index(word))
                else:
                    # Map unknown words to the fallback entry
                    sentence.append(len(self.dict) - 1)
                node = node.next
            self.sentences.append(sentence)
        print('Done')
    # Return the dataset
    def getDataset(self, size):
        self.__createDataset(size=size)
        return self.sentences

    # Return the size of the dictionary
    def lenDict(self):
        return len(self.dict)

    # Return the word for a dictionary index
    def index2word(self, index):
        return self.dict[index]
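
# getDataset() returns one list of word indices per tweet, e.g. (hypothetical
# indices) [[0, 17, 5], [0, 17, 42, 8], ...]; each index points into the
# dictionary saved as dict_<len>.npy, whose last entry (None) marks
# out-of-vocabulary words.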
# Trainer
class Trainer:
    def __init__(self, size):
        self.t2w = Tweet2Words()
        self.data = self.t2w.getDataset(size=size)
        self.net = LSTMNetwork(num_id=self.t2w.lenDict())
        self.cls = L.Classifier(self.net)
        self.opt = chainer.optimizers.Adam()
        self.opt.setup(self.cls)
    # Train on one tweet
    def train_sequence(self, seq):
        self.net.reset()
        sum_loss = 0.0
        for i in six.moves.range(len(seq) - 1):
            self.cls.zerograds()
            # x: the current word, t: the word that follows it
            x = Variable(np.asarray([seq[i]], dtype=np.int32))
            t = Variable(np.asarray([seq[i+1]], dtype=np.int32))
            # Learn to output t when x is the input
            loss = self.cls(x, t)
            sum_loss += loss.data
            loss.backward()
            self.opt.update()
        return sum_loss
    # Test on one tweet: feed the first few words, then let the network
    # predict the rest
    def test_sequence(self, seq_part, times_est):
        self.net.reset()
        res_seq = []
        # Feed the first few words
        for i in six.moves.range(len(seq_part)):
            x = Variable(np.asarray([seq_part[i]], dtype=np.int32))
            y = self.net(x)
            res_seq.append(seq_part[i])
        res_seq.append(np.argmax(y.data))
        # Predict the remaining words
        for i in six.moves.range(times_est):
            x = Variable(np.asarray([np.argmax(y.data)], dtype=np.int32))
            y = self.net(x)
            res_seq.append(np.argmax(y.data))
        return res_seq
    # Check whether the network can reproduce the tweets
    def validation(self, pre_words):
        correct = 0.0
        trial = min(5, len(self.data))
        for i, d in enumerate(self.data):
            sentence_est = self.test_sequence(d[0:pre_words], len(d) - pre_words - 1)
            if sentence_est == d:
                correct += 1.0
            # Print only the first few tweets as samples
            if i > trial:
                continue
            # Print the tweet
            sys.stdout.write(str(i) + ':')
            voice = ''
            for word_est in sentence_est:
                # index2word() may return None (the OOV entry), so fall back to ''
                word = self.t2w.index2word(word_est) or ''
                sys.stdout.write(word + ' ')
                voice += word
            sys.stdout.write('\n')
            # Speak the first tweet aloud
            if i == 0:
                jtalk(voice)
        # Reproduction score is defined as matched tweets / all tweets
        return correct / len(self.data)
    # Run the training loop
    def run(self, pre_words):
        # Roughly 75 epochs
        max_iter = 75 * len(self.data)
        perm = np.random.permutation(max_iter) % len(self.data)

        # For matplotlib
        grx = np.arange(max_iter)
        gry = np.zeros(max_iter)
        plt.clf()
        plt.plot(grx, gry)
        plt.pause(0.001)

        for i in six.moves.range(max_iter):
            # Run a validation pass every 100 trained tweets
            if i % 100 == 0:
                print('---- Trial {} ----'.format(i))
                score = self.validation(pre_words)
                print('------------------')
                print('Score: {}'.format(score))
            # Train on one tweet and plot the loss
            loss = self.train_sequence(self.data[perm[i]])
            gry[i] = loss
            plt.clf()
            plt.plot(grx, gry)
            plt.pause(0.001)

        # Save the model
        serializers.save_npz('model_' + str(len(self.data)) + '.mod', self.net)
        # Save the loss graph, named after the dataset size like the model file
        plt.savefig('haruka_' + str(len(self.data)) + '.png')
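
# Putting the trainer together (this mirrors the __main__ block below):
#   t = Trainer(size=20)   # fetch ~20 tweets and build the dataset
#   t.run(pre_words=6)     # train, validating every 100 tweets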
# For the demo
def demo(path2mod, path2dict):
    _dict = np.load(path2dict)
    net = LSTMNetwork(num_id=len(_dict))
    serializers.load_npz(path2mod, net)
    net.reset()
    # Feed a random word ...
    r = np.random.randint(0, len(_dict), size=(1, 1)).astype(np.int32)
    txt = ''
    txt += _dict[r[0][0]] or ''
    # ... then loop 10 times
    for i in six.moves.range(10):
        x = Variable(r)
        y = net(x)
        r = np.asarray([np.argmax(y.data)]).astype(np.int32)
        # The dictionary entry may be None (the OOV marker), so fall back to ''
        txt += _dict[r[0]] or ''
    # Speak the result
    print(txt)
    jtalk(txt)
if __name__ == '__main__':
    argv = sys.argv

    # Demo mode: <script> <path/to/model.mod> <path/to/dict.npy>
    if len(argv) > 2:
        demo(path2mod=argv[1], path2dict=argv[2])
        sys.exit()

    # Training mode: train on datasets of 10, 20, 30 and 40 tweets
    for s in six.moves.range(10, 50, 10):
        size = s
        t = Trainer(size=size)
        t.run(pre_words=6)