{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"### 基于Deep LSTM的中文分词\n",
" - 步骤1:读入有标注的训练语料库,处理成keras需要的数据格式。\n",
" - 步骤2:根据训练数据建模,使用deep LSTM方法\n",
" - 步骤3:读入无标注的检验语料库,用两层LSTM模型进行分词标注,用更多层效果会更好一点\n",
" - 步骤4:检查最终的效果 F值0.949"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 步骤1:训练数据读取和转换"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled)\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"import nltk \n",
"import codecs\n",
"import pandas as pd\n",
"import numpy as np\n",
"from os import path\n",
"from nltk.probability import FreqDist \n",
"from gensim.models import word2vec\n",
"from cPickle import load, dump"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 根据微软语料库计算词向量\n",
"\n",
"# 读单个文本\n",
"def load_file(input_file):\n",
" input_data = codecs.open(input_file, 'r', 'utf-8')\n",
" input_text = input_data.read()\n",
" return input_text\n",
"\n",
"# 读取目录下文本成为一个字符串\n",
"def load_dir(input_dir):\n",
" files = list_dir(input_dir)\n",
" seg_files_path = [path.join(input_dir, f) for f in files]\n",
" output = []\n",
" for txt in seg_files_path:\n",
" output.append(load_file(txt))\n",
" return '\\n'.join(output)\n",
"\n",
"# nltk 输入文本,输出词频表\n",
"def freq_func(input_txt):\n",
" corpus = nltk.Text(input_txt) \n",
" fdist = FreqDist(corpus) \n",
" w = fdist.keys() \n",
" v = fdist.values() \n",
" freqdf = pd.DataFrame({'word':w,'freq':v}) \n",
" freqdf.sort('freq',ascending =False, inplace=True)\n",
" freqdf['idx'] = np.arange(len(v))\n",
" return freqdf\n",
"\n",
"# word2vec建模\n",
"def trainW2V(corpus, epochs=20, num_features = 100,sg=1,\\\n",
" min_word_count = 1, num_workers = 4,\\\n",
" context = 4, sample = 1e-5, negative = 5):\n",
" w2v = word2vec.Word2Vec(workers = num_workers,\n",
" sample = sample,\n",
" size = num_features,\n",
" min_count=min_word_count,\n",
" window = context)\n",
" np.random.shuffle(corpus)\n",
" w2v.build_vocab(corpus) \n",
" for epoch in range(epochs):\n",
" print('epoch' + str(epoch))\n",
" np.random.shuffle(corpus)\n",
" w2v.train(corpus)\n",
" w2v.alpha *= 0.9 \n",
" w2v.min_alpha = w2v.alpha \n",
" print(\"word2vec DONE.\")\n",
" return w2v\n",
" \n",
"\n",
"def save_w2v(w2v, idx2word):\n",
" # 保存词向量lookup矩阵,按idx位置存放\n",
" init_weight_wv = []\n",
" for i in range(len(idx2word)):\n",
" init_weight_wv.append(w2v[idx2word[i]])\n",
" return init_weight_wv\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch0\n",
"epoch1\n",
"epoch2\n",
"epoch3\n",
"epoch4\n",
"epoch5\n",
"epoch6\n",
"epoch7\n",
"epoch8\n",
"epoch9\n",
"epoch10\n",
"epoch11\n",
"epoch12\n",
"epoch13\n",
"epoch14\n",
"epoch15\n",
"epoch16\n",
"epoch17\n",
"epoch18\n",
"epoch19\n",
"word2vec DONE.\n"
]
}
],
"source": [
"input_file = 'icwb2-data/training/msr_training.utf8'\n",
"input_text = load_file(input_file) # 读入全部文本\n",
"txtwv = [line.split() for line in input_text.split('\\n') if line != ''] # 为词向量准备的文本格式\n",
"txtnltk = [w for w in input_text.split()] # 为计算词频准备的文本格式\n",
"freqdf = freq_func(txtnltk) # 计算词频表\n",
"maxfeatures = freqdf.shape[0] # 词汇个数\n",
"# 建立两个映射字典\n",
"word2idx = dict((c, i) for c, i in zip(freqdf.word, freqdf.idx))\n",
"idx2word = dict((i, c) for c, i in zip(freqdf.word, freqdf.idx))\n",
"# word2vec\n",
"w2v = trainW2V(txtwv)\n",
"# 存向量\n",
"init_weight_wv = save_w2v(w2v,idx2word)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# 定义'U'为未登陆新字, 'P'为两头padding用途,并增加两个相应的向量表示\n",
"char_num = len(init_weight_wv)\n",
"idx2word[char_num] = u'U'\n",
"word2idx[u'U'] = char_num\n",
"idx2word[char_num+1] = u'P'\n",
"word2idx[u'P'] = char_num+1\n",
"\n",
"init_weight_wv.append(np.random.randn(100,))\n",
"init_weight_wv.append(np.zeros(100,))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 读取数据,将格式进行转换为带四种标签 S B M E\n",
"output_file = 'icwb2-data/training/msr_training.tagging.utf8'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import codecs\n",
"import sys\n",
"\n",
"def character_tagging(input_file, output_file):\n",
" input_data = codecs.open(input_file, 'r', 'utf-8')\n",
" output_data = codecs.open(output_file, 'w', 'utf-8')\n",
" for line in input_data.readlines():\n",
" word_list = line.strip().split()\n",
" for word in word_list:\n",
" if len(word) == 1:\n",
" output_data.write(word + \"/S \")\n",
" else:\n",
" output_data.write(word[0] + \"/B \")\n",
" for w in word[1:len(word)-1]:\n",
" output_data.write(w + \"/M \")\n",
" output_data.write(word[len(word)-1] + \"/E \")\n",
" output_data.write(\"\\n\")\n",
" input_data.close()\n",
" output_data.close()\n",
"\n",
"character_tagging(input_file, output_file)"
]
},
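{
"cell_type": "markdown",
"metadata": {},
"source": [
"The function above streams the conversion to disk. As a quick worked example of the same BMES scheme, the `tag_words` sketch below (a hypothetical helper, not part of the original pipeline) tags a short word list in memory: single-character words get S, and multi-character words get B, then M for interior characters, then E."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Minimal in-memory sketch of the same BMES scheme (illustrative only;\n",
"# tag_words is not part of the original pipeline)\n",
"def tag_words(words):\n",
"    tags = []\n",
"    for word in words:\n",
"        if len(word) == 1:\n",
"            tags.append(word + u'/S')  # single-character word\n",
"        else:\n",
"            tags.append(word[0] + u'/B')   # first character\n",
"            for w in word[1:-1]:\n",
"                tags.append(w + u'/M')     # interior characters\n",
"            tags.append(word[-1] + u'/E')  # last character\n",
"    return u' '.join(tags)\n",
"\n",
"print(tag_words([u'研究', u'生命', u'的', u'起源']))  # -> 研/B 究/E 生/B 命/E 的/S 起/B 源/E"
]
},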
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 分离word 和 label\n",
"with open(output_file) as f:\n",
" lines = f.readlines()\n",
" train_line = [[w[0] for w in line.decode('utf-8').split()] for line in lines]\n",
" train_label = [w[2] for line in lines for w in line.decode('utf-8').split()]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 文档转数字list\n",
"import numpy as np\n",
"def sent2num(sentence, word2idx = word2idx, context = 7):\n",
" predict_word_num = []\n",
" for w in sentence:\n",
" # 文本中的字如果在词典中则转为数字,如果不在则设置为'U\n",
" if w in word2idx:\n",
" predict_word_num.append(word2idx[w])\n",
" else:\n",
" predict_word_num.append(word2idx[u'U'])\n",
" # 首尾padding\n",
" num = len(predict_word_num)\n",
" pad = int((context-1)*0.5)\n",
" for i in range(pad):\n",
" predict_word_num.insert(0,word2idx[u'P'] )\n",
" predict_word_num.append(word2idx[u'P'] )\n",
" train_x = []\n",
" for i in range(num):\n",
" train_x.append(predict_word_num[i:i+context])\n",
" return train_x"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[[88120, 88120, 88120, 9, 21, 107, 1976],\n",
" [88120, 88120, 9, 21, 107, 1976, 26],\n",
" [88120, 9, 21, 107, 1976, 26, 1116],\n",
" [9, 21, 107, 1976, 26, 1116, 1397],\n",
" [21, 107, 1976, 26, 1116, 1397, 7],\n",
" [107, 1976, 26, 1116, 1397, 7, 10],\n",
" [1976, 26, 1116, 1397, 7, 10, 538],\n",
" [26, 1116, 1397, 7, 10, 538, 2300],\n",
" [1116, 1397, 7, 10, 538, 2300, 2353],\n",
" [1397, 7, 10, 538, 2300, 2353, 378],\n",
" [7, 10, 538, 2300, 2353, 378, 0],\n",
" [10, 538, 2300, 2353, 378, 0, 46],\n",
" [538, 2300, 2353, 378, 0, 46, 2118],\n",
" [2300, 2353, 378, 0, 46, 2118, 18],\n",
" [2353, 378, 0, 46, 2118, 18, 2610],\n",
" [378, 0, 46, 2118, 18, 2610, 1],\n",
" [0, 46, 2118, 18, 2610, 1, 1172],\n",
" [46, 2118, 18, 2610, 1, 1172, 2183],\n",
" [2118, 18, 2610, 1, 1172, 2183, 78],\n",
" [18, 2610, 1, 1172, 2183, 78, 7],\n",
" [2610, 1, 1172, 2183, 78, 7, 16],\n",
" [1, 1172, 2183, 78, 7, 16, 141],\n",
" [1172, 2183, 78, 7, 16, 141, 70],\n",
" [2183, 78, 7, 16, 141, 70, 110],\n",
" [78, 7, 16, 141, 70, 110, 1],\n",
" [7, 16, 141, 70, 110, 1, 2300],\n",
" [16, 141, 70, 110, 1, 2300, 2353],\n",
" [141, 70, 110, 1, 2300, 2353, 378],\n",
" [70, 110, 1, 2300, 2353, 378, 0],\n",
" [110, 1, 2300, 2353, 378, 0, 87],\n",
" [1, 2300, 2353, 378, 0, 87, 3986],\n",
" [2300, 2353, 378, 0, 87, 3986, 1962],\n",
" [2353, 378, 0, 87, 3986, 1962, 7],\n",
" [378, 0, 87, 3986, 1962, 7, 462],\n",
" [0, 87, 3986, 1962, 7, 462, 180],\n",
" [87, 3986, 1962, 7, 462, 180, 91],\n",
" [3986, 1962, 7, 462, 180, 91, 1962],\n",
" [1962, 7, 462, 180, 91, 1962, 1],\n",
" [7, 462, 180, 91, 1962, 1, 1384],\n",
" [462, 180, 91, 1962, 1, 1384, 37],\n",
" [180, 91, 1962, 1, 1384, 37, 1],\n",
" [91, 1962, 1, 1384, 37, 1, 43],\n",
" [1962, 1, 1384, 37, 1, 43, 463],\n",
" [1, 1384, 37, 1, 43, 463, 1380],\n",
" [1384, 37, 1, 43, 463, 1380, 2],\n",
" [37, 1, 43, 463, 1380, 2, 88120],\n",
" [1, 43, 463, 1380, 2, 88120, 88120],\n",
" [43, 463, 1380, 2, 88120, 88120, 88120]]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 输入字符list,输出数字list\n",
"sent2num(train_line[0])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# 将所有训练文本转成数字list\n",
"train_word_num = []\n",
"for line in train_line:\n",
" train_word_num.extend(sent2num(line))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4050469\n",
"4050469\n"
]
}
],
"source": [
"print len(train_word_num)\n",
"print len(train_label)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from cPickle import load\n",
"#dump(train_word_num, open('train_word_num.pickle', 'wb'))\n",
"#train_word_num = load(open('train_word_num.pickle','rb'))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"nb_classes = len(np.unique(train_label))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 步骤3:训练模型"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from __future__ import absolute_import\n",
"from __future__ import print_function\n",
"\n",
"from keras.preprocessing import sequence\n",
"from keras.optimizers import SGD, RMSprop, Adagrad\n",
"from keras.utils import np_utils\n",
"from keras.models import Sequential,Graph\n",
"from keras.layers.core import Dense, Dropout, Activation, TimeDistributedDense\n",
"from keras.layers.embeddings import Embedding\n",
"from keras.layers.recurrent import LSTM, GRU,SimpleRNN\n",
"from keras.layers.core import Reshape, Flatten ,Dropout\n",
"from keras.regularizers import l1,l2\n",
"from keras.layers.convolutional import Convolution2D, MaxPooling2D,MaxPooling1D"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{u'M': 2, u'S': 3, u'B': 0, u'E': 1}\n",
"{0: u'B', 1: u'E', 2: u'M', 3: u'S'}\n"
]
}
],
"source": [
"# 建立两个字典\n",
"label_dict = dict(zip(np.unique(train_label), range(4)))\n",
"num_dict = {n:l for l,n in label_dict.iteritems()}\n",
"print(label_dict)\n",
"print(num_dict)\n",
"# 将目标变量转为数字\n",
"train_label = [label_dict[y] for y in train_label]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 切分数据集\n",
"from sklearn.cross_validation import train_test_split\n",
"train_word_num = np.array(train_word_num)\n",
"train_X, test_X, train_y, test_y = train_test_split(train_word_num, train_label , train_size=0.9, random_state=1)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"Y_train = np_utils.to_categorical(train_y, nb_classes)\n",
"Y_test = np_utils.to_categorical(test_y, nb_classes)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3645422 train sequences\n",
"405047 test sequences\n"
]
}
],
"source": [
"print(len(train_X), 'train sequences')\n",
"print(len(test_X), 'test sequences')"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 初始字向量格式准备\n",
"init_weight = [np.array(init_weight_wv)]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"batch_size = 128"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"maxfeatures = init_weight[0].shape[0] # 词典大小\n",
"word_dim = 100\n",
"maxlen = 7\n",
"hidden_units = 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# stacking LSTM"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"stacking LSTM...\n"
]
}
],
"source": [
"print('stacking LSTM...')\n",
"model = Sequential()\n",
"model.add(Embedding(maxfeatures, word_dim,input_length=maxlen))\n",
"model.add(LSTM(output_dim=hidden_units, return_sequences =True))\n",
"model.add(LSTM(output_dim=hidden_units, return_sequences =False))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(nb_classes))\n",
"model.add(Activation('softmax'))\n",
"model.compile(loss='categorical_crossentropy', optimizer='adam')"
]
},
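{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `init_weight`, the word2vec lookup matrix prepared above, is never passed to the `Embedding` layer, so the embeddings here are trained from scratch. If one wanted to seed them with the pretrained vectors, Keras 0.x layers accept a `weights` argument, roughly as in the commented sketch below (an assumption about this Keras version, not something the notebook ran):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Hedged sketch: seed the Embedding layer with the pretrained word2vec vectors\n",
"# via the Keras 0.x `weights` argument (assumed API; not run in the original):\n",
"# model.add(Embedding(maxfeatures, word_dim, input_length=maxlen,\n",
"#                     weights=init_weight))"
]
},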
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train...\n",
"Train on 3645422 samples, validate on 405047 samples\n",
"Epoch 1/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.2642 - acc: 0.9054 - val_loss: 0.1585 - val_acc: 0.9446\n",
"Epoch 2/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.1341 - acc: 0.9549 - val_loss: 0.1166 - val_acc: 0.9603\n",
"Epoch 3/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0993 - acc: 0.9671 - val_loss: 0.0976 - val_acc: 0.9673\n",
"Epoch 4/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0813 - acc: 0.9733 - val_loss: 0.0870 - val_acc: 0.9711\n",
"Epoch 5/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0703 - acc: 0.9769 - val_loss: 0.0818 - val_acc: 0.9728\n",
"Epoch 6/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0630 - acc: 0.9792 - val_loss: 0.0760 - val_acc: 0.9749\n",
"Epoch 7/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0579 - acc: 0.9810 - val_loss: 0.0723 - val_acc: 0.9764\n",
"Epoch 8/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0541 - acc: 0.9821 - val_loss: 0.0698 - val_acc: 0.9772\n",
"Epoch 9/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0512 - acc: 0.9832 - val_loss: 0.0697 - val_acc: 0.9774\n",
"Epoch 10/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0493 - acc: 0.9838 - val_loss: 0.0678 - val_acc: 0.9781\n",
"Epoch 11/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0474 - acc: 0.9844 - val_loss: 0.0664 - val_acc: 0.9788\n",
"Epoch 12/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0458 - acc: 0.9849 - val_loss: 0.0659 - val_acc: 0.9787\n",
"Epoch 13/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0446 - acc: 0.9853 - val_loss: 0.0647 - val_acc: 0.9797\n",
"Epoch 14/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0435 - acc: 0.9857 - val_loss: 0.0627 - val_acc: 0.9798\n",
"Epoch 15/20\n",
"3645422/3645422 [==============================] - 648s - loss: 0.0426 - acc: 0.9860 - val_loss: 0.0625 - val_acc: 0.9799\n",
"Epoch 16/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0420 - acc: 0.9862 - val_loss: 0.0615 - val_acc: 0.9804\n",
"Epoch 17/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0411 - acc: 0.9864 - val_loss: 0.0601 - val_acc: 0.9804\n",
"Epoch 18/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0406 - acc: 0.9866 - val_loss: 0.0614 - val_acc: 0.9810\n",
"Epoch 19/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0401 - acc: 0.9867 - val_loss: 0.0624 - val_acc: 0.9804\n",
"Epoch 20/20\n",
"3645422/3645422 [==============================] - 649s - loss: 0.0398 - acc: 0.9869 - val_loss: 0.0618 - val_acc: 0.9807\n"
]
}
],
"source": [
"# train_X, test_X, Y_train, Y_test\n",
"print(\"Train...\")\n",
"result = model.fit(train_X, Y_train, batch_size=batch_size, \n",
" nb_epoch=20, validation_data = (test_X,Y_test), show_accuracy=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# bidirectional_lstm"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"graph = Graph()\n",
"graph.add_input(name='input', input_shape=(maxlen,), dtype=int)\n",
"graph.add_node(Embedding(maxfeatures, word_dim, input_length=maxlen),\n",
" name='embedding', input='input')\n",
"graph.add_node(LSTM(output_dim=hidden_units), name='forward', input='embedding')\n",
"graph.add_node(LSTM(output_dim=hidden_units, go_backwards =True), name='backward', input='embedding')\n",
"graph.add_node(Dropout(0.5), name='dropout', inputs=['forward', 'backward'])\n",
"graph.add_node(Dense(nb_classes, activation='softmax'), name='softmax', input='dropout')\n",
"graph.add_output(name='output', input='softmax')\n",
"graph.compile(loss = {'output': 'categorical_crossentropy'}, optimizer='adam')"
]
},
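{
"cell_type": "markdown",
"metadata": {},
"source": [
"Passing `inputs=['forward', 'backward']` to the Dropout node merges the forward and backward LSTM outputs (assuming the Graph default merge mode, concat), so the softmax sees a 2 x `hidden_units` vector. Below is a hedged usage sketch for getting predictions out of the Graph model; it assumes the Keras 0.x dict-based Graph API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Usage sketch (assumption: Graph.predict takes and returns dicts keyed by node names)\n",
"probs = graph.predict({'input': test_X[:10]})['output']\n",
"print(probs.shape)          # expected: (10, nb_classes)\n",
"print(probs.argmax(axis=1)) # predicted tag indices"
]
},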
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 3645422 samples, validate on 405047 samples\n",
"Epoch 1/20\n",
"3645422/3645422 [==============================] - 703s - loss: 0.2795 - val_loss: 0.1710\n",
"Epoch 2/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.1448 - val_loss: 0.1225\n",
"Epoch 3/20\n",
"3645422/3645422 [==============================] - 702s - loss: 0.1090 - val_loss: 0.1029\n",
"Epoch 4/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.0906 - val_loss: 0.0930\n",
"Epoch 5/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.0794 - val_loss: 0.0876\n",
"Epoch 6/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.0727 - val_loss: 0.0841\n",
"Epoch 7/20\n",
"3645422/3645422 [==============================] - 702s - loss: 0.0677 - val_loss: 0.0809\n",
"Epoch 8/20\n",
"3645422/3645422 [==============================] - 702s - loss: 0.0640 - val_loss: 0.0789\n",
"Epoch 9/20\n",
"3645422/3645422 [==============================] - 702s - loss: 0.0613 - val_loss: 0.0762\n",
"Epoch 10/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.0592 - val_loss: 0.0760\n",
"Epoch 11/20\n",
"3645422/3645422 [==============================] - 702s - loss: 0.0574 - val_loss: 0.0761\n",
"Epoch 12/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.0563 - val_loss: 0.0727\n",
"Epoch 13/20\n",
"3645422/3645422 [==============================] - 702s - loss: 0.0552 - val_loss: 0.0732\n",
"Epoch 14/20\n",
"3645422/3645422 [==============================] - 702s - loss: 0.0542 - val_loss: 0.0717\n",
"Epoch 15/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.0533 - val_loss: 0.0735\n",
"Epoch 16/20\n",
"3645422/3645422 [==============================] - 700s - loss: 0.0525 - val_loss: 0.0707\n",
"Epoch 17/20\n",
"3645422/3645422 [==============================] - 702s - loss: 0.0520 - val_loss: 0.0716\n",
"Epoch 18/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.0517 - val_loss: 0.0711\n",
"Epoch 19/20\n",
"3645422/3645422 [==============================] - 701s - loss: 0.0510 - val_loss: 0.0700\n",
"Epoch 20/20\n",
"3645422/3645422 [==============================] - 700s - loss: 0.0506 - val_loss: 0.0715\n"
]
}
],
"source": [
"result2 = graph.fit({'input':train_X, 'output':Y_train}, \n",
" batch_size=batch_size, \n",
" nb_epoch=20, validation_data = ({'input':test_X,'output':Y_test}))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(88121, 100)\n",
"(100, 100)\n",
"(100, 100)\n",
"(100,)\n",
"(100, 100)\n",
"(100, 100)\n",
"(100,)\n",
"(100, 100)\n",
"(100, 100)\n",
"(100,)\n",
"(100, 100)\n",
"(100, 100)\n",
"(100,)\n",
"(100, 100)\n",
"(100, 100)\n",
"(100,)\n",
"(100, 100)\n",
"(100, 100)\n",
"(100,)\n",
"(100, 100)\n",
"(100, 100)\n",
"(100,)\n",
"(100, 100)\n",
"(100, 100)\n",
"(100,)\n",
"(100, 4)\n",
"(4,)\n"
]
}
],
"source": [
"import theano\n",
"# 模型待学习参数如下:即W的规模\n",
"weight_len = len(model.get_weights())\n",
"for i in range(weight_len):\n",
" print(model.get_weights()[i].shape)\n",
"# lstm 分别是embeding, 4种不同的门 , dense_w, dense_b"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(10, 4)"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"layer = theano.function([model.layers[0].input],model.layers[3].get_output(train=False),allow_input_downcast=True) \n",
"layer_out = layer(test_X[:10]) \n",
"layer_out.shape # 前10篇文章的第0层输出,经过reku计算"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 步骤4:用test文本进行预测,评估效果"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"temp_txt = u'国家食药监总局发布通知称,酮康唑口服制剂因存在严重肝毒性不良反应,即日起停止生产销售使用。'\n",
"temp_txt = list(temp_txt)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[[88120, 88120, 88120, 309, 223, 5082, 1522],\n",
" [88120, 88120, 309, 223, 5082, 1522, 6778],\n",
" [88120, 309, 223, 5082, 1522, 6778, 396],\n",
" [309, 223, 5082, 1522, 6778, 396, 1773],\n",
" [223, 5082, 1522, 6778, 396, 1773, 1053]]"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"temp_num = sent2num(temp_txt)\n",
"temp_num[:5]"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 根据输入得到标注推断\n",
"def predict_num(input_num,input_txt, \\\n",
" model,\\\n",
" label_dict=label_dict,\\\n",
" num_dict=num_dict):\n",
" input_num = np.array(input_num)\n",
" predict_prob = model.predict_proba(input_num, verbose=False)\n",
" predict_lable = model.predict_classes(input_num, verbose=False)\n",
" for i , lable in enumerate(predict_lable[:-1]):\n",
" # 如果是首字 ,不可为E, M\n",
" if i==0:\n",
" predict_prob[i, label_dict[u'E']] = 0\n",
" predict_prob[i, label_dict[u'M']] = 0 \n",
" # 前字为B,后字不可为B,S\n",
" if lable == label_dict[u'B']:\n",
" predict_prob[i+1,label_dict[u'B']] = 0\n",
" predict_prob[i+1,label_dict[u'S']] = 0\n",
" # 前字为E,后字不可为M,E\n",
" if lable == label_dict[u'E']:\n",
" predict_prob[i+1,label_dict[u'M']] = 0\n",
" predict_prob[i+1,label_dict[u'E']] = 0\n",
" # 前字为M,后字不可为B,S\n",
" if lable == label_dict[u'M']:\n",
" predict_prob[i+1,label_dict[u'B']] = 0\n",
" predict_prob[i+1,label_dict[u'S']] = 0\n",
" # 前字为S,后字不可为M,E\n",
" if lable == label_dict[u'S']:\n",
" predict_prob[i+1,label_dict[u'M']] = 0\n",
" predict_prob[i+1,label_dict[u'E']] = 0\n",
" predict_lable[i+1] = predict_prob[i+1].argmax()\n",
" predict_lable_new = [num_dict[x] for x in predict_lable]\n",
" result = [w+'/' +l for w, l in zip(input_txt,predict_lable_new)]\n",
" return ' '.join(result) + '\\n'"
]
},
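{
"cell_type": "markdown",
"metadata": {},
"source": [
"The zeroing logic above encodes a fixed transition table over the four tags. As a compact restatement (an illustrative sketch equivalent to the checks in `predict_num`): after each tag, only the tags listed below may follow."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Allowed BMES transitions, restating the constraints enforced in predict_num:\n",
"# B/M must be continued (by M or E); E/S must be followed by a new word (B or S).\n",
"allowed = {\n",
"    u'B': [u'M', u'E'],\n",
"    u'M': [u'M', u'E'],\n",
"    u'E': [u'B', u'S'],\n",
"    u'S': [u'B', u'S'],\n",
"}"
]
},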
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"国/B 家/M 食/M 药/M 监/M 总/M 局/E 发/B 布/E 通/B 知/E 称/S ,/S 酮/B 康/E 唑/B 口/E 服/S 制/B 剂/E 因/S 存/B 在/E 严/B 重/E 肝/S 毒/B 性/E 不/B 良/E 反/B 应/E ,/S 即/B 日/E 起/S 停/B 止/E 生/B 产/E 销/B 售/E 使/B 用/E 。/S\n",
"\n"
]
}
],
"source": [
"temp = predict_num(temp_num,temp_txt, model = model)\n",
"print(temp)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"test_file = 'icwb2-data/testing/msr_test.utf8'\n",
"with open(test_file,'r') as f:\n",
" lines = f.readlines()\n",
" test_texts = [list(line.decode('utf-8').strip()) for line in lines]"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"test_output = []\n",
"for line in test_texts:\n",
" test_num = sent2num(line)\n",
" output_line = predict_num(test_num,input_txt=line,model = model)\n",
" test_output.append(output_line.encode('utf-8'))"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"with open('icwb2-data/testing/msr_test_output.utf8','w') as f:\n",
" f.writelines(test_output)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"input_file = 'icwb2-data/testing/msr_test_output.utf8'\n",
"output_file = 'icwb2-data/testing/msr_test.split.tag2word.utf8'"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import codecs\n",
"import sys\n",
"\n",
"def character_2_word(input_file, output_file):\n",
" input_data = codecs.open(input_file, 'r', 'utf-8')\n",
" output_data = codecs.open(output_file, 'w', 'utf-8')\n",
" # 4 tags for character tagging: B(Begin), E(End), M(Middle), S(Single)\n",
" for line in input_data.readlines():\n",
" char_tag_list = line.strip().split()\n",
" for char_tag in char_tag_list:\n",
" char_tag_pair = char_tag.split('/')\n",
" char = char_tag_pair[0]\n",
" tag = char_tag_pair[1]\n",
" if tag == 'B':\n",
" output_data.write(' ' + char)\n",
" elif tag == 'M':\n",
" output_data.write(char)\n",
" elif tag == 'E':\n",
" output_data.write(char + ' ')\n",
" else: # tag == 'S'\n",
" output_data.write(' ' + char + ' ')\n",
" output_data.write(\"\\n\")\n",
" input_data.close()\n",
" output_data.close()\n",
"\n",
"character_2_word(input_file, output_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### - 最终使用perl脚本检验的F值为0.949"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"! ./icwb2-data/scripts/score ./icwb2-data/gold/msr_training_words.utf8 ./icwb2-data/gold/msr_test_gold.utf8 ./icwb2-data/testing/msr_test.split.tag2word.utf8 > deep.score"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}