Created
June 8, 2018 02:22
-
-
Save tbenst/de3f61ac9956778a2ca6d5db94b526ef to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using matplotlib backend: Qt5Agg\n", | |
"Populating the interactive namespace from numpy and matplotlib\n" | |
] | |
} | |
], | |
"source": [ | |
"%pylab" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch as T\n", | |
"from torch.autograd import Variable\n", | |
"import torch.nn.functional as F\n", | |
"from torch.autograd import Variable" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"filters = T.rand(8,300,300,2).cuda()\n", | |
"img = np.random.rand(300,300).astype(np.float32)\n", | |
"def conv_only(img, w, shape):\n", | |
" fimg = T.zeros(*w.shape[1:]).cuda()\n", | |
" fimg[:,:,0] = T.from_numpy(img).cuda()\n", | |
" conv = T.ifft(T.fft(fimg,2)*w,2)\n", | |
" return conv\n", | |
"\n", | |
"def find_position(img, w, shape):\n", | |
" fimg = T.zeros(*w.shape[1:]).cuda()\n", | |
" fimg[:,:,0] = T.from_numpy(img).cuda()\n", | |
" conv = T.ifft(T.fft(fimg,2)*w,2)\n", | |
" idx = conv.argmax()\n", | |
" return np.unravel_index(idx, shape)\n", | |
"\n", | |
"def find_position_2(img, w, shape):\n", | |
" fimg = T.zeros(*w.shape[1:]).cuda()\n", | |
" fimg[:,:,0] = T.from_numpy(img).cuda()\n", | |
" conv = T.ifft(T.fft(fimg,2)*w,2)\n", | |
" conv = conv.cpu().numpy()\n", | |
" idx = np.argmax(conv)\n", | |
" return np.unravel_index(idx, shape)\n", | |
"\n", | |
"shape = [*filters.shape]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# initialize CuBLAS\n", | |
"_ = conv_only(img,filters, shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.17 ms ± 61.2 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -r 3 -n 3\n", | |
"conv_only(img,filters, shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"27.1 ms ± 828 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -r 3 -n 3\n", | |
"find_position(img,filters, shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"3.25 ms ± 285 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -r 3 -n 3\n", | |
"find_position_2(img,filters, shape)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment