gm is class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "bf16[32768, 50304]"):
        # File: /home/shunting/ws/pytorch/test/inductor/test_online_softmax.py:51 in f, code: return torch.softmax(x, dim=-1)
        convert_element_type: "f32[32768, 50304]" = torch.ops.prims.convert_element_type.default(arg0_1, torch.float32); arg0_1 = None
        amax: "f32[32768, 1]" = torch.ops.aten.amax.default(convert_element_type, [-1], True)
        sub: "f32[32768, 50304]" = torch.ops.aten.sub.Tensor(convert_element_type, amax); convert_element_type = amax = None
        exp: "f32[32768, 50304]" = torch.ops.aten.exp.default(sub); sub = None
        sum_1: "f32[32768, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div: "f32[32768, 50304]" = torch.ops.aten.div.Tensor(exp, sum_1); exp = sum_1 = None
        convert_element_type_1: "bf16[32768, 50304]" = torch.ops.prims.convert_element_type.default(div, torch.bfloat16); div = None
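
The graph above is the standard two-pass softmax: one full read of each row for amax, another for the exp-sum, then the normalization. The online-softmax rewrite this gist is about fuses the max and the sum into a single pass by rescaling the running sum whenever the running max grows. A minimal Python sketch of that recurrence, written as a simple per-row loop; the function name and structure are illustrative, not the Inductor-generated kernel:

import torch

def online_softmax_row(row: torch.Tensor) -> torch.Tensor:
    # Single pass over the row: keep a running max m and a running sum d of
    # exp(x - m); whenever m grows, rescale the old sum by exp(m_old - m_new).
    x32 = row.to(torch.float32)
    m = torch.full((), float("-inf"), dtype=torch.float32, device=row.device)
    d = torch.zeros((), dtype=torch.float32, device=row.device)
    for x in x32:
        m_new = torch.maximum(m, x)
        d = d * torch.exp(m - m_new) + torch.exp(x - m_new)
        m = m_new
    # One more read of the row to normalize; the separate amax pass is gone.
    return (torch.exp(x32 - m) / d).to(row.dtype)

For a row r, torch.softmax(r, dim=-1) should match this up to rounding in the final downcast; the saving is the eliminated extra pass over the 50304-wide rows.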
diff --git a/profile_gpt2.cu b/profile_gpt2.cu
index fa5e528..b0f08ad 100644
--- a/profile_gpt2.cu
+++ b/profile_gpt2.cu
@@ -42,7 +42,7 @@ int main(int argc, char *argv[]) {
     gpt2_init_common(&model);
     gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin");
-    int B = 24; // if program OOMs decrease this number, e.g. all the way down to 4 or etc
+    int B = 32; // if program OOMs decrease this number, e.g. all the way down to 4 or etc
import sys
sys.path.append("/home/shunting/ws/pytorch/test/inductor")

import torch
from test_torchinductor import check_model_gpu
from torch.testing._internal.common_utils import TestCase, run_tests

class MyTest(TestCase):
    def test_fft(self):
        args = [
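
The repro above stops at args = [. As a hedged sketch of how this kind of standalone check_model_gpu repro is usually completed; the input shapes, the function under test (a softmax stand-in, not the original test_fft body), and the assumed (self, model, example_inputs) calling convention are placeholders:

import sys
sys.path.append("/home/shunting/ws/pytorch/test/inductor")

import torch
from test_torchinductor import check_model_gpu
from torch.testing._internal.common_utils import TestCase, run_tests


class MyTest(TestCase):
    def test_repro(self):
        # Placeholder inputs; the original args list is not shown in the gist.
        args = [torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)]

        def f(x):
            return torch.softmax(x, dim=-1)

        # Compiles f with Inductor, runs it against eager, and compares results.
        check_model_gpu(self, f, args)


if __name__ == "__main__":
    run_tests()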
/home/shunting/.conda/envs/pytorch/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Compiled module path: /tmp/torchinductor_shunting/wr/cwriulytre5lncfhnfi5h6tcsd5hxyyidu27nsa4bntav5ig2bki.py
Compiled module path: /tmp/torchinductor_shunting/ka/ckaklm6mmv4t62ufhwpchtw3uneqaouzqg5v7hb2nym2pstnjmks.py
Running pytorch 2.7.0a0+git02dd7a7
using device: cuda
total desired batch size: 32768
=> calculated gradient accumulation steps: 1
loading weights from pretrained gpt: gpt2
compiling the model...
Multi-GPU support is disabled. Using a single GPU.
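
Quick arithmetic tying these runs together (assuming the usual GPT-2 context length T = 1024): a micro batch of B = 32 gives 32 * 1024 = 32768 tokens per step, so gradient accumulation steps = 32768 / 32768 = 1, and the softmax input in the first dump is correspondingly shaped [32768, 50304], i.e. B*T rows over the GPT-2 vocabulary of 50257 padded up to 50304 = 393 * 128.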
+-----------------------+----------------------------------------------------+
| Parameter | Value |
+-----------------------+----------------------------------------------------+
| train data pattern | dev/data/tinyshakespeare/tiny_shakespeare_train.bin |
| val data pattern | dev/data/tinyshakespeare/tiny_shakespeare_train.bin |
| output log dir | NULL |
| checkpoint_every | 0 |
| resume | 0 |
| micro batch size B | 32 |
# AOT ID: ['0_backward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from triton.testing import do_bench
import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
def eager(inp):
    inp32 = inp.to(torch.float32)
    ref_max = inp32.amax(dim=-1, keepdim=True)
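
The eager reference above is cut off after ref_max. A hedged sketch of how such a reference would be finished and checked against the compiled path, following the decomposition in the FX graph at the top; everything past ref_max, the input shape, and the tolerances are assumptions, not the author's test:

import torch

def eager_softmax(inp):
    # Same steps as the FX graph: upcast to fp32, subtract the row max,
    # exponentiate, normalize, then cast back to the input dtype.
    inp32 = inp.to(torch.float32)
    ref_max = inp32.amax(dim=-1, keepdim=True)
    exp = (inp32 - ref_max).exp()
    return (exp / exp.sum(dim=-1, keepdim=True)).to(inp.dtype)

x = torch.randn(32, 50304, device="cuda", dtype=torch.bfloat16)
compiled = torch.compile(lambda t: torch.softmax(t, dim=-1))
torch.testing.assert_close(compiled(x), eager_softmax(x), atol=1e-2, rtol=1e-2)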
Fatbin elf code:
================
arch = sm_52
code version = [1,7]
host = linux
compile_size = 64bit
Fatbin elf code:
================
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index 8ca1f8b32..3d88b1e5a 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -348,6 +348,7 @@ class CompiledKernel:
     launch_exit_hook = None
     def __init__(self, src, metadata_group, hash):
+        print(metadata_group)
         from collections import namedtuple
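
With that print in place, compiling any @triton.jit kernel should dump metadata_group (presumably the mapping of cache file names to paths that Triton bundles with the compiled kernel) the first time the kernel is built. A minimal kernel to trigger it; the kernel itself is just a placeholder:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
# The first launch triggers compilation, which constructs CompiledKernel
# and hits the added print(metadata_group).
add_kernel[(triton.cdiv(4096, 1024),)](x, y, out, 4096, BLOCK=1024)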