Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save HDCharles/62ecd38aa852a8bc2b658277cf816307 to your computer and use it in GitHub Desktop.
dynamically_quantize_per_channel triton graph
===== __compiled_fn_12 =====
# <eval_with_key>.199
class GraphModule(torch.nn.Module):
    """torch.compile (Dynamo)-captured graph of ``dynamically_quantize_per_channel``.

    Quantizes ``L_x_`` per channel along axis 3 to int8 using a symmetric
    scale; the zero point works out to 0 (128 + quant_min where
    quant_min = -128), and values are clamped to [-128, 127].

    Returns:
        tuple: (x_q int8 tensor, per-channel float64 scales,
        per-channel int64 zero points — all zeros here).

    NOTE(review): this is auto-generated FX code recovered from a paste that
    lost all indentation. Only formatting and comments were restored; every
    operation, temporary name, and ``= None`` release (Dynamo's explicit
    frees) is unchanged. The graph hard-codes ``device(type='cuda', index=0)``
    (``device`` here is ``torch.device``, bound by the FX code emitter) and a
    channel count of 32 in ``torch.zeros((32,), ...)``, so it only runs on a
    machine with CUDA device 0 and inputs with 32 channels on axis 3.
    """

    def forward(self, L_x_: torch.Tensor):
        l_x_ = L_x_
        # quant_primitives.py:76 — move the quantization axis (3) to the front
        permute = l_x_.permute([3, 1, 2, 0])
        # quant_primitives.py:77 — collapse all non-channel dims into one
        flatten = torch.flatten(permute, start_dim=1); permute = None
        # quant_primitives.py:78 — per-channel minima
        min_1 = flatten.min(dim=1)
        getitem = min_1[0]; min_1 = None
        # quant_primitives.py:79 — per-channel maxima
        max_1 = flatten.max(dim=1); flatten = None
        getitem_2 = max_1[0]; max_1 = None
        # quant_primitives.py:85 — symmetric range: max(max_val, -min_val)
        neg = -getitem; getitem = None
        max_2 = torch.max(getitem_2, neg); getitem_2 = neg = None
        # quant_primitives.py:87 — scales = 2*max_val_pos / (quant_max - quant_min),
        # computed in float64; 255 == 127 - (-128)
        to = max_2.to(torch.float64); max_2 = None
        mul = 2 * to; to = None
        tensor = torch.tensor([255], device=device(type='cuda', index=0))
        to_1 = tensor.to(torch.float64); tensor = None
        truediv = mul / to_1; mul = to_1 = None
        # quant_primitives.py:89 — guard against zero scale (eps = float32 machine epsilon)
        clamp = torch.clamp(truediv, min=1.1920928955078125e-07); truediv = None
        # quant_primitives.py:90 — zero_points = zeros + 128 + quant_min == all zeros
        zeros = torch.zeros((32,), dtype=torch.int64, device=device(type='cuda', index=0))
        add = zeros + 128; zeros = None
        add_1 = add + -128; add = None
        # quant_primitives.py:94 — put the channel axis last, then divide by scales
        transpose = l_x_.transpose(3, -1); l_x_ = None
        truediv_1 = transpose / clamp; transpose = None
        # quant_primitives.py:98 — round to nearest
        round_1 = torch.round(truediv_1); truediv_1 = None
        # quant_primitives.py:99 — shift by zero points
        add_2 = round_1 + add_1; round_1 = None
        # quant_primitives.py:100 — restore the original axis order
        transpose_1 = add_2.transpose(3, -1); add_2 = None
        # quant_primitives.py:101 — clamp to [quant_min, quant_max] and cast to int8
        clamp_1 = torch.clamp(transpose_1, -128, 127); transpose_1 = None
        to_2 = clamp_1.to(torch.int8); clamp_1 = None
        return (to_2, clamp, add_1)
[2023-05-10 00:56:18,997] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
__compiled_fn_12 <eval_with_key>.199 opcode name target args kwargs
------------- ----------- ---------------------------------------------------------- ------------------------ --------------------------------------------------------------
placeholder l_x_ L_x_ () {}
call_method permute permute (l_x_, [3, 1, 2, 0]) {}
call_function flatten <built-in method flatten of type object at 0x7fcb3c9b58a0> (permute,) {'start_dim': 1}
call_method min_1 min (flatten,) {'dim': 1}
call_function getitem <built-in function getitem> (min_1, 0) {}
call_method max_1 max (flatten,) {'dim': 1}
call_function getitem_2 <built-in function getitem> (max_1, 0) {}
call_function neg <built-in function neg> (getitem,) {}
call_function max_2 <built-in method max of type object at 0x7fcb3c9b58a0> (getitem_2, neg) {}
call_method to to (max_2, torch.float64) {}
call_function mul <built-in function mul> (2, to) {}
call_function tensor <built-in method tensor of type object at 0x7fcb3c9b58a0> ([255],) {'device': device(type='cuda', index=0)}
call_method to_1 to (tensor, torch.float64) {}
call_function truediv <built-in function truediv> (mul, to_1) {}
call_function clamp <built-in method clamp of type object at 0x7fcb3c9b58a0> (truediv,) {'min': 1.1920928955078125e-07}
call_function zeros <built-in method zeros of type object at 0x7fcb3c9b58a0> ((32,),) {'dtype': torch.int64, 'device': device(type='cuda', index=0)}
call_function add <built-in function add> (zeros, 128) {}
call_function add_1 <built-in function add> (add, -128) {}
call_method transpose transpose (l_x_, 3, -1) {}
call_function truediv_1 <built-in function truediv> (transpose, clamp) {}
call_function round_1 <built-in method round of type object at 0x7fcb3c9b58a0> (truediv_1,) {}
call_function add_2 <built-in function add> (round_1, add_1) {}
call_method transpose_1 transpose (add_2, 3, -1) {}
call_function clamp_1 <built-in method clamp of type object at 0x7fcb3c9b58a0> (transpose_1, -128, 127) {}
call_method to_2 to (clamp_1, torch.int8) {}
output output output ((to_2, clamp, add_1),) {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment