Created
October 16, 2023 17:39
-
-
Save yaoyaoding/ab4ffe3edd3a66d68780572f7eb8b0e3 to your computer and use it in GitHub Desktop.
Batch update of blocks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List
from functools import lru_cache
import torch
import hidet

# Point hidet's cache directory (presumably where it stores compiled kernels)
# at ./outs/cache, relative to the current working directory. This runs at
# import time, before build_kernel() is ever called, so the first build
# already uses this location.
hidet.option.cache_dir('./outs/cache')
@lru_cache(maxsize=None)
def build_kernel():
    """Compile the hidet CUDA kernel behind ``block_update`` (cached after the first call).

    The kernel runs one CUDA thread block per update block. Thread block ``k``:
      * reads ``weight_indices[k]``, ``block_x[k]`` and ``block_y[k]``;
      * views ``weights[weight_indices[k]]`` as a float16 matrix of shape
        ``(weight_rows[w], weight_columns[w])``;
      * adds the ``block_size`` x ``block_size`` float16 matrix at ``blocks[k]``
        into it, with the block's top-left corner at row ``block_x[k]`` and
        column ``block_y[k]``.

    There is no bounds checking on the device: callers must keep every block
    inside its target weight. The update is a plain (non-atomic) ``+=``, so
    blocks that overlap within the same weight race with each other.

    Returns:
        The built hidet script module. Calling it launches ``kernel`` with the
        eight arguments in the order declared below.

    Raises:
        ValueError: if ``num_threads`` cannot tile a ``block_size`` x
            ``block_size`` block (only possible if the constants are edited).
    """
    from hidet.lang import attrs
    from hidet.lang.types import int32, float16, tensor_pointer
    from hidet.lang.cuda import threadIdx, blockIdx
    from hidet.lang.mapping import repeat

    block_size = 64
    num_threads = 256

    # Each step, the thread block covers a (rows_per_step x block_size) tile,
    # one element per thread. The tile is repeated `steps` times down the rows,
    # so the mapping spans exactly block_size x block_size elements.
    # For 64 / 256 this is 4 rows x 16 steps, i.e. repeat(16, 1).spatial(4, 64).
    rows_per_step, remainder = divmod(num_threads, block_size)
    if remainder or rows_per_step == 0 or block_size % rows_per_step:
        raise ValueError(
            f'num_threads ({num_threads}) must be block_size ({block_size}) '
            f'times a divisor of block_size'
        )
    steps = block_size // rows_per_step

    with hidet.script_module() as script_module:

        @hidet.script
        def kernel(
            weights: ~(~float16),
            weight_rows: ~int32,
            weight_columns: ~int32,
            num_blocks: int,
            blocks: ~(~float16),
            weight_indices: ~int32,
            block_x: ~int32,
            block_y: ~int32,
        ):
            attrs.func_kind = 'cuda_kernel'
            attrs.cuda.block_dim = num_threads
            attrs.cuda.grid_dim = num_blocks

            # One thread block per update block.
            block_idx = blockIdx.x

            # Target weight of this block, viewed as a (rows x columns) float16 matrix.
            weight_idx = weight_indices[block_idx]
            weight_row = weight_rows[weight_idx]
            weight_column = weight_columns[weight_idx]
            weight_ptr = weights[weight_idx]
            weight = tensor_pointer(dtype=float16, shape=[weight_row, weight_column], init=weight_ptr)

            # The block to add and its top-left corner (row bi, column bj) in the weight.
            block_ptr = blocks[block_idx]
            bi = block_x[block_idx]
            bj = block_y[block_idx]
            block = tensor_pointer(dtype=float16, shape=[block_size, block_size], init=block_ptr)

            # Non-atomic read-modify-write. Nothing here checks bi/bj against
            # weight_row/weight_column.
            for i, j in repeat(steps, 1).spatial(rows_per_step, block_size).on(threadIdx.x):
                weight[bi + i, bj + j] += block[i, j]

    return script_module.build()
def block_update(
    weights: List[torch.Tensor],
    blocks: List[torch.Tensor],
    weight_indices: List[int],
    block_x: List[int],
    block_y: List[int]
):
    """Add each 64x64 block into a window of its target weight with one GPU kernel launch.

    For every ``k``, ``blocks[k]`` is added in place into
    ``weights[weight_indices[k]][block_x[k]:block_x[k] + 64, block_y[k]:block_y[k] + 64]``.
    ``block_x`` is the row and ``block_y`` the column of the window's top-left
    corner. The work is done by the kernel from ``build_kernel()``, one thread
    block per entry of ``blocks``.

    Args:
        weights: 2-D, contiguous, float16 CUDA tensors; updated in place. The
            kernel addresses them through raw pointers as dense row-major
            matrices, which is why contiguity is required.
        blocks: contiguous float16 CUDA tensors of shape (64, 64).
        weight_indices: for each block, the index into ``weights`` of its target.
        block_x: for each block, the row of its top-left corner in the target.
        block_y: for each block, the column of its top-left corner in the target.

    Raises:
        ValueError: if any of the following holds (all checks run before any
            GPU work):
            * the four per-block lists differ in length;
            * a weight is not 2-D or a block is not 64x64;
            * a weight index is out of range;
            * a block would extend outside its target weight;
            * a tensor is not a contiguous float16 CUDA tensor.

    Note:
        Blocks that overlap inside the same weight are not supported. The
        kernel does a plain, non-atomic ``+=``, so such updates may be lost.
    """
    block_size = 64  # must agree with ``block_size`` in build_kernel()

    num_blocks = len(blocks)
    if not len(weight_indices) == len(block_x) == len(block_y) == num_blocks:
        raise ValueError(
            'expected exactly one weight index, x and y coordinate per block; got '
            f'{num_blocks} blocks, {len(weight_indices)} weight indices, '
            f'{len(block_x)} x and {len(block_y)} y coordinates'
        )
    if num_blocks == 0:
        # Nothing to add. A launch with 0 thread blocks is not a valid CUDA
        # launch configuration, so skip the kernel entirely.
        return

    # Shape / index / bounds checks. The kernel itself does no bounds checking,
    # so any of these violations would read or write outside a GPU buffer.
    for i, weight in enumerate(weights):
        if weight.dim() != 2:
            raise ValueError(f'weights[{i}] must be 2-D, got shape {tuple(weight.shape)}')
    for k, (block, wi, bx, by) in enumerate(zip(blocks, weight_indices, block_x, block_y)):
        if tuple(block.shape) != (block_size, block_size):
            raise ValueError(
                f'blocks[{k}] must have shape ({block_size}, {block_size}), '
                f'got {tuple(block.shape)}'
            )
        if not 0 <= wi < len(weights):
            raise ValueError(
                f'weight index {wi} of block {k} is out of range for {len(weights)} weights'
            )
        rows, cols = weights[wi].shape
        if not (0 <= bx <= rows - block_size and 0 <= by <= cols - block_size):
            raise ValueError(
                f'block {k} at ({bx}, {by}) lies outside weights[{wi}] of shape ({rows}, {cols})'
            )

    # The kernel reinterprets raw data pointers as dense float16 data on the GPU,
    # so dtype, device and layout must match exactly.
    named_tensors = [(f'weights[{i}]', w) for i, w in enumerate(weights)]
    named_tensors += [(f'blocks[{k}]', b) for k, b in enumerate(blocks)]
    for name, tensor in named_tensors:
        if tensor.dtype != torch.float16 or not tensor.is_cuda or not tensor.is_contiguous():
            raise ValueError(
                f'{name} must be a contiguous float16 CUDA tensor, got dtype={tensor.dtype}, '
                f'device={tensor.device}, contiguous={tensor.is_contiguous()}'
            )

    kernel_func = build_kernel()

    # Per-weight metadata: raw device pointers (the kernel takes float16**) and shapes.
    # NOTE(review): these metadata tensors are released when this function returns,
    # possibly before the kernel has run. If hidet launches on a stream other than
    # torch's current stream, confirm the caching allocator cannot reuse their memory
    # before the kernel has read it.
    weight_ptrs = torch.asarray([w.data_ptr() for w in weights], dtype=torch.int64, device='cuda')
    weight_rows = torch.asarray([w.size(0) for w in weights], dtype=torch.int32, device='cuda')
    weight_columns = torch.asarray([w.size(1) for w in weights], dtype=torch.int32, device='cuda')

    # Per-block metadata: block pointers, target weight and top-left corner.
    block_ptrs = torch.asarray([b.data_ptr() for b in blocks], dtype=torch.int64, device='cuda')
    weight_indices_gpu = torch.asarray(weight_indices, dtype=torch.int32, device='cuda')
    block_x_gpu = torch.asarray(block_x, dtype=torch.int32, device='cuda')
    block_y_gpu = torch.asarray(block_y, dtype=torch.int32, device='cuda')

    kernel_func(
        weight_ptrs, weight_rows, weight_columns, num_blocks,
        block_ptrs, weight_indices_gpu, block_x_gpu, block_y_gpu,
    )
def block_update_ref(
    weights: List[torch.Tensor],
    blocks: List[torch.Tensor],
    weight_indices: List[int],
    block_x: List[int],
    block_y: List[int]
):
    """Reference (pure PyTorch) version of ``block_update``.

    Each block is added in place into the 64x64 window of its target weight
    whose top-left corner is at row ``block_x[k]``, column ``block_y[k]``.
    Blocks are applied one after another, so overlapping blocks accumulate.
    Iteration stops at the shortest of the four per-block lists.
    """
    placements = zip(weight_indices, block_x, block_y, blocks)
    for target, row0, col0, patch in placements:
        window = weights[target][row0:row0 + 64, col0:col0 + 64]
        # ``window`` is a view, so the in-place add writes into weights[target].
        window += patch
def run_test(
    weights,
    blocks,
    weight_indices,
    block_x,
    block_y
):
    """Check ``block_update`` against ``block_update_ref`` on the same inputs.

    ``weights`` itself is updated in place by ``block_update``. The reference
    runs on clones of ``weights`` and ``blocks``. Any element mismatch beyond
    ``rtol=atol=1e-3`` makes ``torch.testing.assert_close`` raise.
    """
    expected = [torch.clone(w) for w in weights]
    ref_blocks = [torch.clone(b) for b in blocks]

    # Kernel path: mutates `weights` in place.
    block_update(weights, blocks, weight_indices, block_x, block_y)
    torch.cuda.synchronize()

    # Reference path: mutates the clones in place.
    block_update_ref(expected, ref_blocks, weight_indices, block_x, block_y)
    torch.cuda.synchronize()

    for got, want in zip(weights, expected):
        torch.testing.assert_close(actual=got, expected=want, rtol=1e-3, atol=1e-3)
def demo_usage():
    """Scatter three random 64x64 float16 blocks into two random float16 weights and verify.

    The weights are 256x256 and 1024x1024 tensors on CUDA. The result of
    ``block_update`` is checked against the reference via ``run_test``.
    """
    def cuda_randn(*shape):
        return torch.randn(*shape, dtype=torch.float16, device='cuda')

    # Same draw order as before: both weights first, then the three blocks.
    weights = [cuda_randn(256, 256), cuda_randn(1024, 1024)]
    blocks = [cuda_randn(64, 64) for _ in range(3)]

    # One (target weight index, row of top-left corner, column of top-left corner)
    # entry per block.
    placements = [
        (0, 0, 0),
        (1, 512, 64),
        (0, 64, 128),
    ]
    weight_indices, block_x, block_y = (list(column) for column in zip(*placements))

    run_test(weights, blocks, weight_indices, block_x, block_y)
# Run the demo only when executed as a script. It needs a CUDA device
# (tensors are created with device='cuda') and hidet to build the kernel.
if __name__ == '__main__':
    demo_usage()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment