# Earlier scratch attempt, kept for reference:
# import triton
# @triton.jit
# def kernel(a, b):
#     return a + b
# # kernel.compile('kernel', return_type=triton.int32)
# binary = kernel[(1, 1)](1, 1)
# print(binary.asm)
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which
    # program we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the
    # programs would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
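    # For example, with n_elements = 98432 and BLOCK_SIZE = 1024 (the values
    # used below), the last program (pid = 96) gets offsets [98304:99328], so
    # only its first 128 lanes pass the mask; the rest are masked off.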
    # Load x and y from DRAM, masking out any extra elements in case the input
    # is not a multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in
    # parallel. It is analogous to CUDA launch grids. It can be either
    # Tuple[int] or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
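    # With size = 98432 and BLOCK_SIZE = 1024 this evaluates to
    # (triton.cdiv(98432, 1024),) == (97,), i.e. 97 program instances.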
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its
    #    first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain
    #    a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keyword arguments.
    # Unlike the upstream tutorial, we also keep the handle returned by the
    # launch so we can inspect the compiled assembly.
    binary = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to `output` but, since `torch.cuda.synchronize()`
    # hasn't been called, the kernel may still be running asynchronously at
    # this point.
    return binary, output


torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
binary, output = add(x, y)
print(binary.asm["ptx"])
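# Sanity check: a minimal sketch, assuming the Triton 2.x behaviour this gist
# relies on, where a launch returns a compiled-kernel handle whose `.asm` dict
# maps IR names (e.g. "ttir", "llir", "ptx") to their text.
torch.cuda.synchronize()  # wait for the asynchronous launch to finish
assert torch.allclose(output, x + y)  # kernel result matches eager addition
print(sorted(binary.asm.keys()))  # list every lowering Triton produced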