Skip to content

Instantly share code, notes, and snippets.

@makslevental
Last active March 30, 2024 21:42
Show Gist options
  • Save makslevental/15fa1b55effee9a84b28a9d7a7e2c3ae to your computer and use it in GitHub Desktop.
Save makslevental/15fa1b55effee9a84b28a9d7a7e2c3ae to your computer and use it in GitHub Desktop.
@aie.device(AIEDevice.ipu)
def ipu():
tile_0_0 = aie.tile(0, 0)
tile_0_1 = aie.tile(0, 1)
tile_0_2 = aie.tile(0, 2)
tile_1_0 = aie.tile(1, 0)
tile_2_0 = aie.tile(2, 0)
tile_3_0 = aie.tile(3, 0)
buffer_weight = aie.buffer(tile_0_2, (N_CONSUMERS * N_WRITE_OUTS * K,), T.i32())
lock_read_weight = aie.lock(tile_0_2, init=1)
lock_send_weight = aie.lock(tile_0_2, init=0)
core_to_mem_fl = aie.flow(tile_0_2, DMA, 0, tile_0_1, DMA, 2)
@aie.core(tile_0_2)
def core():
for i in range_(iters):
with aiex.hold_lock(lock_read_weight, lock_send_weight):
for j in range(N_CONSUMERS):
ii = i + 1
jj = j + 1
linalg.fill(ii + jj + ii * jj, buffer_weight[j * N_WRITE_OUTS * K : (j + 1) * N_WRITE_OUTS * K])
yield_([])
@aie.mem(tile_0_2)
def mem_0_2():
@aie.dma(MM2S, core_to_mem_fl.source_channel, repeat_count=iters - 1)
def dma3():
aie.use_lock(lock_send_weight, AcquireGreaterEqual)
aie.dma_bd(buffer_weight)
aie.use_lock(lock_read_weight, Release)
aie.end()
# try to use different channels as much as possible to prevent false positives
mem_to_shim_tile_0_0_fl = aie.flow(tile_0_1, DMA, 1, tile_0_0, DMA, 0)
col_shim_channel_index[0] = int(mem_to_shim_tile_0_0_fl.dest_channel)
mem_to_shim_tile_1_0_fl = aie.flow(tile_0_1, DMA, 2, tile_1_0, DMA, 1)
col_shim_channel_index[1] = int(mem_to_shim_tile_1_0_fl.dest_channel)
mem_to_shim_tile_2_0_fl = aie.flow(tile_0_1, DMA, 3, tile_2_0, DMA, 0)
col_shim_channel_index[2] = int(mem_to_shim_tile_2_0_fl.dest_channel)
mem_to_shim_tile_3_0_fl = aie.flow(tile_0_1, DMA, 4, tile_3_0, DMA, 0)
col_shim_channel_index[3] = int(mem_to_shim_tile_3_0_fl.dest_channel)
consumer_flows = [mem_to_shim_tile_0_0_fl, mem_to_shim_tile_1_0_fl, mem_to_shim_tile_2_0_fl, mem_to_shim_tile_3_0_fl]
@aie.memtile_dma(tile_0_1)
def memtile_dma_0_1():
buffer_0_1_c = aie.buffer(tile_0_1, (N_CONSUMERS * N_WRITE_OUTS * K,), dtype)
read_in_locks = [aie.lock(tile_0_1, init=N_WRITE_OUTS if i % 2 else 2 * N_WRITE_OUTS) for i in range(N_CONSUMERS)]
write_out_locks = [aie.lock(tile_0_1, init=0) for i in range(N_CONSUMERS)]
# read in
@aie.dma(S2MM, core_to_mem_fl.dest_channel, repeat_count=iters - 1, num_bds=N_CONSUMERS)
def dma5():
aie.use_lock(read_in_locks[0], AcquireGreaterEqual, value=int(read_in_locks[0].owner.opview.init))
aie.dma_bd(buffer_0_1_c)
aie.use_lock(write_out_locks[0], Release, value=int(read_in_locks[0].owner.opview.init))
for i in range(1, N_CONSUMERS):
write_outs = int(read_in_locks[i].owner.opview.init)
@aie.another_bd(dma5)
def _():
aie.use_lock(read_in_locks[i], AcquireGreaterEqual, value=write_outs)
aie.dma_bd(buffer_0_1_c, len=0)
aie.use_lock(write_out_locks[i], Release, value=write_outs)
# write out
for i, fl in enumerate(consumer_flows):
write_outs = int(read_in_locks[i].owner.opview.init)
k = K if i % 2 else K // 2
@aie.dma(MM2S, fl.source_channel, repeat_count=(iters * write_outs) - 1)
def dma6():
aie.use_lock(write_out_locks[i], AcquireGreaterEqual, value=1)
aie.dma_bd(buffer_0_1_c, len=k, offset=i * write_outs * k, iteration=(write_outs, k))
aie.use_lock(read_in_locks[i], Release, value=1)
aie.end()
module {
aie.device(ipu) {
%shim_tile_0_0 = tile(0, 0)
%mem_tile_0_1 = tile(0, 1)
%core_tile_0_2 = tile(0, 2)
%shim_tile_1_0 = tile(1, 0)
%shim_tile_2_0 = tile(2, 0)
%shim_tile_3_0 = tile(3, 0)
%buffer_0_2 = buffer(%core_tile_0_2) : memref<64xi32>
%lock_0_2 = lock(%core_tile_0_2) {init = 1 : i32}
%lock_0_2_0 = lock(%core_tile_0_2) {init = 0 : i32}
flow(%core_tile_0_2, DMA : 0, %mem_tile_0_1, DMA : 2)
%core_0_2 = core(%core_tile_0_2) {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
scf.for %arg0 = %c0 to %c8 step %c1 {
aie.use_lock(%lock_0_2, AcquireGreaterEqual)
%c1_1 = arith.constant 1 : index
%0 = arith.addi %arg0, %c1_1 : index
%c1_2 = arith.constant 1 : index
%1 = arith.addi %0, %c1_2 : index
%c1_3 = arith.constant 1 : index
%2 = arith.muli %0, %c1_3 : index
%3 = arith.addi %1, %2 : index
%subview = memref.subview %buffer_0_2[0] [16] [1] : memref<64xi32> to memref<16xi32>
linalg.fill ins(%3 : index) outs(%subview : memref<16xi32>)
%c1_4 = arith.constant 1 : index
%4 = arith.addi %arg0, %c1_4 : index
%c2 = arith.constant 2 : index
%5 = arith.addi %4, %c2 : index
%c2_5 = arith.constant 2 : index
%6 = arith.muli %4, %c2_5 : index
%7 = arith.addi %5, %6 : index
%subview_6 = memref.subview %buffer_0_2[16] [16] [1] : memref<64xi32> to memref<16xi32, strided<[1], offset: 16>>
linalg.fill ins(%7 : index) outs(%subview_6 : memref<16xi32, strided<[1], offset: 16>>)
%c1_7 = arith.constant 1 : index
%8 = arith.addi %arg0, %c1_7 : index
%c3 = arith.constant 3 : index
%9 = arith.addi %8, %c3 : index
%c3_8 = arith.constant 3 : index
%10 = arith.muli %8, %c3_8 : index
%11 = arith.addi %9, %10 : index
%subview_9 = memref.subview %buffer_0_2[32] [16] [1] : memref<64xi32> to memref<16xi32, strided<[1], offset: 32>>
linalg.fill ins(%11 : index) outs(%subview_9 : memref<16xi32, strided<[1], offset: 32>>)
%c1_10 = arith.constant 1 : index
%12 = arith.addi %arg0, %c1_10 : index
%c4 = arith.constant 4 : index
%13 = arith.addi %12, %c4 : index
%c4_11 = arith.constant 4 : index
%14 = arith.muli %12, %c4_11 : index
%15 = arith.addi %13, %14 : index
%subview_12 = memref.subview %buffer_0_2[48] [16] [1] : memref<64xi32> to memref<16xi32, strided<[1], offset: 48>>
linalg.fill ins(%15 : index) outs(%subview_12 : memref<16xi32, strided<[1], offset: 48>>)
aie.use_lock(%lock_0_2_0, Release)
}
aie.end
}
%mem_0_2 = mem(%core_tile_0_2) {
%0 = aie.dma(MM2S, 0) {loop = false, repeat_count = 7 : i32} [{
aie.use_lock(%lock_0_2_0, AcquireGreaterEqual)
aie.dma_bd(%buffer_0_2 : memref<64xi32>)
aie.use_lock(%lock_0_2, Release)
}]
aie.end
}
flow(%mem_tile_0_1, DMA : 1, %shim_tile_0_0, DMA : 0)
flow(%mem_tile_0_1, DMA : 2, %shim_tile_1_0, DMA : 1)
flow(%mem_tile_0_1, DMA : 3, %shim_tile_2_0, DMA : 0)
flow(%mem_tile_0_1, DMA : 4, %shim_tile_3_0, DMA : 0)
%memtile_dma_0_1 = memtile_dma(%mem_tile_0_1) {
%buffer_0_1 = aie.buffer(%mem_tile_0_1) : memref<64xi32>
%read_in_0 = aie.lock(%mem_tile_0_1) {init = 8 : i32, sym_name = "read_in_0"}
%read_in_1 = aie.lock(%mem_tile_0_1) {init = 4 : i32, sym_name = "read_in_1"}
%read_in_2 = aie.lock(%mem_tile_0_1) {init = 8 : i32, sym_name = "read_in_2"}
%read_in_3 = aie.lock(%mem_tile_0_1) {init = 4 : i32, sym_name = "read_in_3"}
%write_out_0 = aie.lock(%mem_tile_0_1) {init = 0 : i32, sym_name = "write_out_0"}
%write_out_1 = aie.lock(%mem_tile_0_1) {init = 0 : i32, sym_name = "write_out_1"}
%write_out_2 = aie.lock(%mem_tile_0_1) {init = 0 : i32, sym_name = "write_out_2"}
%write_out_3 = aie.lock(%mem_tile_0_1) {init = 0 : i32, sym_name = "write_out_3"}
%0 = aie.dma(S2MM, 2) {loop = false, repeat_count = 7 : i32} [{
aie.use_lock(%read_in_0, AcquireGreaterEqual) {value = 8 : i32}
aie.dma_bd(%buffer_0_1 : memref<64xi32>)
aie.use_lock(%write_out_0, Release) {value = 8 : i32}
}, {
aie.use_lock(%read_in_1, AcquireGreaterEqual) {value = 4 : i32}
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {len = 0 : i32}
aie.use_lock(%write_out_1, Release) {value = 4 : i32}
}, {
aie.use_lock(%read_in_2, AcquireGreaterEqual) {value = 8 : i32}
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {len = 0 : i32}
aie.use_lock(%write_out_2, Release) {value = 8 : i32}
}, {
aie.use_lock(%read_in_3, AcquireGreaterEqual) {value = 4 : i32}
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {len = 0 : i32}
aie.use_lock(%write_out_3, Release) {value = 4 : i32}
}]
%1 = aie.dma(MM2S, 1) {loop = false, repeat_count = 63 : i32} [{
aie.use_lock(%write_out_0, AcquireGreaterEqual) {value = 1 : i32}
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {iteration = #aie.bd_dim_layout<size = 8, stride = 2>, len = 2 : i32}
aie.use_lock(%read_in_0, Release) {value = 1 : i32}
}]
%2 = aie.dma(MM2S, 2) {loop = false, repeat_count = 31 : i32} [{
aie.use_lock(%write_out_1, AcquireGreaterEqual) {value = 1 : i32}
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {iteration = #aie.bd_dim_layout<size = 4, stride = 4>, len = 4 : i32, offset = 16 : i32}
aie.use_lock(%read_in_1, Release) {value = 1 : i32}
}]
%3 = aie.dma(MM2S, 3) {loop = false, repeat_count = 63 : i32} [{
aie.use_lock(%write_out_2, AcquireGreaterEqual) {value = 1 : i32}
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {iteration = #aie.bd_dim_layout<size = 8, stride = 2>, len = 2 : i32, offset = 32 : i32}
aie.use_lock(%read_in_2, Release) {value = 1 : i32}
}]
%4 = aie.dma(MM2S, 4) {loop = false, repeat_count = 31 : i32} [{
aie.use_lock(%write_out_3, AcquireGreaterEqual) {value = 1 : i32}
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {iteration = #aie.bd_dim_layout<size = 4, stride = 4>, len = 4 : i32, offset = 48 : i32}
aie.use_lock(%read_in_3, Release) {value = 1 : i32}
}]
aie.end
}
}
}
compiling core 0 2
Running kernel
time=477.524us
[[[ 3 3 3 3]
[ 3 3 3 3]
[ 3 3 3 3]
[ 3 3 3 3]]
[[ 5 5 5 5]
[ 5 5 5 5]
[ 5 5 5 5]
[ 5 5 5 5]]
[[ 7 7 7 7]
[ 7 7 7 7]
[ 7 7 7 7]
[ 7 7 7 7]]
[[ 9 9 9 9]
[ 9 9 9 9]
[ 9 9 9 9]
[ 9 9 9 9]]
[[11 11 11 11]
[11 11 11 11]
[11 11 11 11]
[11 11 11 11]]
[[13 13 13 13]
[13 13 13 13]
[13 13 13 13]
[13 13 13 13]]
[[15 15 15 15]
[15 15 15 15]
[15 15 15 15]
[15 15 15 15]]
[[17 17 17 17]
[17 17 17 17]
[17 17 17 17]
[17 17 17 17]]]
[[[ 5 5 5 5]
[ 5 5 5 5]
[ 5 5 5 5]
[ 5 5 5 5]]
[[ 8 8 8 8]
[ 8 8 8 8]
[ 8 8 8 8]
[ 8 8 8 8]]
[[11 11 11 11]
[11 11 11 11]
[11 11 11 11]
[11 11 11 11]]
[[14 14 14 14]
[14 14 14 14]
[14 14 14 14]
[14 14 14 14]]
[[17 17 17 17]
[17 17 17 17]
[17 17 17 17]
[17 17 17 17]]
[[20 20 20 20]
[20 20 20 20]
[20 20 20 20]
[20 20 20 20]]
[[23 23 23 23]
[23 23 23 23]
[23 23 23 23]
[23 23 23 23]]
[[26 26 26 26]
[26 26 26 26]
[26 26 26 26]
[26 26 26 26]]]
[[[ 7 7 7 7]
[ 7 7 7 7]
[ 7 7 7 7]
[ 7 7 7 7]]
[[11 11 11 11]
[11 11 11 11]
[11 11 11 11]
[11 11 11 11]]
[[15 15 15 15]
[15 15 15 15]
[15 15 15 15]
[15 15 19 19]]
[[19 19 19 19]
[19 19 19 19]
[19 19 19 19]
[19 19 19 19]]
[[27 27 27 27]
[27 27 27 27]
[27 27 27 27]
[27 27 27 27]]
[[27 27 27 27]
[27 27 27 27]
[27 27 27 27]
[27 27 27 27]]
[[35 35 35 35]
[35 35 35 35]
[35 35 35 35]
[35 35 35 35]]
[[35 35 35 35]
[35 35 35 35]
[35 35 35 35]
[35 35 35 35]]]
[[[ 9 9 9 9]
[ 9 9 9 9]
[ 9 9 9 9]
[ 9 9 9 9]]
[[14 14 14 14]
[14 14 14 14]
[14 14 14 14]
[14 14 14 14]]
[[19 19 19 19]
[19 19 19 19]
[19 19 19 19]
[19 19 19 19]]
[[24 24 24 24]
[24 24 24 24]
[24 24 24 24]
[24 24 24 24]]
[[29 29 29 29]
[34 34 34 34]
[34 34 34 34]
[34 34 34 34]]
[[34 34 34 34]
[34 34 34 34]
[34 34 34 34]
[34 34 34 34]]
[[39 39 39 39]
[39 39 39 39]
[44 44 44 44]
[44 44 44 44]]
[[44 44 44 44]
[44 44 44 44]
[44 44 44 44]
[44 44 44 44]]]
PASSED
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment