# Created June 25, 2024 07:04
# Save shunting314/a1ff2b88a54bda5644effc0c216c0b88 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently
# than what appears below. To review, open the file in an editor that reveals hidden Unicode
# characters. Learn more about bidirectional Unicode characters.
import torch
import triton
import triton.language as tl
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._dynamo.testing import rand_strided
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
from triton.compiler.compiler import AttrsDescriptor
@triton_heuristics.persistent_reduction(
    size_hints=[131072, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={
        'signature': {
            0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32',
            6: '*fp16', 7: '*fp32', 8: '*fp32', 9: 'i32', 10: 'i32',
        },
        'device': DeviceProperties(
            type='cuda',
            index=0,
            cc=80,
            major=8,
            regs_per_multiprocessor=65536,
            max_threads_per_multi_processor=2048,
            multi_processor_count=108,
        ),
        'constants': {},
        'configs': [
            AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=()),
        ],
    },
    inductor_meta={
        'autotune_hints': set(),
        'kernel_name': 'triton_per_fused__softmax__softmax_backward_data__to_copy_151',
        'mutated_arg_names': [],
        'no_x_dim': False,
        'num_load': 6,
        'num_reduction': 1,
        'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596',
        'are_deterministic_algorithms_enabled': True,
        'assert_indirect_indexing': True,
        'autotune_local_cache': True,
        'autotune_pointwise': True,
        'autotune_remote_cache': False,
        'force_disable_caches': True,
        'dynamic_scale_rblock': True,
        'max_autotune': False,
        'max_autotune_pointwise': False,
        'min_split_scan_rblock': 256,
        'spill_threshold': 16,
        'store_cubin': False,
        'kernel_num_gb': 0.041178776,
    },
)
@triton.jit
def triton_per_fused__softmax__softmax_backward_data__to_copy_151(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    """Fused row-wise kernel: normalised exponential plus a weighted row sum.

    For every row ``x`` in ``[0, xnumel)`` and column ``r`` in ``[0, 49)``:

        logit = float32(in_ptr0[r + 49*x])
                + in_ptr2[x1 + 4*idx]            # idx gathered from in_ptr1
                + in_ptr3[r + 49*x0 + 2401*x2]
        p     = exp(logit - in_ptr4[x]) / in_ptr5[x]

    where ``x0 = x % 49``, ``x1 = (x // 49) % 4`` and ``x2 = x // 196``.
    ``p`` is written to ``out_ptr0[r + 49*x]`` and the row reduction
    ``sum_r(float32(in_ptr6[r + 49*x]) * p)`` is written to ``out_ptr1[x]``.

    Pointer element types are taken from ``triton_meta['signature']``:
    in_ptr0/in_ptr6 are fp16, in_ptr1 is int64, all remaining pointers are fp32.

    The reduction width is fixed at 49 (the ``rnumel`` argument is overwritten)
    and padded to ``RBLOCK = 64`` lanes; padded lanes are masked out.

    Judging only by the generated kernel name, in_ptr4 / in_ptr5 are presumably
    a precomputed softmax row max / row sum -- TODO confirm against the
    producing graph; this file does not show how they were computed.

    NOTE(review): the in_ptr3 offset ``2401 * (xindex // 196)`` grows with the
    full row index. The harness in this file launches with xnumel = 100352
    (so ``xindex // 196`` reaches 511) but allocates only 64 * 49 * 49 elements
    for that argument, which would make this load read past the buffer. The
    intended indexing cannot be determined from this file -- confirm before
    relying on the results.
    """
    # The reduction extent is a compile-time fact for this kernel; the runtime
    # argument of the same name is deliberately shadowed.
    rnumel = 49
    RBLOCK: tl.constexpr = 64

    # Row (x) indices handled by this program, shaped [XBLOCK, 1].
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel

    # Column (r) indices, shaped [1, RBLOCK]; lanes 49..63 are padding.
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel

    # Decompose the flat row index into the three factors used for addressing.
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 4
    x2 = (xindex // 196)

    # Loads. Masked-out lanes of the [XBLOCK, RBLOCK] tiles read as 0.
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.load(in_ptr3 + (r3 + (49*x0) + (2401*x2)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    # Per-row scalars (one value per x, broadcast across r).
    tmp12 = tl.load(in_ptr4 + (x4), xmask, eviction_policy='evict_last')
    tmp15 = tl.load(in_ptr5 + (x4), xmask, eviction_policy='evict_last')
    tmp17 = tl.load(in_ptr6 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)

    # Wrap negative gather indices into [0, 169) and assert the result is in
    # range for every active lane before using it as an address.
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    # Gather from a table addressed as [index, x1] with row stride 4.
    tmp8 = tl.load(in_ptr2 + (x1 + (4*tmp6)), rmask & xmask, eviction_policy='evict_last')

    # logit = in0 + gathered term + in3; p = exp(logit - in4) / in5.
    tmp9 = tmp1 + tmp8
    tmp11 = tmp9 + tmp10
    tmp13 = tmp11 - tmp12
    tmp14 = tl_math.exp(tmp13)
    tmp16 = tmp14 / tmp15

    # Weighted row sum of p; padded / out-of-range lanes contribute 0.
    tmp18 = tmp17.to(tl.float32)
    tmp19 = tmp18 * tmp16
    tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
    tmp22 = tl.where(rmask & xmask, tmp20, 0)
    tmp23 = tl.sum(tmp22, 1)[:, None]

    tl.store(out_ptr0 + (r3 + (49*x4)), tmp16, rmask & xmask)
    tl.store(out_ptr1 + (x4), tmp23, xmask)
def get_args():
    """Allocate random CUDA tensors matching the kernel's nine pointer arguments.

    Returns:
        A 9-tuple ``(arg_0, ..., arg_8)`` in the kernel's positional order:
        seven inputs (``in_ptr0`` .. ``in_ptr6``) followed by the two outputs
        (``out_ptr0``, ``out_ptr1``). All tensors live on ``cuda:0``.

    NOTE(review): ``arg_3`` holds 64 * 49 * 49 elements, while the kernel in
    this file offsets ``in_ptr3`` by ``2401 * (xindex // 196)`` and is launched
    with xnumel = 100352. Confirm that this size is what the kernel expects.
    """
    arg_0 = rand_strided((2048, 49, 49), (2401, 49, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((2401,), (1,), device='cuda:0', dtype=torch.int64)
    arg_2 = rand_strided((169, 4), (4, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((64, 49, 49), (2401, 49, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((512, 4, 49, 1), (196, 49, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((512, 4, 49, 1), (196, 49, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((2048, 49, 49), (2401, 49, 1), device='cuda:0', dtype=torch.float16)
    # Output buffers.
    arg_7 = rand_strided((512, 4, 49, 49), (9604, 2401, 49, 1), device='cuda:0', dtype=torch.float32)
    arg_8 = rand_strided((512, 4, 49, 1), (196, 49, 1, 100352), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8,
def call(args):
    """Launch the kernel once on CUDA device 0.

    Args:
        args: The tensor arguments, in the order produced by ``get_args()``.
            They are followed positionally by xnumel = 100352 and rnumel = 49.
    """
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__softmax__softmax_backward_data__to_copy_151.run(
            *args, 100352, 49, grid=grid(100352), stream=stream0
        )
def benchmark_all_configs(args):
    """Benchmark every autotuning config of the kernel on CUDA device 0.

    Args:
        args: The tensor arguments, in the order produced by ``get_args()``.

    Returns:
        Whatever the kernel wrapper's ``benchmark_all_configs`` returns for
        xnumel = 100352 and rnumel = 49.
    """
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__softmax__softmax_backward_data__to_copy_151.benchmark_all_configs(
            *args, 100352, 49, grid=grid(100352)
        )
if __name__ == '__main__':
    # Stand-alone benchmark: time one kernel launch and report the achieved
    # memory bandwidth.
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    # Same figure as 'kernel_num_gb' in the kernel's inductor_meta.
    num_gb = 0.041178776
    # ms is in milliseconds; convert to seconds for GB/s.
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.