Batch update of blocks
from typing import List
from functools import lru_cache

import torch
import hidet

hidet.option.cache_dir('./outs/cache')

@lru_cache(maxsize=None)
def build_kernel():
    from hidet.lang import attrs, printf
    from hidet.lang.types import int32, float16, tensor_pointer
    from hidet.lang.cuda import threadIdx, blockIdx, syncthreads
    from hidet.lang.mapping import spatial, repeat

    block_size = 64    # each update is a 64x64 tile
    num_threads = 256  # threads per CUDA thread block

    with hidet.script_module() as script_module:
        @hidet.script
        def kernel(
            weights: ~(~float16),    # raw device pointers, one per weight tensor
            weight_rows: ~int32,     # number of rows of each weight tensor
            weight_columns: ~int32,  # number of columns of each weight tensor
            num_blocks: int,         # total number of 64x64 blocks to apply
            blocks: ~(~float16),     # raw device pointers, one per 64x64 block
            weight_indices: ~int32,  # which weight tensor each block updates
            block_x: ~int32,         # row offset of each block inside its weight
            block_y: ~int32,         # column offset of each block inside its weight
        ):
            attrs.func_kind = 'cuda_kernel'
            attrs.cuda.block_dim = num_threads
            attrs.cuda.grid_dim = num_blocks  # one thread block per 64x64 update

            block_idx = blockIdx.x
            weight_idx = weight_indices[block_idx]
            weight_row = weight_rows[weight_idx]
            weight_column = weight_columns[weight_idx]
            weight_ptr = weights[weight_idx]
            weight = tensor_pointer(dtype=float16, shape=[weight_row, weight_column], init=weight_ptr)
            block_ptr = blocks[block_idx]
            bi = block_x[block_idx]
            bj = block_y[block_idx]
            block = tensor_pointer(dtype=float16, shape=[block_size, block_size], init=block_ptr)

            # 256 threads cooperatively cover the 64x64 tile (16 elements per thread):
            # spatial(4, 64) places each thread in a 4x64 grid, and repeat(16, 1) walks it
            # over 16 rows of its assigned column.
            for i, j in repeat(16, 1).spatial(4, 64).on(threadIdx.x):
                weight[bi + i, bj + j] += block[i, j]

    return script_module.build()

def block_update(
    weights: List[torch.Tensor],
    blocks: List[torch.Tensor],
    weight_indices: List[int],
    block_x: List[int],
    block_y: List[int]
):
    kernel_func = build_kernel()  # compiled once and cached by lru_cache

    # prepare arguments: pass raw device pointers and shapes so that a single kernel
    # launch can update blocks scattered across several weight tensors
    weight_ptrs = torch.asarray([w.data_ptr() for w in weights], dtype=torch.int64, device='cuda')
    weight_rows = torch.asarray([w.size(0) for w in weights], dtype=torch.int32, device='cuda')
    weight_columns = torch.asarray([w.size(1) for w in weights], dtype=torch.int32, device='cuda')
    num_blocks = len(blocks)
    block_ptrs = torch.asarray([b.data_ptr() for b in blocks], dtype=torch.int64, device='cuda')
    weight_indices = torch.asarray(weight_indices, dtype=torch.int32, device='cuda')
    block_x = torch.asarray(block_x, dtype=torch.int32, device='cuda')
    block_y = torch.asarray(block_y, dtype=torch.int32, device='cuda')

    # call kernel
    kernel_func(weight_ptrs, weight_rows, weight_columns, num_blocks, block_ptrs, weight_indices, block_x, block_y)

def block_update_ref(
    weights: List[torch.Tensor],
    blocks: List[torch.Tensor],
    weight_indices: List[int],
    block_x: List[int],
    block_y: List[int]
):
    for b, wi, bx, by in zip(blocks, weight_indices, block_x, block_y):
        weights[wi][bx:bx + 64, by:by + 64] += b

def run_test(
    weights,
    blocks,
    weight_indices,
    block_x,
    block_y
):
    weights_ref = [w.clone() for w in weights]
    blocks_ref = [b.clone() for b in blocks]
    block_update(weights, blocks, weight_indices, block_x, block_y)
    torch.cuda.synchronize()
    block_update_ref(weights_ref, blocks_ref, weight_indices, block_x, block_y)
    torch.cuda.synchronize()
    for w, w_ref in zip(weights, weights_ref):
        torch.testing.assert_close(actual=w, expected=w_ref, rtol=1e-3, atol=1e-3)

def demo_usage():
    w1 = torch.randn(256, 256, dtype=torch.float16, device='cuda')
    w2 = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
    b1 = torch.randn(64, 64, dtype=torch.float16, device='cuda')
    b2 = torch.randn(64, 64, dtype=torch.float16, device='cuda')
    b3 = torch.randn(64, 64, dtype=torch.float16, device='cuda')
    weights = [w1, w2]
    blocks = [b1, b2, b3]
    weight_indices = [0, 1, 0]  # the index of the weight to be updated by each block
    block_x = [0, 512, 64]      # the row offset of each block's top-left corner in its weight
    block_y = [0, 64, 128]      # the column offset of each block's top-left corner in its weight
    run_test(weights, blocks, weight_indices, block_x, block_y)

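# Timing sketch (not part of the original gist): compares the single fused kernel launch
# against the Python-loop reference using torch.cuda.Event. The number of weights, block
# count, block placement, and iteration count below are illustrative assumptions.
def demo_benchmark(num_blocks: int = 256, iters: int = 100):
    weights = [torch.randn(1024, 1024, dtype=torch.float16, device='cuda') for _ in range(4)]
    blocks = [torch.randn(64, 64, dtype=torch.float16, device='cuda') for _ in range(num_blocks)]
    weight_indices = [i % len(weights) for i in range(num_blocks)]
    block_x = [(i * 64) % 960 for i in range(num_blocks)]
    block_y = [(i * 128) % 960 for i in range(num_blocks)]

    def time_ms(fn):
        fn()  # warm-up run (also triggers the one-time kernel build)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters

    fused = time_ms(lambda: block_update(weights, blocks, weight_indices, block_x, block_y))
    looped = time_ms(lambda: block_update_ref(weights, blocks, weight_indices, block_x, block_y))
    print('fused kernel: {:.3f} ms per call'.format(fused))
    print('python loop : {:.3f} ms per call'.format(looped))
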
if __name__ == '__main__':
    demo_usage()