Skip to content

Instantly share code, notes, and snippets.

@tkf
Created February 10, 2018 01:13
Show Gist options
  • Save tkf/3668ccf9aa704e5f1f321629ea71250f to your computer and use it in GitHub Desktop.
Save tkf/3668ccf9aa704e5f1f321629ea71250f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"fused_tangent! (generic function with 2 methods)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Base.Cartesian: @nexprs\n",
"\n",
"type RNN{TM, TV}\n",
" W::TM\n",
" b::TV\n",
" s::TV\n",
"\n",
" function RNN(W::TM, b::TV) where {TM <: AbstractMatrix,\n",
" TV <: AbstractVector}\n",
" (n,) = size(b)\n",
" @assert size(W) == (n, n)\n",
" return new{TM, TV}(W, b, similar(b))\n",
" end\n",
"end\n",
"\n",
"@inline function phase_dynamics!(du, u, rnn, t)\n",
" rnn.s .= tanh.(u .+ rnn.b)\n",
" A_mul_B!(du, rnn.W, rnn.s)\n",
"end\n",
"\n",
"@inline function separated_tangent!(du, u, rnn, t)\n",
" @views phase_dynamics!(du[:, 1], u[:, 1], rnn, t)\n",
"\n",
" Y1 = @view du[:, 2:end]\n",
" Y0 = @view u[:, 2:end]\n",
" slopes = (rnn.s .= 1 .- rnn.s.^2)\n",
" n, m = size(Y1)\n",
"\n",
" Y1 .= 0\n",
" @inbounds for k in 1:m\n",
" for j in 1:n\n",
" @views Y1[:, k] .+= rnn.W[:, j] .* (slopes[j] * Y0[j, k])\n",
" end\n",
" end\n",
"end\n",
"\n",
"# I thought I'd benchmark this function first but it is not faster\n",
"# than separated_tangent!. Presumably, this is because @simd works\n",
"# only for the inner-most loop.\n",
"@inline function fused_tangent!(du, u, rnn, t)\n",
" n, l = size(du)\n",
" du .= 0\n",
" @inbounds for j in 1:n\n",
" sj = tanh(u[j, 1] + rnn.b[j])\n",
" slope = 1 - sj^2\n",
" @simd for i = 1:n\n",
" du[i, 1] += rnn.W[i, j] * sj\n",
" for k in 1:l-1\n",
" du[i, k+1] += rnn.W[i, j] * slope * u[j, k+1]\n",
" end\n",
" end\n",
" end\n",
"end\n",
"\n",
"@inline @generated function fused_tangent!(du, u, rnn, t,\n",
" ::Type{Val{m}}) where {m}\n",
" quote\n",
" n, l = size(du)\n",
" @assert l - 1 == $m\n",
"\n",
" du .= 0\n",
" @inbounds for j in 1:n\n",
" sj = tanh(u[j, 1] + rnn.b[j])\n",
" slope = 1 - sj^2\n",
" @simd for i = 1:n\n",
" du[i, 1] += rnn.W[i, j] * sj\n",
" @nexprs $m k->(du[i, k+1] += rnn.W[i, j] * slope * u[j, k+1])\n",
" end\n",
" end\n",
" end\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"using Base.Test\n",
"using BenchmarkTools"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"st_trials = []\n",
"vft_trials = []\n",
"ft_trials = []\n",
"\n",
"for n in 2.^(5:11)\n",
" k = 2\n",
" rnn = RNN(randn(n, n), randn(n))\n",
" u0 = randn(n, 1 + k)\n",
" u1 = similar(u0)\n",
" u2 = similar(u0)\n",
" u3 = similar(u0)\n",
"\n",
" separated_tangent!(u1, u0, rnn, 0)\n",
" fused_tangent!(u2, u0, rnn, 0, Val{k})\n",
" fused_tangent!(u3, u0, rnn, 0)\n",
" @test u1 ≈ u2\n",
" @test u1 ≈ u3\n",
"\n",
" push!(st_trials, @benchmark separated_tangent!($u1, $u0, $rnn, 0))\n",
" push!(vft_trials, @benchmark fused_tangent!($u2, $u0, $rnn, 0, Val{$k}))\n",
" # push!(ft_trials, @benchmark fused_tangent!($u3, $u0, $rnn, 0))\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"using Plots"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"<?xml version=\"1.0\" encoding=\"utf-8\"?>\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"600\" height=\"400\" viewBox=\"0 0 600 400\">\n",
"<defs>\n",
" <clipPath id=\"clip7100\">\n",
" <rect x=\"0\" y=\"0\" width=\"600\" height=\"400\"/>\n",
" </clipPath>\n",
"</defs>\n",
"<polygon clip-path=\"url(#clip7100)\" points=\"\n",
"0,400 600,400 600,0 0,0 \n",
" \" fill=\"#ffffff\" fill-opacity=\"1\"/>\n",
"<defs>\n",
" <clipPath id=\"clip7101\">\n",
" <rect x=\"120\" y=\"0\" width=\"421\" height=\"400\"/>\n",
" </clipPath>\n",
"</defs>\n",
"<polygon clip-path=\"url(#clip7100)\" points=\"\n",
"47.9701,360.065 580.315,360.065 580.315,11.811 47.9701,11.811 \n",
" \" fill=\"#ffffff\" fill-opacity=\"1\"/>\n",
"<defs>\n",
" <clipPath id=\"clip7102\">\n",
" <rect x=\"47\" y=\"11\" width=\"533\" height=\"349\"/>\n",
" </clipPath>\n",
"</defs>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 171.55,360.065 171.55,11.811 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 303.58,360.065 303.58,11.811 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 435.61,360.065 435.61,11.811 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 567.64,360.065 567.64,11.811 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 47.9701,293.536 580.315,293.536 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 47.9701,223.273 580.315,223.273 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 47.9701,153.01 580.315,153.01 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 47.9701,82.7468 580.315,82.7468 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#000000; stroke-width:0.5; stroke-opacity:0.1; fill:none\" points=\"\n",
" 47.9701,12.4836 580.315,12.4836 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 47.9701,360.065 580.315,360.065 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 47.9701,360.065 47.9701,11.811 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 171.55,360.065 171.55,354.842 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 303.58,360.065 303.58,354.842 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 435.61,360.065 435.61,354.842 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 567.64,360.065 567.64,354.842 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 47.9701,293.536 55.9553,293.536 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 47.9701,223.273 55.9553,223.273 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 47.9701,153.01 55.9553,153.01 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 47.9701,82.7468 55.9553,82.7468 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 47.9701,12.4836 55.9553,12.4836 \n",
" \"/>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:middle;\" transform=\"rotate(0, 171.55, 373.865)\" x=\"171.55\" y=\"373.865\">500</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:middle;\" transform=\"rotate(0, 303.58, 373.865)\" x=\"303.58\" y=\"373.865\">1000</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:middle;\" transform=\"rotate(0, 435.61, 373.865)\" x=\"435.61\" y=\"373.865\">1500</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:middle;\" transform=\"rotate(0, 567.64, 373.865)\" x=\"567.64\" y=\"373.865\">2000</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:end;\" transform=\"rotate(0, 41.9701, 298.036)\" x=\"41.9701\" y=\"298.036\">0.7</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:end;\" transform=\"rotate(0, 41.9701, 227.773)\" x=\"41.9701\" y=\"227.773\">0.8</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:end;\" transform=\"rotate(0, 41.9701, 157.51)\" x=\"41.9701\" y=\"157.51\">0.9</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:end;\" transform=\"rotate(0, 41.9701, 87.2468)\" x=\"41.9701\" y=\"87.2468\">1.0</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:end;\" transform=\"rotate(0, 41.9701, 16.9836)\" x=\"41.9701\" y=\"16.9836\">1.1</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:16; text-anchor:middle;\" transform=\"rotate(0, 314.143, 397.6)\" x=\"314.143\" y=\"397.6\">system size</text>\n",
"</g>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:16; text-anchor:middle;\" transform=\"rotate(-90, 14.4, 185.938)\" x=\"14.4\" y=\"185.938\">relative time</text>\n",
"</g>\n",
"<polyline clip-path=\"url(#clip7102)\" style=\"stroke:#009af9; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 47.9701,38.555 56.4201,11.811 73.3199,270.473 107.12,256.738 174.719,169.773 309.918,358.117 580.315,360.065 \n",
" \"/>\n",
"<polygon clip-path=\"url(#clip7100)\" points=\"\n",
"489.799,62.931 562.315,62.931 562.315,32.691 489.799,32.691 \n",
" \" fill=\"#ffffff\" fill-opacity=\"1\"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#000000; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 489.799,62.931 562.315,62.931 562.315,32.691 489.799,32.691 489.799,62.931 \n",
" \"/>\n",
"<polyline clip-path=\"url(#clip7100)\" style=\"stroke:#009af9; stroke-width:1; stroke-opacity:1; fill:none\" points=\"\n",
" 495.799,47.811 531.799,47.811 \n",
" \"/>\n",
"<g clip-path=\"url(#clip7100)\">\n",
"<text style=\"fill:#000000; fill-opacity:1; font-family:Arial,Helvetica Neue,Helvetica,sans-serif; font-size:12; text-anchor:start;\" transform=\"rotate(0, 537.799, 52.311)\" x=\"537.799\" y=\"52.311\">y1</text>\n",
"</g>\n",
"</svg>\n"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"st_mtime = [est.time for est in map(minimum, st_trials)]\n",
"vft_mtime = [est.time for est in map(minimum, vft_trials)]\n",
"\n",
"plt = plot(2.^(5:11), vft_mtime ./ st_mtime, xlabel=\"system size\", ylabel=\"relative time\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Julia Version 0.6.2\n",
"Commit d386e40c17 (2017-12-13 18:08 UTC)\n",
"Platform Info:\n",
" OS: Linux (x86_64-pc-linux-gnu)\n",
" CPU: Intel(R) Core(TM) i7-4500U CPU @ 1.80GHz\n",
" WORD_SIZE: 64\n",
" BLAS: libopenblas (USE64BITINT DYNAMIC_ARCH NO_AFFINITY Haswell)\n",
" LAPACK: libopenblas64_\n",
" LIBM: libopenlibm\n",
" LLVM: libLLVM-3.9.1 (ORCJIT, haswell)\n"
]
}
],
"source": [
"versioninfo()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Base.JLOptions().check_bounds = 0\n"
]
},
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@show Base.JLOptions().check_bounds # 0: unspecified; 1: yes; 2: no"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Base.JLOptions().can_inline = 1\n"
]
},
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@show Base.JLOptions().can_inline # 0: no; 1: yes"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Architecture: x86_64\n",
"CPU op-mode(s): 32-bit, 64-bit\n",
"Byte Order: Little Endian\n",
"CPU(s): 4\n",
"On-line CPU(s) list: 0-3\n",
"Thread(s) per core: 2\n",
"Core(s) per socket: 2\n",
"Socket(s): 1\n",
"NUMA node(s): 1\n",
"Vendor ID: GenuineIntel\n",
"CPU family: 6\n",
"Model: 69\n",
"Model name: Intel(R) Core(TM) i7-4500U CPU @ 1.80GHz\n",
"Stepping: 1\n",
"CPU MHz: 2699.853\n",
"CPU max MHz: 3000.0000\n",
"CPU min MHz: 800.0000\n",
"BogoMIPS: 4788.93\n",
"Virtualization: VT-x\n",
"L1d cache: 32K\n",
"L1i cache: 32K\n",
"L2 cache: 256K\n",
"L3 cache: 4096K\n",
"NUMA node0 CPU(s): 0-3\n",
"Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm epb invpcid_single retpoline kaiser tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt dtherm ida arat pln pts\n"
]
}
],
"source": [
";lscpu"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.6.2",
"language": "julia",
"name": "julia-0.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment