Last active
November 4, 2018 10:04
-
-
Save mikeyee/dedb7f938cadb3e55074a8d01e52013c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import numpy as np\n", | |
"from emo_utils import *\n", | |
"import emoji\n", | |
"import matplotlib.pyplot as plt\n", | |
"import re\n", | |
"import jieba\n", | |
"\n%matplotlib inline" | |
], | |
"outputs": [], | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load the training and test message sets\n", | |
"# NOTE(review): 'tesss_ch.csv' looks like a misspelling of 'test' -- confirm the file name on disk\n", | |
"\n", | |
"X_train, Y_train = read_csv('data/train_emoji_ch.csv')\n", | |
"X_test, Y_test = read_csv('data/tesss_ch.csv')" | |
], | |
"outputs": [], | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Try the label -> emoji lookup with one label\n", | |
"label_to_emoji(7)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 3, | |
"data": { | |
"text/plain": [ | |
"'😡'" | |
] | |
}, | |
"metadata": {} | |
} | |
], | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Find the longest sentence in the training set and demonstrate jieba word segmentation on it\n", | |
"\n", | |
"# key=len picks the sentence with the most characters, not the most tokens\n", | |
"text=max(X_train, key=len)\n", | |
"text = list(jieba.cut(text,cut_all=False))\n", | |
"maxLen = len(text)\n", | |
"\n\n", | |
"maxLen=18 # overrides the value computed above: strings are split further later on, so the largest length is used here\n", | |
"print(maxLen)\n", | |
"text" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Building prefix dict from the default dictionary ...\n", | |
"Loading model from cache /var/folders/bv/2q78w3811yg71lmsppx6lxrw0000gn/T/jieba.cache\n", | |
"Loading model cost 0.947 seconds.\n", | |
"Prefix dict has been built succesfully.\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"18\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 4, | |
"data": { | |
"text/plain": [ | |
"['我', '的', '程式', '明明', '可以', '執行', '但', '導師', '仍給', '我', '零分']" | |
] | |
}, | |
"metadata": {} | |
} | |
], | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Show one training message together with the emoji for its label\n", | |
"index = 7\n", | |
"print(X_train[index], label_to_emoji(Y_train[index]))" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"這功課太多了 😞\n" | |
] | |
} | |
], | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# One-hot encode the labels of both splits (C = 8 classes)\n", | |
"Y_oh_train = convert_to_one_hot(Y_train, C = 8)\n", | |
"Y_oh_test = convert_to_one_hot(Y_test, C = 8)" | |
], | |
"outputs": [], | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Show one training label next to its one-hot row\n", | |
"index = 0\n", | |
"print(Y_train[index], \"is converted into one hot\", Y_oh_train[index])" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"3 is converted into one hot [0. 0. 0. 1. 0. 0. 0. 0.]\n" | |
] | |
} | |
], | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load Facebook's pre-trained Chinese word vectors\n", | |
"word_to_index, index_to_word, word_to_vec_map = read_glove_vecs('data/wiki.zh.vec')" | |
], | |
"outputs": [], | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Look up one word in the loaded vectors: word -> index, index -> word, and the length of its vector\n", | |
"word = \"差\"\n", | |
"index = 235052\n", | |
"print(\"the index of\", word, \"in the vocabulary is\", word_to_index[word])\n", | |
"print(\"the\", str(index) + \"th word in the vocabulary is\", index_to_word[index])\n", | |
"len(word_to_vec_map[word])" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"the index of 差 in the vocabulary is 235052\n", | |
"the 235052th word in the vocabulary is 差\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 9, | |
"data": { | |
"text/plain": [ | |
"300" | |
] | |
}, | |
"metadata": {} | |
} | |
], | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#加總留言字詞向量\n", | |
"\n", | |
"def sentence_to_avg(sentence, word_to_vec_map):\n", | |
" \"\"\"\n", | |
" Converts a sentence (string) into a list of words (strings). Extracts the GloVe representation of each word\n", | |
" and averages its value into a single vector encoding the meaning of the sentence.\n", | |
"\n", | |
" Arguments:\n", | |
" sentence -- string, one training example from X\n", | |
" word_to_vec_map -- dictionary mapping every word in a vocabulary into its 50-dimensional vector representation\n", | |
"\n", | |
" Returns:\n", | |
" avg -- average vector encoding information about the sentence, numpy-array of shape (50,)\n", | |
" \"\"\"\n", | |
"\n\n", | |
" words = list(jieba.cut(sentence,cut_all=False))\n", | |
" temps=[]\n", | |
" for w in words:\n", | |
" try:\n", | |
" word_to_vec_map[w]\n", | |
" temp=w\n", | |
" except:\n", | |
" temp=list(w)\n", | |
" temps.append(temp)\n", | |
" flat_list=[]\n", | |
" _ = [flat_list.extend(item) if isinstance(item, list) else flat_list.append(item) for item in temps if item]\n", | |
" while ' ' in flat_list:\n", | |
" flat_list.remove(' ')\n", | |
" words=flat_list\n", | |
"\n", | |
" avg = np.zeros((300,))\n", | |
"\n", | |
" for w in words:\n", | |
" avg += word_to_vec_map[w]\n", | |
" avg = avg/len(words)\n", | |
"\n\n return avg" | |
], | |
"outputs": [], | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Try summing and averaging the word vectors of one message\n", | |
"\n", | |
"avg = sentence_to_avg(\"這是我一生最差勁的一天 \", word_to_vec_map)\n", | |
"print(\"avg = \", avg)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"avg = [-1.13861875e-01 1.04477125e+00 -5.89520000e-01 3.24456250e-01\n", | |
" 7.38658750e-01 -4.64008750e-01 -5.60848750e-01 2.58263750e-02\n", | |
" 4.23486250e-01 -1.34124125e-01 9.67411250e-01 1.42290625e-01\n", | |
" 1.35261125e+00 -9.27590000e-01 2.08043500e-01 -6.24710000e-01\n", | |
" 9.57933750e-01 -4.19742250e-01 -1.12124625e+00 -6.56251250e-01\n", | |
" -9.99676250e-01 1.06239875e+00 1.18810875e+00 8.31875000e-03\n", | |
" 5.75907500e-01 -4.11583750e-01 9.39983750e-01 1.99731000e-01\n", | |
" 7.24997500e-01 2.47968175e-01 7.64845000e-01 -1.00287625e+00\n", | |
" -9.05686250e-01 -5.49693750e-01 -8.29576250e-01 1.19548875e+00\n", | |
" 1.05869750e+00 -5.47840000e-01 1.93496250e-01 1.26827750e+00\n", | |
" 7.34667500e-01 1.07951875e+00 1.17510625e+00 -1.09544750e+00\n", | |
" -3.72637425e-01 -4.29343750e-01 -7.19981250e-01 -3.67023750e-01\n", | |
" 5.13111375e-01 -8.94367500e-01 1.25980625e+00 -3.23177875e-01\n", | |
" -3.98428750e-01 -3.33925750e-01 4.74337750e-02 -9.53198750e-01\n", | |
" -1.10110875e+00 -1.26326375e+00 -3.35262500e+00 3.98726250e-01\n", | |
" 1.14972250e+00 -1.73711250e-01 8.78507500e-01 1.17542625e+00\n", | |
" 4.49652500e-01 8.36718750e-01 -9.26055000e-01 -7.36686250e-01\n", | |
" -5.71937500e-01 -3.44220563e-01 1.33897750e+00 -8.14458750e-01\n", | |
" 1.17143875e+00 -6.51095000e-01 -8.13095000e-01 4.27236625e-01\n", | |
" -1.26325313e-01 -8.62573750e-01 -5.51665000e-01 -1.07435750e+00\n", | |
" -3.31313750e-01 2.76122500e-03 7.34642500e-01 -7.46127500e-01\n", | |
" -6.92805000e-01 -9.29848750e-01 2.51736750e-01 5.15750000e-01\n", | |
" 8.19151250e-02 -1.02911625e+00 -6.78373750e-01 7.21015000e-01\n", | |
" -2.91189750e-01 -2.18004375e-01 -9.07598750e-01 1.01732000e+00\n", | |
" 6.72615000e-01 7.39838750e-01 3.13162500e-02 1.33909000e-01\n", | |
" 5.16993750e-01 7.64466250e-01 -5.07446250e-01 8.62035000e-01\n", | |
" -5.14123750e-01 6.94323750e-01 -6.69157500e-02 -6.91896250e-01\n", | |
" -8.54900000e-01 8.69995000e-02 3.28381250e-01 -7.18932500e-01\n", | |
" -1.10166787e-01 5.39598750e-01 2.46171625e-01 4.52513750e-01\n", | |
" -5.27893750e-01 2.25719250e-01 -8.03552500e-01 -9.48873750e-01\n", | |
" -8.62896250e-01 9.29968750e-01 6.87881250e-01 -7.01957500e-02\n", | |
" -6.60365000e-01 -8.36058750e-01 8.01417500e-01 8.33808750e-01\n", | |
" -8.51356250e-01 3.72643750e-01 8.79275000e-01 3.20965000e-01\n", | |
" -9.38807500e-01 9.33356250e-01 -1.07136500e+00 4.09523750e-01\n", | |
" -4.92283750e-01 -2.83116625e-01 -7.83887500e-01 -4.53477500e-01\n", | |
" 1.90866250e-02 7.95182500e-01 1.12467125e+00 -1.14985250e+00\n", | |
" -8.80171250e-01 3.48727375e-01 -1.42132000e+00 -4.00085000e-02\n", | |
" -7.22596250e-01 -2.97139738e-01 5.76632500e-01 7.57280000e-01\n", | |
" -2.39040000e-01 1.53809625e-01 1.01213625e+00 -7.82800000e-01\n", | |
" 3.56789875e-01 -4.20626250e-01 -4.67568000e-02 6.97733750e-01\n", | |
" 5.33088750e-01 2.95906500e-01 3.64505750e-01 -1.32033875e-01\n", | |
" -4.67241125e-02 3.72525000e-01 -2.80784750e-01 8.81593750e-01\n", | |
" 5.25787500e-01 3.54515000e-01 5.13618750e-01 6.84813750e-01\n", | |
" 1.07395750e+00 -1.06314563e-01 5.69411125e-02 8.81600000e-02\n", | |
" 6.70105000e-01 7.66377500e-01 -1.22160150e-01 -4.73730000e-01\n", | |
" -6.71396250e-01 -5.60345000e-01 2.17737125e-01 1.59381812e-01\n", | |
" -2.97697750e-01 8.97773750e-01 -5.75956375e-02 -6.32732500e-01\n", | |
" -4.58300000e-01 8.10427500e-01 -3.30479000e-01 -8.91516250e-02\n", | |
" -5.06517875e-01 -4.16750000e-01 -2.51331125e-01 -6.24966250e-01\n", | |
" -8.69408750e-01 -8.17258750e-01 3.27987500e-03 4.15771250e-01\n", | |
" 3.45661250e-02 -5.34625000e-01 -6.72638750e-02 -6.20112500e-01\n", | |
" -9.17050000e-03 1.18512750e-01 3.74020000e-01 -4.08962500e-01\n", | |
" -1.78093500e-01 4.28253750e-01 -9.24425000e-01 2.97356375e-01\n", | |
" -5.55003750e-01 5.52548750e-01 -6.87260000e-01 1.47368000e-01\n", | |
" 4.36581250e-01 -2.34388250e-01 1.02227375e+00 6.80735000e-01\n", | |
" -5.53666250e-01 7.15455000e-01 3.31547375e-01 -2.03136875e-01\n", | |
" -9.04048750e-01 7.41556250e-01 -1.52710375e-01 -3.47773750e-01\n", | |
" -9.16613750e-01 2.28735000e-02 -7.97330000e-01 -4.90311250e-01\n", | |
" 3.36850750e-01 -5.14460000e-01 -4.94883625e-02 1.04841000e+00\n", | |
" 3.49311625e-01 -9.82771250e-01 -2.46683125e-01 -5.05893750e-01\n", | |
" 7.70962500e-01 -7.81355000e-01 5.95733750e-01 -1.43723687e-01\n", | |
" -3.12608750e-01 -5.29705000e-03 -4.71833750e-01 -1.94240000e-01\n", | |
" -6.85587500e-01 4.58510000e-01 -3.72441250e-01 -1.05118125e-01\n", | |
" -8.62170000e-01 1.70604925e-01 -5.23317500e-01 9.52750000e-04\n", | |
" 4.80752500e-01 6.40448750e-01 -1.55436250e+00 -5.59417500e-02\n", | |
" 3.91768750e-01 -2.51005000e-01 -4.59052500e-01 -3.99057500e-02\n", | |
" -3.77393892e-01 -3.03079500e-01 -6.06856250e-01 -9.01596250e-01\n", | |
" -1.79545000e+00 -1.00723125e+00 7.72862500e-03 3.44258750e-01\n", | |
" 1.05848500e+00 -6.01907500e-01 -3.07066625e-02 -2.38434125e-01\n", | |
" -4.34837500e-02 8.50942500e-01 7.36421250e-02 -4.72792500e-01\n", | |
" 6.98252500e-01 1.26219250e+00 1.41322875e-01 6.67311250e-01\n", | |
" -2.27989913e-01 -2.26622500e-01 5.46474750e-01 3.89776250e-01\n", | |
" -9.58172500e-01 6.50233500e-02 -9.52265000e-01 -1.01913250e+00\n", | |
" -5.33670000e-01 -4.93423750e-01 -2.16531000e-02 -8.96391250e-01\n", | |
" 2.50648500e-01 3.73640362e-01 1.06248250e+00 2.75682000e-01]\n" | |
] | |
} | |
], | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#模型參數 learning_rate 及 num_iterations均可以修改\n", | |
"\n", | |
"def model(X, Y, word_to_vec_map, learning_rate = 0.005, num_iterations = 1000):\n", | |
" \"\"\"\n", | |
" Model to train word vector representations in numpy.\n", | |
"\n", | |
" Arguments:\n", | |
" X -- input data, numpy array of sentences as strings, of shape (m, 1)\n", | |
" Y -- labels, numpy array of integers between 0 and 7, numpy-array of shape (m, 1)\n", | |
" word_to_vec_map -- dictionary mapping every word in a vocabulary into its 50-dimensional vector representation\n", | |
" learning_rate -- learning_rate for the stochastic gradient descent algorithm\n", | |
" num_iterations -- number of iterations\n", | |
"\n", | |
" Returns:\n", | |
" pred -- vector of predictions, numpy-array of shape (m, 1)\n", | |
" W -- weight matrix of the softmax layer, of shape (n_y, n_h)\n", | |
" b -- bias of the softmax layer, of shape (n_y,)\n", | |
" \"\"\"\n", | |
" np.random.seed(1)\n", | |
"\n\n", | |
" m = Y.shape[0] # number of training examples\n", | |
" n_y = 8 # number of classes \n", | |
" n_h = 300 # dimensions of the GloVe vectors \n", | |
"\n\n", | |
" W = np.random.randn(n_y, n_h) / np.sqrt(n_h)\n", | |
" b = np.zeros((n_y,))\n", | |
"\n\n", | |
" Y_oh = convert_to_one_hot(Y, C = n_y) \n", | |
"\n\n", | |
" for t in range(num_iterations): # Loop over the number of iterations\n", | |
" for i in range(m): # Loop over the training examples\n", | |
" # Average the word vectors of the words from the j'th training example\n", | |
" avg = sentence_to_avg(X[i], word_to_vec_map)\n", | |
" # Forward propagate the avg through the softmax layer\n", | |
" z = np.dot(W, avg) + b\n", | |
" a = softmax(z)\n", | |
" # Compute cost using the j'th training label's one hot representation and \"A\" (the output of the softmax)\n", | |
" cost = -np.sum(Y_oh[i] * np.log(a))\n", | |
"\n", | |
" # Compute gradients \n", | |
" dz = a - Y_oh[i]\n", | |
" dW = np.dot(dz.reshape(n_y,1), avg.reshape(1, n_h))\n", | |
" db = dz\n", | |
"\n", | |
" # Update parameters with Stochastic Gradient Descent\n", | |
" W = W - learning_rate * dW\n", | |
" b = b - learning_rate * db\n", | |
" \n", | |
" if t % 100 == 0:\n", | |
" print(\"Epoch: \" + str(t) + \" --- cost = \" + str(cost))\n", | |
" pred = predict(X, Y, W, b, word_to_vec_map)\n", | |
"\n return pred, W, b" | |
], | |
"outputs": [], | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Train the model\n", | |
"pred, W, b = model(X_train, Y_train, word_to_vec_map)\n", | |
"print(pred)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch: 0 --- cost = 2.838557314655528\n", | |
"Accuracy: 0.059602649006622516\n", | |
"Epoch: 100 --- cost = 1.208787957239081\n", | |
"Accuracy: 0.5960264900662252\n", | |
"Epoch: 200 --- cost = 0.619723973651847\n", | |
"Accuracy: 0.7218543046357616\n", | |
"Epoch: 300 --- cost = 0.4027164591549059\n", | |
"Accuracy: 0.7682119205298014\n", | |
"Epoch: 400 --- cost = 0.3002130588259839\n", | |
"Accuracy: 0.8013245033112583\n", | |
"Epoch: 500 --- cost = 0.23925132641522678\n", | |
"Accuracy: 0.8211920529801324\n", | |
"Epoch: 600 --- cost = 0.19674781080695078\n", | |
"Accuracy: 0.8609271523178808\n", | |
"Epoch: 700 --- cost = 0.1649123704857175\n", | |
"Accuracy: 0.8807947019867549\n", | |
"Epoch: 800 --- cost = 0.14047599147368353\n", | |
"Accuracy: 0.9139072847682119\n", | |
"Epoch: 900 --- cost = 0.12151579557825612\n", | |
"Accuracy: 0.9403973509933775\n", | |
"[[3.]\n", | |
" [2.]\n", | |
" [3.]\n", | |
" [0.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [3.]\n", | |
" [1.]\n", | |
" [3.]\n", | |
" [3.]\n", | |
" [1.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [3.]\n", | |
" [1.]\n", | |
" [2.]\n", | |
" [7.]\n", | |
" [0.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [4.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [0.]\n", | |
" [7.]\n", | |
" [4.]\n", | |
" [2.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [3.]\n", | |
" [4.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [0.]\n", | |
" [2.]\n", | |
" [3.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [4.]\n", | |
" [7.]\n", | |
" [7.]\n", | |
" [4.]\n", | |
" [2.]\n", | |
" [1.]\n", | |
" [1.]\n", | |
" [1.]\n", | |
" [2.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [4.]\n", | |
" [4.]\n", | |
" [2.]\n", | |
" [1.]\n", | |
" [2.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [0.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [1.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [4.]\n", | |
" [3.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [4.]\n", | |
" [0.]\n", | |
" [0.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [7.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [0.]\n", | |
" [1.]\n", | |
" [2.]\n", | |
" [7.]\n", | |
" [0.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [2.]\n", | |
" [4.]\n", | |
" [1.]\n", | |
" [3.]\n", | |
" [3.]\n", | |
" [4.]\n", | |
" [1.]\n", | |
" [1.]\n", | |
" [3.]\n", | |
" [1.]\n", | |
" [0.]\n", | |
" [4.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [3.]\n", | |
" [4.]\n", | |
" [4.]\n", | |
" [4.]\n", | |
" [7.]\n", | |
" [2.]\n", | |
" [7.]\n", | |
" [0.]\n", | |
" [4.]\n", | |
" [4.]\n", | |
" [0.]\n", | |
" [3.]\n", | |
" [3.]\n", | |
" [7.]\n", | |
" [7.]\n", | |
" [3.]\n", | |
" [3.]\n", | |
" [2.]\n", | |
" [5.]\n", | |
" [5.]\n", | |
" [5.]\n", | |
" [5.]\n", | |
" [5.]\n", | |
" [5.]\n", | |
" [6.]\n", | |
" [7.]\n", | |
" [6.]\n", | |
" [5.]\n", | |
" [5.]\n", | |
" [1.]\n", | |
" [5.]\n", | |
" [7.]\n", | |
" [7.]\n", | |
" [7.]\n", | |
" [7.]\n", | |
" [7.]\n", | |
" [7.]\n", | |
" [7.]\n", | |
" [3.]\n", | |
" [0.]\n", | |
" [6.]\n", | |
" [6.]\n", | |
" [2.]]\n" | |
] | |
} | |
], | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Run predict with the trained W and b on the training set and on the test set\n", | |
"print(\"Training set:\")\n", | |
"pred_train = predict(X_train, Y_train, W, b, word_to_vec_map)\n", | |
"print('Test set:')\n", | |
"pred_test = predict(X_test, Y_test, W, b, word_to_vec_map)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Training set:\n", | |
"Accuracy: 0.9536423841059603\n", | |
"Test set:\n", | |
"Accuracy: 0.5882352941176471\n" | |
] | |
} | |
], | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Enter several messages and look at the predicted emoji\n", | |
"X_my_sentences = np.array([\"政府可恥\", \"我憎恨你\", \"這位同學是白癡\", \"曼聯快搶攻\", \"真想吃薯條\", \"我可不可以去廁所\"])\n", | |
"Y_my_labels = np.array([[7], [7], [2], [1], [4],[6]])\n", | |
"\n", | |
"pred = predict(X_my_sentences, Y_my_labels , W, b, word_to_vec_map)\n", | |
"print_predictions(X_my_sentences, pred)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Accuracy: 0.6666666666666666\n", | |
"\n", | |
"政府可恥 😡\n", | |
"我憎恨你 😞\n", | |
"這位同學是白癡 😄\n", | |
"曼聯快搶攻 ⚾\n", | |
"真想吃薯條 🍴\n", | |
"我可不可以去廁所 😡\n" | |
] | |
} | |
], | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#輸入單詞看結果\n", | |
"\n", | |
"my_sentence=\"我訂了三文魚\"\n", | |
"X_my_sentences = np.array([my_sentence])\n", | |
"#Y_my_labels = np.array([[1]])\n", | |
"pred = predict(X_my_sentences, Y_my_labels , W, b, word_to_vec_map)\n", | |
"print_predictions(X_my_sentences, pred)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Accuracy: 0.16666666666666666\n", | |
"\n", | |
"我訂了三文魚 🍴\n" | |
] | |
} | |
], | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(Y_test.shape)\n", | |
"print(' '+ label_to_emoji(0)+ ' ' + label_to_emoji(1) + ' ' + label_to_emoji(2)+ ' ' + label_to_emoji(3)+' ' + label_to_emoji(4)+' ' + label_to_emoji(5)+' ' + label_to_emoji(6)+' ' + label_to_emoji(7))\n", | |
"print(pd.crosstab(Y_test, pred_test.reshape(51,), rownames=['Actual'], colnames=['Predicted'], margins=True))\n", | |
"plot_confusion_matrix(Y_test, pred_test)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"(51,)\n", | |
" ❤️ ⚾ 😄 😞 🍴 😭 💩 😡\n", | |
"Predicted 0.0 1.0 2.0 3.0 4.0 7.0 All\n", | |
"Actual \n", | |
"0 6 0 0 0 1 0 7\n", | |
"1 1 4 0 0 0 1 6\n", | |
"2 5 0 7 2 0 4 18\n", | |
"3 1 1 1 8 0 3 14\n", | |
"4 0 0 1 0 5 0 6\n", | |
"All 13 5 9 10 6 8 51\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/pandas/core/generic.py:7460: RuntimeWarning: '<' not supported between instances of 'str' and 'float', sort order is undefined for incomparable objects\n", | |
" return_indexers=True)\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 336x288 with 2 Axes>" | |
], | |
"image/png": [ | |
"iVBORw0KGgoAAAANSUhEUgAAASwAAAD+CAYAAACA/DjlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGh1JREFUeJzt3X2cXVV97/HPdzITQkwQwgQcSWK4PEWKimVKvWIpBEUQCil4LbEoKrfBtrwKPpSH3nurXm8vohUtr6o1FmosjxakIBUwrzSAsRBJeIYQSBAupIEwQIAghAR+94+9Bg6Tk5kzM2c/Tb7v1+u85ux99lm/dfac+c1aa6+9tyICM7M66Ci7AmZmrXLCMrPacMIys9pwwjKz2nDCMrPacMIys9pwwjKz2nDCMrPa6Cy7AnmStD+wESAiVpRUh46IeK2AOAcCXcDmiFiad7yGuKXs4xLjFr6fJSk8wxsYwy0sSUcCPwX+DPgXSZ8uKO5Rkr4i6RxJOxeUrD4MXAMcBVwq6VRJkwqIW9Y+LituKfsZGJ/ij9m/15ZFxJh6AAImAT8Djknr3gesAj6bc+zfBX4NfBz4B+CXwPuBrhw/63bAD4GPpXX7AwuBLwITx9I+Ljlu4fs5xdkLuAJ4R1ruyCtWHR5jLmNHZgOwDNhBUldE3AqcAJwp6VM5ht8P+HlEXBIRnwWuBM4ADoD2/4dMn3UjsAJ4t6RJEXEncDrwESCXlkdZ+7jkuIXv5+QJ4FHgHEnTI+K1bbmlNZY/+BPAYcD2ABGxDPgEcKqk3XOKeRuwvaRZKeZ5wBLgW5J2jPy6h3cDOwN7SOqMiPuAvwQ+L+k9OcWEcvZx4XEbEkRh+1nSuyRdFREvAF8GHgG+ua0nrTH3oSUJICK+C0wEvifprem/8RKyL11eA5hPAJuBD0nqTvX4W+Be4JScYhIR1wEbgL8A9kstgOXA9WTdmbziFrqPJY0rMq6kaZLG9/+jKXg/PwKEpMtT0jqHrOu7TSctpX5xrUnaB5hC1lV4LSJebXjtUuBl4Fayo6KfB34/Ih5vU+xxA+K9F/gqcANwY0TcI+msVK+vtyHensCOwL0R8fKA184FJpMdPXsM+AJwUEQ80oa4vwV0AysiYl3jkas897GkDwC7R8Q/p+XxEfFKAXE/TNay+eOIeDglxU3ptTz389si4on0fDvgn4DtIuJ4SZOBs4GZwF+1I17d1D5hSToO+L/AmvRYBvwwIp5v2OYzwNuB9wBfTk350cbdOyIeTM/HRcSr/X/EKWmdQpZYAjgQmBMR94wy5tFkn/Vpstbc30TEvQP+mA4F3g3sDXwnIu4fTcxU5pHAucDDZIf050XEmgFx27qPU+thIrCUrPVyfkT8Q3ptQn+yzul3ezjwdbLf39URcVpa//o/p5z28yzgfuDvyP4xzJf0FuDbwNSImJOS1leBHch+D5tHG7dWihjZz+tB9sdzOdl/N4DjgW8AfwO8tcn227Up7tHAb4BLGtaNSz870s9usiM8HydrIYw25vvJBn3fm5a/C1zY8HrHgO072/RZDwEeBA5My1cBH2z8zHns44byziBrwfwI+Nwg27Xrd/tBsq7Xb6Xv18+Bg/Pez6msaWRjnmemuD8im0IxA/gWcGXabgeyBNa2/VyXx1joA+9Alhgg+2O6luyLNheyiX6Sfju9/spog6X/eKeSHSF6RdJFAJG1sDrjjYH1zRHxUGRHDH892rjJuRFxR3r+JWBK6jYQ2ZjG76RWGMCrTUsYvieBUyLiV5LeRjZ141RJ3wc+CSDpgHbu4wE2A9OBBcCBks6TdE6K+/4c4o4DPhlZS+0twEqy5NU/gfO19J06Km3frv1MZF3ZXwG/TXb08TrgT8gS1wXAdEnnR8TzEfFUu+I2IymG8bi+hfIekXSPpDslLUvrpkhaKOmh9HOnocqpdcKKrDtyHnCcpN9LyWIJcCdwsKTtgYOA/0zbj7r/GxE
vAp8BLiGbgzOhIWltBkhHjE6UNKH/IEAbLAV+ksofRzYv6B1kCRtJ04BZZF3itnzWVM6KiFicFk8GvhsRc4BbgCMlzQQOpo37eICrgSciYhHZZ/tTsq4awO+0O25E3BAR/6HsDIX1wL8BX5L0rogISV3AHsDt7Yzb8D05i2wYoRtYS9btfAj4X2Qtv++2I16LdWrpkeraikMjYv+I6E3LZwGLImIvYFFaHlzZTbzRPoAJZC2e+by56X4jsEcB8Xcmm291UVp+N1nXdJccY3aSTaBclJZPBL4JTC54318H7J1zjLeTDTz/Cdkf7l+TtaI/ThqDLeBz/m+ywe7+7n4uccnG6saTjVFdDDxANvYJWS9ipwJ/t9HR0dHSA1jWQnmPAN0D1q0EetLzHmDlUOXU/lzCiHhZ0sVk/5XOTgOXG4GpZIeg847/tKRTgG9IWknWaj04ItblGHMzsEHSY6l7dDjw6cgOf+ei8ahgWj4e2AXILSZARPynpMfIWhh/HhE/TQPeqxrrk7O7gM+RjY++llfcVG7/MMNNZIP5/5peeyiPmINpX+cAyP4+fy4pgO9HxHxg14hYm15/Ath1qEJqn7AAIuJZST8gO8JyCtmh7hMj4smC4vdJuhs4EvhQwy8hF6n70AX8Xvp5WN5f6P4/0jRmdiLZFII/yvuzJj8gO1q3PC3fFAWco9kvIq6U9Edkg+KPFBBvZZoKM1PSxIj4Td4xB5JER0drI0avvvpqd/+4VDI/JaRGH4jsyPIuwEJJDzS+GBGRktmgxkTCAohsbs5iSTdni8V9odNg4UeAw2OUUxda0fCf+KvAbQX/932NbGzluIhYWUTAiHgMeKy/lVfw77Y/5seKipncChxXcMw3GUYLqy/eGJdqKiLWpJ/rJF1FNtXnSUk9EbFWUg8wZK+k9vOwqqJxblCBMX3ZkTGsrNYVQEdHR4wfP76lbTdu3Lh8sISVjqx3RMQL6flCsnHBw4CnI+JrqUU5JSLOGCzWmGlhla3oZJViOlmNYWUlq35tHMPaFbgqlddJNn/xekm3AT+WdDLZCd5DtmLdwjKzLXR0dMSECRNa2vall14atIXVTm5hmdkWGuZYVYoTlpk1VcWEVeuZ7q2QNG9biOm4YztuGTE7OjpaehRap0KjlaOML3Upf0iOO6bjlpEkWz01pzDuEprZFqo6hlWLo4RTpkyJadOmjei9zzzzDFOmTBnRe1udhzLQU089xdSpU0f03tGoY9zRfP/6+vro7m71vNs3G80f42g+78aNG0f0vtF8j9esWcMzzzwzrA/c2dkZkydPbmnb9evX+yhho2nTpnHttdcWHnfGjBmFx9zWbNq0qZS4XV1dpcRdvXp14THnzJkzovdVsYVVi4RlZsVzwjKz2nDCMrNaGM7VGorkhGVmTbmFZWa14YRlZrXhhGVmtVDViaNOWGbWlBOWmdVGFRNWKcctJR0haaWkVenSqGZWMb5aA6/fBPQ7ZHeY2ReYK2nfouthZlvX6pUaim6FldHCOpDsnnIPpzvdXAYcW0I9zGwQTliZ3YDHGpYfT+vMrEKqmLAqO+ierrA4D2C33ZzPzIrmQffMGmB6w/K0tO5NImJ+RPRGRO9IrwNkZiPnFlbmNmAvSbuTJaoTgI+XUA8z2wqf/JxExGZJpwI3AOOACyPivqLrYWaDq2KXsJQxrIj4GfCzMmKbWWucsMysNpywzKwWfPKzmdWKE5aZ1YYTlpnVhqc1mFkteAzLzGrFCcvMasMJy8xqwwnLzGrDCWuEurq66OnpKTxuX19f4TEBuru7S4lbhq6urlLibtq0qZS4ZXzekSSePAbd09WGlwFrIuLodAGEy4CdgeXAJ9JFPbeqesctzawScrim+2nAioblc4FvRcSewLPAyUPWaVifwMy2Ge28HpakacBRwD+mZQGzgSvSJguAOUOVU4suoZkVr81dwm8DZwCT0/LOwPqI2JyWW7pUultYZraFYd41p1vSsobHvAF
lHQ2si4jlo62XW1hm1tQwWlh9EdE7yOsHAcdI+ggwAdgB+DtgR0mdqZXV9FLpA7mFZWZNtWsMKyLOjohpETGT7JLo/x4RfwwsBj6aNjsJuHqospywzKypAm5CcSbweUmryMa0LhjqDe4SmtkW8roJRUTcCNyYnj9MdmPlljlhmVlTnuluZrVRxYRVyhiWpAslrZN0bxnxzWxoVbyRalmD7j8Ejigptpm1oIoJq6z7Et4saWYZsc1saL7iqJnVihPWMKTp/fMAZsyYUXJtzLY9VbwJRfVqlETE/IjojYjeben6UGZV4TEsM6uFqo5hlTWt4VLgFmAfSY9LGvLCXWZWLLewkoiYW0ZcM2tdFVtY7hKaWVNOWGZWC3md/DxaTlhm1pRbWGZWG05YZlYbTlhmVhtOWGZWC1WdOOqEZWZNOWGZWW14WoOZ1YK7hKOwefNm+vr6Co/b09NTeEyA1atXlxK3jM87ceLEwmMCPPfcc6XEXbJkSeExN2zYMKL3OWGZWW04YZlZbThhmVltOGGZWS140N3MasXTGsysNtzCMrPacMIys1rwGJaZ1YoTlpnVRhUTVuGHASRNl7RY0v2S7pN0WtF1MLOh+TZfmc3AFyLidkmTgeWSFkbE/SXUxcya8E0okohYC6xNz1+QtALYDXDCMqsQdwkHkDQTeC+wtMx6mNmW2tUllDRB0q8k3ZWGgb6S1u8uaamkVZIulzR+qLJKS1iSJgFXAqdHxPNNXp8naZmkZU8//XTxFTTbxrVxDGsjMDsi3gPsDxwh6X3AucC3ImJP4Fng5KEKKiVhSeoiS1YXR8RPmm0TEfMjojcienfeeediK2hmbUtYkem/KFdXegQwG7girV8AzBmqrDKOEgq4AFgREecVHd/MhtZqsmp1nEvSOEl3AuuAhcBqYH1EbE6bPE42lj2oMlpYBwGfAGZLujM9PlJCPcxsEMNIWN39wzfpMW9gWRHxakTsD0wDDgRmjaROZRwlXAJU7/CDmb3JMKY19EVEbysbRsR6SYuB/wrsKKkztbKmAWuGrFOrNTKzbUsbjxJOlbRjer498CFgBbAY+Gja7CTg6qHK8qk5ZraFNs9i7wEWSBpH1kj6cURcK+l+4DJJ/we4g2xse1BOWGbWVLsSVkTcTTbfcuD6h8nGs1rmhGVmTVVxpvtWE5akn5LNlWgqIo7JpUZmVgm1SljA3xZWCzOrlNqd/BwRNxVZETOrlrq1sACQtBdwDrAvMKF/fUT8lxzrZWYlq2LCaqXN90/A98iuY3Uo8CPgojwrZWblq+sF/LaPiEWSFBGPAl+WtBz465zr9rrOzk66u7uLCve6TZs2FR4TYMaMGaXEXb16deExZ80a0Rkao9bV1VVK3DK+x52dI5sMUMUWViufZKOkDuAhSaeSTZ+flG+1zKxMZbSeWtFKl/A0YCLwF8ABZCcun5RnpcysfLXsEkbEbenpBuDT+VbHzKqiVtMa+qUzq7eYQBoRs3OpkZlVQhW7hK2MYX2x4fkE4HiyI4ZmNkZVdQyrlS7h8gGrfinpVznVx8wqopYJS9KUhsUOsoH3t+ZWIzOrhFomLGA52RiWyLqCv6aFu1uYWb3VNWG9MyJeblwhabuc6mNmFVHFhNXKccv/aLLulnZXxMyqo/9qDa08ijTY9bDeRnbbne0lvZc3bhyxA9lEUjMbw6rYwhqsS/hh4FNkd7P4Jm8krOeBvxppQEkTgJuB7VL8KyLiSyMtz8zyUauEFRELyC4cf3xEXNnGmP23rd6Q7gC9RNJ1EXFrG2OY2ShVMWG10gE9oP8WPQCSdkp3uRiRQW5bbWYV0e47P7dLKwnryIhY378QEc8Co7pT88DbVkfE0tGUZ2btV9eENa5xGkO6EeKopjUMvG21pP0GbiNpXv+tr/v6+kYTzsxGoK4J62JgkaSTJf13YCGwoB3BU8ttMXBEk9fmR0RvRPSWcdEzs21draY19IuIcyXdBXyQbKzpBuAdIw0oaSq
wKSLWN9y2+tyRlmdm7Vfbk5+TJ8mS1X8jOzVnNEcNm962ehTlmVkOapWwJO0NzE2PPuByQBFx6GgCbu221WZWLbVKWMADwC+AoyNiFYCkzxVSKzMrXRUT1mAjZscBa4HFkn4g6TDemO1uZmNcrY4SRsS/RsQJwCyyI3mnA7tI+p6kw4uqoJkVr7YTRyPixYi4JCL+gGze1B3AmbnXzMxKVcVpDcOKFhHPpvlRh+VVITOrhiq2sEZ2S1gzG/OqOOjuhGVmW6jqxNHq3SnRzCqhXV1CSdMlLZZ0v6T7JJ2W1k+RtFDSQ+nnTkOV5YRlZk21cQxrM/CFiNgXeB/w55L2Bc4CFkXEXsCitDyoWnQJJdHV1VV2NQqzadOmUuLOmjWr8Jhr164tPCZAT09PKXHL2McTJkwY0fvadQQwItaSzekkIl6QtILs8uvHAoekzRYANzLEDIRaJCwzK1ZeY1iSZpKdmrcU2DUlM4AngF2Her8Tlpk1NYyE1S1pWcPy/IiY36S8SWQXTjg9Ip5vLD8iQtKQVx52wjKzpoaRsPoioneIsrrIktXFEfGTtPpJST0RsVZSD9kViAflQXcza6qNRwkFXACsiIjzGl66BjgpPT8JuHqostzCMrOm2jiGdRDwCeAeZfdygOxWgV8DfizpZOBR4GNDFeSEZWZbaOege0QsYetXehnWaX5OWGbWVNEnNrfCCcvMmqriqTlOWGa2haqeS+iEZWZNVTFhldZJVXb35zsk+Y45ZhXk62G92WnACmCHEutgZlvhFlYiaRpwFPCPZcQ3s6G5hfWGbwNnAJNLim9mg5BUyWkNhddI0tHAuohYPsR28yQtk7TsqaeeKqh2Ztavii2sMlLoQcAxkh4BLgNmS7po4EbpZhe9EdE7derUoutots1zwgIi4uyImBYRM4ETgH+PiBOLroeZDa6KCcvzsMxsC5442kRE3Eh2WVQzqxgnLDOrDScsM6uNKk5rcMIysy14DMvMasUJy8xqwwnLzGrDCcvMasMJy8xqwYPuZlYrntZgZrXhFtYIvfDCCyxevLjwuLNmzSo8JkBfX18pcTdt2lR4zO7u7sJjAjzwwAOlxH3nO99ZStyRcMIys1rwGJaZ1YoTlpnVhhOWmdWGjxKaWS14DMvMasUJy8xqwwnLzGrDCcvMasMJy8xqoaqD7rket5Q0R1JImpWWZ0q6Nz0/RNK1ecY3s5Hr6Oho6VFonXIufy6wJP00sxpp141UJV0oaV1/YyWtmyJpoaSH0s+dWqlTbglL0iTgA8DJZHd4NrMaaeOdn38IHDFg3VnAoojYC1iUloeUZwvrWOD6iHgQeFrSATnGMrM2ajVZtZKwIuJm4JkBq48FFqTnC4A5rdQrz4Q1F7gsPb+MYXYLJc2TtEzSsueee67tlTOzwbWxhdXMrhGxNj1/Ati1lTflcpRQ0hRgNvAuSQGMAwL4TqtlRMR8YD7APvvsE3nU08y2bhjJqFvSsobl+envtyURESlPDCmvaQ0fBf45Ik7pXyHpJmB6TvHMrM2GkbD6IqJ3mMU/KaknItZK6gHWtfKmvLqEc4GrBqy7Ejg7p3hm1kaS8p7WcA1wUnp+EnB1K2/KpYUVEYc2WXc+cH7D8o3AjXnEN7PRa9fEUUmXAoeQdR0fB74EfA34saSTgUeBj7VSlme6m1lT7UpYEbG1A26HDbcsJywza6qKp+Y4YZlZU05YZlYLVT352QnLzJpywjKz2vBNKMysNtzCMrNa8BiWmdWKE5aZ1YYT1ghNnjyZQw/d4myfMaunp6fsKlhOIoq/8Iik5SN8X7urMmq1SFhmVjwnLDOrhf6rNVSNE5aZNeUWlpnVhhOWmdWGE5aZ1YInjppZrThhmVltOGGZWW14WoOZ1YLHsMysVpywzKw2nLDMrDacsMysNpywhkHSPGBeWtwgaeUIi+oG+tpTq0rHdNyxHXc0Md8x3Df45Odhioj5wPzRliNpWUT0tqFKlY7puGM7bkk
xiwzXksomLDMrlxOWmdVCVedhVa+T2n6j7lbWJOagcSW9KulOSfdK+hdJE0caRNIhkq5Nz48B/t8g2+4o6c9GEOPLkr44xGaV289jKWZ/0hrqUaQxn7DSWNiYj9lC3JciYv+I2A94Bfhs44vKDPv7EBHXRMRxg2yyIzDshNVi7Cru5zET0wnLquIXwJ6SZkpaKelHwL3AdEmHS7pF0u2pJTYJQNIRkh6QdDvweoKS9ClJf5+e7yrpKkl3pcf7ga8Be6TW3TfSdn8p6TZJd0v6SkNZ/0PSg5KWAPsUtjesqSomLI9hbWMkdQJHAtenVXsBJ0XErZK6gf8JfDAiXpR0JvB5SV8HfgDMBlYBl2+l+POBmyLiDyWNAyYBZwH7RcT+Kf7hKeaBgIBrJB0MvAicAOxP9r28HRjR3V5s9Dytwcq2vaQ70/NfABcAbwcejYhb0/r3AfsCv0z/OccDtwCzgF9HxEMAki7ijTlyjWYDnwSIiFeB5yTtNGCbw9PjjrQ8iSyBTQauiojfpBjXjOrT2qhVcdDdCWvb8VJ/K6df+kK+2LgKWBgRcwds96b3jZKAcyLi+wNinN7GGNYGVUxY1WvzWZluBQ6StCeApLdI2ht4AJgpaY+03dytvH8R8KfpveMkvRV4gaz11O8G4DMNY2O7SdoFuBmYI2l7SZOBP2jzZ7NhquIYlhOWvS4ingI+BVwq6W5SdzAiXibrAv5bGnRft5UiTgMOlXQP2fjTvhHxNFkX815J34iInwOXALek7a4AJkfE7WRjY3cB1wG35fZBrSXtTFjpoM1KSasknTXiOpVx62wzq7be3t5YunRpS9t2dnYuH+y0oXQA5kHgQ8DjZP+M5kbE/cOtl1tYZtZUG1tYBwKrIuLhiHgFuAw4diR18qC7mTXVxmkNuwGPNSw/DvzuSApywjKzLSxfvvyGNC+vFRMkLWtYnp/XzHwnLDPbQkQc0cbi1gDTG5anpXXD5jEsM8vbbcBeknaXNJ7sjIYRTQx2C8vMchURmyWdSjYHbxxwYUTcN5KyPK3BzGrDXUIzqw0nLDOrDScsM6sNJywzqw0nLDOrDScsM6sNJywzqw0nLDOrjf8P7wgeTQb6vxIAAAAASUVORK5CYII=\n" | |
] | |
}, | |
"metadata": {} | |
} | |
], | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load the same vector file with gensim to look at how word vectors relate to each other\n", | |
"\n", | |
"from gensim.models.keyedvectors import KeyedVectors\n", | |
"word_vectors = KeyedVectors.load_word2vec_format(\"data/wiki.zh.vec\", binary = False)" | |
], | |
"outputs": [], | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Ten entries most similar to the query word, with their similarity scores\n", | |
"word_vectors.most_similar('小貓', topn = 10)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n", | |
" if np.issubdtype(vec.dtype, np.int):\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 19, | |
"data": { | |
"text/plain": [ | |
"[('愛犬', 0.9518550038337708),\n", | |
" ('花貓', 0.9500274062156677),\n", | |
" ('養貓', 0.9474766850471497),\n", | |
" ('貓', 0.947167158126831),\n", | |
" ('貓咪', 0.9470548033714294),\n", | |
" ('小臉', 0.9433143734931946),\n", | |
" ('猴兒', 0.9423208236694336),\n", | |
" ('貓叫', 0.9421137571334839),\n", | |
" ('小輩', 0.9417970180511475),\n", | |
" ('柴犬', 0.9417864680290222)]" | |
] | |
}, | |
"metadata": {} | |
} | |
], | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Similarity score between two words\n", | |
"word_vectors.similarity(\"電視\", \"劉德華\")" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n", | |
" if np.issubdtype(vec.dtype, np.int):\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 20, | |
"data": { | |
"text/plain": [ | |
"0.85784405" | |
] | |
}, | |
"metadata": {} | |
} | |
], | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"outputs": [], | |
"execution_count": 21, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"outputs": [], | |
"execution_count": 21, | |
"metadata": { | |
"collapsed": false, | |
"outputHidden": false, | |
"inputHidden": false | |
} | |
} | |
], | |
"metadata": { | |
"kernel_info": { | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.6.5", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"language": "python", | |
"display_name": "Python 3" | |
}, | |
"nteract": { | |
"version": "0.12.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment