Skip to content

Instantly share code, notes, and snippets.

@protonu
Created February 7, 2025 01:08
Show Gist options
  • Save protonu/12c61763c53cc691cb032a613871f8c3 to your computer and use it in GitHub Desktop.
Save protonu/12c61763c53cc691cb032a613871f8c3 to your computer and use it in GitHub Desktop.
C++ test
// Builds a fusion by hand, one op per tensor, following what appears to be a
// printed IR dump (reproduced in the per-op comments below), then runs it
// through FusionExecutorCache on CUDA device 0.
//
// The test does not compare numerical results; it passes when the run
// completes without throwing.
TEST_F(Tutorial, HF_SEG) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // ---- Fusion inputs ------------------------------------------------------
  // T0_g_float[bS0{1}, iS1{32}, bS2{1}]
  // T12_g___bfloat[bS243{1}, bS244{1 ex 6}, iS245{i17}]
  auto t0 = TensorViewBuilder().shape({1, 32, 1}).build();
  auto t12 = TensorViewBuilder()
                 .shape({1, 6, -1})
                 .expanded({false, true, false})
                 .dtype(DataType::BFloat16)
                 .build();
  fusion->addInput(t0);
  fusion->addInput(t12);

  // ---- Chain starting at t0: {1, 32, 1} -> {1, 1, 64} ----------------------
  // T16_l_float[bS50{1}, bS52{1}, iS51{32}]
  //   = Set.Permute( T0_g_float[bS0{1}, iS1{32}, bS2{1}], cache_op=Streaming )
  auto t16 = permute(t0, {0, 2, 1});
  t16->setMemoryType(MemoryType::Local);

  // T17_g_float[bS53{1}, bS54{1}, bS55{1}, iS56{32}]
  //   = broadcast( T16_l_float[bS50{1}, bS52{1}, iS51{32}],
  //                flags = {false, false, true, false} )
  auto t17 = broadcast(t16, {false, false, true, false});
  t17->setMemoryType(MemoryType::Global);

  // T18_g_float[bS57{1}, bS58{1}, bS59{1 ex 2}, iS60{32}]
  //   = expand( T17_g_float[bS53{1}, bS54{1}, bS55{1}, iS56{32}],
  //             {1, 1, 2, 32} )
  auto t18 = expand(
      t17,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(2),
       IrBuilder::create<Val>(32)});
  t18->setMemoryType(MemoryType::Global);

  // T19_g_float[bS61{1}, bS62{1}, iS67{64}rf]
  //   = view( T18_g_float[bS57{1}, bS58{1}, bS59{1 ex 2}, iS60{32}] )
  auto t19 = reshape(
      t18,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(64)});
  t19->setMemoryType(MemoryType::Global);

  // ---- First consumer of t19: broadcast -> expand -> set (output t34) ------
  // T44_g_float[bS236{1}, bS237{1}, bS238{1}, iS239{64}]
  //   = broadcast( T19_g_float[bS61{1}, bS62{1}, iS67{64}rf],
  //                flags = {false, true, false, false} )
  auto t44 = broadcast(t19, {false, true, false, false});
  t44->setMemoryType(MemoryType::Global);

  // T33_l_float[bS123{1}, bS124{1}, bS125{1 ex 6}, iS126{64}]
  //   = expand( T44_g_float[bS236{1}, bS237{1}, bS238{1}, iS239{64}],
  //             {1, 1, 6, 64} )
  auto t33 = expand(
      t44,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(6),
       IrBuilder::create<Val>(64)});
  t33->setMemoryType(MemoryType::Local);

  // T34_g_float[bS127{1}, bS128{1}, bS129{1 ex 6}, iS130{64}]
  //   = Set( T33_l_float[bS123{1}, bS124{1}, bS125{1 ex 6}, iS126{64}],
  //          cache_op=Streaming )
  auto t34 = set(t33);
  t34->setMemoryType(MemoryType::Global);

  // ---- Chain starting at t12: reshape -> permute (output t14) -> cast ------
  // T42_l___bfloat[bS165{1}, bS166{1 ex 6}, iS169{8}rf,
  //                iS170{( ceilDiv(i17, 8) )}rf]
  //   = view( T12_g___bfloat[bS243{1}, bS244{1 ex 6}, iS245{i17}] )
  //
  // NOTE: the IR above has the symbolic extent ceilDiv(i17, 8) for the last
  // dimension; this test uses the constant 64 instead, so the runtime input
  // for t12 must have a last dimension of 8 * 64 = 512.
  auto t42 = reshape(
      t12,
      {t12->getLogicalDomain()[0]->getMaybeExpandedExtent(),
       t12->getLogicalDomain()[1]->getMaybeExpandedExtent(),
       IrBuilder::create<Val>(8),
       IrBuilder::create<Val>(64)});
  t42->setMemoryType(MemoryType::Local);

  // T14_g___bfloat[bS179{1}, iS181{8}, bS180{1 ex 6}, iS182{64}]
  //   = Set.Permute( T42_l___bfloat[bS165{1}, bS166{1 ex 6}, iS169{8}rf,
  //                                 iS170{( ceilDiv(i17, 8) )}rf],
  //                  cache_op=Streaming )
  auto t14 = permute(t42, {0, 2, 1, 3});
  t14->setMemoryType(MemoryType::Global);

  // T15_g_float[bS183{1}, iS184{8}, bS185{1 ex 6}, iS186{64}]
  //   = __bfloat2float(T14_g___bfloat[bS179{1}, iS181{8}, bS180{1 ex 6},
  //                                   iS182{64}]);
  auto t15 = castOp(DataType::Float, t14);
  t15->setMemoryType(MemoryType::Global);

  // ---- Second consumer of t19: broadcast -> expand -> set -> expand --------
  // T20_g_float[bS68{1}, bS69{1}, bS70{1}, iS71{64}]
  //   = broadcast( T19_g_float[bS61{1}, bS62{1}, iS67{64}rf],
  //                flags = {false, true, false, false} )
  auto t20 = broadcast(t19, {false, true, false, false});
  t20->setMemoryType(MemoryType::Global);

  // T21_g_float[bS72{1}, bS73{1}, bS74{1 ex 6}, iS75{64}]
  //   = expand( T20_g_float[bS68{1}, bS69{1}, bS70{1}, iS71{64}],
  //             {1, 1, 6, 64} )
  auto t21 = expand(
      t20,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(6),
       IrBuilder::create<Val>(64)});
  t21->setMemoryType(MemoryType::Global);

  // T22_l_float[bS76{1}, bS77{1}, bS78{1 ex 6}, iS79{64}]
  //   = Set( T21_g_float[bS72{1}, bS73{1}, bS74{1 ex 6}, iS75{64}],
  //          cache_op=Streaming )
  auto t22 = set(t21);
  t22->setMemoryType(MemoryType::Local);

  // T23_g_float[bS80{1}, bS81{1 ex 8}, bS82{1 ex 6}, iS83{64}]
  //   = expand( T22_l_float[bS76{1}, bS77{1}, bS78{1 ex 6}, iS79{64}],
  //             {1, 8, 6, 64} )
  auto t23 = expand(
      t22,
      {IrBuilder::create<Val>(1),
       IrBuilder::create<Val>(8),
       IrBuilder::create<Val>(6),
       IrBuilder::create<Val>(64)});
  t23->setMemoryType(MemoryType::Global);

  // ---- Join of the two chains (output t24) ---------------------------------
  // T24_g_float[bS187{1}, iS188{8}, bS189{1 ex 6}, iS87{64}]
  //   = T15_g_float[bS183{1}, iS184{8}, bS185{1 ex 6}, iS186{64}]
  //   * T23_g_float[bS80{1}, bS81{1 ex 8}, bS82{1 ex 6}, iS83{64}];
  auto t24 = mul(t15, t23);
  t24->setMemoryType(MemoryType::Global);

  fusion->addOutput(t14);
  fusion->addOutput(t24);
  fusion->addOutput(t34);

  // Dump the constructed math so it can be compared with the IR comments.
  fusion->printMath();

  // ---- Runtime inputs ------------------------------------------------------
  auto options1 = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto options2 = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);

  // Must match t0's declared shape {1, 32, 1}: the fusion expands/reshapes
  // it to {1, 1, 2, 32} -> {1, 1, 64}, which only works with that shape.
  auto i0 = at::randn({1, 32, 1}, options1);
  // Matches t12: second dimension expanded from 1 to 6, last dimension
  // 512 = 8 * 64 as required by the reshape that produces t42.
  auto i1 = at::randn({1, 1, 512}, options2).expand({1, 6, 512});

  FusionExecutorCache executor_cache(std::move(fusion));

  // No numerical validation is performed; the test passes when the run
  // completes without throwing.
  ASSERT_NO_THROW(executor_cache.runFusionWithInputs({i0, i1}));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment