Last active
May 2, 2019 00:03
-
-
Save mkolod/9dfe13db0b93064093ba23894716f124 to your computer and use it in GitHub Desktop.
CUDA pointwise fuser
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## TODO:\n", | |
"\n", | |
"1. Graph API (keep the RPN API as well):\n", | |
" a = Tensor([1., 2., 3.])\n", | |
" b = Tensor.randn(3)\n", | |
" c = Sin(a)\n", | |
" d = Add(c, b)\n", | |
" e = Pow(d, 2.0)\n", | |
" \n", | |
"2. GraphViz to show compute graph\n", | |
"\n", | |
"3. Tensor placement on CPU and GPU\n", | |
"\n", | |
"4. Autograd" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pycuda.autoinit\n", | |
"import pycuda.driver as drv\n", | |
"import numpy as np\n", | |
"\n", | |
"from pycuda.compiler import SourceModule" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class StackElem:
    """Base type for everything that can appear on the RPN evaluation stack."""

    def __init__(self):
        pass


class Assign(StackElem):
    """A generated C assignment statement: `<dtype> <name> = <value>`."""

    def __init__(self, value, name, dtype):
        super(Assign, self).__init__()
        self.value = value  # right-hand side: an already-rendered expression string
        self.name = name    # generated variable name, e.g. "_0"
        self.dtype = dtype  # C type name as a string, e.g. "float"

    def forward(self, symbolic=True):
        # An assignment evaluates to itself; generate() pushes it back on the stack.
        return self

    def __repr__(self):
        return " {} {} = {}".format(self.dtype, self.name, self.value)


class Constant(StackElem):
    """A named scalar constant operand; renders as its name."""

    def __init__(self, value, name, dtype):
        super(Constant, self).__init__()
        self.value = value
        self.name = name
        # dtype is a Python type object; store its C-like name string.
        self.dtype = dtype.__name__

    def forward(self, symbolic=True):
        return self

    def __repr__(self):
        return self.name


class Tensor(StackElem):
    """A named array operand; ops index it with [i] when rendering."""

    def __init__(self, name, dtype, value=np.array([])):
        super(Tensor, self).__init__()
        self.value = value
        self.name = name
        # dtype is a Python type object; store its C-like name string.
        self.dtype = dtype.__name__

    def forward(self, symbolic=True):
        return self

    def __repr__(self):
        return self.name
" \n", | |
class Op(StackElem):
    """Base class for all operators in the RPN stream."""

    def __init__(self):
        # BUG FIX: original called super(StackElem, self).__init__(), which
        # starts the MRO lookup *after* StackElem and skips its __init__.
        # super() must receive the current class.
        super(Op, self).__init__()


class BinaryOp(Op):
    """An operator that consumes two stack operands."""

    def __init__(self):
        # BUG FIX: original passed Op instead of BinaryOp to super(),
        # skipping Op.__init__ in the MRO.
        super(BinaryOp, self).__init__()

    def forward(self, x, y):
        raise NotImplementedError

    def backward(self, x, y):
        raise NotImplementedError


class UnaryOp(Op):
    """An operator that consumes one stack operand."""

    def __init__(self):
        super(UnaryOp, self).__init__()

    def forward(self, x):
        # BUG FIX: original raised the misspelled name `NotImpoementedError`,
        # which would surface as a NameError instead of NotImplementedError.
        raise NotImplementedError
" \n", | |
class Plus(BinaryOp):
    """Elementwise addition; renders `x[i] + y[i];`, indexing only Tensor operands."""

    def __init__(self):
        super(Plus, self).__init__()

    def forward(self, x, y, symbolic=True):
        # BUG FIX: the original hard-coded "[i]" on the second operand, which
        # produced invalid code (e.g. "c[i]") when y was a scalar Constant.
        # Index each operand only if it is actually a Tensor.
        idx_x = "[i]" if isinstance(x, Tensor) else ""
        idx_y = "[i]" if isinstance(y, Tensor) else ""
        if symbolic:
            x = x.name
            y = y.name
        else:
            x = x.value
            y = y.value
        # Tensor+Tensor output is unchanged: "x[i] + y[i];"
        return "{}{} + {}{};".format(x, idx_x, y, idx_y)
" \n", | |
class Pow(BinaryOp):
    """Elementwise power; renders powf(x, y) (or the __powf intrinsic)."""

    def __init__(self, fast_math=False):
        super(Pow, self).__init__()
        self.fast_math = fast_math  # emit the faster, less accurate __powf intrinsic

    def forward(self, x, y, symbolic=True):
        op = "__powf" if self.fast_math else "powf"
        idx1 = "[i]" if isinstance(x, Tensor) else ""
        idx2 = "[i]" if isinstance(y, Tensor) else ""
        if symbolic:
            x = x.name
            y = y.name
        else:
            x = x.value
            y = y.value
        # BUG FIX: terminate with ';' -- generate() emits this value verbatim as
        # the RHS of an assignment, and Plus/Neg/ReLU already append ';'.
        # Without it the generated CUDA line would not compile.
        return "{}({}{}, {}{});".format(op, x, idx1, y, idx2)
"\n", | |
class Sin(UnaryOp):
    """Elementwise sine; renders sinf(x) for float, hsin(x) for half."""

    def __init__(self):
        super(Sin, self).__init__()

    def forward(self, x, symbolic=True):
        idx = "[i]" if isinstance(x, Tensor) else ""
        if x.dtype == "float":
            op = "sinf"
        elif x.dtype == "half":
            op = "hsin"
        else:
            # BUG FIX: original fell through with `op` unbound, raising an
            # opaque UnboundLocalError; fail with an explicit message instead.
            raise ValueError("Sin: unsupported dtype {}".format(x.dtype))
        if symbolic:
            x = x.name
        else:
            x = x.value
        # BUG FIX: terminate with ';' so the generated assignment compiles,
        # consistent with Plus/Neg/ReLU.
        return "{}({}{});".format(op, x, idx)
"\n", | |
"\n", | |
class Tanh(UnaryOp):
    """Elementwise tanh; renders tanh(x)."""

    def __init__(self):
        super(Tanh, self).__init__()

    def forward(self, x, symbolic=True):
        idx = "[i]" if isinstance(x, Tensor) else ""
        # NOTE(review): both branches originally mapped to "tanh"; for float
        # the dedicated CUDA function would be "tanhf" -- confirm intent
        # before changing the emitted name.
        if x.dtype in ("float", "half"):
            op = "tanh"
        else:
            # BUG FIX: original fell through with `op` unbound for other dtypes.
            raise ValueError("Tanh: unsupported dtype {}".format(x.dtype))
        if symbolic:
            x = x.name
        else:
            x = x.value
        # BUG FIX: terminate with ';' so the generated assignment compiles,
        # consistent with Plus/Neg/ReLU.
        return "{}({}{});".format(op, x, idx)
" \n", | |
class ReLU(UnaryOp):
    """Elementwise ReLU; renders max(0.0f, x);."""

    def __init__(self):
        super(ReLU, self).__init__()

    def forward(self, x, symbolic=True):
        idx = "[i]" if isinstance(x, Tensor) else ""
        # Both original branches emitted "max"; collapse them and guard the
        # fall-through case.
        if x.dtype in ("float", "half"):
            op = "max"
        else:
            # BUG FIX: original left `op` unbound for any other dtype.
            raise ValueError("ReLU: unsupported dtype {}".format(x.dtype))
        if symbolic:
            x = "{}".format(x.name)
        else:
            x = x.value
        return "{}(0.0f, {}{});".format(op, x, idx)

    def backward(self, x, symbolic=True):
        # Placeholder: gradient codegen is not implemented yet; emits a marker
        # call so the gap is visible in generated source.
        idx = "[i]" if isinstance(x, Tensor) else ""
        if symbolic:
            x = x.name
        else:
            x = x.value
        return "unimplemented({}{})".format(x, idx)
" \n", | |
class Neg(UnaryOp):
    """Unary negation; renders `-x[i];`, indexing only Tensor operands."""

    def __init__(self):
        super(Neg, self).__init__()

    def forward(self, x, symbolic=True):
        # Tensors are indexed per element; scalars are emitted bare.
        suffix = "[i]" if isinstance(x, Tensor) else ""
        operand = x.name if symbolic else x.value
        return "-{}{};".format(operand, suffix)
" \n", | |
class Cast(UnaryOp):
    """Type-cast placeholder op; renders a pseudo-call, not real CUDA code."""

    def __init__(self, dtype):
        super(Cast, self).__init__()
        self.dtype = dtype  # target type of the cast

    def forward(self, x, symbolic=True):
        operand = x.name if symbolic else x.value
        return "Cast({}, {})".format(operand, self.dtype)

    def backward(self, x, symbolic=True):
        # NOTE(review): this emits negation, identical to Neg.forward -- looks
        # like a copy-paste placeholder; confirm the intended gradient of a cast.
        operand = x.name if symbolic else x.value
        return "-{}".format(operand)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class Stack:
    """Minimal LIFO stack over a Python list, used to evaluate the RPN stream."""

    def __init__(self):
        self.__storage = []

    def empty(self):
        """Return True when no elements remain."""
        return not self.__storage

    def push(self, p):
        """Place `p` on top of the stack."""
        self.__storage.append(p)

    def pop(self):
        """Remove and return the top element (IndexError when empty)."""
        return self.__storage.pop()

    def __len__(self):
        return len(self.__storage)

    def __repr__(self):
        return repr(self.__storage)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def generate(postfix, fun_name, out_type, forward=True, intermediate=True):
    """Render a fused CUDA kernel from a postfix (RPN) stream of ops.

    Args:
        postfix: list of Tensor / Assign / UnaryOp / BinaryOp items in RPN order.
        fun_name: name of the emitted __global__ function.
        out_type: Python type whose __name__ becomes the output pointer's C type.
        forward: when True emit each op's forward() code, otherwise backward().
        intermediate: when True emit full `dtype name = value` assignments;
            otherwise emit only the right-hand-side values.

    Returns:
        The complete CUDA C source of the kernel as a string.
    """
    # NOTE: removed unused `operands` / `last_input_var` and a duplicated
    # `result = ""` from the original.
    result = ""
    ctr = -1          # counter for intermediate variable names _0, _1, ...
    stack = Stack()
    inputs = []       # rendered kernel parameter declarations
    is_fp16 = False   # any half operand => the kernel needs cuda_fp16.h

    for x in postfix:
        if isinstance(x, Tensor):
            inputs.append("{}* __restrict__ {}".format(x.dtype, x))
            stack.push(x)
            if x.dtype == "half":
                is_fp16 = True

        if isinstance(x, Assign):
            stack.push(x)
            if x.dtype == "half":
                is_fp16 = True

        if isinstance(x, UnaryOp):
            ctr += 1
            a = stack.pop()
            res = Assign(x.forward(a), "_{}".format(ctr), a.dtype)
            result += "{}\n".format(res if intermediate else res.value)
            stack.push(res)

        if isinstance(x, BinaryOp):
            ctr += 1
            a = stack.pop()
            b = stack.pop()

            # BUG FIX: typo "mstch" -> "match" in the assertion message.
            assert a.dtype == b.dtype, "operand types don't match: (a: {}, b: {})".format(a.dtype, b.dtype)

            # b was pushed first, so it is the left-hand operand.
            if forward:
                res = Assign(x.forward(b, a), "_{}".format(ctr), a.dtype)
            else:
                res = Assign(x.backward(b, a), "_{}".format(ctr), a.dtype)

            result += "{}\n".format(res if intermediate else res.value)
            stack.push(res)

    result += "    out[i] = _{};".format(ctr)

    # BUG FIX: the original assigned a bare "cuda_fp16.h\n" string to `out`
    # here and then immediately overwrote it with the kernel signature, so
    # fp16 kernels were emitted without the required header.
    out = "#include <cuda_fp16.h>\n\n" if is_fp16 else ""

    out += "__global__ void {}(\n".format(fun_name)
    for i in inputs:
        out += "    {},\n".format(i)
    out += "    {}* __restrict__ out,\n".format(out_type.__name__)
    out += "    int numel) {\n"

    # Grid-stride loop: each thread strides by the total grid width so the
    # kernel is correct for any numel / launch-configuration combination.
    out += "\n  // grid-stride loop"
    out += "\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x;\n"
    out += "       i < numel;\n"
    out += "       i += blockDim.x * gridDim.x) {\n\n"

    out += "{}\n".format(result)
    out += "  }\n"
    out += "}\n"

    return out
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"__global__ void fused_foo(\n", | |
" float* __restrict__ x,\n", | |
" float* __restrict__ y,\n", | |
" float* __restrict__ out,\n", | |
" int numel) {\n", | |
"\n", | |
" // grid-stride loop\n", | |
" for (int i = blockIdx.x * blockDim.x + threadIdx.x;\n", | |
" i < numel;\n", | |
" i += blockDim.x * gridDim.x) {\n", | |
"\n", | |
" float _0 = x[i] + y[i];\n", | |
" float _1 = -_0;\n", | |
" out[i] = _1;\n", | |
" }\n", | |
"}\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
# Build the expression -(x + y) as a postfix (RPN) op stream:
# push x, push y, add them, then negate the sum.
rpn_from_graph = [
    Tensor("x", float),
    Tensor("y", float),
    Plus(),
    Neg(),
]

print(generate(rpn_from_graph, "fused_foo", float))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"__global__ void fused_foo(\n", | |
" float* __restrict__ x,\n", | |
" float* __restrict__ y,\n", | |
" float* __restrict__ out,\n", | |
" int numel) {\n", | |
"\n", | |
" // grid-stride loop\n", | |
" for (int i = blockIdx.x * blockDim.x + threadIdx.x;\n", | |
" i < numel;\n", | |
" i += blockDim.x * gridDim.x) {\n", | |
"\n", | |
" float _0 = x[i] + y[i];\n", | |
" float _1 = -_0;\n", | |
" out[i] = _1;\n", | |
" }\n", | |
"}\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
# For more information about PyCUDA, see:
# https://documen.tician.de/pycuda/

kernel_name = "fused_foo"
source = generate(rpn_from_graph, kernel_name, float)
print(source)

# Compile the generated CUDA C and fetch the kernel entry point.
mod = SourceModule(source)
kernel = mod.get_function(kernel_name)

# NOTE: with grid=(1, 1) below, numel doubles as the block size, so it must
# not exceed the device's max threads per block (typically 1024).
numel = 128
a = np.random.randn(numel).astype(np.float32)
b = np.random.randn(numel).astype(np.float32)
dest = np.zeros_like(a)
numel_np = np.int32(numel)

kernel(
    drv.In(a), drv.In(b), drv.Out(dest), numel_np,
    block=(numel, 1, 1), grid=(1, 1))

# BUG FIX (robustness): compare floats with a tolerance instead of exact
# equality -- bitwise-identical results are not guaranteed between the GPU
# kernel and the NumPy reference computation.
np.testing.assert_allclose(dest, -(a + b), rtol=1e-6, atol=1e-7,
                           err_msg="not all values agree")
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment