Created
June 3, 2024 22:08
-
-
Save jjsjann123/2c4db9f6659cfe2cc8aa9503cb8a806c to your computer and use it in GitHub Desktop.
rope_nvfuser_prototype
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

from nvfuser import FusionDefinition, DataType

# RoPE / model configuration for the prototype fusion below.
bsz = 2            # batch size
block_size = 1024  # sequence length
n_head = 16        # number of attention heads
head_size = 32     # per-head embedding size
rope_n_elem = 8    # leading channels of each head that RoPE rotates
def rope_fusion(fd: FusionDefinition) -> None:
    """Define the RoPE (rotary position embedding) prototype fusion on *fd*.

    Fusion inputs (definition order is also the ``fd.execute`` input order):
      q:   (bsz, n_head, block_size, head_size), bfloat16, contiguous
      cos: (block_size, rope_n_elem), bfloat16, contiguous
      sin: (block_size, rope_n_elem), bfloat16, contiguous

    Fusion output: q with RoPE applied to its first ``rope_n_elem`` channels;
    the remaining channels pass through unchanged, and the result is cast
    back to bfloat16.

    The rotate-half computation is expressed entirely with slice + zero-pad
    so every operand is padded out to the full ``head_size`` axis and the
    pieces combine with plain elementwise adds (cf. ``apply_rope`` below).
    """
    q = fd.define_tensor(
        shape=[bsz, n_head, block_size, head_size],
        contiguity=[True, True, True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[3, 2, 1, 0],
    )
    cos = fd.define_tensor(
        shape=[block_size, rope_n_elem],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[1, 0],
    )
    sin = fd.define_tensor(
        shape=[block_size, rope_n_elem],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[1, 0],
    )
    # The rotation swaps the two halves of the roped channels.
    offset_0 = rope_n_elem // 2
    # NOTE(review): the reversed(...) calls appear to convert torch-style pad
    # lists (innermost dimension first) into nvfuser's pad-width ordering --
    # confirm against the nvfuser python-frontend docs.
    # Split q into the roped channels and the pass-through remainder; the
    # remainder is zero-padded back to head_size so it can be merged by a sum.
    q_rope = fd.ops.slice(q, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_remainder = fd.ops.slice(q, start_indices=[0, 0, 0, rope_n_elem], end_indices=[bsz, n_head, block_size, head_size], strides=[1, 1, 1, 1])
    q_remainder = fd.ops.pad(q_remainder, list(reversed([0, 0, 0, 0, 0, 0, 0, rope_n_elem])))
    # First/second halves of the roped channels, each padded so the rotated
    # (-x2, x1) pieces land at their final head_size positions.
    q_left = fd.ops.slice(q_rope, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, offset_0], strides=[1, 1, 1, 1])
    q_left = fd.ops.pad(q_left, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem, rope_n_elem - offset_0])))
    q_right = fd.ops.slice(q_rope, start_indices=[0, 0, 0, offset_0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_right = fd.ops.pad(q_right, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem + offset_0, 0])))
    # note that this is identical to q_left and q_right. We should be able to merge it back.
    q_left_cos = fd.ops.slice(q_rope, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, offset_0], strides=[1, 1, 1, 1])
    q_left_cos = fd.ops.pad(q_left_cos, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem + offset_0, 0])))
    q_right_cos = fd.ops.slice(q_rope, start_indices=[0, 0, 0, offset_0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_right_cos = fd.ops.pad(q_right_cos, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem, rope_n_elem - offset_0])))
    # Slice cos/sin into matching halves, pad to head_size, and broadcast to
    # q's rank so the elementwise math below lines up.
    cos_left = fd.ops.slice(cos, start_indices=[0, 0], end_indices=[block_size, offset_0], strides=[1, 1])
    cos_left = fd.ops.pad(cos_left, list(reversed([0, 0, head_size - offset_0, 0])))
    cos_left = fd.ops.broadcast_in_dim(cos_left, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    cos_right = fd.ops.slice(cos, start_indices=[0, offset_0], end_indices=[block_size, rope_n_elem], strides=[1, 1])
    cos_right = fd.ops.pad(cos_right, list(reversed([0, 0, head_size - rope_n_elem, offset_0])))
    cos_right = fd.ops.broadcast_in_dim(cos_right, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    sin_left = fd.ops.slice(sin, start_indices=[0, 0], end_indices=[block_size, offset_0], strides=[1, 1])
    sin_left = fd.ops.pad(sin_left, list(reversed([0, 0, head_size - offset_0, 0])))
    sin_left = fd.ops.broadcast_in_dim(sin_left, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    sin_right = fd.ops.slice(sin, start_indices=[0, offset_0], end_indices=[block_size, rope_n_elem], strides=[1, 1])
    sin_right = fd.ops.pad(sin_right, list(reversed([0, 0, head_size - rope_n_elem, offset_0])))
    sin_right = fd.ops.broadcast_in_dim(sin_right, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    # Rotate-half: q0/q1 are the two halves of (q * cos + rotated(q) * sin),
    # each already placed at its final head_size offset by the pads above.
    q0 = (-q_right) * sin_left + cos_left * q_left_cos
    q1 = q_left * sin_right + cos_right * q_right_cos
    q_out = q0 + q1 + q_remainder
    q_out = fd.ops.cast(q_out, dtype=DataType.BFloat16)
    # (Removed a stray `q0 = fd.ops.cast(q0, ...)` from the original draft:
    # its result was never consumed nor registered as an output.)
    fd.add_output(q_out)
# Build the fusion and run it once on random bfloat16 CUDA inputs; the input
# list order must match the define_tensor order inside rope_fusion.
with FusionDefinition() as fd:
    rope_fusion(fd)

inputs = [
    torch.randn((bsz, n_head, block_size, head_size), dtype=torch.bfloat16, device="cuda:0"),
    torch.randn((block_size, rope_n_elem), dtype=torch.bfloat16, device="cuda:0"),
    torch.randn((block_size, rope_n_elem), dtype=torch.bfloat16, device="cuda:0"),
]
o = fd.execute(inputs)[0]
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Eager-mode RoPE reference (rotate-half formulation).

    Args:
        x: (B, nh, T, hs) tensor whose entire last dimension is rotated.
        cos, sin: tables broadcastable against x's trailing dimensions.

    Returns:
        (x * cos) + (rotate_half(x) * sin), cast back to x's dtype.
    """
    hs = x.size(-1)  # local, not the module-level head_size constant
    x1 = x[..., : hs // 2]  # (B, nh, T, hs/2)
    x2 = x[..., hs // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)


def rope_one_entry(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int) -> torch.Tensor:
    """Apply RoPE to the first ``rope_n_elem`` channels of x; pass the rest through."""
    x_roped = apply_rope(x[..., :rope_n_elem], cos, sin)
    return torch.cat((x_roped, x[..., rope_n_elem:]), dim=-1)
#import thunder | |
#thunder_rope_one = thunder.jit(rope_one_entry, nv_enable_bookend=False) | |
#o_ref = thunder_rope_one(*inputs, rope_n_elem) |
Generated CUDA kernel
======= Codegen output for kernel: nvfuser_pointwise_f0_c1_r0_g6 =======
// Auto-generated nvFuser pointwise kernel for the RoPE fusion above, as
// captured from the codegen dump (original indentation was lost in the paste).
// T4..T26 are the kernel's bf16 input tensors and T48 is the bf16 output;
// presumably T4..T26 correspond to the padded slice operands of the fusion.
__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, Tensor<__bfloat, 4, 4> T48) {
NVFUSER_DEFINE_MAGIC_ZERO;
// blockIdx.y is decomposed as i0 = blockIdx.y / 32 and i3 = blockIdx.y % 32.
// 32 matches bsz * n_head = 32 for this problem size -- presumably i3 indexes
// the fused (batch, head) dimension; TODO confirm against the launch params.
nvfuser_index_t i0;
i0 = ((nvfuser_index_t)blockIdx.y) / 32;
nvfuser_index_t i1;
i1 = 8LL * ((nvfuser_index_t)blockDim.x);
// Per-input base offsets. The -4 / -8 terms appear to undo the pad offsets
// (offset_0 = 4, rope_n_elem = 8) baked into the padded operands -- unverified.
nvfuser_index_t i2;
i2 = ((-4 + ((8LL * T26.alloc_stride[0LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((i1 * T26.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i3;
i3 = ((nvfuser_index_t)blockIdx.y) % 32;
nvfuser_index_t i4;
i4 = (((-4 + ((8LL * T6.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((1024LL * T6.alloc_stride[2LL]) * i3)) + ((i1 * T6.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i5;
i5 = (((8LL * T14.alloc_stride[0LL]) * ((nvfuser_index_t)threadIdx.x)) + i0) + ((i1 * T14.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i6;
i6 = ((((8LL * T10.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x)) + i0) + ((1024LL * T10.alloc_stride[2LL]) * i3)) + ((i1 * T10.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i7;
i7 = ((((8LL * T8.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x)) + i0) + ((1024LL * T8.alloc_stride[2LL]) * i3)) + ((i1 * T8.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i8;
i8 = (((8LL * T22.alloc_stride[0LL]) * ((nvfuser_index_t)threadIdx.x)) + i0) + ((i1 * T22.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i9;
i9 = (((-8 + ((8LL * T4.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((1024LL * T4.alloc_stride[2LL]) * i3)) + ((i1 * T4.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i10;
i10 = ((-4 + ((8LL * T18.alloc_stride[0LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((i1 * T18.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i11;
i11 = (((-4 + ((8LL * T12.alloc_stride[2LL]) * ((nvfuser_index_t)threadIdx.x))) + i0) + ((1024LL * T12.alloc_stride[2LL]) * i3)) + ((i1 * T12.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x));
nvfuser_index_t i12;
i12 = 8LL * ((nvfuser_index_t)threadIdx.x);
nvfuser_index_t i13;
i13 = i1 * ((nvfuser_index_t)blockIdx.x);
// Linear output offset for the 8-wide vector, and the tail-guard predicate
// (in-bounds test against an extent of 1024).
nvfuser_index_t i14;
i14 = (i12 + (1024LL * ((nvfuser_index_t)blockIdx.y))) + i13;
bool b15;
b15 = ((7 + i12) + i13) < 1024;
// Fast path: the whole 8-element vector is in bounds, so the per-element
// loads below need no b15 guard (only the pad-region ?: predicates remain).
if ((((i12 + 7) + i13) < 1024)) {
Array<__bfloat, 8, 8> T50;
#pragma unroll
for(nvfuser_index_t i16 = 0; i16 < 8; ++i16) {
nvfuser_index_t i17;
i17 = i16 + nvfuser_zero;
// Each ?: predicate selects the live region of a padded operand;
// out-of-range lanes read the pad value 0.
__bfloat T27[1];
T27[0] = 0;
T27[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T26[(i2 + (T26.alloc_stride[0LL] * i17))] : 0.0000e+00;
__bfloat T28[1];
T28[0]
= T27[0];
__bfloat T29[1];
T29[0]
= T28[0];
__bfloat T7[1];
T7[0] = 0;
T7[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T6[(i4 + (T6.alloc_stride[2LL] * i17))] : 0.0000e+00;
// All arithmetic is performed in fp32 after widening from bf16.
float T38[1];
T38[0]
= __bfloat2float(T7[0]);
float T39[1];
T39[0]
= __bfloat2float(T29[0]);
float T40[1];
T40[0]
= T38[0]
* T39[0];
__bfloat T15[1];
T15[0] = 0;
T15[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T14[(i5 + (T14.alloc_stride[0LL] * i17))] : 0.0000e+00;
__bfloat T16[1];
T16[0]
= T15[0];
__bfloat T17[1];
T17[0]
= T16[0];
__bfloat T11[1];
T11[0] = 0;
T11[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T10[(i6 + (T10.alloc_stride[2LL] * i17))] : 0.0000e+00;
float T35[1];
T35[0]
= __bfloat2float(T11[0]);
float T34[1];
T34[0]
= __bfloat2float(T17[0]);
float T36[1];
T36[0]
= T34[0]
* T35[0];
// Negated operand: corresponds to the (-q_right) * sin term of the fusion.
__bfloat T9[1];
T9[0] = 0;
T9[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T8[(i7 + (T8.alloc_stride[2LL] * i17))] : 0.0000e+00;
float T30[1];
T30[0]
= __bfloat2float(T9[0]);
float T31[1];
T31[0]
= -T30[0];
__bfloat T23[1];
T23[0] = 0;
T23[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T22[(i8 + (T22.alloc_stride[0LL] * i17))] : 0.0000e+00;
__bfloat T24[1];
T24[0]
= T23[0];
__bfloat T25[1];
T25[0]
= T24[0];
float T32[1];
T32[0]
= __bfloat2float(T25[0]);
float T33[1];
T33[0]
= T31[0]
* T32[0];
// T4/T5: presumably the pass-through q_remainder operand (note the wider
// live region, 24 = head_size - rope_n_elem).
__bfloat T5[1];
T5[0] = 0;
T5[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 8) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 8) < 24)) ? T4[(i9 + (T4.alloc_stride[2LL] * i17))] : 0.0000e+00;
float T46[1];
T46[0]
= __bfloat2float(T5[0]);
__bfloat T19[1];
T19[0] = 0;
T19[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T18[(i10 + (T18.alloc_stride[0LL] * i17))] : 0.0000e+00;
__bfloat T20[1];
T20[0]
= T19[0];
__bfloat T21[1];
T21[0]
= T20[0];
__bfloat T13[1];
T13[0] = 0;
T13[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T12[(i11 + (T12.alloc_stride[2LL] * i17))] : 0.0000e+00;
float T42[1];
T42[0]
= __bfloat2float(T13[0]);
float T41[1];
T41[0]
= __bfloat2float(T21[0]);
float T43[1];
T43[0]
= T41[0]
* T42[0];
// Final combination: sum of the four products plus the remainder, then
// narrowed back to bf16 into the vector register T50.
float T44[1];
T44[0]
= T40[0]
+ T43[0];
float T37[1];
T37[0]
= T33[0]
+ T36[0];
float T45[1];
T45[0]
= T37[0]
+ T44[0];
float T47[1];
T47[0]
= T45[0]
+ T46[0];
T50[i16]
= __float2bfloat(T47[0]);
}
NVFUSER_UPDATE_MAGIC_ZERO;
// Vectorized 8 x bf16 store to the output tensor.
loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i14], &T50[0]);
} else {
// Tail path: identical computation, but every load and the final store are
// additionally guarded by the bounds predicate b15.
Array<__bfloat, 8, 8> T50;
#pragma unroll
for(nvfuser_index_t i16 = 0; i16 < 8; ++i16) {
nvfuser_index_t i18;
i18 = i16 + nvfuser_zero;
__bfloat T27[1];
T27[0] = 0;
if (b15) {
T27[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T26[(i2 + (T26.alloc_stride[0LL] * i18))] : 0.0000e+00;
}
__bfloat T28[1];
T28[0]
= T27[0];
__bfloat T29[1];
T29[0]
= T28[0];
__bfloat T7[1];
T7[0] = 0;
if (b15) {
T7[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T6[(i4 + (T6.alloc_stride[2LL] * i18))] : 0.0000e+00;
}
float T38[1];
T38[0]
= __bfloat2float(T7[0]);
float T39[1];
T39[0]
= __bfloat2float(T29[0]);
float T40[1];
T40[0]
= T38[0]
* T39[0];
__bfloat T15[1];
T15[0] = 0;
if (b15) {
T15[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T14[(i5 + (T14.alloc_stride[0LL] * i18))] : 0.0000e+00;
}
__bfloat T16[1];
T16[0]
= T15[0];
__bfloat T17[1];
T17[0]
= T16[0];
__bfloat T11[1];
T11[0] = 0;
if (b15) {
T11[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T10[(i6 + (T10.alloc_stride[2LL] * i18))] : 0.0000e+00;
}
float T35[1];
T35[0]
= __bfloat2float(T11[0]);
float T34[1];
T34[0]
= __bfloat2float(T17[0]);
float T36[1];
T36[0]
= T34[0]
* T35[0];
__bfloat T9[1];
T9[0] = 0;
if (b15) {
T9[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T8[(i7 + (T8.alloc_stride[2LL] * i18))] : 0.0000e+00;
}
float T30[1];
T30[0]
= __bfloat2float(T9[0]);
float T31[1];
T31[0]
= -T30[0];
__bfloat T23[1];
T23[0] = 0;
if (b15) {
T23[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 0) < 4)) ? T22[(i8 + (T22.alloc_stride[0LL] * i18))] : 0.0000e+00;
}
__bfloat T24[1];
T24[0]
= T23[0];
__bfloat T25[1];
T25[0]
= T24[0];
float T32[1];
T32[0]
= __bfloat2float(T25[0]);
float T33[1];
T33[0]
= T31[0]
* T32[0];
__bfloat T5[1];
T5[0] = 0;
if (b15) {
T5[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 8) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 8) < 24)) ? T4[(i9 + (T4.alloc_stride[2LL] * i18))] : 0.0000e+00;
}
float T46[1];
T46[0]
= __bfloat2float(T5[0]);
__bfloat T19[1];
T19[0] = 0;
if (b15) {
T19[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T18[(i10 + (T18.alloc_stride[0LL] * i18))] : 0.0000e+00;
}
__bfloat T20[1];
T20[0]
= T19[0];
__bfloat T21[1];
T21[0]
= T20[0];
__bfloat T13[1];
T13[0] = 0;
if (b15) {
T13[0]
= ((((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) >= 0) && (((((nvfuser_index_t)blockIdx.y) / (2 * 16)) - 4) < 4)) ? T12[(i11 + (T12.alloc_stride[2LL] * i18))] : 0.0000e+00;
}
float T42[1];
T42[0]
= __bfloat2float(T13[0]);
float T41[1];
T41[0]
= __bfloat2float(T21[0]);
float T43[1];
T43[0]
= T41[0]
* T42[0];
float T44[1];
T44[0]
= T40[0]
+ T43[0];
float T37[1];
T37[0]
= T33[0]
+ T36[0];
float T45[1];
T45[0]
= T37[0]
+ T44[0];
float T47[1];
T47[0]
= T45[0]
+ T46[0];
T50[i16]
= __float2bfloat(T47[0]);
}
NVFUSER_UPDATE_MAGIC_ZERO;
if (b15) {
loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i14], &T50[0]);
}
}
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note: all slices are segmented out as no-op kernels, and the entire operation runs as a single kernel afterwards.