Skip to content

Instantly share code, notes, and snippets.

@HDCharles
HDCharles / gist:c6413717039002c2c20b6cd669edba3e
Created May 9, 2023 22:35
triton graph for safe_int_mm
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
@HDCharles
HDCharles / gist:17300b0c0e2cd2e7a3e49d546dc9e19a
Created May 10, 2023 00:25
dynamically_quantize_per_tensor triton graph
===== __compiled_fn_21 =====
<eval_with_key>.144 class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
# File: /fsx/users/cdhernandez/protoquant/ao_experimental/quant_primitives.py:29, code: min_val = torch.min(x)
min_1 = torch.min(l_x_)
# File: /fsx/users/cdhernandez/protoquant/ao_experimental/quant_primitives.py:30, code: max_val = torch.max(x)
max_1 = torch.max(l_x_)
@HDCharles
HDCharles / gist:62ecd38aa852a8bc2b658277cf816307
Created May 10, 2023 00:57
dynamically_quantize_per_channel triton graph
===== __compiled_fn_12 =====
<eval_with_key>.199 class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
# File: /fsx/users/cdhernandez/protoquant/ao_experimental/quant_primitives.py:76, code: x2 = x.permute(new_axis_list)
permute = l_x_.permute([3, 1, 2, 0])
# File: /fsx/users/cdhernandez/protoquant/ao_experimental/quant_primitives.py:77, code: x2 = torch.flatten(x2, start_dim = 1)
flatten = torch.flatten(permute, start_dim = 1); permute = None
@HDCharles
HDCharles / gist:43cbcb07f873c89988ec0c020fa764ee
Created May 10, 2023 01:10
dequantize_per_channel triton graph
===== __compiled_fn_12 =====
<eval_with_key>.135 class GraphModule(torch.nn.Module):
def forward(self, L_int_repr_ : torch.Tensor, L_scales_ : torch.Tensor, L_zero_points_ : torch.Tensor):
l_int_repr_ = L_int_repr_
l_scales_ = L_scales_
l_zero_points_ = L_zero_points_
# File: /fsx/users/cdhernandez/protoquant/ao_experimental/quant_primitives.py:134, code: y = int_repr.transpose(-1, axis)
transpose = l_int_repr_.transpose(-1, 2); l_int_repr_ = None
@HDCharles
HDCharles / gist:47449e0dc2512256eb01205dfb8ad144
Created May 10, 2023 01:12
dequantize_per_tensor triton graph
===== __compiled_fn_11 =====
<eval_with_key>.128 class GraphModule(torch.nn.Module):
def forward(self, L_int_repr_ : torch.Tensor):
l_int_repr_ = L_int_repr_
# File: /fsx/users/cdhernandez/protoquant/ao_experimental/quant_primitives.py:118, code: return (int_repr.to(out_dtype) - zero_point) * scale
to = l_int_repr_.to(torch.float64); l_int_repr_ = None
sub = to - 9; to = None
mul = sub * 34.638118489583334; sub = None
return (mul,)
@HDCharles
HDCharles / gist:ad3fc0be203a52cd440ec70ae5e4925a
Last active May 15, 2023 18:17
benchmark_linear with integer math, 32 bit scale/zp and 32 bit sums
shape_x shape_w lin_ms qlin_ms trit_qlin_ms qlin_speedup trit_qlin_speedup matmul_ms trit_matmul_ms trit_matmul_speedup
-------------- -------------- ---------- ---------- -------------- -------------- ------------------- ----------- ---------------- ---------------------
(512, 512) (512, 512) 21.9914 404.657 195.39 0.0543458 0.112551 81.2566 77.3025 1.05115
(512, 512) (512, 2048) 21.9638 390.717 195.653 0.0562142 0.112259 81.7465 77.3369 1.05702
(512, 512) (512, 16384) 50.5481 655.018 218.912 0.0771705 0.230906 391.546 123.41 3.17272
(512, 2048) (2048, 512) 22.2229 388.179 195.835 0.0572492 0.113478 137.091 77.5579 1.7676
(512, 2048) (2048, 2048) 37.0526 408.55
shape_x shape_w lin_ms qlin_ms trit_qlin_ms qlin_speedup trit_qlin_speedup matmul_ms trit_matmul_ms trit_matmul_speedup
-------------- -------------- ---------- ---------- -------------- -------------- ------------------- ----------- ---------------- ---------------------
(512, 512) (512, 512) 29.4059 474.733 293.763 0.061942 0.100101 112.424 138.208 0.813438
(512, 512) (512, 2048) 28.5993 461.46 290.286 0.0619758 0.0985211 111.883 139.265 0.80338
(512, 512) (512, 16384) 50.4579 781.971 292.295 0.0645266 0.172627 496.533 139.407 3.56174
(512, 2048) (2048, 512) 28.4453 472.304 294.685 0.0602266 0.0965278 145.547 138.449 1.05127
(512, 2048) (2048, 2048) 37.0158 476.439
shape_x q_per_tensor_int8_ms trit_q_per_tensor_int8_ms dq_per_tensor_int8_ms trit_dq_per_tensor_int8_ms trit_q_per_tensor_int8_speedup trit_dq_per_tensor_int8_speedup
-------------- ---------------------- --------------------------- ----------------------- ---------------------------- -------------------------------- ---------------------------------
(512, 512) 249.966 96.5174 25.0979 54.0861 2.58986 0.464037
(512, 2048) 258.848 98.001 24.1691 54.1419 2.64128 0.446402
(512, 16384) 666.822 149.243 73.2224 53.4983 4.46801 1.36869
(2048, 512) 260.73 98.24
shape_x shape_w mm_half_ms int_mm_ms safe_int_mm_ms trit_int_mm_ms trit_int_t_mm_ms int_mm_speedup safe_int_mm_speedup int_mm_trit_speedup int_mm_trit_t_speedup
-------------- -------------- ------------ ----------- ---------------- ---------------- ------------------ ---------------- --------------------- --------------------- -----------------------
(512, 512) (512, 512) 13.9634 30.6949 30.7042 59.4251 59.263 0.454909 0.454773 0.234975 0.235618
(512, 512) (512, 2048) 18.0451 30.7888 30.7703 59.1577 59.7134 0.586093 0.586445 0.305034 0.302195
(512, 512) (512, 16384) 48.1365 130.105 130.148 174.705 78.7401 0.369981 0.369859 0.275531 0.611334
(512, 204
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile