import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
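# Standalone reproducer emitted by TorchInductor: a persistent-reduction
# Triton kernel that recomputes a softmax and fuses in the row-sum term of
# softmax backward, plus a small harness to launch and benchmark it.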
@triton_heuristics.persistent_reduction(
    size_hints=[131072, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp16', 7: '*fp32', 8: '*fp32', 9: 'i32', 10: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__softmax_backward_data__to_copy_151', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.041178776}
)
@triton.jit
def triton_per_fused__softmax__softmax_backward_data__to_copy_151(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
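    # Persistent reduction: each program instance owns XBLOCK rows and keeps
    # the entire 49-element reduction axis resident (RBLOCK = 64 >= rnumel).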
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 4
    x2 = (xindex // 196)
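    # Decompose the flat row index: x0 = row within a 49 x 49 attention block,
    # x1 = one of 4 heads, x2 = block index selecting the (49, 49) slice of
    # in_ptr3 added to the logits; x4 is the flat per-row index.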
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.load(in_ptr3 + (r3 + (49*x0) + (2401*x2)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp12 = tl.load(in_ptr4 + (x4), xmask, eviction_policy='evict_last')
    tmp15 = tl.load(in_ptr5 + (x4), xmask, eviction_policy='evict_last')
    tmp17 = tl.load(in_ptr6 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
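    # Recompute the forward softmax from saved statistics: logits = scores
    # (tmp0) + gathered bias (in_ptr2, indexed via in_ptr1 with negative
    # indices wrapped by the table size 169) + additive term (tmp10); tmp12
    # and tmp15 hold the saved row max and row sum, so tmp16 reproduces the
    # softmax output.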
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (4*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp11 = tmp9 + tmp10
    tmp13 = tmp11 - tmp12
    tmp14 = tl_math.exp(tmp13)
    tmp16 = tmp14 / tmp15
    tmp18 = tmp17.to(tl.float32)
    tmp19 = tmp18 * tmp16
    tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
    tmp22 = tl.where(rmask & xmask, tmp20, 0)
    tmp23 = tl.sum(tmp22, 1)[:, None]
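    # tmp23 = sum over r of grad * softmax, the reduced term needed by
    # softmax backward; store the recomputed softmax and the row sums.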
    tl.store(out_ptr0 + (r3 + (49*x4)), tmp16, rmask & xmask)
    tl.store(out_ptr1 + (x4), tmp23, xmask)
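# Construct random arguments with the exact shapes, strides, and dtypes the
# kernel was compiled for.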
def get_args():
    arg_0 = rand_strided((2048, 49, 49), (2401, 49, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((2401,), (1,), device='cuda:0', dtype=torch.int64)
    arg_2 = rand_strided((169, 4), (4, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((64, 49, 49), (2401, 49, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((512, 4, 49, 1), (196, 49, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((512, 4, 49, 1), (196, 49, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((2048, 49, 49), (2401, 49, 1), device='cuda:0', dtype=torch.float16)
    arg_7 = rand_strided((512, 4, 49, 49), (9604, 2401, 49, 1), device='cuda:0', dtype=torch.float32)
    arg_8 = rand_strided((512, 4, 49, 1), (196, 49, 1, 100352), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8,
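# Launch the kernel once on device 0 with xnumel=100352 and rnumel=49.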
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__softmax__softmax_backward_data__to_copy_151.run(*args, 100352, 49, grid=grid(100352), stream=stream0)
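# Time every autotuning config generated for this kernel.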
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__softmax__softmax_backward_data__to_copy_151.benchmark_all_configs(*args, 100352, 49, grid=grid(100352))
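# Benchmark the launch; bandwidth is derived from the byte count Inductor
# recorded as 'kernel_num_gb' in inductor_meta.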
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.041178776
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")