Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shunting314/75c161368a833a535bd0d240b8099d7e to your computer and use it in GitHub Desktop.
Save shunting314/75c161368a833a535bd0d240b8099d7e to your computer and use it in GitHub Desktop.
buf1: SchedulerNode(ComputedBuffer)
buf1.writes = [MemoryDep('buf1', c0, {c0: 2097152})]
buf1.unmet_dependencies = []
buf1.met_dependencies = [MemoryDep('arg1_1', c0, {c0: 2097152})]
buf1.users = [NodeUser(node=SchedulerNode(name='buf2'), can_inplace=False, is_weak=False)]
buf1.group.device = cuda:0
buf1.group.iteration = (2097152, 1)
buf1.sizes = ([2097152], [])
arg1_1_layout = FixedLayout('cuda', torch.float32, size=[2048, 1024], stride=[1024, 1])
buf1_layout = FixedLayout('cuda', torch.float32, size=[2048, 1024], stride=[1024, 1])
class buf1_loop_body:
    """Loop body for buf1 as printed in the scheduler debug dump.

    This is a pretty-printed representation rather than importable Python:
    the symbolic iteration variable ``z0`` is not defined anywhere in this
    text, so evaluating the class body as-is would fail.
    """
    # A single iteration variable z0 over all 2097152 elements (2048 * 1024 per
    # the arg1_1/buf1 layouts), i.e. the 2-D buffer is iterated as a flat 1-D range.
    var_ranges = {z0: 2097152}
    # Named index expression: element z0 maps straight to offset z0, which works
    # because both layouts are contiguous (stride [1024, 1]).
    index0 = z0
    def body(self, ops):
        """Compute one element: buf1[index0] = arg1_1[index0] * 3.0 (float32).

        Every operation goes through the ``ops`` handler. The order of the calls
        below matches the tmp0 -> tmp1 -> tmp2 -> store sequence in the
        generated Triton kernel, so presumably calls are emitted in statement
        order; do not reorder them.
        """
        get_index = self.get_index('index0')
        load = ops.load('arg1_1', get_index)
        constant = ops.constant(3.0, torch.float32)
        mul = ops.mul(load, constant)
        # The index is looked up again for the store rather than reusing get_index.
        get_index_1 = self.get_index('index0')
        # NOTE(review): trailing None is presumably the store mode (plain store,
        # no accumulation) -- confirm against the ops.store signature.
        store = ops.store('buf1', get_index_1, mul, None)
        return store
buf1 Triton code:
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
# Inductor-generated pointwise Triton kernel for buf1:
#     out_ptr0[i] = in_ptr0[i] * 3.0   for i in [0, 2097152), float32.
# The metadata below is passed to the triton_heuristics.pointwise decorator.
# `signature` and `divisible_by_16=(0, 1, 2)` refer to kernel arguments by
# position (0 = in_ptr0, 1 = out_ptr0, 2 = xnumel), so the parameter order of
# triton_ must stay as-is.
@triton_heuristics.pointwise(
    size_hints=[2097152],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '918f7a18d5a3802eb6e9bef806b53d18de2f20e5e789828bf19d422fc9d53f3f', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # The element count is baked in as a constant (static shape), shadowing the
    # runtime xnumel argument.
    xnumel = 2097152
    # Each program instance handles one contiguous tile of XBLOCK elements
    # starting at program_id(0) * XBLOCK.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # The bounds mask is computed but never used: the load/store below pass
    # mask=None.
    # NOTE(review): unmasked access is only in-bounds if every launched tile is
    # fully inside [0, 2097152). That holds if the grid is
    # ceil(2097152 / XBLOCK) and XBLOCK divides 2097152 (= 2**21); presumably
    # this is why the mask was elided. Confirm before reusing the kernel for
    # another size.
    xmask = xindex < xnumel
    # Flat offset: both buffers are contiguous, so element index == offset.
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = 3.0
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0), tmp2, None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment