
# Imports as emitted at the top of an Inductor-generated Triton kernel file,
# plus the repro setup that toggles the scatter optimization off.
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided

import torch
import torch._inductor.config as inductor_config

# Disable the scatter-upon-const-tensor optimization to capture the baseline kernel.
inductor_config.optimize_scatter_upon_const_tensor = False
torch.set_default_device("cuda")

# Indices for an (M, N) one-hot style scatter: one column index per row.
M, N = 1024, 2048
x = torch.randint(0, N, (M,), dtype=torch.int64)
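# A minimal sketch (assumed; the gist truncates the repro here) of the kind of
# function the flag above targets: scattering ones into a constant all-zero
# tensor, the pattern cross-entropy's backward uses to build a one-hot matrix.
# With optimize_scatter_upon_const_tensor = False, compiling this shows the
# unoptimized codegen.
@torch.compile
def one_hot_via_scatter(x):
    y = torch.full([M, N], 0.0)
    return y.scatter(1, x.unsqueeze(1), 1.0)

out = one_hot_via_scatter(x)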
# AOT ID: ['0_backward']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
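# The preambles above are the headers Inductor emits at the top of its
# generated wrapper code, one file per compiled graph ('0_backward',
# '0_inference', ...). One way to dump such files, assuming PyTorch 2.x
# structured logging:
#   TORCH_LOGS="output_code" python repro.py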
<!DOCTYPE html>
<html>
<head>
</head>
<body>
<script type="module">
import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
// Embedded GPU memory snapshot rendered by MemoryViz; the base64 pickle
// payload is truncated in the original gist.
const local_files = [{"name": "snapshot.pickle", "base64": "gASVAgABAAAAAAB9lCiM..."}]
</script>
</body>
</html>
class GraphModule(torch.nn.Module):
    def forward(self, primals_3: "f32[768]", primals_9: "f32[768]", primals_15: "f32[768]", primals_21: "f32[768]", primals_27: "f32[768]", primals_33: "f32[768]", primals_39: "f32[768]", primals_45: "f32[768]", primals_51: "f32[768]", primals_57: "f32[768]", primals_63: "f32[768]", primals_69: "f32[768]", primals_75: "f32[768]", primals_81: "f32[768]", primals_87: "f32[768]", primals_93: "f32[768]", primals_99: "f32[768]", primals_105: "f32[768]", primals_111: "f32[768]", primals_117: "f32[768]", primals_123: "f32[768]", primals_129: "f32[768]", primals_135: "f32[768]", primals_141: "f32[768]", primals_147: "f32[768]", primals_150: "i64[32, 1024]", primals_151: "i64[32, 1024]", iota: "i64[1024]", embedding: "f32[32, 1024, 768]", embedding_1: "f32[1024, 768]", getitem_1: "f32[32, 1024, 1]", rsqrt: "f32[32, 1024, 1]", view: "bf16[32768, 768]", permute_1: "bf16[32, 12, 1024, 64]", permute_2: "bf16[32, 12, 1024, 64]", permute_3: "bf16[32, 12, 1024, 64]", getitem_5: "bf16[32, 1
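# The dump above is the (truncated) input signature AOTAutograd records for a
# transformer block's backward graph. A way to print such graph dumps,
# assuming PyTorch 2.x logging:
#   TORCH_LOGS="aot_graphs" python repro.py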
import torch

aten = torch.ops.aten
prims = torch.ops.prims

def fuse_scatter_upon_allzero(graph):
    # Stub: the early return disables the pass; the intended node walk follows.
    return  # TODO
    for cur_node in graph.nodes:
        if cur_node.op != "call_function":
            continue
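# A hypothetical sketch (not the gist's implementation) of the check the loop
# above would perform on each node: an aten.scatter call whose destination is
# a constant all-zero tensor built by aten.zeros / aten.full.
def is_scatter_upon_allzero(node):
    if node.op != "call_function" or node.target not in (
        torch.ops.aten.scatter.value,
        torch.ops.aten.scatter.src,
    ):
        return False
    dst = node.args[0]
    if dst.op != "call_function":
        return False
    if dst.target is torch.ops.aten.zeros.default:
        return True
    # aten.full(size, fill_value): only an all-zero fill qualifies.
    return dst.target is torch.ops.aten.full.default and dst.args[1] == 0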
import copy

import torch
from torch import nn
import torch._inductor.config as inductor_config
from triton.testing import do_bench

# Emit a standalone benchmark harness for each generated kernel, and give
# kernels readable names instead of generic numbered ones.
inductor_config.benchmark_kernel = True
inductor_config.triton.unique_kernel_names = True
torch.set_default_device("cuda")
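# A minimal sketch (assumed; the gist truncates here) of how these imports fit
# together: time an eager module against a torch.compile'd deep copy.
model = nn.Sequential(nn.Linear(768, 768), nn.ReLU())
opt_model = torch.compile(copy.deepcopy(model))
inp = torch.randn(32, 768)

eager_ms = do_bench(lambda: model(inp))
compiled_ms = do_bench(lambda: opt_model(inp))
print(f"eager: {eager_ms:.3f} ms  compiled: {compiled_ms:.3f} ms")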