Skip to content

Instantly share code, notes, and snippets.

@leslie-fang-intel
Created May 14, 2024 04:52
Show Gist options
  • Save leslie-fang-intel/37d81441237b5139c8295f5e6c4cd31a to your computer and use it in GitHub Desktop.
Save leslie-fang-intel/37d81441237b5139c8295f5e6c4cd31a to your computer and use it in GitHub Desktop.
# Usage: TORCHINDUCTOR_FREEZING=1 TORCH_LOGS="+output_code" numactl -C 56-111 -m 1 python test_softmax.py
import torch
import time
import random
import numpy as np
from torch._inductor import config as inductor_config
# inductor_config.cpp_wrapper = True
# Seed PyTorch, NumPy and Python's `random` (in that order) with one shared
# value so the randomly generated benchmark input is reproducible across runs.
local_seed = 2024
for _seed_all in (torch.manual_seed, np.random.seed, random.seed):
    _seed_all(local_seed)
del _seed_all
class M(torch.nn.Module):
    """Minimal module whose forward pass is a softmax over the last dimension."""

    def __init__(self):
        super().__init__()
        # Registered as a submodule but never applied in forward(); kept so the
        # module's structure (children, repr) is unchanged.
        self.attn_dropout = torch.nn.Dropout(0.1)

    def forward(self, attn_weights):
        """Return ``softmax(attn_weights, dim=-1)``.

        NOTE(review): the expected input was annotated as size (4, 12, 1024, 1024)
        with contiguous stride (12582912, 1048576, 1024, 1), but the script's
        benchmark builds a (4, 12, 1025, 1024) tensor — confirm which is intended.
        """
        return torch.nn.functional.softmax(attn_weights, dim=-1)
dynamic = True
if __name__ == "__main__":
with torch.no_grad():
m = M().eval()
input = torch.randn(4, 12, 1025, 1024).to(torch.bfloat16)
m(input)
warmup_steps = 100
steps = 1000
# Refer path
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
ref_res = m(input)
for _ in range(warmup_steps):
m(input)
ref_start = time.time()
for _ in range(steps):
m(input)
ref_end = time.time()
# Compiler Path
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
c_m = torch.compile(m, dynamic=dynamic)
inductor_res = c_m(input)
for _ in range(warmup_steps):
c_m(input)
inductor_start = time.time()
for _ in range(steps):
c_m(input)
inductor_end = time.time()
print("ref time is: {}".format(ref_end - ref_start), flush=True)
print("inductor time is: {}".format(inductor_end - inductor_start), flush=True)
print(torch.allclose(ref_res[0], inductor_res[0], atol=0.01, rtol=0.01), flush=True)
print(torch.allclose(ref_res[1], inductor_res[1], atol=0.01, rtol=0.01), flush=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment