Skip to content

Instantly share code, notes, and snippets.

@jjsjann123
Created June 3, 2024 22:08
Show Gist options
  • Save jjsjann123/2c4db9f6659cfe2cc8aa9503cb8a806c to your computer and use it in GitHub Desktop.
Save jjsjann123/2c4db9f6659cfe2cc8aa9503cb8a806c to your computer and use it in GitHub Desktop.
rope_nvfuser_prototype
import torch
from nvfuser import FusionDefinition, DataType
# Problem sizes, read as module-level globals by rope_fusion() and by the
# random test inputs built after the fusion is defined.
# q has shape [bsz, n_head, block_size, head_size];
# cos and sin have shape [block_size, rope_n_elem].
bsz = 2
block_size = 1024
n_head = 16
head_size = 32
# Only the first rope_n_elem entries of q's last dimension are rotated;
# the remaining head_size - rope_n_elem entries are passed through unchanged.
rope_n_elem = 8
def rope_fusion(fd: FusionDefinition) -> None:
    """Define a partial rotary-embedding (RoPE) computation on ``fd``.

    The first ``rope_n_elem`` entries of the last dimension of ``q`` are
    rotated with ``cos``/``sin``; the remaining entries are passed through.
    No concatenation is used: every piece is sliced out, zero-padded back to
    the full ``head_size`` at its destination columns, and the padded pieces
    are combined with pointwise multiplies and adds.

    Pad widths are spelled ``list(reversed([...]))``; after the reversal the
    list starts with the (left, right) padding of the last dimension.

    Reads the module-level sizes ``bsz``, ``n_head``, ``block_size``,
    ``head_size`` and ``rope_n_elem``.

    Args:
        fd: Fusion definition being populated. Three bfloat16, non-CPU
            inputs are registered in the order ``q``, ``cos``, ``sin`` and
            one bfloat16 output is added.
    """
    q = fd.define_tensor(
        shape=[bsz, n_head, block_size, head_size],
        contiguity=[True, True, True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[3, 2, 1, 0],
    )
    cos = fd.define_tensor(
        shape=[block_size, rope_n_elem],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[1, 0],
    )
    sin = fd.define_tensor(
        shape=[block_size, rope_n_elem],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[1, 0],
    )

    # Half of the rotated span: columns [0, offset_0) are the "left" half,
    # columns [offset_0, rope_n_elem) the "right" half.
    offset_0 = rope_n_elem // 2

    # Split q into the rotated span and the pass-through remainder.
    q_rope = fd.ops.slice(
        q,
        start_indices=[0, 0, 0, 0],
        end_indices=[bsz, n_head, block_size, rope_n_elem],
        strides=[1, 1, 1, 1],
    )
    q_remainder = fd.ops.slice(
        q,
        start_indices=[0, 0, 0, rope_n_elem],
        end_indices=[bsz, n_head, block_size, head_size],
        strides=[1, 1, 1, 1],
    )
    # Put the remainder back at columns [rope_n_elem, head_size).
    q_remainder = fd.ops.pad(q_remainder, list(reversed([0, 0, 0, 0, 0, 0, 0, rope_n_elem])))

    # Rotated copy: the left half is moved to columns [offset_0, rope_n_elem)
    # and the right half to columns [0, offset_0), i.e. the halves are swapped.
    q_left = fd.ops.slice(
        q_rope,
        start_indices=[0, 0, 0, 0],
        end_indices=[bsz, n_head, block_size, offset_0],
        strides=[1, 1, 1, 1],
    )
    q_left = fd.ops.pad(
        q_left,
        list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem, rope_n_elem - offset_0])),
    )
    q_right = fd.ops.slice(
        q_rope,
        start_indices=[0, 0, 0, offset_0],
        end_indices=[bsz, n_head, block_size, rope_n_elem],
        strides=[1, 1, 1, 1],
    )
    q_right = fd.ops.pad(
        q_right,
        list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem + offset_0, 0])),
    )

    # Un-rotated copy: the same two slices as q_left / q_right, but padded
    # back to their original columns (only the pad widths differ).
    # We should be able to merge the duplicated slices back.
    q_left_cos = fd.ops.slice(
        q_rope,
        start_indices=[0, 0, 0, 0],
        end_indices=[bsz, n_head, block_size, offset_0],
        strides=[1, 1, 1, 1],
    )
    q_left_cos = fd.ops.pad(
        q_left_cos,
        list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem + offset_0, 0])),
    )
    q_right_cos = fd.ops.slice(
        q_rope,
        start_indices=[0, 0, 0, offset_0],
        end_indices=[bsz, n_head, block_size, rope_n_elem],
        strides=[1, 1, 1, 1],
    )
    q_right_cos = fd.ops.pad(
        q_right_cos,
        list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem, rope_n_elem - offset_0])),
    )

    # Slice cos/sin into halves, pad each half to head_size at its own
    # columns, and broadcast to the 4-D shape of q.
    cos_left = fd.ops.slice(cos, start_indices=[0, 0], end_indices=[block_size, offset_0], strides=[1, 1])
    cos_left = fd.ops.pad(cos_left, list(reversed([0, 0, head_size - offset_0, 0])))
    cos_left = fd.ops.broadcast_in_dim(cos_left, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    cos_right = fd.ops.slice(cos, start_indices=[0, offset_0], end_indices=[block_size, rope_n_elem], strides=[1, 1])
    cos_right = fd.ops.pad(cos_right, list(reversed([0, 0, head_size - rope_n_elem, offset_0])))
    cos_right = fd.ops.broadcast_in_dim(cos_right, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    sin_left = fd.ops.slice(sin, start_indices=[0, 0], end_indices=[block_size, offset_0], strides=[1, 1])
    sin_left = fd.ops.pad(sin_left, list(reversed([0, 0, head_size - offset_0, 0])))
    sin_left = fd.ops.broadcast_in_dim(sin_left, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    sin_right = fd.ops.slice(sin, start_indices=[0, offset_0], end_indices=[block_size, rope_n_elem], strides=[1, 1])
    sin_right = fd.ops.pad(sin_right, list(reversed([0, 0, head_size - rope_n_elem, offset_0])))
    sin_right = fd.ops.broadcast_in_dim(sin_right, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])

    # Columns [0, offset_0): -x_right * sin + x_left * cos.
    q0 = (-q_right) * sin_left + cos_left * q_left_cos
    # Columns [offset_0, rope_n_elem): x_left * sin + x_right * cos.
    q1 = q_left * sin_right + cos_right * q_right_cos
    # The three terms are non-zero on disjoint column ranges, so the sum
    # assembles the full head_size-wide result.
    q_out = q0 + q1 + q_remainder
    q_out = fd.ops.cast(q_out, dtype=DataType.BFloat16)
    fd.add_output(q_out)
# Build the fusion; it is executed after the `with` block has closed.
with FusionDefinition() as fd:
    rope_fusion(fd)

# Random bfloat16 CUDA inputs, in the order rope_fusion registers them:
# q, cos, sin.
inputs = [
    torch.randn(shape, dtype=torch.bfloat16, device="cuda:0")
    for shape in (
        (bsz, n_head, block_size, head_size),
        (block_size, rope_n_elem),
        (block_size, rope_n_elem),
    )
]

# The fusion has a single output.
o = fd.execute(inputs)[0]
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding over the last dimension of ``x``.

    The last dimension is split into two halves ``(x1, x2)``; the result is
    ``x * cos + cat(-x2, x1) * sin``, cast back to the dtype of ``x``.

    Args:
        x: Input tensor; the rotation acts on its last dimension.
        cos: Cosine table, broadcastable against ``x``.
        sin: Sine table, broadcastable against ``x``.

    Returns:
        Tensor with the same dtype as ``x``.
    """
    half = x.size(-1) // 2
    first_half, second_half = x[..., :half], x[..., half:]
    # Rotate-half: (x1, x2) -> (-x2, x1).
    rotated = torch.cat((-second_half, first_half), dim=-1)
    roped = x * cos + rotated * sin
    return roped.to(dtype=x.dtype)
def rope_one_entry(x, cos, sin, rope_n_elem):
    """Apply RoPE to the leading ``rope_n_elem`` features of ``x`` only.

    The first ``rope_n_elem`` entries of the last dimension go through
    ``apply_rope``; the remaining entries are concatenated back unchanged.
    """
    rotary_part = x[..., :rope_n_elem]
    passthrough = x[..., rope_n_elem:]
    return torch.cat((apply_rope(rotary_part, cos, sin), passthrough), dim=-1)
#import thunder
#thunder_rope_one = thunder.jit(rope_one_entry, nv_enable_bookend=False)
#o_ref = thunder_rope_one(*inputs, rope_n_elem)
@jjsjann123
Copy link
Author

Note: all slice ops are segmented out as no-op kernels, and the rest of the operation then runs as a single pointwise kernel.

Segmented_Fusion Dump: -- fusion segments:
Segmented_Fusion{
groups:
g{35}

g{23}

g{41}

g{29}

g{3}

g{1, 7, 11, 15, 19}

g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}

edges:
e{ g{35}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T22_g[ iS91{1024}rf, iS93{4}rf ]) }

e{ g{23}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T14_g[ iS63{1024}rf, iS65{4}rf ]) }

e{ g{41}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T26_g[ iS105{1024}rf, iS107{4}rf ]) }

e{ g{29}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T18_g[ iS77{1024}rf, iS79{4}rf ]) }

e{ g{3}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T4_g[ iS13{2}rf, iS14{16}rf, iS15{1024}rf, iS17{24}rf ]) }

e{ g{1, 7, 11, 15, 19}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T10_g[ iS43{2}rf, iS44{16}rf, iS45{1024}rf, iS47{4}rf ]) }

e{ g{1, 7, 11, 15, 19}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T8_g[ iS33{2}rf, iS34{16}rf, iS35{1024}rf, iS37{4}rf ]) }

e{ g{1, 7, 11, 15, 19}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T12_g[ iS53{2}rf, iS54{16}rf, iS55{1024}rf, iS57{4}rf ]) }

e{ g{1, 7, 11, 15, 19}
 -> g{5, 9, 13, 17, 21, 25, 26, 27, 31, 32, 33, 37, 38, 39, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}
(T6_g[ iS23{2}rf, iS24{16}rf, iS25{1024}rf, iS27{4}rf ]) }


group details:
g{(no_op)
inputs:
T2_g[ iS6{1024}, iS7{8} ] __bfloat
outputs:
T22_g[ iS91{1024}rf, iS93{4}rf ] __bfloat


T22_g[ iS91{1024}rf, iS93{4}rf ]
   = slice( T2_g[ iS6{1024}, iS7{8} ], { {0, 1024, 1} {0, 4, 1} } )
(35)
}

g{(no_op)
inputs:
T1_g[ iS4{1024}, iS5{8} ] __bfloat
outputs:
T14_g[ iS63{1024}rf, iS65{4}rf ] __bfloat


T14_g[ iS63{1024}rf, iS65{4}rf ]
   = slice( T1_g[ iS4{1024}, iS5{8} ], { {0, 1024, 1} {0, 4, 1} } )
(23)
}

g{(no_op)
inputs:
T2_g[ iS6{1024}, iS7{8} ] __bfloat
outputs:
T26_g[ iS105{1024}rf, iS107{4}rf ] __bfloat


T26_g[ iS105{1024}rf, iS107{4}rf ]
   = slice( T2_g[ iS6{1024}, iS7{8} ], { {0, 1024, 1} {4, 8, 1} } )
(41)
}

g{(no_op)
inputs:
T1_g[ iS4{1024}, iS5{8} ] __bfloat
outputs:
T18_g[ iS77{1024}rf, iS79{4}rf ] __bfloat


T18_g[ iS77{1024}rf, iS79{4}rf ]
   = slice( T1_g[ iS4{1024}, iS5{8} ], { {0, 1024, 1} {4, 8, 1} } )
(29)
}

g{(no_op)
inputs:
T0_g[ iS0{2}, iS1{16}, iS2{1024}, iS3{32} ] __bfloat
outputs:
T4_g[ iS13{2}rf, iS14{16}rf, iS15{1024}rf, iS17{24}rf ] __bfloat


T4_g[ iS13{2}rf, iS14{16}rf, iS15{1024}rf, iS17{24}rf ]
   = slice( T0_g[ iS0{2}, iS1{16}, iS2{1024}, iS3{32} ], { {0, 2, 1} {0, 16, 1} {0, 1024, 1} {8, 32, 1} } )
(3)
}

g{(no_op)
inputs:
T0_g[ iS0{2}, iS1{16}, iS2{1024}, iS3{32} ] __bfloat
outputs:
T6_g[ iS23{2}rf, iS24{16}rf, iS25{1024}rf, iS27{4}rf ] __bfloat
T8_g[ iS33{2}rf, iS34{16}rf, iS35{1024}rf, iS37{4}rf ] __bfloat
T10_g[ iS43{2}rf, iS44{16}rf, iS45{1024}rf, iS47{4}rf ] __bfloat
T12_g[ iS53{2}rf, iS54{16}rf, iS55{1024}rf, iS57{4}rf ] __bfloat


T3_g[ iS8{2}rf, iS9{16}rf, iS10{1024}rf, iS12{8}rf ]
   = slice( T0_g[ iS0{2}, iS1{16}, iS2{1024}, iS3{32} ], { {0, 2, 1} {0, 16, 1} {0, 1024, 1} {0, 8, 1} } )
(1)
T12_g[ iS53{2}rf, iS54{16}rf, iS55{1024}rf, iS57{4}rf ]
   = slice( T3_g[ iS8{2}rf, iS9{16}rf, iS10{1024}rf, iS12{8}rf ], { {0, 2, 1} {0, 16, 1} {0, 1024, 1} {4, 8, 1} } )
(19)
T6_g[ iS23{2}rf, iS24{16}rf, iS25{1024}rf, iS27{4}rf ]
   = slice( T3_g[ iS8{2}rf, iS9{16}rf, iS10{1024}rf, iS12{8}rf ], { {0, 2, 1} {0, 16, 1} {0, 1024, 1} {0, 4, 1} } )
(7)
T10_g[ iS43{2}rf, iS44{16}rf, iS45{1024}rf, iS47{4}rf ]
   = slice( T3_g[ iS8{2}rf, iS9{16}rf, iS10{1024}rf, iS12{8}rf ], { {0, 2, 1} {0, 16, 1} {0, 1024, 1} {0, 4, 1} } )
(15)
T8_g[ iS33{2}rf, iS34{16}rf, iS35{1024}rf, iS37{4}rf ]
   = slice( T3_g[ iS8{2}rf, iS9{16}rf, iS10{1024}rf, iS12{8}rf ], { {0, 2, 1} {0, 16, 1} {0, 1024, 1} {4, 8, 1} } )
(11)
}

g{(pointwise)
inputs:
T4_g[ iS13{2}rf, iS14{16}rf, iS15{1024}rf, iS17{24}rf ] __bfloat
T6_g[ iS23{2}rf, iS24{16}rf, iS25{1024}rf, iS27{4}rf ] __bfloat
T8_g[ iS33{2}rf, iS34{16}rf, iS35{1024}rf, iS37{4}rf ] __bfloat
T10_g[ iS43{2}rf, iS44{16}rf, iS45{1024}rf, iS47{4}rf ] __bfloat
T12_g[ iS53{2}rf, iS54{16}rf, iS55{1024}rf, iS57{4}rf ] __bfloat
T14_g[ iS63{1024}rf, iS65{4}rf ] __bfloat
T18_g[ iS77{1024}rf, iS79{4}rf ] __bfloat
T22_g[ iS91{1024}rf, iS93{4}rf ] __bfloat
T26_g[ iS105{1024}rf, iS107{4}rf ] __bfloat
outputs:
T48_g[ iS191{2}, iS192{16}, iS193{1024}, iS194{32} ] __bfloat


T11_g[ iS48{2}, iS49{16}, iS50{1024}, iS52{32}rf ]
   = pad( T10_g[ iS43{2}rf, iS44{16}rf, iS45{1024}rf, iS47{4}rf ], {0, 0, 0, 0, 0, 0, 0, 28} )
(17)
T35_g[ iS139{2}, iS140{16}, iS141{1024}, iS142{32} ]
   = __bfloat2float(T11_g[ iS48{2}, iS49{16}, iS50{1024}, iS52{32}rf ]);
(51)
T9_g[ iS38{2}, iS39{16}, iS40{1024}, iS42{32}rf ]
   = pad( T8_g[ iS33{2}rf, iS34{16}rf, iS35{1024}rf, iS37{4}rf ], {0, 0, 0, 0, 0, 0, 0, 28} )
(13)
T30_l[ iS119{2}, iS120{16}, iS121{1024}, iS122{32} ]
   = __bfloat2float(T9_g[ iS38{2}, iS39{16}, iS40{1024}, iS42{32}rf ]);
(46)
T31_g[ iS123{2}, iS124{16}, iS125{1024}, iS126{32} ]
   = -T30_l[ iS119{2}, iS120{16}, iS121{1024}, iS122{32} ];
(47)
T23_g[ iS94{1024}, iS96{32}rf ]
   = pad( T22_g[ iS91{1024}rf, iS93{4}rf ], {0, 0, 0, 28} )
(37)
T24_g[ bS97{1}, bS98{1}, iS99{1024}, iS100{32} ]
   = broadcast( T23_g[ iS94{1024}, iS96{32}rf ] )
(38)
T25_g[ bS101{1}, bS102{1}, iS103{1024}, iS104{32} ]
   = Set( T24_g[ bS97{1}, bS98{1}, iS99{1024}, iS100{32} ], cache_op=Streaming )
(39)
T32_g[ bS127{1}, bS128{1}, iS129{1024}, iS130{32} ]
   = __bfloat2float(T25_g[ bS101{1}, bS102{1}, iS103{1024}, iS104{32} ]);
(48)
T15_g[ iS66{1024}, iS68{32}rf ]
   = pad( T14_g[ iS63{1024}rf, iS65{4}rf ], {0, 0, 0, 28} )
(25)
T16_g[ bS69{1}, bS70{1}, iS71{1024}, iS72{32} ]
   = broadcast( T15_g[ iS66{1024}, iS68{32}rf ] )
(26)
T17_g[ bS73{1}, bS74{1}, iS75{1024}, iS76{32} ]
   = Set( T16_g[ bS69{1}, bS70{1}, iS71{1024}, iS72{32} ], cache_op=Streaming )
(27)
T33_l[ iS131{2}, iS132{16}, iS133{1024}, iS134{32} ]
   = T31_g[ iS123{2}, iS124{16}, iS125{1024}, iS126{32} ]
   * T32_g[ bS127{1}, bS128{1}, iS129{1024}, iS130{32} ];
(49)
T34_l[ bS135{1}, bS136{1}, iS137{1024}, iS138{32} ]
   = __bfloat2float(T17_g[ bS73{1}, bS74{1}, iS75{1024}, iS76{32} ]);
(50)
T36_g[ iS143{2}, iS144{16}, iS145{1024}, iS146{32} ]
   = T34_l[ bS135{1}, bS136{1}, iS137{1024}, iS138{32} ]
   * T35_g[ iS139{2}, iS140{16}, iS141{1024}, iS142{32} ];
(52)
T37_g[ iS147{2}, iS148{16}, iS149{1024}, iS150{32} ]
   = T33_l[ iS131{2}, iS132{16}, iS133{1024}, iS134{32} ]
   + T36_g[ iS143{2}, iS144{16}, iS145{1024}, iS146{32} ];
(53)
T19_g[ iS80{1024}, iS82{32}rf ]
   = pad( T18_g[ iS77{1024}rf, iS79{4}rf ], {0, 0, 4, 24} )
(31)
T20_g[ bS83{1}, bS84{1}, iS85{1024}, iS86{32} ]
   = broadcast( T19_g[ iS80{1024}, iS82{32}rf ] )
(32)
T21_g[ bS87{1}, bS88{1}, iS89{1024}, iS90{32} ]
   = Set( T20_g[ bS83{1}, bS84{1}, iS85{1024}, iS86{32} ], cache_op=Streaming )
(33)
T41_l[ bS163{1}, bS164{1}, iS165{1024}, iS166{32} ]
   = __bfloat2float(T21_g[ bS87{1}, bS88{1}, iS89{1024}, iS90{32} ]);
(57)
T13_g[ iS58{2}, iS59{16}, iS60{1024}, iS62{32}rf ]
   = pad( T12_g[ iS53{2}rf, iS54{16}rf, iS55{1024}rf, iS57{4}rf ], {0, 0, 0, 0, 0, 0, 4, 24} )
(21)
T42_g[ iS167{2}, iS168{16}, iS169{1024}, iS170{32} ]
   = __bfloat2float(T13_g[ iS58{2}, iS59{16}, iS60{1024}, iS62{32}rf ]);
(58)
T43_g[ iS171{2}, iS172{16}, iS173{1024}, iS174{32} ]
   = T41_l[ bS163{1}, bS164{1}, iS165{1024}, iS166{32} ]
   * T42_g[ iS167{2}, iS168{16}, iS169{1024}, iS170{32} ];
(59)
T27_g[ iS108{1024}, iS110{32}rf ]
   = pad( T26_g[ iS105{1024}rf, iS107{4}rf ], {0, 0, 4, 24} )
(43)
T28_g[ bS111{1}, bS112{1}, iS113{1024}, iS114{32} ]
   = broadcast( T27_g[ iS108{1024}, iS110{32}rf ] )
(44)
T29_g[ bS115{1}, bS116{1}, iS117{1024}, iS118{32} ]
   = Set( T28_g[ bS111{1}, bS112{1}, iS113{1024}, iS114{32} ], cache_op=Streaming )
(45)
T5_g[ iS18{2}, iS19{16}, iS20{1024}, iS22{32}rf ]
   = pad( T4_g[ iS13{2}rf, iS14{16}rf, iS15{1024}rf, iS17{24}rf ], {0, 0, 0, 0, 0, 0, 8, 0} )
(5)
T46_g[ iS183{2}, iS184{16}, iS185{1024}, iS186{32} ]
   = __bfloat2float(T5_g[ iS18{2}, iS19{16}, iS20{1024}, iS22{32}rf ]);
(62)
T7_l[ iS28{2}, iS29{16}, iS30{1024}, iS32{32}rf ]
   = pad( T6_g[ iS23{2}rf, iS24{16}rf, iS25{1024}rf, iS27{4}rf ], {0, 0, 0, 0, 0, 0, 4, 24} )
(9)
T38_g[ iS151{2}, iS152{16}, iS153{1024}, iS154{32} ]
   = __bfloat2float(T7_l[ iS28{2}, iS29{16}, iS30{1024}, iS32{32}rf ]);
(54)
T39_l[ bS155{1}, bS156{1}, iS157{1024}, iS158{32} ]
   = __bfloat2float(T29_g[ bS115{1}, bS116{1}, iS117{1024}, iS118{32} ]);
(55)
T40_g[ iS159{2}, iS160{16}, iS161{1024}, iS162{32} ]
   = T38_g[ iS151{2}, iS152{16}, iS153{1024}, iS154{32} ]
   * T39_l[ bS155{1}, bS156{1}, iS157{1024}, iS158{32} ];
(56)
T44_g[ iS175{2}, iS176{16}, iS177{1024}, iS178{32} ]
   = T40_g[ iS159{2}, iS160{16}, iS161{1024}, iS162{32} ]
   + T43_g[ iS171{2}, iS172{16}, iS173{1024}, iS174{32} ];
(60)
T45_l[ iS179{2}, iS180{16}, iS181{1024}, iS182{32} ]
   = T37_g[ iS147{2}, iS148{16}, iS149{1024}, iS150{32} ]
   + T44_g[ iS175{2}, iS176{16}, iS177{1024}, iS178{32} ];
(61)
T47_g[ iS187{2}, iS188{16}, iS189{1024}, iS190{32} ]
   = T45_l[ iS179{2}, iS180{16}, iS181{1024}, iS182{32} ]
   + T46_g[ iS183{2}, iS184{16}, iS185{1024}, iS186{32} ];
(63)
T48_g[ iS191{2}, iS192{16}, iS193{1024}, iS194{32} ]
   = __float2bfloat(T47_g[ iS187{2}, iS188{16}, iS189{1024}, iS190{32} ]);
(64)
}

} //Segmented_Fusion

@jjsjann123
Copy link
Author

Generated cuda kernel

======= Codegen output for kernel: nvfuser_pointwise_f0_c1_r0_g6 =======

// Generated nvFuser kernel for the pointwise segment of the RoPE fusion.
// It reads the nine sliced inputs (T4, T6, T8, T10, T12, T14, T18, T22, T26),
// applies the zero-padding inline as predicated loads, combines the values in
// float, and writes the bfloat16 result to T48. Each thread fills an
// 8-element local array (T50) and stores it with one vec_size=8 store.
__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, Tensor<__bfloat, 4, 4> T48) {
  NVFUSER_DEFINE_MAGIC_ZERO;
  // i0 / i3: quotient and remainder of blockIdx.y divided by 32.
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)blockIdx.y) / 32;
  // i1: elements covered by one block along x (8 per thread).
  nvfuser_index_t i1;
  i1 = 8LL * ((nvfuser_index_t)blockDim.x);
  // i2, i4 .. i11: per-input base offsets. Inputs whose padded read below
  // subtracts 4 (T26, T6, T18, T12) or 8 (T4) in its predicate carry a
  // matching -4 / -8 term here; the others (T14, T10, T8, T22) carry none.
  nvfuser_index_t i2;
  i2 = ((-4 + ((8LL * T26.alloc_stride[0LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((i1 * T26.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i3;
  i3 = ((nvfuser_index_t)blockIdx.y) % 32;
  nvfuser_index_t i4;
  i4 = (((-4 + ((8LL * T6.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((1024LL * T6.alloc_stride[2LL]) * i3)) + ((i1 * T6.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i5;
  i5 = (((8LL * T14.alloc_stride[0LL]) * ((nvfuser_index_t)threadIdx.x)) + i0) + ((i1 * T14.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i6;
  i6 = ((((8LL * T10.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x)) + i0) + ((1024LL * T10.alloc_stride[2LL]) * i3)) + ((i1 * T10.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i7;
  i7 = ((((8LL * T8.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x)) + i0) + ((1024LL * T8.alloc_stride[2LL]) * i3)) + ((i1 * T8.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i8;
  i8 = (((8LL * T22.alloc_stride[0LL]) * ((nvfuser_index_t)threadIdx.x)) + i0) + ((i1 * T22.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i9;
  i9 = (((-8 + ((8LL * T4.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((1024LL * T4.alloc_stride[2LL]) * i3)) + ((i1 * T4.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i10;
  i10 = ((-4 + ((8LL * T18.alloc_stride[0LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((i1 * T18.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i11;
  i11 = (((-4 + ((8LL * T12.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((1024LL * T12.alloc_stride[2LL]) * i3)) + ((i1 * T12.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
  // i12 / i13: this thread's and this block's starting element along x.
  nvfuser_index_t i12;
  i12 = 8LL * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i13;
  i13 = i1 * ((nvfuser_index_t)blockIdx.x);
  // i14: linear offset of this thread's 8-element vector in the output T48.
  nvfuser_index_t i14;
  i14 = (i12 + (1024LL * ((nvfuser_index_t)blockIdx.y))) + i13;
  // b15: true when the last of this thread's 8 elements is still below 1024.
  bool b15;
  b15 = ((7 + i12) + i13) < 1024;
  // Fast path: the whole vector is in range, so loads carry only the pad
  // predicate and the store is unconditional.
  if ((((i12 + 7) + i13) < 1024)) {
    Array<__bfloat, 8, 8> T50;
    #pragma unroll
    for(nvfuser_index_t i16 = 0; i16 < 8; ++i16) {
      nvfuser_index_t i17;
      i17 = i16 + nvfuser_zero;
      // Padded read of T26: taken only when (blockIdx.y / 32) - 4 lies in
      // [0, 4), zero otherwise. T28 / T29 are plain copies.
      __bfloat T27[1];
      T27[0] = 0;
      T27[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T26[(i2 + (T26.alloc_stride[0LL] * i17))] : 0.0000e+00;
      __bfloat T28[1];
      T28[0]
         = T27[0];
      __bfloat T29[1];
      T29[0]
         = T28[0];
      // Padded read of T6 with the same predicate.
      __bfloat T7[1];
      T7[0] = 0;
      T7[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T6[(i4 + (T6.alloc_stride[2LL] * i17))] : 0.0000e+00;
      // T40 = float(T7) * float(T29)
      float T38[1];
      T38[0]
         = __bfloat2float(T7[0]);
      float T39[1];
      T39[0]
         = __bfloat2float(T29[0]);
      float T40[1];
      T40[0]
        = T38[0]
        * T39[0];
      // Padded read of T14: taken only when (blockIdx.y / 32) lies in
      // [0, 4). T16 / T17 are plain copies.
      __bfloat T15[1];
      T15[0] = 0;
      T15[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T14[(i5 + (T14.alloc_stride[0LL] * i17))] : 0.0000e+00;
      __bfloat T16[1];
      T16[0]
         = T15[0];
      __bfloat T17[1];
      T17[0]
         = T16[0];
      // Padded read of T10 with the same predicate.
      __bfloat T11[1];
      T11[0] = 0;
      T11[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T10[(i6 + (T10.alloc_stride[2LL] * i17))] : 0.0000e+00;
      // T36 = float(T17) * float(T11)
      float T35[1];
      T35[0]
         = __bfloat2float(T11[0]);
      float T34[1];
      T34[0]
         = __bfloat2float(T17[0]);
      float T36[1];
      T36[0]
        = T34[0]
        * T35[0];
      // Padded read of T8 (predicate [0, 4)), then negated in float:
      // T31 = -float(T9).
      __bfloat T9[1];
      T9[0] = 0;
      T9[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T8[(i7 + (T8.alloc_stride[2LL] * i17))] : 0.0000e+00;
      float T30[1];
      T30[0]
         = __bfloat2float(T9[0]);
      float T31[1];
      T31[0]
         = -T30[0];
      // Padded read of T22 (predicate [0, 4)). T24 / T25 are plain copies.
      __bfloat T23[1];
      T23[0] = 0;
      T23[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T22[(i8 + (T22.alloc_stride[0LL] * i17))] : 0.0000e+00;
      __bfloat T24[1];
      T24[0]
         = T23[0];
      __bfloat T25[1];
      T25[0]
         = T24[0];
      // T33 = T31 * float(T25)
      float T32[1];
      T32[0]
         = __bfloat2float(T25[0]);
      float T33[1];
      T33[0]
        = T31[0]
        * T32[0];
      // Padded read of T4: taken only when (blockIdx.y / 32) - 8 lies in
      // [0, 24); converted to float as T46.
      __bfloat T5[1];
      T5[0] = 0;
      T5[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 8) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 8) < 24)) ? T4[(i9 + (T4.alloc_stride[2LL] * i17))] : 0.0000e+00;
      float T46[1];
      T46[0]
         = __bfloat2float(T5[0]);
      // Padded read of T18 ((blockIdx.y / 32) - 4 in [0, 4)).
      // T20 / T21 are plain copies.
      __bfloat T19[1];
      T19[0] = 0;
      T19[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T18[(i10 + (T18.alloc_stride[0LL] * i17))] : 0.0000e+00;
      __bfloat T20[1];
      T20[0]
         = T19[0];
      __bfloat T21[1];
      T21[0]
         = T20[0];
      // Padded read of T12 with the same predicate.
      __bfloat T13[1];
      T13[0] = 0;
      T13[0]
         = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T12[(i11 + (T12.alloc_stride[2LL] * i17))] : 0.0000e+00;
      // T43 = float(T21) * float(T13)
      float T42[1];
      T42[0]
         = __bfloat2float(T13[0]);
      float T41[1];
      T41[0]
         = __bfloat2float(T21[0]);
      float T43[1];
      T43[0]
        = T41[0]
        * T42[0];
      // Sum the four products and the T46 term:
      // T47 = (T33 + T36) + (T40 + T43) + T46.
      float T44[1];
      T44[0]
        = T40[0]
        + T43[0];
      float T37[1];
      T37[0]
        = T33[0]
        + T36[0];
      float T45[1];
      T45[0]
        = T37[0]
        + T44[0];
      float T47[1];
      T47[0]
        = T45[0]
        + T46[0];
      // Convert back to bfloat16 into the per-thread output vector.
      T50[i16]
         = __float2bfloat(T47[0]);
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    // One 8-wide store of T50 into T48 at offset i14.
    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i14], &T50[0]);
  } else {
    // Tail path: same computation as above, but every global load and the
    // final store are additionally guarded by b15; when the guard fails the
    // zero-initialised value is used.
    // NOTE(review): the outer `if` tests the same expression as b15, so b15
    // appears to be false throughout this branch (no loads, no store) - confirm.
    Array<__bfloat, 8, 8> T50;
    #pragma unroll
    for(nvfuser_index_t i16 = 0; i16 < 8; ++i16) {
      nvfuser_index_t i18;
      i18 = i16 + nvfuser_zero;
      __bfloat T27[1];
      T27[0] = 0;
      if (b15) {
        T27[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T26[(i2 + (T26.alloc_stride[0LL] * i18))] : 0.0000e+00;
      }
      __bfloat T28[1];
      T28[0]
         = T27[0];
      __bfloat T29[1];
      T29[0]
         = T28[0];
      __bfloat T7[1];
      T7[0] = 0;
      if (b15) {
        T7[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T6[(i4 + (T6.alloc_stride[2LL] * i18))] : 0.0000e+00;
      }
      float T38[1];
      T38[0]
         = __bfloat2float(T7[0]);
      float T39[1];
      T39[0]
         = __bfloat2float(T29[0]);
      float T40[1];
      T40[0]
        = T38[0]
        * T39[0];
      __bfloat T15[1];
      T15[0] = 0;
      if (b15) {
        T15[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T14[(i5 + (T14.alloc_stride[0LL] * i18))] : 0.0000e+00;
      }
      __bfloat T16[1];
      T16[0]
         = T15[0];
      __bfloat T17[1];
      T17[0]
         = T16[0];
      __bfloat T11[1];
      T11[0] = 0;
      if (b15) {
        T11[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T10[(i6 + (T10.alloc_stride[2LL] * i18))] : 0.0000e+00;
      }
      float T35[1];
      T35[0]
         = __bfloat2float(T11[0]);
      float T34[1];
      T34[0]
         = __bfloat2float(T17[0]);
      float T36[1];
      T36[0]
        = T34[0]
        * T35[0];
      __bfloat T9[1];
      T9[0] = 0;
      if (b15) {
        T9[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T8[(i7 + (T8.alloc_stride[2LL] * i18))] : 0.0000e+00;
      }
      float T30[1];
      T30[0]
         = __bfloat2float(T9[0]);
      float T31[1];
      T31[0]
         = -T30[0];
      __bfloat T23[1];
      T23[0] = 0;
      if (b15) {                                         
        T23[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T22[(i8 + (T22.alloc_stride[0LL] * i18))] : 0.0000e+00;
      }
      __bfloat T24[1];
      T24[0]
         = T23[0]; 
      __bfloat T25[1];
      T25[0]
         = T24[0];
      float T32[1];
      T32[0]
         = __bfloat2float(T25[0]);
      float T33[1];
      T33[0]
        = T31[0]
        * T32[0];
      __bfloat T5[1];
      T5[0] = 0;
      if (b15) {
        T5[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 8) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 8) < 24)) ? T4[(i9 + (T4.alloc_stride[2LL] * i18))] : 0.0000e+00;
      }
      float T46[1];
      T46[0]
         = __bfloat2float(T5[0]);
      __bfloat T19[1];
      T19[0] = 0;
      if (b15) {
        T19[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T18[(i10 + (T18.alloc_stride[0LL] * i18))] : 0.0000e+00;
      }
      __bfloat T20[1];
      T20[0]
         = T19[0];
      __bfloat T21[1];
      T21[0]
         = T20[0];
      __bfloat T13[1];
      T13[0] = 0;
      if (b15) {
        T13[0]
           = ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T12[(i11 + (T12.alloc_stride[2LL] * i18))] : 0.0000e+00;
      }
      float T42[1];
      T42[0]
         = __bfloat2float(T13[0]);
      float T41[1];
      T41[0]
         = __bfloat2float(T21[0]);
      float T43[1];
      T43[0]
        = T41[0]
        * T42[0];
      float T44[1];
      T44[0]
        = T40[0]
        + T43[0];
      float T37[1];
      T37[0]
        = T33[0]
        + T36[0];
      float T45[1];
      T45[0]
        = T37[0]
        + T44[0];
      float T47[1];
      T47[0]
        = T45[0]
        + T46[0];
      T50[i16]
         = __float2bfloat(T47[0]);
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    // Store only when the whole vector is in range.
    if (b15) {
      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i14], &T50[0]);
    }
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment