Created
June 4, 2018 04:57
-
-
Save RottenFruits/03845be50ee02bcedddff58425c9ddaa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Matrix factorization (MF) trained with stochastic gradient descent" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"MatrixFactorization" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"module MatrixFactorization\n", | |
"\n", | |
"# Model state for a matrix-factorisation recommender.\n", | |
"#   K              number of latent factors per user / item\n", | |
"#   alpha          SGD step size\n", | |
"#   beta           L2 regularisation strength applied in every SGD step\n", | |
"#   n_user/n_item  rows of the user / item factor matrices; ids in R are used\n", | |
"#                  directly as row indices, so these must be >= the largest id\n", | |
"#   R              training data, one rating per row: [user id, item id, rating]\n", | |
"#   user_factors   n_user x K matrix, (re)initialised by fit\n", | |
"#   item_factors   n_item x K matrix, (re)initialised by fit\n", | |
"mutable struct MatrixFactorizationModel\n", | |
"    K::Int64\n", | |
"    alpha::Float64\n", | |
"    beta::Float64\n", | |
"    n_user::Int64\n", | |
"    n_item::Int64\n", | |
"    R::Array\n", | |
"    user_factors::Array\n", | |
"    item_factors::Array\n", | |
"end\n", | |
"\n", | |
"# Training: initialise both factor matrices with uniform random values in [0, 1)\n", | |
"# and run n_iter SGD epochs over model.R. Mutates model; returns nothing.\n", | |
"function fit(model::MatrixFactorizationModel, n_iter::Int64)\n", | |
"    model.user_factors = rand(Float64, model.n_user, model.K)\n", | |
"    model.item_factors = rand(Float64, model.n_item, model.K)\n", | |
"\n", | |
"    for epoch = 1:n_iter\n", | |
"        sgd(model)\n", | |
"    end\n", | |
"\n", | |
"    return nothing\n", | |
"end\n", | |
"\n", | |
"# Predicted rating for (user, item): dot product of row `user` of U with row `item`\n", | |
"# of V, accumulated element by element so no row copies are allocated.\n", | |
"function _score(U::AbstractMatrix, V::AbstractMatrix, user::Integer, item::Integer)\n", | |
"    s = 0.0\n", | |
"    for k in 1:size(U, 2)\n", | |
"        s += U[user, k] * V[item, k]\n", | |
"    end\n", | |
"    return s\n", | |
"end\n", | |
"\n", | |
"# Stochastic gradient descent: one epoch over every rating in model.R, visited in\n", | |
"# random order. Each step updates the whole user and item factor rows.\n", | |
"function sgd(model::MatrixFactorizationModel)\n", | |
"    # The struct fields are abstractly typed (`Array`); passing them to a separate\n", | |
"    # method (a function barrier) lets the inner loop compile for the concrete types.\n", | |
"    _sgd_epoch!(model.user_factors, model.item_factors, model.R, model.alpha, model.beta)\n", | |
"    return nothing\n", | |
"end\n", | |
"\n", | |
"# Body of one SGD epoch; updates U and V in place.\n", | |
"function _sgd_epoch!(U::AbstractMatrix, V::AbstractMatrix, R::AbstractMatrix, alpha::Float64, beta::Float64)\n", | |
"    # randperm visits every rating row exactly once without copying R.\n", | |
"    for row in randperm(size(R, 1))\n", | |
"        user = R[row, 1]\n", | |
"        item = R[row, 2]\n", | |
"        err = R[row, 3] - _score(U, V, user, item)\n", | |
"        # Update both rows from their pre-step values (a simultaneous gradient step),\n", | |
"        # so the item update does not see the freshly updated user row.\n", | |
"        for k in 1:size(U, 2)\n", | |
"            u_k = U[user, k]\n", | |
"            v_k = V[item, k]\n", | |
"            U[user, k] = u_k + alpha * (err * v_k - beta * u_k)\n", | |
"            V[item, k] = v_k + alpha * (err * u_k - beta * v_k)\n", | |
"        end\n", | |
"    end\n", | |
"    return nothing\n", | |
"end\n", | |
"\n", | |
"# Prediction: predicted rating for each row of X, where column 1 is the user id and\n", | |
"# column 2 the item id (further columns are ignored). fit must have been called first.\n", | |
"function predict(model::MatrixFactorizationModel, X::Array)\n", | |
"    U = model.user_factors\n", | |
"    V = model.item_factors\n", | |
"    n = size(X, 1)\n", | |
"    rate = zeros(n)\n", | |
"    for i = 1:n\n", | |
"        rate[i] = _score(U, V, X[i, 1], X[i, 2])\n", | |
"    end\n", | |
"    return rate\n", | |
"end\n", | |
"\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Experiment: train and evaluate on MovieLens 100k" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"using DataFrames \n", | |
"using CSV" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table class=\"data-frame\"><thead><tr><th></th><th>Column1</th><th>Column2</th><th>Column3</th><th>Column4</th></tr></thead><tbody><tr><th>1</th><td>196</td><td>242</td><td>3</td><td>881250949</td></tr><tr><th>2</th><td>186</td><td>302</td><td>3</td><td>891717742</td></tr><tr><th>3</th><td>22</td><td>377</td><td>1</td><td>878887116</td></tr><tr><th>4</th><td>244</td><td>51</td><td>2</td><td>880606923</td></tr><tr><th>5</th><td>166</td><td>346</td><td>1</td><td>886397596</td></tr><tr><th>6</th><td>298</td><td>474</td><td>4</td><td>884182806</td></tr><tr><th>7</th><td>115</td><td>265</td><td>2</td><td>881171488</td></tr><tr><th>8</th><td>253</td><td>465</td><td>5</td><td>891628467</td></tr><tr><th>9</th><td>305</td><td>451</td><td>3</td><td>886324817</td></tr><tr><th>10</th><td>6</td><td>86</td><td>3</td><td>883603013</td></tr><tr><th>11</th><td>62</td><td>257</td><td>2</td><td>879372434</td></tr><tr><th>12</th><td>286</td><td>1014</td><td>5</td><td>879781125</td></tr><tr><th>13</th><td>200</td><td>222</td><td>5</td><td>876042340</td></tr><tr><th>14</th><td>210</td><td>40</td><td>3</td><td>891035994</td></tr><tr><th>15</th><td>224</td><td>29</td><td>3</td><td>888104457</td></tr><tr><th>16</th><td>303</td><td>785</td><td>3</td><td>879485318</td></tr><tr><th>17</th><td>122</td><td>387</td><td>5</td><td>879270459</td></tr><tr><th>18</th><td>194</td><td>274</td><td>2</td><td>879539794</td></tr><tr><th>19</th><td>291</td><td>1042</td><td>4</td><td>874834944</td></tr><tr><th>20</th><td>234</td><td>1184</td><td>2</td><td>892079237</td></tr><tr><th>21</th><td>119</td><td>392</td><td>4</td><td>886176814</td></tr><tr><th>22</th><td>167</td><td>486</td><td>4</td><td>892738452</td></tr><tr><th>23</th><td>299</td><td>144</td><td>4</td><td>877881320</td></tr><tr><th>24</th><td>291</td><td>118</td><td>2</td><td>874833878</td></tr><tr><th>25</th><td>308</td><td>1</td><td>4</td><td>887736532</td></tr><tr><th>26</th><td>95</td><td>546</td><td>2</td><td>879196566</td></tr><tr><th>27</
th><td>38</td><td>95</td><td>5</td><td>892430094</td></tr><tr><th>28</th><td>102</td><td>768</td><td>2</td><td>883748450</td></tr><tr><th>29</th><td>63</td><td>277</td><td>4</td><td>875747401</td></tr><tr><th>30</th><td>160</td><td>234</td><td>5</td><td>876861185</td></tr><tr><th>⋮</th><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td></tr></tbody></table>" | |
], | |
"text/plain": [ | |
"100000×4 DataFrames.DataFrame\n", | |
"│ Row │ Column1 │ Column2 │ Column3 │ Column4 │\n", | |
"├────────┼─────────┼─────────┼─────────┼───────────┤\n", | |
"│ 1 │ 196 │ 242 │ 3 │ 881250949 │\n", | |
"│ 2 │ 186 │ 302 │ 3 │ 891717742 │\n", | |
"│ 3 │ 22 │ 377 │ 1 │ 878887116 │\n", | |
"│ 4 │ 244 │ 51 │ 2 │ 880606923 │\n", | |
"│ 5 │ 166 │ 346 │ 1 │ 886397596 │\n", | |
"│ 6 │ 298 │ 474 │ 4 │ 884182806 │\n", | |
"│ 7 │ 115 │ 265 │ 2 │ 881171488 │\n", | |
"│ 8 │ 253 │ 465 │ 5 │ 891628467 │\n", | |
"│ 9 │ 305 │ 451 │ 3 │ 886324817 │\n", | |
"│ 10 │ 6 │ 86 │ 3 │ 883603013 │\n", | |
"│ 11 │ 62 │ 257 │ 2 │ 879372434 │\n", | |
"⋮\n", | |
"│ 99989 │ 421 │ 498 │ 4 │ 892241344 │\n", | |
"│ 99990 │ 495 │ 1091 │ 4 │ 888637503 │\n", | |
"│ 99991 │ 806 │ 421 │ 4 │ 882388897 │\n", | |
"│ 99992 │ 676 │ 538 │ 4 │ 892685437 │\n", | |
"│ 99993 │ 721 │ 262 │ 3 │ 877137285 │\n", | |
"│ 99994 │ 913 │ 209 │ 2 │ 881367150 │\n", | |
"│ 99995 │ 378 │ 78 │ 3 │ 880056976 │\n", | |
"│ 99996 │ 880 │ 476 │ 3 │ 880175444 │\n", | |
"│ 99997 │ 716 │ 204 │ 5 │ 879795543 │\n", | |
"│ 99998 │ 276 │ 1090 │ 1 │ 874795795 │\n", | |
"│ 99999 │ 13 │ 225 │ 2 │ 882399156 │\n", | |
"│ 100000 │ 12 │ 203 │ 3 │ 879959583 │" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Load the raw MovieLens 100k ratings file (tab-separated, no header row).\n", | |
"df = CSV.read(\"data/ml-100k/u.data\"; header = false, delim = '\\t')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1682" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Convert to a plain matrix of [user id, item id, rating]; column 4 (presumably a\n", | |
"# timestamp) is dropped. CSV.read yields Union{Int64, Missing} columns, so convert to\n", | |
"# a concrete Matrix{Int64} (this errors loudly if any value is actually missing).\n", | |
"df = Int64.(Array(df[:, 1:3]))\n", | |
"\n", | |
"# Ids are used directly as row indices into the factor matrices, so size them by the\n", | |
"# largest id rather than by the number of distinct ids (these differ if ids have gaps).\n", | |
"user = maximum(df[:, 1])\n", | |
"item = maximum(df[:, 2])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"20001×3 Array{Union{Int64, Missings.Missing},2}:\n", | |
" 165 326 5\n", | |
" 319 301 4\n", | |
" 429 32 4\n", | |
" 191 286 4\n", | |
" 7 204 5\n", | |
" 622 1039 5\n", | |
" 867 270 5\n", | |
" 627 197 5\n", | |
" 746 157 4\n", | |
" 83 685 4\n", | |
" 354 753 5\n", | |
" 416 54 5\n", | |
" 621 222 4\n", | |
" ⋮ \n", | |
" 653 423 2\n", | |
" 60 21 3\n", | |
" 157 1283 2\n", | |
" 386 825 4\n", | |
" 49 93 5\n", | |
" 265 245 4\n", | |
" 406 596 3\n", | |
" 303 129 5\n", | |
" 256 526 3\n", | |
" 236 692 4\n", | |
" 328 132 5\n", | |
" 316 515 4" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Fix the global RNG seed so the split (and later random initialisation, when cells\n", | |
"# run in order) is reproducible.\n", | |
"srand(2018)\n", | |
"\n", | |
"# Shuffle into a new binding so `df` keeps its original order and this cell is idempotent.\n", | |
"shuffled = df[shuffle(1:end), :]\n", | |
"\n", | |
"# 80/20 train/test split; floor avoids InexactError when 0.8 * N is not integral,\n", | |
"# and the test set starts one row after the training set so no rating is in both.\n", | |
"N = size(shuffled, 1)\n", | |
"train_size = floor(Int64, 0.8 * N)\n", | |
"train_df = shuffled[1:train_size, :]\n", | |
"test_df = shuffled[train_size+1:N, :]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Training: build the model on the training split and run 10 SGD epochs.\n", | |
"MF = MatrixFactorization.MatrixFactorizationModel(\n", | |
"    20,           # K: number of latent factors\n", | |
"    0.01,         # alpha: SGD step size\n", | |
"    0.5,          # beta: L2 regularisation strength\n", | |
"    user, item,   # rows of the user / item factor matrices\n", | |
"    train_df,     # training ratings\n", | |
"    [], [])       # factor matrices, initialised by fit\n", | |
"MatrixFactorization.fit(MF, 10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0777975155497033" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Accuracy on the held-out ratings, measured as root-mean-square error (RMSE).\n", | |
"pred = MatrixFactorization.predict(MF, test_df)\n", | |
"sqrt(mean(abs2.(pred .- test_df[:, 3])))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training time" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"main (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Build a fresh model on the training split (K = 20, alpha = 0.01, beta = 0.5) and\n", | |
"# run n_iter SGD epochs (default 50). Returns nothing, so `@time main()` shows only timing.\n", | |
"function main(n_iter::Int64 = 50)\n", | |
"    model = MatrixFactorization.MatrixFactorizationModel(20, 0.01, 0.5, user, item, train_df, [], [])\n", | |
"    MatrixFactorization.fit(model, n_iter)\n", | |
"    return nothing\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" 7.633972 seconds (155.94 M allocations: 3.576 GiB, 32.54% gc time)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Time one full training run; the first call also includes compiling `main`.\n", | |
"@time(main())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 0.6.1", | |
"language": "julia", | |
"name": "julia-0.6" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "0.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment