Skip to content

Instantly share code, notes, and snippets.

@yamaguchiyuto
Created March 20, 2017 00:41
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yamaguchiyuto/182dc366f291e66a2a902f5efa1809e1 to your computer and use it in GitHub Desktop.
Save yamaguchiyuto/182dc366f291e66a2a902f5efa1809e1 to your computer and use it in GitHub Desktop.
JTM experiments
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from jtm import JTM"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"vocab = ['apple', 'banana', 'computer', 'mac', 'burger', 'ipad']\n",
"categories = ['PC', 'FOOD']\n",
"\n",
"X = [[],[]]\n",
"X[0].append([0,0,2,3,3,5]) # apple, apple, computer, mac, mac, ipad\n",
"X[1].append([0]) # PC\n",
"X[0].append([0,0,1,3,3,4]) # apple, apple, banana, mac, mac, burger\n",
"X[1].append([1]) # FOOD\n",
"X[0].append([2,2,5,5]) # computer, computer, ipad, ipad\n",
"X[1].append([0]) # PC\n",
"X[0].append([1,1,4,4]) # banana, banana, burger, burger\n",
"X[1].append([1]) # FOOD\n",
"X[0].append([0,3]) # apple, mac\n",
"X[1].append([0]) # PC\n",
"X[0].append([0,3]) # apple, mac\n",
"X[1].append([1]) # FOOD\n",
"\n",
"V = [len(vocab), len(categories)]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# モデルの定義\n",
"K = 2\n",
"alpha=0.01\n",
"beta=0.01\n",
"max_iter=1000\n",
"model = JTM(K=K, alpha=alpha, beta=beta, max_iter=max_iter)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<jtm.JTM instance at 0x10cd1c638>"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# JTMのフィッティング\n",
"model.fit(X,V)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[array([1, 1, 1, 1, 1, 1]), array([0, 0, 0, 0, 0, 0]), array([1, 1, 1, 1]), array([0, 0, 0, 0]), array([1, 1]), array([0, 0])]\n",
"[array([1]), array([0]), array([1]), array([0]), array([1]), array([0])]\n"
]
}
],
"source": [
"# トピック割り当ての結果\n",
"# 最後の2つのドキュメントは出現単語は同じだけどカテゴリの情報をうまく使って正しいトピックに割り当てられている\n",
"print model.Z[0]\n",
"print model.Z[1]"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<jtm.JTM instance at 0x10cd1c638>"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# LDAのフィッティング\n",
"# カテゴリの情報(side information)を与えなければJTMはLDAと等価\n",
"X = [X[0]]\n",
"V = [V[0]]\n",
"model.fit(X,V)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[array([1, 1, 0, 1, 1, 0]), array([1, 1, 1, 1, 1, 1]), array([0, 0, 0, 0]), array([1, 1, 1, 1]), array([1, 1]), array([1, 1])]\n"
]
}
],
"source": [
"# トピック割り当ての結果\n",
"# カテゴリ情報が与えられていないので、普通のLDAでは当然最後の2つのドキュメントを正しくトピック割り当て出来ない\n",
"print model.Z[0]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment