Skip to content

Instantly share code, notes, and snippets.

@tbenst
Created June 8, 2018 02:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tbenst/de3f61ac9956778a2ca6d5db94b526ef to your computer and use it in GitHub Desktop.
Save tbenst/de3f61ac9956778a2ca6d5db94b526ef to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using matplotlib backend: Qt5Agg\n",
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%pylab"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch as T\n",
"from torch.autograd import Variable\n",
"import torch.nn.functional as F\n",
"from torch.autograd import Variable"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Fix the RNG seeds so the benchmark inputs are identical on every Restart & Run All.\n",
"T.manual_seed(0)\n",
"np.random.seed(0)\n",
"# Bank of 8 frequency-domain filters, 300x300 each. The trailing axis of size 2 holds the\n",
"# (real, imag) parts, which is the complex layout the legacy T.fft / T.ifft calls expect.\n",
"filters = T.rand(8, 300, 300, 2).cuda()\n",
"# Single 300x300 test image, cast to float32 (the default dtype of T.rand above).\n",
"img = np.random.rand(300, 300).astype(np.float32)\n",
"def conv_only(img, w, shape):\n",
"    \"\"\"Multiply the 2-D FFT of `img` with every filter in `w` and transform back.\n",
"\n",
"    Args:\n",
"        img: 2-D NumPy array (float32 in this notebook) matching `w.shape[1:3]`.\n",
"        w: filter tensor shaped (n_filters, H, W, 2); the last axis holds the\n",
"            (real, imag) parts used by the legacy `T.fft` / `T.ifft` API.\n",
"        shape: unused here; kept so all benchmark functions share one signature.\n",
"\n",
"    Returns:\n",
"        Tensor with the same shape as `w`, living on the same device as `w`.\n",
"    \"\"\"\n",
"    # Allocate the complex image buffer directly on the filters' device instead of\n",
"    # creating it on the host and copying it over; the imaginary channel stays zero.\n",
"    fimg = T.zeros(*w.shape[1:], device=w.device)\n",
"    fimg[:, :, 0] = T.from_numpy(img).to(w.device)\n",
"    # NOTE(review): `*` multiplies the real and imaginary channels independently, which\n",
"    # is not a complex product. Fine for timing; confirm before trusting the values.\n",
"    return T.ifft(T.fft(fimg, 2) * w, 2)\n",
"\n",
"def find_position(img, w, shape):\n",
"    \"\"\"Locate the largest filter response, with the argmax computed by PyTorch.\n",
"\n",
"    Args:\n",
"        img: 2-D NumPy image, forwarded to `conv_only`.\n",
"        w: filter tensor shaped (n_filters, H, W, 2), forwarded to `conv_only`.\n",
"        shape: dimensions used to decode the flat argmax index; expected to equal\n",
"            the shape of the response tensor (i.e. the shape of `w`).\n",
"\n",
"    Returns:\n",
"        Tuple of indices, one per entry of `shape`, as produced by np.unravel_index.\n",
"    \"\"\"\n",
"    # Reuse the shared helper rather than repeating the FFT pipeline here.\n",
"    conv = conv_only(img, w, shape)\n",
"    # .item() copies the scalar index to the host explicitly. Handing the 0-dim device\n",
"    # tensor to NumPy relies on an implicit conversion that is not available for CUDA\n",
"    # tensors in later PyTorch releases.\n",
"    idx = conv.argmax().item()\n",
"    return np.unravel_index(idx, shape)\n",
"\n",
"def find_position_2(img, w, shape):\n",
"    \"\"\"Locate the largest filter response, with the argmax computed by NumPy on the host.\n",
"\n",
"    Takes the same arguments as `find_position`; the only difference is that the whole\n",
"    response tensor is copied to host memory before searching for the maximum.\n",
"    \"\"\"\n",
"    # Run the FFT pipeline on the device, then pull the full result back to the CPU.\n",
"    response_host = conv_only(img, w, shape).cpu().numpy()\n",
"    flat_idx = response_host.argmax()\n",
"    return np.unravel_index(flat_idx, shape)\n",
"\n",
"# Shape of `filters` as a plain list; passed to np.unravel_index to decode the flat argmax.\n",
"shape = list(filters.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Warm-up call: the first GPU FFT pays one-time setup costs (such as cuFFT initialisation)\n",
"# that must not leak into the timings below. The FFT path uses cuFFT, not cuBLAS.\n",
"_ = conv_only(img, filters, shape)\n",
"torch.cuda.synchronize()  # make sure the warm-up work has actually finished"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.17 ms ± 61.2 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"%%timeit -r 3 -n 3\n",
"# CUDA kernels are launched asynchronously: without the synchronize below this cell would\n",
"# time little more than the launch overhead and not be comparable to the cells that read\n",
"# the result back on the host.\n",
"conv_only(img, filters, shape)\n",
"torch.cuda.synchronize()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27.1 ms ± 828 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"%%timeit -r 3 -n 3\n",
"# Argmax computed by PyTorch; the index is consumed by NumPy on the host, so the timing\n",
"# covers the complete GPU work.\n",
"find_position(img, filters, shape)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.25 ms ± 285 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"%%timeit -r 3 -n 3\n",
"# Argmax computed by NumPy after copying the whole response tensor to host memory.\n",
"find_position_2(img, filters, shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment