Created
March 18, 2018 06:10
-
-
Save denizyuret/ee6f262c9241378e0446a27c25be43dd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"195471971×16 BitArray{2}:\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" ⋮ ⋮ ⋮ ⋮\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false\n", | |
" false false false false false false false false false false false false false false false false" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Setup: load packages, print arrays up to 120 columns wide, define helpers, load the dataset.\n", | |
"using Knet,FileIO\n", | |
"ENV[\"COLUMNS\"]=120\n", | |
"# array type alias for Float32 Knet arrays\n", | |
"Atype = KnetArray{Float32}\n", | |
"# @date expr: print the current time and the expression's text, flush stdout, then @time the expression\n", | |
"macro date(_x) :(println(\"$(now()) \"*$(string(_x)));flush(STDOUT);@time $(esc(_x))) end\n", | |
"# variable \"data\" from levent.jld2 (cell output: a 195471971x16 BitArray); later cells use column 1 as the label\n", | |
"data = load(\"levent.jld2\",\"data\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.009750594881963922" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# fraction of rows whose label (column 1) is true\n", | |
"count(data[:,1]) / size(data,1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"390943944" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# size of the bit-packed BitArray in bytes (about 1 bit per entry)\n", | |
"sizeof(data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.338441114020383" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# fraction of true entries over the whole matrix (label and feature columns together)\n", | |
"count(data) / length(data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((195471971, 15), (195471971,))" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# y = label column, x = remaining feature columns\n", | |
"y = data[:,1]\n", | |
"x = data[:,2:end]\n", | |
"(size(x), size(y))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"window (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Build a windowed dataset from x (N-by-X features) and y (N labels) with window size 2r+1.\n", | |
"# Row j of newx concatenates rows j, j+1, ..., j+2r of x and is paired with newy[j] = y[j+r],\n", | |
"# the label of the center row; the first and last r rows have no full window and are dropped.\n", | |
"function window(x,y,r)\n", | |
"    n = size(x,1)\n", | |
"    lo, hi = r+1, n-r   # range of center rows that have a full window\n", | |
"    newx = hcat((x[lo+s:hi+s, :] for s in -r:r)...)\n", | |
"    newy = y[lo:hi]\n", | |
"    return (newx, newy)\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"balance (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Subsample (x,y) so that true and false labels in y occur equally often: the majority class\n", | |
"# is randomly downsampled to the size of the minority class, and surviving rows keep their\n", | |
"# original order. Returns (x[i,:], y[i,:]); note that y comes back as a one-column matrix.\n", | |
"function balance(x,y)\n", | |
"    pos = find(y)\n", | |
"    neg = find(.!y)\n", | |
"    k = min(length(pos), length(neg))\n", | |
"    length(neg) > k && (neg = neg[randperm(length(neg))[1:k]])\n", | |
"    length(pos) > k && (pos = pos[randperm(length(pos))[1:k]])\n", | |
"    keep = sort([pos; neg])\n", | |
"    return x[keep,:], y[keep,:]\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(::gradfun) (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Loss functions.\n", | |
"# zeroone: classification error over a dataset, i.e. 1 - accuracy.\n", | |
"function zeroone(w,data,model)\n", | |
"    return 1 - accuracy(w,data,model)\n", | |
"end\n", | |
"# softmax (dataset form): mean of the minibatch loss over all (x,y) pairs in data.\n", | |
"function softmax(w,data,model)\n", | |
"    return mean(softmax(w,x,y,model) for (x,y) in data)\n", | |
"end\n", | |
"# softmax (minibatch form): nll of model(w,x) against labels y; keywords (e.g. pdrop) go to model.\n", | |
"# NOTE(review): despite its name this computes a negative log-likelihood loss, not a softmax.\n", | |
"function softmax(w,x,y,model; o...)\n", | |
"    return nll(model(w,x;o...), y)\n", | |
"end\n", | |
"# softgrad(w,x,y,model; o...): gradient of the minibatch loss with respect to its first argument w.\n", | |
"softgrad = grad(softmax)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"train (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 63, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Train predict(w,x) on the minibatches in data for the given number of epochs and return a vector\n", | |
"# of weight snapshots: weights[1] is a copy of the initial w, weights[k+1] a copy after epoch k.\n", | |
"# w is updated in place by update! with per-parameter state from optimizers(w, optimizer)\n", | |
"# (Adam by default; this is not plain SGD). Gradients come from the global softgrad, and any\n", | |
"# extra keywords (kw..., e.g. pdrop) are forwarded through it to predict.\n", | |
"function train(w,data,predict; epochs=1, optimizer=Adam, kw...)\n", | |
"    weights = Any[deepcopy(w)]\n", | |
"    o = optimizers(w, optimizer)\n", | |
"    for epoch in 1:epochs\n", | |
"        for (x,y) in data\n", | |
"            g = softgrad(w,x,y,predict;kw...)\n", | |
"            update!(w,g,o)   # one optimizer step per minibatch\n", | |
"        end\n", | |
"        push!(weights,deepcopy(w))\n", | |
"    end\n", | |
"    return weights\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Weight initialization for a fully connected net with layer sizes h = (in, h1, ..., hn, out).\n", | |
"# Returns [w1,b1,w2,b2,...]: for each consecutive pair of sizes, a xavier weight matrix of size\n", | |
"# h[i]-by-h[i-1] followed by a zero bias column of size h[i]-by-1, each converted with wtype.\n", | |
"function winit(h...; wtype=KnetArray{Float32}) # e.g. winit(x,h1,h2,...,hn,y) for an n hidden layer model\n", | |
"    params = Any[]\n", | |
"    for (nin, nout) in zip(h[1:end-1], h[2:end])\n", | |
"        push!(params, xavier(nout,nin), zeros(nout,1))\n", | |
"    end\n", | |
"    return map(wtype, params)\n", | |
"end;" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Multi-layer perceptron over parameters w = [w1,b1,w2,b2,...] (the layout winit returns).\n", | |
"# Each layer applies dropout to its input, then the affine map w[i]*mat(x) .+ w[i+1];\n", | |
"# relu follows every layer except the last, whose outputs are returned as-is.\n", | |
"# pdrop = (dropout rate for the input, dropout rate for the inputs of all later layers).\n", | |
"function mlp(w,x; pdrop=(0,0))\n", | |
"    nw = length(w)\n", | |
"    for i in 1:2:nw\n", | |
"        p = (i == 1) ? pdrop[1] : pdrop[2]\n", | |
"        x = w[i] * mat(dropout(x, p)) .+ w[i+1]\n", | |
"        if i + 1 < nw\n", | |
"            x = relu.(x)\n", | |
"        end\n", | |
"    end\n", | |
"    return x\n", | |
"end;" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" 8.487600 seconds (984 allocations: 3.719 GiB, 3.07% gc time)\n", | |
" 0.200014 seconds (60 allocations: 35.902 MiB, 51.32% gc time)\n", | |
" 10.626496 seconds (88 allocations: 5.098 GiB, 4.44% gc time)\n", | |
" 0.285115 seconds (60 allocations: 49.534 MiB, 1.36% gc time)\n", | |
" 12.992490 seconds (103 allocations: 6.476 GiB, 5.28% gc time)\n", | |
" 0.442815 seconds (60 allocations: 63.167 MiB, 0.74% gc time)\n", | |
" 17.211757 seconds (132 allocations: 9.234 GiB, 5.90% gc time)\n", | |
" 0.762351 seconds (60 allocations: 90.432 MiB, 0.90% gc time)\n", | |
" 25.543356 seconds (189 allocations: 14.748 GiB, 5.28% gc time)\n", | |
" 1.417267 seconds (60 allocations: 144.962 MiB, 0.23% gc time)\n", | |
" 42.517745 seconds (302 allocations: 25.778 GiB, 5.37% gc time)\n", | |
" 2.716464 seconds (60 allocations: 254.022 MiB, 0.36% gc time)\n", | |
" 76.806931 seconds (529 allocations: 47.836 GiB, 5.32% gc time)\n", | |
" 5.375362 seconds (61 allocations: 472.143 MiB, 0.12% gc time)\n" | |
] | |
} | |
], | |
"source": [ | |
"# For window radii r = 0,1,2,4,...,32 (r = floor(2^p), p = -1:5) build a balanced windowed dataset\n", | |
"# and split it at random into train/validation/test minibatches; dtrn[k], dval[k], dtst[k] hold radius k.\n", | |
"# NOTE(review): balance() runs before the split, so val/test are 50/50 as well (chance accuracy 0.5).\n", | |
"# NOTE(review): consecutive windows share rows of x, so a random split places overlapping windows in\n", | |
"# different splits; a contiguous (time-ordered) split may give a more honest estimate.\n", | |
"dtrn,dval,dtst = [],[],[]\n", | |
"ntrn,nval,ntst = 1000000,10000,10000\n", | |
"batchsize, xtype = 100, KnetArray{Float32}\n", | |
"for p in -1:5\n", | |
"    r = floor(Int,2.0^p)   # p = -1 gives r = 0 (no neighbouring rows)\n", | |
"    @time xr,yr = balance(window(x,y,r)...)\n", | |
"    @time xr,yr = xr',1+vec(yr)   # one column per sample; labels false/true -> 1/2\n", | |
"    n = size(xr,2)\n", | |
"    n >= ntrn+nval+ntst || error(\"r=$r: only $n balanced samples, need $(ntrn+nval+ntst)\")\n", | |
"    i = randperm(n)\n", | |
"    itrn = i[1:ntrn]\n", | |
"    ival = i[ntrn+1:ntrn+nval]\n", | |
"    itst = i[ntrn+nval+1:ntrn+nval+ntst]\n", | |
"    push!(dtrn, minibatch(xr[:,itrn],yr[itrn],batchsize,xtype=xtype))\n", | |
"    push!(dval, minibatch(xr[:,ival],yr[ival],batchsize,xtype=xtype))\n", | |
"    push!(dtst, minibatch(xr[:,itst],yr[itst],batchsize,xtype=xtype))\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"m = size((first(dtrn[i]))[1], 1) = 15\n", | |
" 51.424511 seconds (45.89 M allocations: 2.600 GiB, 2.19% gc time)\n", | |
"trn[0.545653 0.632752 0.632661 0.632666 0.632666 0.632664 0.632664 0.632663 0.632663 0.632663 0.632653]\n", | |
"val[0.5411 0.6257 0.6258 0.6258 0.6258 0.6258 0.6258 0.6258 0.6258 0.6258 0.6258]\n", | |
"tst[0.5359 0.6259 0.625 0.625 0.625 0.625 0.625 0.625 0.625 0.625 0.625]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 45\n", | |
" 51.639358 seconds (45.98 M allocations: 3.747 GiB, 2.15% gc time)\n", | |
"trn[0.503749 0.638616 0.638538 0.638637 0.638669 0.638673 0.638684 0.638691 0.638696 0.638692 0.638697]\n", | |
"val[0.5065 0.6357 0.6352 0.6358 0.6358 0.6359 0.6358 0.6358 0.6359 0.6359 0.6358]\n", | |
"tst[0.4998 0.6359 0.6356 0.6358 0.636 0.6359 0.6359 0.6359 0.6358 0.6358 0.6358]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 75\n", | |
" 52.164389 seconds (48.09 M allocations: 4.932 GiB, 2.36% gc time)\n", | |
"trn[0.509977 0.644329 0.644335 0.644372 0.644399 0.644451 0.644456 0.644483 0.644495 0.644435 0.644427]\n", | |
"val[0.5082 0.6418 0.6416 0.6413 0.6413 0.6414 0.6415 0.6415 0.6418 0.6417 0.6417]\n", | |
"tst[0.5085 0.6462 0.6462 0.6461 0.6461 0.6462 0.6461 0.6458 0.646 0.646 0.6461]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 135\n", | |
" 53.357136 seconds (48.09 M allocations: 7.234 GiB, 2.69% gc time)\n", | |
"trn[0.472325 0.651045 0.65119 0.65134 0.651395 0.651455 0.651458 0.651446 0.651465 0.651476 0.651489]\n", | |
"val[0.474 0.6489 0.6494 0.6492 0.6492 0.6489 0.649 0.6489 0.649 0.6491 0.6491]\n", | |
"tst[0.4771 0.6531 0.6526 0.6525 0.6529 0.653 0.6529 0.6529 0.6533 0.6532 0.6531]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 255\n", | |
" 57.588248 seconds (48.09 M allocations: 11.846 GiB, 3.18% gc time)\n", | |
"trn[0.534265 0.661574 0.661669 0.661688 0.66173 0.66177 0.661772 0.661785 0.661784 0.661753 0.661733]\n", | |
"val[0.5308 0.6687 0.6689 0.6683 0.6684 0.6685 0.6686 0.6687 0.6685 0.6685 0.6684]\n", | |
"tst[0.5254 0.6722 0.6714 0.6719 0.6716 0.672 0.672 0.672 0.6719 0.6721 0.6718]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 495\n", | |
" 67.429157 seconds (48.09 M allocations: 21.067 GiB, 3.86% gc time)\n", | |
"trn[0.490907 0.671939 0.672006 0.672068 0.672106 0.672126 0.672125 0.672099 0.672117 0.672132 0.672142]\n", | |
"val[0.5061 0.6611 0.6609 0.6609 0.6611 0.6612 0.6612 0.6611 0.6611 0.6614 0.6614]\n", | |
"tst[0.4951 0.6701 0.6704 0.6706 0.6706 0.6706 0.6707 0.6708 0.6707 0.6706 0.6708]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 975\n", | |
" 89.599265 seconds (48.30 M allocations: 39.512 GiB, 4.46% gc time)\n", | |
"trn[0.489031 0.676754 0.676924 0.676961 0.676974 0.677004 0.67706 0.677117 0.677153 0.67717 0.677199]\n", | |
"val[0.4916 0.6845 0.6853 0.6852 0.6853 0.6853 0.6854 0.6855 0.6853 0.6856 0.6857]\n", | |
"tst[0.4859 0.6782 0.679 0.6791 0.6791 0.6796 0.6797 0.6791 0.6789 0.6791 0.6794]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Linear model (winit(m,2): no hidden layer) on every window size: train for 10 epochs, then print\n", | |
"# train/val/test accuracy of the initial weights followed by the weights after each epoch.\n", | |
"for i in eachindex(dtrn)\n", | |
"    @show m = size(first(dtrn[i])[1],1)\n", | |
"    @time weights = train(winit(m,2),dtrn[i],mlp,epochs=10)\n", | |
"    splits = ((:trn, dtrn[i]), (:val, dval[i]), (:tst, dtst[i]))\n", | |
"    for (tag, d) in splits\n", | |
"        println(tag, [ accuracy(w,d,mlp) for w in weights ]')\n", | |
"        flush(STDOUT)\n", | |
"    end\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"m = size((first(dtrn[i]))[1], 1) = 15\n", | |
" 83.893496 seconds (78.96 M allocations: 3.756 GiB, 1.74% gc time)\n", | |
"trn[0.48936 0.633894 0.634274 0.634316 0.634331 0.637423 0.637309 0.63743 0.637312 0.637314 0.637432]\n", | |
"val[0.4876 0.628 0.6278 0.6278 0.6278 0.6356 0.6349 0.6355 0.6349 0.6349 0.6355]\n", | |
"tst[0.5008 0.6258 0.6259 0.626 0.6261 0.6364 0.6367 0.6367 0.6367 0.6367 0.6367]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 45\n", | |
" 84.765788 seconds (79.10 M allocations: 4.904 GiB, 1.93% gc time)\n", | |
"trn[0.51415 0.640233 0.641429 0.64149 0.642206 0.64232 0.642009 0.642374 0.642378 0.642745 0.642801]\n", | |
"val[0.5258 0.6381 0.637 0.6377 0.6362 0.6364 0.6367 0.6367 0.6367 0.6369 0.6368]\n", | |
"tst[0.5158 0.6373 0.6376 0.6379 0.6396 0.6408 0.641 0.6395 0.6392 0.6406 0.642]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 75\n", | |
" 85.588212 seconds (79.09 M allocations: 6.057 GiB, 2.07% gc time)\n", | |
"trn[0.454612 0.646706 0.647891 0.647904 0.648285 0.648426 0.648711 0.649018 0.648678 0.64903 0.648843]\n", | |
"val[0.4535 0.6391 0.6451 0.6448 0.6458 0.6452 0.6482 0.6487 0.6472 0.647 0.6466]\n", | |
"tst[0.45 0.6499 0.6528 0.6518 0.652 0.6551 0.6526 0.653 0.6494 0.6502 0.6498]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 135\n", | |
" 86.032194 seconds (79.08 M allocations: 8.360 GiB, 2.32% gc time)\n", | |
"trn[0.492708 0.653582 0.65384 0.654691 0.654985 0.655462 0.655369 0.655432 0.655745 0.655578 0.655892]\n", | |
"val[0.5025 0.6553 0.6506 0.6515 0.6529 0.6534 0.6548 0.6524 0.6546 0.6544 0.6531]\n", | |
"tst[0.4979 0.6557 0.6549 0.6549 0.6569 0.6572 0.6568 0.6563 0.6566 0.6572 0.6562]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 255\n", | |
" 87.855101 seconds (79.09 M allocations: 12.972 GiB, 2.76% gc time)\n", | |
"trn[0.526806 0.659916 0.660705 0.662868 0.664051 0.664277 0.664627 0.664244 0.664574 0.665462 0.665754]\n", | |
"val[0.526 0.6674 0.6671 0.6696 0.6717 0.6688 0.6692 0.6709 0.6711 0.67 0.6699]\n", | |
"tst[0.5155 0.6694 0.6732 0.6693 0.6723 0.6753 0.6746 0.673 0.674 0.6748 0.6751]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 495\n", | |
" 93.570356 seconds (79.10 M allocations: 22.193 GiB, 3.43% gc time)\n", | |
"trn[0.498532 0.670728 0.669856 0.67231 0.672802 0.673276 0.674112 0.674328 0.67394 0.674245 0.673782]\n", | |
"val[0.511 0.6625 0.6633 0.6631 0.6631 0.6626 0.6629 0.6645 0.6643 0.6639 0.6622]\n", | |
"tst[0.4942 0.672 0.6704 0.6726 0.6715 0.674 0.6746 0.674 0.6752 0.6754 0.6761]\n", | |
"m = size((first(dtrn[i]))[1], 1) = 975\n", | |
"114.815320 seconds (79.29 M allocations: 40.639 GiB, 4.19% gc time)\n", | |
"trn[0.470306 0.685803 0.690103 0.690596 0.69056 0.690461 0.690734 0.690718 0.690593 0.690626 0.69041]\n", | |
"val[0.4717 0.6915 0.6984 0.6996 0.6991 0.6978 0.6966 0.6952 0.6964 0.6959 0.6953]\n", | |
"tst[0.4673 0.688 0.6896 0.6879 0.6897 0.6888 0.6896 0.6897 0.6883 0.6884 0.6885]\n" | |
] | |
} | |
], | |
"source": [ | |
"# MLP with one hidden layer of 64 relu units (winit(m,64,2)) on every window size: train for\n", | |
"# 10 epochs, then print train/val/test accuracy of the initial weights and after each epoch.\n", | |
"for i in eachindex(dtrn)\n", | |
"    @show m = size(first(dtrn[i])[1],1)\n", | |
"    @time weights = train(winit(m,64,2),dtrn[i],mlp,epochs=10)\n", | |
"    splits = ((:trn, dtrn[i]), (:val, dval[i]), (:tst, dtst[i]))\n", | |
"    for (tag, d) in splits\n", | |
"        println(tag, [ accuracy(w,d,mlp) for w in weights ]')\n", | |
"        flush(STDOUT)\n", | |
"    end\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 0.6.2", | |
"language": "julia", | |
"name": "julia-0.6" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "0.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment