Skip to content

Instantly share code, notes, and snippets.

@djsutherland
Created June 6, 2017 23:11
Show Gist options
  • Save djsutherland/45d05765c6291e79489c8c31e1ca99ba to your computer and use it in GitHub Desktop.
Save djsutherland/45d05765c6291e79489c8c31e1ca99ba to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import ctypes\n",
"import numpy as np\n",
"from numpy.ctypeslib import ndpointer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"mkl = ctypes.cdll['libmkl_rt.dylib']"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# cblas constants, according to\n",
"# https://github.com/nicholas-moreles/blaspy/blob/master/blaspy/helpers.py\n",
"ROW_MAJOR = 101\n",
"NO_TRANS = 111"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"cblas_dgemm_batch = mkl.cblas_dgemm_batch\n",
"cblas_dgemm_batch.argtypes = [\n",
" ctypes.c_int, ndpointer(dtype=np.int32), ndpointer(dtype=np.int32), # layout, transA, transB\n",
" ndpointer(dtype=np.int32), ndpointer(dtype=np.int32), ndpointer(dtype=np.int32), # m, n, k\n",
" ndpointer(dtype=np.float64), # alpha\n",
" ctypes.POINTER(ctypes.POINTER(ctypes.c_double)), ndpointer(dtype=np.int32), # a, lda\n",
" ctypes.POINTER(ctypes.POINTER(ctypes.c_double)), ndpointer(dtype=np.int32), # b, ldb\n",
" ndpointer(dtype=np.float64), # beta\n",
" ctypes.POINTER(ctypes.POINTER(ctypes.c_double)), ndpointer(dtype=np.int32), # c, ldc\n",
" ctypes.c_int, ndpointer(np.int32), # group_count, group_size\n",
"]\n",
"cblas_dgemm_batch.restype = None"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sz = 10000\n",
"m, n, k = 9, 9, 9\n",
"A = np.random.randn(sz, m, k)\n",
"B = np.random.randn(sz, k, n)\n",
"C = np.empty((sz, m, n), dtype=np.float64)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# set up arguments\n",
"i = lambda x: np.array([x], dtype=np.int32)\n",
"d = lambda x: np.array([x], dtype=np.float64)\n",
"\n",
"t = ctypes.POINTER(ctypes.c_double) * sz\n",
"a_array = t()\n",
"b_array = t()\n",
"c_array = t()\n",
"for idx in range(sz):\n",
" a_array[idx] = A[idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))\n",
" b_array[idx] = B[idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))\n",
" c_array[idx] = C[idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The slowest run took 45.63 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000 loops, best of 3: 1.22 ms per loop\n"
]
}
],
"source": [
"%%timeit\n",
"cblas_dgemm_batch(\n",
" ROW_MAJOR, i(NO_TRANS), i(NO_TRANS),\n",
" i(m), i(n), i(k),\n",
" d(1),\n",
" a_array, i(k),\n",
" b_array, i(n),\n",
" d(0),\n",
" c_array, i(n),\n",
" 1, i(sz)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 3: 183 ms per loop\n"
]
}
],
"source": [
"%%timeit\n",
"C = [np.ctypeslib.as_array(ary, (m, n)) for ary in c_array]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"C = [np.ctypeslib.as_array(ary, (m, n)) for ary in c_array]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.allclose(C, A @ B)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 loops, best of 3: 37.4 ms per loop\n"
]
}
],
"source": [
"%timeit A @ B"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here's the equivalent call to `dgemm`."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cblas_dgemm = mkl.cblas_dgemm\n",
"cblas_dgemm.argtypes = [\n",
" ctypes.c_int, ctypes.c_int, ctypes.c_int, # layout, transA, transB\n",
" ctypes.c_int, ctypes.c_int, ctypes.c_int, # m, n, k\n",
" ctypes.c_double, # alpha\n",
" ndpointer(dtype=np.float64, ndim=2), ctypes.c_int, # a, lda\n",
" ndpointer(dtype=np.float64, ndim=2), ctypes.c_int, # b, ldb\n",
" ctypes.c_double, # beta\n",
" ndpointer(dtype=np.float64, ndim=2, flags='WRITEABLE'), ctypes.c_int, # c, ldc\n",
"]\n",
"cblas_dgemm.restype = None\n",
"\n",
"A = np.random.randn(9, 9)\n",
"B = np.random.randn(9, 9)\n",
"C = np.empty((A.shape[0], B.shape[0]))\n",
"\n",
"cblas_dgemm(\n",
" ROW_MAJOR, NO_TRANS, NO_TRANS,\n",
" A.shape[0], B.shape[1], A.shape[1],\n",
" 1,\n",
" A, A.shape[0],\n",
" B, B.shape[0],\n",
" 0,\n",
" C, C.shape[0]\n",
")\n",
"\n",
"np.allclose(C, A @ B)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment