Last active
March 30, 2024 21:42
-
-
Save makslevental/15fa1b55effee9a84b28a9d7a7e2c3ae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@aie.device(AIEDevice.ipu) | |
def ipu(): | |
tile_0_0 = aie.tile(0, 0) | |
tile_0_1 = aie.tile(0, 1) | |
tile_0_2 = aie.tile(0, 2) | |
tile_1_0 = aie.tile(1, 0) | |
tile_2_0 = aie.tile(2, 0) | |
tile_3_0 = aie.tile(3, 0) | |
buffer_weight = aie.buffer(tile_0_2, (N_CONSUMERS * N_WRITE_OUTS * K,), T.i32()) | |
lock_read_weight = aie.lock(tile_0_2, init=1) | |
lock_send_weight = aie.lock(tile_0_2, init=0) | |
core_to_mem_fl = aie.flow(tile_0_2, DMA, 0, tile_0_1, DMA, 2) | |
@aie.core(tile_0_2) | |
def core(): | |
for i in range_(iters): | |
with aiex.hold_lock(lock_read_weight, lock_send_weight): | |
for j in range(N_CONSUMERS): | |
ii = i + 1 | |
jj = j + 1 | |
linalg.fill(ii + jj + ii * jj, buffer_weight[j * N_WRITE_OUTS * K : (j + 1) * N_WRITE_OUTS * K]) | |
yield_([]) | |
@aie.mem(tile_0_2) | |
def mem_0_2(): | |
@aie.dma(MM2S, core_to_mem_fl.source_channel, repeat_count=iters - 1) | |
def dma3(): | |
aie.use_lock(lock_send_weight, AcquireGreaterEqual) | |
aie.dma_bd(buffer_weight) | |
aie.use_lock(lock_read_weight, Release) | |
aie.end() | |
# try to use different channels as much as possible to prevent false positives | |
mem_to_shim_tile_0_0_fl = aie.flow(tile_0_1, DMA, 1, tile_0_0, DMA, 0) | |
col_shim_channel_index[0] = int(mem_to_shim_tile_0_0_fl.dest_channel) | |
mem_to_shim_tile_1_0_fl = aie.flow(tile_0_1, DMA, 2, tile_1_0, DMA, 1) | |
col_shim_channel_index[1] = int(mem_to_shim_tile_1_0_fl.dest_channel) | |
mem_to_shim_tile_2_0_fl = aie.flow(tile_0_1, DMA, 3, tile_2_0, DMA, 0) | |
col_shim_channel_index[2] = int(mem_to_shim_tile_2_0_fl.dest_channel) | |
mem_to_shim_tile_3_0_fl = aie.flow(tile_0_1, DMA, 4, tile_3_0, DMA, 0) | |
col_shim_channel_index[3] = int(mem_to_shim_tile_3_0_fl.dest_channel) | |
consumer_flows = [mem_to_shim_tile_0_0_fl, mem_to_shim_tile_1_0_fl, mem_to_shim_tile_2_0_fl, mem_to_shim_tile_3_0_fl] | |
@aie.memtile_dma(tile_0_1) | |
def memtile_dma_0_1(): | |
buffer_0_1_c = aie.buffer(tile_0_1, (N_CONSUMERS * N_WRITE_OUTS * K,), dtype) | |
read_in_locks = [aie.lock(tile_0_1, init=N_WRITE_OUTS if i % 2 else 2 * N_WRITE_OUTS) for i in range(N_CONSUMERS)] | |
write_out_locks = [aie.lock(tile_0_1, init=0) for i in range(N_CONSUMERS)] | |
# read in | |
@aie.dma(S2MM, core_to_mem_fl.dest_channel, repeat_count=iters - 1, num_bds=N_CONSUMERS) | |
def dma5(): | |
aie.use_lock(read_in_locks[0], AcquireGreaterEqual, value=int(read_in_locks[0].owner.opview.init)) | |
aie.dma_bd(buffer_0_1_c) | |
aie.use_lock(write_out_locks[0], Release, value=int(read_in_locks[0].owner.opview.init)) | |
for i in range(1, N_CONSUMERS): | |
write_outs = int(read_in_locks[i].owner.opview.init) | |
@aie.another_bd(dma5) | |
def _(): | |
aie.use_lock(read_in_locks[i], AcquireGreaterEqual, value=write_outs) | |
aie.dma_bd(buffer_0_1_c, len=0) | |
aie.use_lock(write_out_locks[i], Release, value=write_outs) | |
# write out | |
for i, fl in enumerate(consumer_flows): | |
write_outs = int(read_in_locks[i].owner.opview.init) | |
k = K if i % 2 else K // 2 | |
@aie.dma(MM2S, fl.source_channel, repeat_count=(iters * write_outs) - 1) | |
def dma6(): | |
aie.use_lock(write_out_locks[i], AcquireGreaterEqual, value=1) | |
aie.dma_bd(buffer_0_1_c, len=k, offset=i * write_outs * k, iteration=(write_outs, k)) | |
aie.use_lock(read_in_locks[i], Release, value=1) | |
aie.end() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module { | |
aie.device(ipu) { | |
%shim_tile_0_0 = tile(0, 0) | |
%mem_tile_0_1 = tile(0, 1) | |
%core_tile_0_2 = tile(0, 2) | |
%shim_tile_1_0 = tile(1, 0) | |
%shim_tile_2_0 = tile(2, 0) | |
%shim_tile_3_0 = tile(3, 0) | |
%buffer_0_2 = buffer(%core_tile_0_2) : memref<64xi32> | |
%lock_0_2 = lock(%core_tile_0_2) {init = 1 : i32} | |
%lock_0_2_0 = lock(%core_tile_0_2) {init = 0 : i32} | |
flow(%core_tile_0_2, DMA : 0, %mem_tile_0_1, DMA : 2) | |
%core_0_2 = core(%core_tile_0_2) { | |
%c0 = arith.constant 0 : index | |
%c8 = arith.constant 8 : index | |
%c1 = arith.constant 1 : index | |
scf.for %arg0 = %c0 to %c8 step %c1 { | |
aie.use_lock(%lock_0_2, AcquireGreaterEqual) | |
%c1_1 = arith.constant 1 : index | |
%0 = arith.addi %arg0, %c1_1 : index | |
%c1_2 = arith.constant 1 : index | |
%1 = arith.addi %0, %c1_2 : index | |
%c1_3 = arith.constant 1 : index | |
%2 = arith.muli %0, %c1_3 : index | |
%3 = arith.addi %1, %2 : index | |
%subview = memref.subview %buffer_0_2[0] [16] [1] : memref<64xi32> to memref<16xi32> | |
linalg.fill ins(%3 : index) outs(%subview : memref<16xi32>) | |
%c1_4 = arith.constant 1 : index | |
%4 = arith.addi %arg0, %c1_4 : index | |
%c2 = arith.constant 2 : index | |
%5 = arith.addi %4, %c2 : index | |
%c2_5 = arith.constant 2 : index | |
%6 = arith.muli %4, %c2_5 : index | |
%7 = arith.addi %5, %6 : index | |
%subview_6 = memref.subview %buffer_0_2[16] [16] [1] : memref<64xi32> to memref<16xi32, strided<[1], offset: 16>> | |
linalg.fill ins(%7 : index) outs(%subview_6 : memref<16xi32, strided<[1], offset: 16>>) | |
%c1_7 = arith.constant 1 : index | |
%8 = arith.addi %arg0, %c1_7 : index | |
%c3 = arith.constant 3 : index | |
%9 = arith.addi %8, %c3 : index | |
%c3_8 = arith.constant 3 : index | |
%10 = arith.muli %8, %c3_8 : index | |
%11 = arith.addi %9, %10 : index | |
%subview_9 = memref.subview %buffer_0_2[32] [16] [1] : memref<64xi32> to memref<16xi32, strided<[1], offset: 32>> | |
linalg.fill ins(%11 : index) outs(%subview_9 : memref<16xi32, strided<[1], offset: 32>>) | |
%c1_10 = arith.constant 1 : index | |
%12 = arith.addi %arg0, %c1_10 : index | |
%c4 = arith.constant 4 : index | |
%13 = arith.addi %12, %c4 : index | |
%c4_11 = arith.constant 4 : index | |
%14 = arith.muli %12, %c4_11 : index | |
%15 = arith.addi %13, %14 : index | |
%subview_12 = memref.subview %buffer_0_2[48] [16] [1] : memref<64xi32> to memref<16xi32, strided<[1], offset: 48>> | |
linalg.fill ins(%15 : index) outs(%subview_12 : memref<16xi32, strided<[1], offset: 48>>) | |
aie.use_lock(%lock_0_2_0, Release) | |
} | |
aie.end | |
} | |
%mem_0_2 = mem(%core_tile_0_2) { | |
%0 = aie.dma(MM2S, 0) {loop = false, repeat_count = 7 : i32} [{ | |
aie.use_lock(%lock_0_2_0, AcquireGreaterEqual) | |
aie.dma_bd(%buffer_0_2 : memref<64xi32>) | |
aie.use_lock(%lock_0_2, Release) | |
}] | |
aie.end | |
} | |
flow(%mem_tile_0_1, DMA : 1, %shim_tile_0_0, DMA : 0) | |
flow(%mem_tile_0_1, DMA : 2, %shim_tile_1_0, DMA : 1) | |
flow(%mem_tile_0_1, DMA : 3, %shim_tile_2_0, DMA : 0) | |
flow(%mem_tile_0_1, DMA : 4, %shim_tile_3_0, DMA : 0) | |
%memtile_dma_0_1 = memtile_dma(%mem_tile_0_1) { | |
%buffer_0_1 = aie.buffer(%mem_tile_0_1) : memref<64xi32> | |
%read_in_0 = aie.lock(%mem_tile_0_1) {init = 8 : i32, sym_name = "read_in_0"} | |
%read_in_1 = aie.lock(%mem_tile_0_1) {init = 4 : i32, sym_name = "read_in_1"} | |
%read_in_2 = aie.lock(%mem_tile_0_1) {init = 8 : i32, sym_name = "read_in_2"} | |
%read_in_3 = aie.lock(%mem_tile_0_1) {init = 4 : i32, sym_name = "read_in_3"} | |
%write_out_0 = aie.lock(%mem_tile_0_1) {init = 0 : i32, sym_name = "write_out_0"} | |
%write_out_1 = aie.lock(%mem_tile_0_1) {init = 0 : i32, sym_name = "write_out_1"} | |
%write_out_2 = aie.lock(%mem_tile_0_1) {init = 0 : i32, sym_name = "write_out_2"} | |
%write_out_3 = aie.lock(%mem_tile_0_1) {init = 0 : i32, sym_name = "write_out_3"} | |
%0 = aie.dma(S2MM, 2) {loop = false, repeat_count = 7 : i32} [{ | |
aie.use_lock(%read_in_0, AcquireGreaterEqual) {value = 8 : i32} | |
aie.dma_bd(%buffer_0_1 : memref<64xi32>) | |
aie.use_lock(%write_out_0, Release) {value = 8 : i32} | |
}, { | |
aie.use_lock(%read_in_1, AcquireGreaterEqual) {value = 4 : i32} | |
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {len = 0 : i32} | |
aie.use_lock(%write_out_1, Release) {value = 4 : i32} | |
}, { | |
aie.use_lock(%read_in_2, AcquireGreaterEqual) {value = 8 : i32} | |
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {len = 0 : i32} | |
aie.use_lock(%write_out_2, Release) {value = 8 : i32} | |
}, { | |
aie.use_lock(%read_in_3, AcquireGreaterEqual) {value = 4 : i32} | |
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {len = 0 : i32} | |
aie.use_lock(%write_out_3, Release) {value = 4 : i32} | |
}] | |
%1 = aie.dma(MM2S, 1) {loop = false, repeat_count = 63 : i32} [{ | |
aie.use_lock(%write_out_0, AcquireGreaterEqual) {value = 1 : i32} | |
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {iteration = #aie.bd_dim_layout<size = 8, stride = 2>, len = 2 : i32} | |
aie.use_lock(%read_in_0, Release) {value = 1 : i32} | |
}] | |
%2 = aie.dma(MM2S, 2) {loop = false, repeat_count = 31 : i32} [{ | |
aie.use_lock(%write_out_1, AcquireGreaterEqual) {value = 1 : i32} | |
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {iteration = #aie.bd_dim_layout<size = 4, stride = 4>, len = 4 : i32, offset = 16 : i32} | |
aie.use_lock(%read_in_1, Release) {value = 1 : i32} | |
}] | |
%3 = aie.dma(MM2S, 3) {loop = false, repeat_count = 63 : i32} [{ | |
aie.use_lock(%write_out_2, AcquireGreaterEqual) {value = 1 : i32} | |
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {iteration = #aie.bd_dim_layout<size = 8, stride = 2>, len = 2 : i32, offset = 32 : i32} | |
aie.use_lock(%read_in_2, Release) {value = 1 : i32} | |
}] | |
%4 = aie.dma(MM2S, 4) {loop = false, repeat_count = 31 : i32} [{ | |
aie.use_lock(%write_out_3, AcquireGreaterEqual) {value = 1 : i32} | |
aie.dma_bd(%buffer_0_1 : memref<64xi32>) {iteration = #aie.bd_dim_layout<size = 4, stride = 4>, len = 4 : i32, offset = 48 : i32} | |
aie.use_lock(%read_in_3, Release) {value = 1 : i32} | |
}] | |
aie.end | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
compiling core 0 2 | |
Running kernel | |
time=477.524us | |
[[[ 3 3 3 3] | |
[ 3 3 3 3] | |
[ 3 3 3 3] | |
[ 3 3 3 3]] | |
[[ 5 5 5 5] | |
[ 5 5 5 5] | |
[ 5 5 5 5] | |
[ 5 5 5 5]] | |
[[ 7 7 7 7] | |
[ 7 7 7 7] | |
[ 7 7 7 7] | |
[ 7 7 7 7]] | |
[[ 9 9 9 9] | |
[ 9 9 9 9] | |
[ 9 9 9 9] | |
[ 9 9 9 9]] | |
[[11 11 11 11] | |
[11 11 11 11] | |
[11 11 11 11] | |
[11 11 11 11]] | |
[[13 13 13 13] | |
[13 13 13 13] | |
[13 13 13 13] | |
[13 13 13 13]] | |
[[15 15 15 15] | |
[15 15 15 15] | |
[15 15 15 15] | |
[15 15 15 15]] | |
[[17 17 17 17] | |
[17 17 17 17] | |
[17 17 17 17] | |
[17 17 17 17]]] | |
[[[ 5 5 5 5] | |
[ 5 5 5 5] | |
[ 5 5 5 5] | |
[ 5 5 5 5]] | |
[[ 8 8 8 8] | |
[ 8 8 8 8] | |
[ 8 8 8 8] | |
[ 8 8 8 8]] | |
[[11 11 11 11] | |
[11 11 11 11] | |
[11 11 11 11] | |
[11 11 11 11]] | |
[[14 14 14 14] | |
[14 14 14 14] | |
[14 14 14 14] | |
[14 14 14 14]] | |
[[17 17 17 17] | |
[17 17 17 17] | |
[17 17 17 17] | |
[17 17 17 17]] | |
[[20 20 20 20] | |
[20 20 20 20] | |
[20 20 20 20] | |
[20 20 20 20]] | |
[[23 23 23 23] | |
[23 23 23 23] | |
[23 23 23 23] | |
[23 23 23 23]] | |
[[26 26 26 26] | |
[26 26 26 26] | |
[26 26 26 26] | |
[26 26 26 26]]] | |
[[[ 7 7 7 7] | |
[ 7 7 7 7] | |
[ 7 7 7 7] | |
[ 7 7 7 7]] | |
[[11 11 11 11] | |
[11 11 11 11] | |
[11 11 11 11] | |
[11 11 11 11]] | |
[[15 15 15 15] | |
[15 15 15 15] | |
[15 15 15 15] | |
[15 15 19 19]] | |
[[19 19 19 19] | |
[19 19 19 19] | |
[19 19 19 19] | |
[19 19 19 19]] | |
[[27 27 27 27] | |
[27 27 27 27] | |
[27 27 27 27] | |
[27 27 27 27]] | |
[[27 27 27 27] | |
[27 27 27 27] | |
[27 27 27 27] | |
[27 27 27 27]] | |
[[35 35 35 35] | |
[35 35 35 35] | |
[35 35 35 35] | |
[35 35 35 35]] | |
[[35 35 35 35] | |
[35 35 35 35] | |
[35 35 35 35] | |
[35 35 35 35]]] | |
[[[ 9 9 9 9] | |
[ 9 9 9 9] | |
[ 9 9 9 9] | |
[ 9 9 9 9]] | |
[[14 14 14 14] | |
[14 14 14 14] | |
[14 14 14 14] | |
[14 14 14 14]] | |
[[19 19 19 19] | |
[19 19 19 19] | |
[19 19 19 19] | |
[19 19 19 19]] | |
[[24 24 24 24] | |
[24 24 24 24] | |
[24 24 24 24] | |
[24 24 24 24]] | |
[[29 29 29 29] | |
[34 34 34 34] | |
[34 34 34 34] | |
[34 34 34 34]] | |
[[34 34 34 34] | |
[34 34 34 34] | |
[34 34 34 34] | |
[34 34 34 34]] | |
[[39 39 39 39] | |
[39 39 39 39] | |
[44 44 44 44] | |
[44 44 44 44]] | |
[[44 44 44 44] | |
[44 44 44 44] | |
[44 44 44 44] | |
[44 44 44 44]]] | |
PASSED |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment