Created
April 30, 2018 23:20
-
-
Save dnlcrl/9e9ee03105a8dc60a5ae7dced0837ee4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0., 1., 2.],\n", | |
" [ 3., 4., 5.]])" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"# explicit float dtype: recent PyTorch releases infer an integer dtype from integer arange arguments,\n", | |
"# while the outputs recorded in this notebook are float\n", | |
"a = torch.arange(6, dtype=torch.float).reshape(2, 3)\n", | |
"a" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"TRANSPOSE" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.73 µs ± 75.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0., 3.],\n", | |
" [ 1., 4.],\n", | |
" [ 2., 5.]])" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ij->ji', [a])\n", | |
"torch.einsum('ij->ji', [a])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.22 µs ± 29.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0., 3.],\n", | |
" [ 1., 4.],\n", | |
" [ 2., 5.]])" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit a.t()\n", | |
"a.t()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"SUM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7.31 µs ± 208 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(15.)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ij->', [a])\n", | |
"torch.einsum('ij->', [a])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.58 µs ± 44.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(15.)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit a.sum()\n", | |
"a.sum()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"COLUMN SUM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"5.42 µs ± 49.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([ 3., 5., 7.])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ij->j', [a])\n", | |
"torch.einsum('ij->j', [a])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.42 µs ± 64.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([ 3., 5., 7.])" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.sum(a, 0)\n", | |
"torch.sum(a, 0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"ROW SUM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4.81 µs ± 27.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([ 3., 12.])" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ij->i', [a])\n", | |
"torch.einsum('ij->i', [a])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.92 µs ± 103 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([ 3., 12.])" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.sum(a, 1)\n", | |
"torch.sum(a, 1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"MATRIX-MATRIX MULTIPLICATION" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# explicit float dtype so results match the float outputs recorded in this section\n", | |
"a = torch.arange(6, dtype=torch.float).reshape(2, 3)\n", | |
"b = torch.arange(15, dtype=torch.float).reshape(3, 5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"23.5 µs ± 715 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 25., 28., 31., 34., 37.],\n", | |
" [ 70., 82., 94., 106., 118.]])" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ik,kj->ij', [a, b])\n", | |
"torch.einsum('ik,kj->ij', [a, b])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.77 µs ± 12.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 25., 28., 31., 34., 37.],\n", | |
" [ 70., 82., 94., 106., 118.]])" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.mm(a, b)\n", | |
"torch.mm(a, b)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"MATRIX-VECTOR MULTIPLICATION" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# explicit float dtype so results match the float outputs recorded in this section\n", | |
"a = torch.arange(6, dtype=torch.float).reshape(2, 3)\n", | |
"b = torch.arange(3, dtype=torch.float)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"19.3 µs ± 799 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([ 5., 14.])" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ik,k->i', [a, b])\n", | |
"torch.einsum('ik,k->i', [a, b])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.51 µs ± 20.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([ 5., 14.])" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.mv(a,b)\n", | |
"torch.mv(a,b)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"DOT PRODUCT" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"vector" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# explicit float dtype so results match the float outputs recorded in this section\n", | |
"a = torch.arange(3, dtype=torch.float)\n", | |
"b = torch.arange(3, 6, dtype=torch.float) # -- a vector of length 3 containing [3, 4, 5]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"17.1 µs ± 913 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(14.)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('i,i->', [a, b])\n", | |
"torch.einsum('i,i->', [a, b])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.92 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(14.)" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.dot(a,b)\n", | |
"torch.dot(a,b)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"matrix (sum of the element-wise products, i.e. the Frobenius inner product)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# explicit float dtype so results match the float outputs recorded in this section\n", | |
"a = torch.arange(6, dtype=torch.float).reshape(2, 3)\n", | |
"b = torch.arange(6, 12, dtype=torch.float).reshape(2, 3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"18.6 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(145.)" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ij,ij->', [a, b])\n", | |
"torch.einsum('ij,ij->', [a, b])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"6.13 µs ± 85.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(145.)" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.dot(a.view(-1),b.view(-1))\n", | |
"torch.dot(a.view(-1),b.view(-1))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"HADAMARD PRODUCT" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# explicit float dtype so results match the float outputs recorded in this section\n", | |
"a = torch.arange(6, dtype=torch.float).reshape(2, 3)\n", | |
"b = torch.arange(6, 12, dtype=torch.float).reshape(2, 3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"5.56 µs ± 72.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0., 7., 16.],\n", | |
" [ 27., 40., 55.]])" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ij,ij->ij', [a, b])\n", | |
"torch.einsum('ij,ij->ij', [a, b])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.31 µs ± 14.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0., 7., 16.],\n", | |
" [ 27., 40., 55.]])" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit a*b\n", | |
"a*b" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"OUTER PRODUCT" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# explicit float dtype so results match the float outputs recorded in this section\n", | |
"a = torch.arange(3, dtype=torch.float)\n", | |
"b = torch.arange(3, 7, dtype=torch.float) # -- a vector of length 4 containing [3, 4, 5, 6]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"10.8 µs ± 622 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0., 0., 0., 0.],\n", | |
" [ 3., 4., 5., 6.],\n", | |
" [ 6., 8., 10., 12.]])" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('i,j->ij', [a, b])\n", | |
"torch.einsum('i,j->ij', [a, b])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.64 µs ± 51 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0., 0., 0., 0.],\n", | |
" [ 3., 4., 5., 6.],\n", | |
" [ 6., 8., 10., 12.]])" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.ger(a, b)\n", | |
"torch.ger(a, b)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"BATCH MATRIX MULTIPLICATION" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"a = torch.randn(3,2,5)\n", | |
"b = torch.randn(3,5,3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"24 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[-3.6068, 3.6341, 3.4859],\n", | |
" [ 2.3148, 2.5504, 3.8194]],\n", | |
"\n", | |
" [[ 2.3448, 2.5390, -0.1359],\n", | |
" [ 3.4580, 3.4026, 0.0316]],\n", | |
"\n", | |
" [[-2.1875, -3.7540, 4.1446],\n", | |
" [ 1.5737, -0.2249, -0.2547]]])" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ijk,ikl->ijl', [a, b])\n", | |
"torch.einsum('ijk,ikl->ijl', [a, b])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4.81 µs ± 150 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[-3.6068, 3.6341, 3.4859],\n", | |
" [ 2.3148, 2.5504, 3.8194]],\n", | |
"\n", | |
" [[ 2.3448, 2.5390, -0.1359],\n", | |
" [ 3.4580, 3.4026, 0.0316]],\n", | |
"\n", | |
" [[-2.1875, -3.7540, 4.1446],\n", | |
" [ 1.5737, -0.2249, -0.2547]]])" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit a.bmm(b)\n", | |
"a.bmm(b)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"TENSOR MULTIPLICATION" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"a = torch.randn(2,3,5,7)\n", | |
"b = torch.randn(11,13,3,17,5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"210 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 7, 11, 13, 17])" | |
] | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape\n", | |
"torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"83.7 µs ± 5.52 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 7, 11, 13, 17])" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.mm(a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5), \\\n", | |
"b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17)).reshape(2, 7, 11, 13, 17).shape\n", | |
"\n", | |
"torch.mm(a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5), \\\n", | |
"b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17)).reshape(2, 7, 11, 13, 17).shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 34, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# compare with a tolerance instead of exact ==: einsum and the mm-based version may accumulate\n", | |
"# the float32 sums in a different order; atol guards entries whose true value is close to zero\n", | |
"torch.allclose(torch.einsum('pqrs,tuqvr->pstuv', [a, b]), torch.mm(a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5), \\\n", | |
"b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17)).reshape(2, 7, 11, 13, 17), atol=1e-5)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"BILINEAR TRANSFORMATION" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"a = torch.randn(2,3)\n", | |
"b = torch.randn(5,3,7)\n", | |
"c = torch.randn(2,7)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"51.3 µs ± 2.25 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n", | |
" [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])" | |
] | |
}, | |
"execution_count": 36, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.einsum('ik,jkl,il->ij', [a, b, c])\n", | |
"torch.einsum('ik,jkl,il->ij', [a, b, c])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"37 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n", | |
" [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])" | |
] | |
}, | |
"execution_count": 37, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n", | |
".view(-1).gather(0, torch.stack([torch.range(0, 9, 2), torch.range(11, 19, 2)]).view(-1).long()).reshape(2, 5)\n", | |
"\n", | |
"torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n", | |
".view(-1).gather(0, torch.stack([torch.range(0, 9, 2), torch.range(11, 19, 2)]).view(-1).long()).reshape(2, 5)" | |
] | |
}, | |
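{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Same computation using `torch.arange` instead of the deprecated `torch.range`; `torch.arange` excludes its end point, hence the upper bounds 10 and 20 instead of 9 and 19" | |
] | |
}, | |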
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"34.8 µs ± 929 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n", | |
" [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])" | |
] | |
}, | |
"execution_count": 44, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%timeit torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n", | |
".view(-1).gather(0, torch.stack([torch.arange(0, 10, 2), torch.arange(11, 20, 2)]).view(-1).long()).reshape(2, 5)\n", | |
"\n", | |
"torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n", | |
".view(-1).gather(0, torch.stack([torch.arange(0, 10, 2), torch.arange(11, 20, 2)]).view(-1).long()).reshape(2, 5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
source: https://rockt.github.io/2018/04/30/einsum#fn.1