import torch

import helion
import helion.language as hl


# static_ranges=[True]: compile the device loops with static (compile-time) ranges.
@helion.jit(config=helion.Config(static_ranges=[True]))
def kernel_8(
    a_shared_tuple: tuple[torch.Tensor, ...],
):
    out = torch.empty_like(a_shared_tuple[0])
    N = out.size(0)
    for tile_n in hl.tile(N):
        # Per-tile fp32 accumulator; the gist is truncated after this line.
        acc = torch.zeros([tile_n], dtype=torch.float32, device=out.device)
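The snippet cuts off inside the loop. One plausible completion, summing the tuple entries per tile, is sketched below; the reduction and the final store are assumptions, since the original body is not shown:

import torch

import helion
import helion.language as hl


@helion.jit(config=helion.Config(static_ranges=[True]))
def kernel_8_sketch(a_shared_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor:
    out = torch.empty_like(a_shared_tuple[0])
    N = out.size(0)
    for tile_n in hl.tile(N):
        acc = torch.zeros([tile_n], dtype=torch.float32, device=out.device)
        # Hypothetical body: accumulate every tensor in the shared tuple.
        for a in a_shared_tuple:
            acc = acc + a[tile_n].to(torch.float32)
        out[tile_n] = acc.to(out.dtype)
    return out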
While executing %u8 : [num_users=1] = call_function[target=helion.language._tracing_ops._get_symnode](args = (u8,), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self):
        # File: /home/joydong/helion/test/test_signal_wait.py:67 in test_block_size_access, code: block_size = tile.block_size
        block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
        tile_block_size: "Sym(u10)" = helion_language_tiles_tile_block_size(block_size_0)

        # File: /home/joydong/helion/test/test_signal_wait.py:68 in test_block_size_access, code: out[tile] = block_size
        out: "i32[64][1]" = helion_language__tracing_ops__host_tensor('out')
        store = helion_language_memory_ops_store(out, [block_size_0], tile_block_size, None);  block_size_0 = store = None
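For context, the source quoted in the dump (test_signal_wait.py:67-68) reduces to a loop body like the one below; everything outside the two quoted lines (decorator, signature, the i32[64] output shape taken from the tensor metadata) is a reconstruction, not from the original:

import torch

import helion
import helion.language as hl


@helion.kernel  # decorator/config reconstructed; only the loop body is quoted above
def block_size_kernel(out: torch.Tensor) -> torch.Tensor:
    for tile in hl.tile(out.size(0)):
        block_size = tile.block_size  # line 67 in the dump
        out[tile] = block_size        # line 68 in the dump
    return out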
//
// Generated by LLVM NVPTX Back-End
//

.version 8.4
.target sm_90a
.address_size 64

// .globl wait_kernel    // -- Begin function wait_kernel
                         // @wait_kernel
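For reference, PTX like this can usually be recovered from a compiled Triton kernel. The stand-in kernel below is purely illustrative, and the `asm` attribute on the launch handle is version-dependent (an assumption; verify against your Triton release):

import torch
import triton
import triton.language as tl


@triton.jit
def ptx_demo_kernel(ptr, BLOCK_SIZE: tl.constexpr):
    # Trivial stand-in kernel, only here so there is something to compile.
    offs = tl.arange(0, BLOCK_SIZE)
    tl.store(ptr + offs, tl.zeros([BLOCK_SIZE], dtype=tl.int32))


buf = torch.zeros(64, dtype=torch.int32, device="cuda")
handle = ptx_demo_kernel[(1,)](buf, BLOCK_SIZE=64)
print(handle.asm["ptx"])  # prints a listing with a header like the one above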
import torch
import triton
import triton.language as tl


@triton.jit
def atomic_cas_test(signal_board_ptr, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
    # Each program operates on its own BLOCK_SIZE-wide slice of the signal board.
    index = tl.arange(0, BLOCK_SIZE)
    addrs = signal_board_ptr + BLOCK_SIZE * tl.program_id(0) + index
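The gist ends mid-kernel. For a runnable point of comparison, the classic Triton spin-wait uses a scalar atomic_cas in a loop (the pattern from Triton's layer-norm tutorial); the one-flag-per-program layout and the launch below are illustrative assumptions:

import torch
import triton
import triton.language as tl


@triton.jit
def spin_wait(signal_board_ptr, sem: tl.constexpr):
    # Acquire-style spin: loop until the CAS flips this program's flag 0 -> 1.
    flag_ptr = signal_board_ptr + tl.program_id(0)
    while tl.atomic_cas(flag_ptr, 0, 1, sem=sem) == 1:
        pass


signal_board = torch.zeros(4, dtype=torch.int32, device="cuda")
spin_wait[(4,)](signal_board, sem="acq_rel")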
--- a/nvshmem_src/src/include/non_abi/device/common/nvshmemi_common_device.cuh
+++ b/nvshmem_src/src/include/non_abi/device/common/nvshmemi_common_device.cuh
@@ -284,8 +284,13 @@
     if (0 == len) return;

+#ifdef __clang_llvm_bitcode_lib__
+    dst = (void *)(dst_p + nelems * 4);
+    src = (void *)(src_p + nelems * 4);
+#else
/**
 * @brief Starting from the end index of the chain, find the location of the
 * minimum score on the chain, walking predecessors until an anchor has no
 * predecessor OR the anchor belongs to another chain.
 *
 * @param max_drop
 * @param z [in] {sc, anchor idx}, sorted by sc
 * @param f [in] score
 * @param p [in] predecessor
 * @param k [in] chain end index
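The comment (and the snippet) cut off before the function body. A minimal Python sketch of the walk it describes is below; the `p[i] < 0` convention for "no predecessor" and the `t[i]` flag for "claimed by another chain" are assumptions, and the max_drop early exit is omitted:

def min_score_on_chain(f, p, t, k):
    """Follow predecessor links from chain-end anchor k and return the index
    of the minimum score seen; stop once an anchor has no predecessor or
    already belongs to another chain."""
    min_i = i = k
    while i >= 0 and not t[i]:
        if f[i] < f[min_i]:
            min_i = i
        i = p[i]  # step to the predecessor (negative means none)
    return min_i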
import torch
from torch._inductor.runtime.benchmarking import benchmarker
from typing import Callable


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # Warm up first so compilation and cache effects stay out of the timing.
    for _ in range(5):
        func(*args, **kwargs)
    # benchmark_gpu reports milliseconds; scale to microseconds.
    return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) * 1e3
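A quick usage example (shapes chosen arbitrarily):

if __name__ == "__main__":
    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    us = benchmark_torch_function_in_microseconds(torch.matmul, a, b)
    print(f"matmul: {us:.1f} us")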
from functools import partial
from typing import Callable

import torch
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import (
    _identity,
    create_block_mask,
    flex_attention,
    noop_mask,
)
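The gist stops at the imports. A minimal sketch of how these pieces compose, with the shapes and the no-op mask chosen purely for illustration:

B, H, S, D = 1, 4, 256, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))

# A block mask built from the no-op predicate admits every (q, kv) pair,
# so this call is equivalent to dense attention.
block_mask = create_block_mask(noop_mask, B, H, S, S)
out = flex_attention(q, k, v, block_mask=block_mask)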