Skip to content

Instantly share code, notes, and snippets.

@RottenFruits
Created June 4, 2018 04:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RottenFruits/03845be50ee02bcedddff58425c9ddaa to your computer and use it in GitHub Desktop.
Save RottenFruits/03845be50ee02bcedddff58425c9ddaa to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MF (Matrix Factorization)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"MatrixFactorization"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"module MatrixFactorization\n",
"\n",
"# Matrix factorization for explicit ratings, trained with stochastic gradient\n",
"# descent. The predicted rating for (user u, item i) is the dot product of row u\n",
"# of `user_factors` and row i of `item_factors`.\n",
"#\n",
"# Fields:\n",
"#   K             number of latent factors (columns of both factor matrices)\n",
"#   alpha         SGD learning rate\n",
"#   beta          L2 regularization strength applied to both factor rows\n",
"#   n_user        rows of `user_factors`; user IDs in R must lie in 1:n_user\n",
"#   n_item        rows of `item_factors`; item IDs in R must lie in 1:n_item\n",
"#   R             training data, one rating per row: (user ID, item ID, rating)\n",
"#   user_factors  n_user x K matrix, (re)initialized by `fit`\n",
"#   item_factors  n_item x K matrix, (re)initialized by `fit`\n",
"# The array fields stay untyped `Array` so callers that pass `[]` placeholders\n",
"# for the factor matrices keep working; the hot loops below sit behind function\n",
"# barriers instead.\n",
"mutable struct MatrixFactorizationModel\n",
"    K::Int64\n",
"    alpha::Float64\n",
"    beta::Float64\n",
"    n_user::Int64\n",
"    n_item::Int64\n",
"    R::Array\n",
"    user_factors::Array\n",
"    item_factors::Array\n",
"end\n",
"\n",
"# Training: (re)initialize both factor matrices with uniform random values in\n",
"# [0, 1) and run `n_iter` SGD epochs over `model.R`. Mutates `model`; returns\n",
"# nothing.\n",
"function fit(model::MatrixFactorizationModel, n_iter::Int64)\n",
"    model.user_factors = rand(Float64, model.n_user, model.K)\n",
"    model.item_factors = rand(Float64, model.n_item, model.K)\n",
"\n",
"    for epoch in 1:n_iter\n",
"        sgd(model)\n",
"    end\n",
"    return nothing\n",
"end\n",
"\n",
"# Stochastic gradient descent, one epoch: visit every row of `model.R` once in\n",
"# random order and update the matching user and item factor rows in place.\n",
"function sgd(model::MatrixFactorizationModel)\n",
"    # The fields are abstractly typed; handing them to a separate kernel lets\n",
"    # Julia compile the hot loop for their concrete runtime types.\n",
"    _sgd_epoch!(model.user_factors, model.item_factors, model.R,\n",
"                model.alpha, model.beta)\n",
"    return nothing\n",
"end\n",
"\n",
"# SGD kernel behind `sgd`. Iterates over a shuffled list of row indices instead\n",
"# of copying R, and updates the K entries of P[u, :] and Q[i, :] one by one\n",
"# (indexing P[u] alone would address a single element, not the factor row).\n",
"function _sgd_epoch!(P, Q, R, alpha::Float64, beta::Float64)\n",
"    for row in shuffle(1:size(R, 1))\n",
"        u = Int(R[row, 1])\n",
"        i = Int(R[row, 2])\n",
"        err = Float64(R[row, 3]) - _rowdot(P, Q, u, i)\n",
"        for k in 1:size(P, 2)\n",
"            # Read both old values first so the item update uses the user\n",
"            # factor from before this step, as in standard SGD.\n",
"            p = P[u, k]\n",
"            q = Q[i, k]\n",
"            P[u, k] = p + alpha * (err * q - beta * p)\n",
"            Q[i, k] = q + alpha * (err * p - beta * q)\n",
"        end\n",
"    end\n",
"    return nothing\n",
"end\n",
"\n",
"# Dot product of P[u, :] and Q[i, :] without allocating row copies.\n",
"function _rowdot(P, Q, u::Int, i::Int)\n",
"    s = 0.0\n",
"    for k in 1:size(P, 2)\n",
"        s += P[u, k] * Q[i, k]\n",
"    end\n",
"    return s\n",
"end\n",
"\n",
"# Prediction: for each row (user ID, item ID, ...) of X, return the dot product\n",
"# of that user's and that item's factor rows as a Float64 vector. Call `fit`\n",
"# first.\n",
"function predict(model::MatrixFactorizationModel, X::Array)\n",
"    return _predict(model.user_factors, model.item_factors, X)\n",
"end\n",
"\n",
"# Prediction kernel behind `predict` (function barrier, as in `sgd`).\n",
"function _predict(P, Q, X)\n",
"    n = size(X, 1)\n",
"    rate = zeros(n)\n",
"    for row in 1:n\n",
"        rate[row] = _rowdot(P, Q, Int(X[row, 1]), Int(X[row, 2]))\n",
"    end\n",
"    return rate\n",
"end\n",
"\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# main"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using DataFrames \n",
"using CSV"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"<table class=\"data-frame\"><thead><tr><th></th><th>Column1</th><th>Column2</th><th>Column3</th><th>Column4</th></tr></thead><tbody><tr><th>1</th><td>196</td><td>242</td><td>3</td><td>881250949</td></tr><tr><th>2</th><td>186</td><td>302</td><td>3</td><td>891717742</td></tr><tr><th>3</th><td>22</td><td>377</td><td>1</td><td>878887116</td></tr><tr><th>4</th><td>244</td><td>51</td><td>2</td><td>880606923</td></tr><tr><th>5</th><td>166</td><td>346</td><td>1</td><td>886397596</td></tr><tr><th>6</th><td>298</td><td>474</td><td>4</td><td>884182806</td></tr><tr><th>7</th><td>115</td><td>265</td><td>2</td><td>881171488</td></tr><tr><th>8</th><td>253</td><td>465</td><td>5</td><td>891628467</td></tr><tr><th>9</th><td>305</td><td>451</td><td>3</td><td>886324817</td></tr><tr><th>10</th><td>6</td><td>86</td><td>3</td><td>883603013</td></tr><tr><th>11</th><td>62</td><td>257</td><td>2</td><td>879372434</td></tr><tr><th>12</th><td>286</td><td>1014</td><td>5</td><td>879781125</td></tr><tr><th>13</th><td>200</td><td>222</td><td>5</td><td>876042340</td></tr><tr><th>14</th><td>210</td><td>40</td><td>3</td><td>891035994</td></tr><tr><th>15</th><td>224</td><td>29</td><td>3</td><td>888104457</td></tr><tr><th>16</th><td>303</td><td>785</td><td>3</td><td>879485318</td></tr><tr><th>17</th><td>122</td><td>387</td><td>5</td><td>879270459</td></tr><tr><th>18</th><td>194</td><td>274</td><td>2</td><td>879539794</td></tr><tr><th>19</th><td>291</td><td>1042</td><td>4</td><td>874834944</td></tr><tr><th>20</th><td>234</td><td>1184</td><td>2</td><td>892079237</td></tr><tr><th>21</th><td>119</td><td>392</td><td>4</td><td>886176814</td></tr><tr><th>22</th><td>167</td><td>486</td><td>4</td><td>892738452</td></tr><tr><th>23</th><td>299</td><td>144</td><td>4</td><td>877881320</td></tr><tr><th>24</th><td>291</td><td>118</td><td>2</td><td>874833878</td></tr><tr><th>25</th><td>308</td><td>1</td><td>4</td><td>887736532</td></tr><tr><th>26</th><td>95</td><td>546</td><td>2</td><td>879196566</td></tr><tr><th>27</
th><td>38</td><td>95</td><td>5</td><td>892430094</td></tr><tr><th>28</th><td>102</td><td>768</td><td>2</td><td>883748450</td></tr><tr><th>29</th><td>63</td><td>277</td><td>4</td><td>875747401</td></tr><tr><th>30</th><td>160</td><td>234</td><td>5</td><td>876861185</td></tr><tr><th>&vellip;</th><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td></tr></tbody></table>"
],
"text/plain": [
"100000×4 DataFrames.DataFrame\n",
"│ Row │ Column1 │ Column2 │ Column3 │ Column4 │\n",
"├────────┼─────────┼─────────┼─────────┼───────────┤\n",
"│ 1 │ 196 │ 242 │ 3 │ 881250949 │\n",
"│ 2 │ 186 │ 302 │ 3 │ 891717742 │\n",
"│ 3 │ 22 │ 377 │ 1 │ 878887116 │\n",
"│ 4 │ 244 │ 51 │ 2 │ 880606923 │\n",
"│ 5 │ 166 │ 346 │ 1 │ 886397596 │\n",
"│ 6 │ 298 │ 474 │ 4 │ 884182806 │\n",
"│ 7 │ 115 │ 265 │ 2 │ 881171488 │\n",
"│ 8 │ 253 │ 465 │ 5 │ 891628467 │\n",
"│ 9 │ 305 │ 451 │ 3 │ 886324817 │\n",
"│ 10 │ 6 │ 86 │ 3 │ 883603013 │\n",
"│ 11 │ 62 │ 257 │ 2 │ 879372434 │\n",
"⋮\n",
"│ 99989 │ 421 │ 498 │ 4 │ 892241344 │\n",
"│ 99990 │ 495 │ 1091 │ 4 │ 888637503 │\n",
"│ 99991 │ 806 │ 421 │ 4 │ 882388897 │\n",
"│ 99992 │ 676 │ 538 │ 4 │ 892685437 │\n",
"│ 99993 │ 721 │ 262 │ 3 │ 877137285 │\n",
"│ 99994 │ 913 │ 209 │ 2 │ 881367150 │\n",
"│ 99995 │ 378 │ 78 │ 3 │ 880056976 │\n",
"│ 99996 │ 880 │ 476 │ 3 │ 880175444 │\n",
"│ 99997 │ 716 │ 204 │ 5 │ 879795543 │\n",
"│ 99998 │ 276 │ 1090 │ 1 │ 874795795 │\n",
"│ 99999 │ 13 │ 225 │ 2 │ 882399156 │\n",
"│ 100000 │ 12 │ 203 │ 3 │ 879959583 │"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the MovieLens 100k ratings file (tab-separated, no header row)\n",
"df = CSV.read(\"data/ml-100k/u.data\"; delim = '\\t', header = false)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"1682"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Convert the first three columns (user ID, item ID, rating) to a plain array\n",
"df = Array(df[:, 1:3])\n",
"\n",
"# Factor-matrix sizes. Users and items are looked up by raw ID, so use the\n",
"# largest ID rather than the number of distinct IDs: the two only agree when\n",
"# the IDs are exactly 1..n with no gaps.\n",
"user = maximum(df[:, 1])\n",
"item = maximum(df[:, 2])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"20001×3 Array{Union{Int64, Missings.Missing},2}:\n",
" 165 326 5\n",
" 319 301 4\n",
" 429 32 4\n",
" 191 286 4\n",
" 7 204 5\n",
" 622 1039 5\n",
" 867 270 5\n",
" 627 197 5\n",
" 746 157 4\n",
" 83 685 4\n",
" 354 753 5\n",
" 416 54 5\n",
" 621 222 4\n",
" ⋮ \n",
" 653 423 2\n",
" 60 21 3\n",
" 157 1283 2\n",
" 386 825 4\n",
" 49 93 5\n",
" 265 245 4\n",
" 406 596 3\n",
" 303 129 5\n",
" 256 526 3\n",
" 236 692 4\n",
" 328 132 5\n",
" 316 515 4"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shuffle the rows so the train/test split is random\n",
"df = df[shuffle(1:end), :]\n",
"\n",
"# Split into training (80%) and test (20%) sets\n",
"N = size(df, 1)\n",
"# floor() avoids an InexactError when 0.8N is not a whole number\n",
"train_size = floor(Int, 0.8 * N)\n",
"train_df = df[1:train_size, :]\n",
"# Start at train_size + 1 so no rating appears in both sets\n",
"test_df = df[(train_size + 1):N, :]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Train: K = 20 latent factors, learning rate 0.01, regularization 0.5, 10 epochs\n",
"MF = let K = 20, alpha = 0.01, beta = 0.5\n",
"    MatrixFactorization.MatrixFactorizationModel(K, alpha, beta, user, item, train_df, [], [])\n",
"end\n",
"MatrixFactorization.fit(MF, 10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"1.0777975155497033"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Accuracy on the held-out test set\n",
"pred = MatrixFactorization.predict(MF, test_df)\n",
"\n",
"# Root-mean-square error between predicted and observed ratings\n",
"let residuals = pred - test_df[:, 3]\n",
"    sqrt(mean(residuals .^ 2))\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Timing"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"main (generic function with 1 method)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build a fresh model (K = 20, alpha = 0.01, beta = 0.5) and train it for 50 epochs.\n",
"function main()\n",
"    model = MatrixFactorization.MatrixFactorizationModel(20, 0.01, 0.5, user, item,\n",
"                                                         train_df, [], [])\n",
"    return MatrixFactorization.fit(model, 50)\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 7.633972 seconds (155.94 M allocations: 3.576 GiB, 32.54% gc time)\n"
]
}
],
"source": [
"# Time one call to main(), reporting allocations and GC share\n",
"@time(main())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.6.1",
"language": "julia",
"name": "julia-0.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment