Created February 7, 2025 01:08
-
-
Save protonu/12c61763c53cc691cb032a613871f8c3 to your computer and use it in GitHub Desktop.
C++ test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Hand-written reproduction of a segmented fusion. The test name suggests it
// was taken from a Hugging Face model — NOTE(review): confirm the origin.
// The IR text above each op appears to be the printMath() dump of the
// original fusion. It is kept so this definition can be compared against it.
//
// Fusion inputs:
//   t0  : float    [1, 32, 1]                 (extent-1 dims are broadcast)
//   t12 : bfloat16 [1, 6 (expanded bcast), i17]
// Fusion outputs:
//   t14 : bfloat16 permute(reshape(t12 -> [1, 6, 8, 64]))  -> [1, 8, 6, 64]
//   t24 : float    cast(t14) * expand(t0-derived tensor)   -> [1, 8, 6, 64]
//   t34 : float    copy of the t0-derived tensor expanded to [1, 1, 6, 64]
//
// The explicit setMemoryType() calls mirror the memory types in the original
// dump: T*_g tensors are Global, T*_l tensors are Local.
TEST_F(Tutorial, HF_SEG) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // T0_g_float[bS0{1}, iS1{32}, bS2{1}]
  auto t0 = TensorViewBuilder().shape({1, 32, 1}).build();
  // T12_g___bfloat[bS243{1}, bS244{1 ex 6}, iS245{i17}]
  auto t12 = TensorViewBuilder()
                 .shape({1, 6, -1})
                 .expanded({false, true, false})
                 .dtype(DataType::BFloat16)
                 .build();
  fusion->addInput(t0);
  fusion->addInput(t12);

  // T16_l_float[bS50{1}, bS52{1}, iS51{32}]
  //   = Set.Permute( T0_g_float[bS0{1}, iS1{32}, bS2{1}], cache_op=Streaming )
  auto t16 = permute(t0, {0, 2, 1});
  t16->setMemoryType(MemoryType::Local);

  // T17_g_float[bS53{1}, bS54{1}, bS55{1}, iS56{32}]
  //   = broadcast( T16_l_float[bS50{1}, bS52{1}, iS51{32}],
  //                flags = {false, false, true, false} )
  auto t17 = broadcast(t16, {false, false, true, false});
  t17->setMemoryType(MemoryType::Global);

  // T18_g_float[bS57{1}, bS58{1}, bS59{1 ex 2}, iS60{32}]
  //   = expand( T17_g_float[bS53{1}, bS54{1}, bS55{1}, iS56{32}],
  //             {1, 1, 2, 32} )
  auto t18 = expand(
      t17,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(2),
       IrBuilder::create<Val>(32)});
  t18->setMemoryType(MemoryType::Global);

  // T19_g_float[bS61{1}, bS62{1}, iS67{64}rf]
  //   = view( T18_g_float[bS57{1}, bS58{1}, bS59{1 ex 2}, iS60{32}] )
  // Merges the expanded dim (2) with the 32 dim into 64.
  auto t19 = reshape(
      t18,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(64)});
  t19->setMemoryType(MemoryType::Global);

  // T44_g_float[bS236{1}, bS237{1}, bS238{1}, iS239{64}]
  //   = broadcast( T19_g_float[bS61{1}, bS62{1}, iS67{64}rf],
  //                flags = {false, true, false, false} )
  auto t44 = broadcast(t19, {false, true, false, false});
  t44->setMemoryType(MemoryType::Global);

  // T33_l_float[bS123{1}, bS124{1}, bS125{1 ex 6}, iS126{64}]
  //   = expand( T44_g_float[bS236{1}, bS237{1}, bS238{1}, iS239{64}],
  //             {1, 1, 6, 64} )
  auto t33 = expand(
      t44,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(6),
       IrBuilder::create<Val>(64)});
  t33->setMemoryType(MemoryType::Local);

  // T34_g_float[bS127{1}, bS128{1}, bS129{1 ex 6}, iS130{64}]
  //   = Set( T33_l_float[bS123{1}, bS124{1}, bS125{1 ex 6}, iS126{64}],
  //          cache_op=Streaming )
  auto t34 = set(t33);
  t34->setMemoryType(MemoryType::Global);

  // T42_l___bfloat[bS165{1}, bS166{1 ex 6}, iS169{8}rf,
  //                iS170{( ceilDiv(i17, 8) )}rf]
  //   = view( T12_g___bfloat[bS243{1}, bS244{1 ex 6}, iS245{i17}] )
  // NOTE: the original dump derived the last extent symbolically (i17 / 8).
  // Here it is hard-coded to 64, so the runtime t12 must have i17 == 512.
  // Symbolic alternative kept for reference:
  //   auto nz = SimplifyingIrBuilder::divExpr(
  //       t12->getLogicalDomain().back()->extent(), IrBuilder::create<Val>(8));
  auto t42 = reshape(
      t12,
      {t12->getLogicalDomain()[0]->getMaybeExpandedExtent(),
       t12->getLogicalDomain()[1]->getMaybeExpandedExtent(),
       IrBuilder::create<Val>(8),
       IrBuilder::create<Val>(64)});
  t42->setMemoryType(MemoryType::Local);

  // T14_g___bfloat[bS179{1}, iS181{8}, bS180{1 ex 6}, iS182{64}]
  //   = Set.Permute( T42_l___bfloat[bS165{1}, bS166{1 ex 6}, iS169{8}rf,
  //                  iS170{( ceilDiv(i17, 8) )}rf], cache_op=Streaming )
  auto t14 = permute(t42, {0, 2, 1, 3});
  t14->setMemoryType(MemoryType::Global);

  // T15_g_float[bS183{1}, iS184{8}, bS185{1 ex 6}, iS186{64}]
  //   = __bfloat2float( T14_g___bfloat[bS179{1}, iS181{8}, bS180{1 ex 6},
  //                     iS182{64}] )
  auto t15 = castOp(DataType::Float, t14);
  t15->setMemoryType(MemoryType::Global);

  // T20_g_float[bS68{1}, bS69{1}, bS70{1}, iS71{64}]
  //   = broadcast( T19_g_float[bS61{1}, bS62{1}, iS67{64}rf],
  //                flags = {false, true, false, false} )
  auto t20 = broadcast(t19, {false, true, false, false});
  t20->setMemoryType(MemoryType::Global);

  // T21_g_float[bS72{1}, bS73{1}, bS74{1 ex 6}, iS75{64}]
  //   = expand( T20_g_float[bS68{1}, bS69{1}, bS70{1}, iS71{64}],
  //             {1, 1, 6, 64} )
  auto t21 = expand(
      t20,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(6),
       IrBuilder::create<Val>(64)});
  t21->setMemoryType(MemoryType::Global);

  // T22_l_float[bS76{1}, bS77{1}, bS78{1 ex 6}, iS79{64}]
  //   = Set( T21_g_float[bS72{1}, bS73{1}, bS74{1 ex 6}, iS75{64}],
  //          cache_op=Streaming )
  auto t22 = set(t21);
  t22->setMemoryType(MemoryType::Local);

  // T23_g_float[bS80{1}, bS81{1 ex 8}, bS82{1 ex 6}, iS83{64}]
  //   = expand( T22_l_float[bS76{1}, bS77{1}, bS78{1 ex 6}, iS79{64}],
  //             {1, 8, 6, 64} )
  auto t23 = expand(
      t22,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(8),
       IrBuilder::create<Val>(6),
       IrBuilder::create<Val>(64)});
  t23->setMemoryType(MemoryType::Global);

  // T24_g_float[bS187{1}, iS188{8}, bS189{1 ex 6}, iS87{64}]
  //   = T15_g_float[bS183{1}, iS184{8}, bS185{1 ex 6}, iS186{64}]
  //   * T23_g_float[bS80{1}, bS81{1 ex 8}, bS82{1 ex 6}, iS83{64}]
  auto t24 = mul(t15, t23);
  t24->setMemoryType(MemoryType::Global);

  fusion->addOutput(t14);
  fusion->addOutput(t24);
  fusion->addOutput(t34);

  fusion->printMath();

  auto options1 = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto options2 = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);
  // t0 is declared with shape {1, 32, 1}. Its last dim is a broadcast of
  // extent 1 (bS2{1}), so the runtime tensor must be {1, 32, 1}, not
  // {1, 32, 6}.
  auto i0 = at::randn({1, 32, 1}, options1);
  // expand() yields a stride-0 view on dim 1, matching t12's expanded
  // broadcast [1, 1 ex 6, i17]. i17 = 512 = 8 * 64 satisfies the hard-coded
  // reshape above.
  auto i1 = at::randn({1, 1, 512}, options2).expand({1, 6, 512});

  FusionExecutorCache executor_cache(std::move(fusion));
  auto outputs = executor_cache.runFusionWithInputs({i0, i1});
  // Check the outputs against a reference evaluation of the same fusion.
  // This replaces the previous vacuous ASSERT_TRUE(1 == 1).
  testValidate(executor_cache.fusion(), outputs, {i0, i1}, __LINE__, __FILE__);
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.