Joy Juechu Dong (joydddd)

joydddd / iterator_index.py
Last active July 15, 2025 19:14
Tuple in Helion
import torch
import helion
import helion.language as hl

@helion.jit(config=helion.Config(static_ranges=[True]))
def kernel_8(
    a_shared_tuple: tuple[torch.Tensor, ...],
):
    out = torch.empty_like(a_shared_tuple[0])
    N = out.size(0)
    for tile_n in hl.tile(N):
        acc = torch.zeros([tile_n], dtype=torch.float32, device=out.device)
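        # Hedged continuation, not the gist's actual code: one plausible way the
        # tile loop finishes, accumulating every tensor in the tuple. With
        # static_ranges=[True], the loop over the tuple has a static trip count
        # and can be unrolled at compile time.
        for a in a_shared_tuple:
            acc = acc + a[tile_n].to(torch.float32)
        out[tile_n] = acc.to(out.dtype)
    return out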
joydddd / graph.py
Created June 13, 2025 21:37
Graph for tile.block_size + atomic_add
While executing %u8 : [num_users=1] = call_function[target=helion.language._tracing_ops._get_symnode](args = (u8,), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self):
        # File: /home/joydong/helion/test/test_signal_wait.py:67 in test_block_size_access, code: block_size = tile.block_size
        block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
        tile_block_size: "Sym(u10)" = helion_language_tiles_tile_block_size(block_size_0)

        # File: /home/joydong/helion/test/test_signal_wait.py:68 in test_block_size_access, code: out[tile] = block_size
        out: "i32[64][1]" = helion_language__tracing_ops__host_tensor('out')
        store = helion_language_memory_ops_store(out, [block_size_0], tile_block_size, None);  block_size_0 = store = None
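The `code:` annotations in the trace pin down the two source statements involved. Below is a hedged reconstruction of a kernel that exercises them; only the two annotated statements come from the gist, while the decorator, signature, and loop bounds are assumptions.

import torch
import helion
import helion.language as hl

@helion.kernel
def block_size_access(out: torch.Tensor) -> torch.Tensor:
    # out is traced above as a 64-element i32 host tensor
    for tile in hl.tile(out.size(0)):
        block_size = tile.block_size  # traced as _get_symnode('block_size_0')
        out[tile] = block_size        # traced as memory_ops_store(out, [block_size_0], ...)
    return out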
joydddd / barrier.ptx
Last active June 9, 2025 21:21
Vectorized Triton
//
// Generated by LLVM NVPTX Back-End
//
.version 8.4
.target sm_90a
.address_size 64
// .globl wait_kernel // -- Begin function wait_kernel
// @wait_kernel
joydddd / test_atomic.py
Last active June 4, 2025 19:22
MLIR Error for Triton tensor `atomic_cas`
import torch
import triton
import triton.language as tl
@triton.jit
def atomic_cas_test(signal_board_ptr, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
    index = tl.arange(0, BLOCK_SIZE)
    addrs = signal_board_ptr + BLOCK_SIZE * tl.program_id(0) + index
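    # Hedged guess at how the preview continues, based on the gist title: a
    # tensor-valued atomic_cas (one compare-and-swap per lane), which is what
    # reportedly triggers the MLIR error. The expected/update values below are
    # illustrative only.
    expect = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
    update = tl.full([BLOCK_SIZE], 1, tl.int32)
    tl.atomic_cas(addrs, expect, update, sem=sem)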
joydddd / bitcode.patch
Last active June 2, 2025 19:38
nvshmem patch
--- a/nvshmem_src/src/include/non_abi/device/common/nvshmemi_common_device.cuh
+++ b/nvshmem_src/src/include/non_abi/device/common/nvshmemi_common_device.cuh
@@ -284,8 +284,13 @@
     if (0 == len) return;
+#ifdef __clang_llvm_bitcode_lib__
+    dst = (void *)(dst_p + nelems * 4);
+    src = (void *)(src_p + nelems * 4);
+#else
joydddd / hit.c
Last active March 13, 2025 20:16
Annotated chaining backtracking
/**
 * @brief Starting from the end index of the chain, find the location of the
 * minimum score on the chain, stopping when an anchor has no predecessor OR
 * the anchor belongs to another chain.
 *
 * @param max_drop
 * @param z [in] {sc, anchor idx}, sorted by sc
 * @param f [in] score
 * @param p [in] predecessor
 * @param k [in] chain end index
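In Python terms, the documented backtracking amounts to the sketch below. This is a hedged re-expression of the @brief text, not the gist's C code; the `used` array marking anchors claimed by other chains is an assumption, and `max_drop` is accepted but not modeled here.

def backtrack_min_score(k, f, p, used, max_drop=None):
    """Walk predecessor links from chain end k and return the index of the
    minimum score seen, stopping at an anchor with no predecessor (p[i] < 0)
    or one already claimed by another chain (used[i])."""
    min_idx = i = k
    while p[i] >= 0 and not used[p[i]]:
        i = p[i]
        if f[i] < f[min_idx]:
            min_idx = i
    return min_idx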
import torch
from torch._inductor.runtime.benchmarking import benchmarker
from typing import Callable, Dict, List, Optional, Tuple, Union
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) * 1e3
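A usage sketch for the helper above; the matmul workload, shapes, and dtype are illustrative.

if __name__ == "__main__":
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    us = benchmark_torch_function_in_microseconds(torch.matmul, a, b)
    print(f"torch.matmul (4096x4096, fp16): {us:.1f} us")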
joydddd / paged_attn_bench.py
Last active October 29, 2024 14:10
PagedFlexAttn Benchmark
import torch
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import (
    _identity,
    create_block_mask,
    flex_attention,
    noop_mask,
)
from functools import partial
from typing import Callable
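For context, a hedged sketch of how these imports typically fit together in a flex_attention benchmark; the shapes, the dense noop_mask block mask, and the use of torch.compile are assumptions, and the paged-KV setup via PagedAttention is omitted.

B, H, S, D = 4, 8, 4096, 64
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)
)
# A block mask that allows every (query, key) pair; a paged or masked variant
# would replace noop_mask in the real benchmark.
block_mask = create_block_mask(noop_mask, B, H, S, S, device="cuda")
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=_identity, block_mask=block_mask)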