-
-
Save c-rhodes/1e9f2d8fd0ca3c6539f167e08079f6ab to your computer and use it in GitHub Desktop.
MLIR function demonstrating load/store to SME ZA.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Integration test demonstrating load/store to SME ZA.
// External print helpers used by the test functions below. They are declared
// without bodies, so their definitions must be supplied when the module is
// linked or executed (presumably by a runner-utils shared library -- confirm
// against the command used to run this test).
//
// NOTE(review): @printI128 is not called by any function visible in this file.
llvm.func @printI128(i128)
llvm.func @printI64(i64)
llvm.func @printF64(f64)
llvm.func @printOpen()
llvm.func @printClose()
llvm.func @printComma()
llvm.func @printNewline()
// Round-trips a tile of i8 elements through a 2-D scalable vector value.
//
// Steps:
//   1. Fill %mem1 so that every element of row r holds r, printing each row
//      as it is written.
//   2. Load the whole buffer as a vector<[16]x[16]xi8> ("ZA").
//   3. Zero-fill a second buffer %mem2 and print it.
//   4. Store the loaded tile to %mem2 and print it again.
//
// Rows are printed one per line, between @printOpen/@printClose, with
// @printComma after every element except the last one.
func.func @za0b() {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8
  %c1_i8 = arith.constant 1 : i8
  %c1_index = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %vscale = vector.vscale

  // "svl" refers to the Streaming Vector Length and "svl_b" the number of
  // 8-bit elements in a vector of SVL bits.
  %svl_b = arith.muli %c16, %vscale : index

  // Index of the last element of a row. It is loop-invariant, so it is
  // computed once here instead of on every iteration of the print loops.
  %last_i = arith.subi %svl_b, %c1_index : index

  // Allocate memory and fill each "row" with row number.
  %size = arith.muli %svl_b, %svl_b : index
  %mem1 = memref.alloca(%size) : memref<?xi8>
  %init_0 = arith.constant 0 : i8
  scf.for %i = %c0 to %size step %svl_b iter_args(%val = %init_0) -> (i8) {
    %splat_val = vector.broadcast %val : i8 to vector<[16]xi8>
    vector.store %splat_val, %mem1[%i] : memref<?xi8>, vector<[16]xi8>
    %val_next = arith.addi %val, %c1_i8 : i8

    // Read the row back from memory and print it.
    %av = vector.load %mem1[%i] : memref<?xi8>, vector<[16]xi8>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_b step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[16]xi8>
      // Widen to i64, the only integer width @printI64 accepts.
      %elem_i64 = llvm.zext %elem : i8 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      // No separator after the last element of the row.
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
    scf.yield %val_next : i8
  }

  // Load ZA from memory
  %za_b = vector.load %mem1[%c0] : memref<?xi8>, vector<[16]x[16]xi8>

  // Allocate new memory to store ZA to
  %mem2 = memref.alloca(%size) : memref<?xi8>

  // Zero new memory
  scf.for %i = %c0 to %size step %c1_index {
    memref.store %c0_i8, %mem2[%i] : memref<?xi8>
  }

  // Dump zeroed memory
  scf.for %i = %c0 to %size step %svl_b {
    %av = vector.load %mem2[%i] : memref<?xi8>, vector<[16]xi8>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_b step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[16]xi8>
      %elem_i64 = llvm.zext %elem : i8 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  // Store ZA to memory
  vector.store %za_b, %mem2[%c0] : memref<?xi8>, vector<[16]x[16]xi8>

  // Dump memory after storing ZA
  scf.for %i = %c0 to %size step %svl_b {
    %av = vector.load %mem2[%i] : memref<?xi8>, vector<[16]xi8>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_b step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[16]xi8>
      %elem_i64 = llvm.zext %elem : i8 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  return
}
// Round-trips a tile of i16 elements through a 2-D scalable vector value.
//
// Steps:
//   1. Fill %mem1 so that every element of row r holds r, printing each row
//      as it is written.
//   2. Load the whole buffer as a vector<[8]x[8]xi16> ("ZA0.H").
//   3. Zero-fill a second buffer %mem2 and print it.
//   4. Store the loaded tile to %mem2 and print it again.
//
// Rows are printed one per line, between @printOpen/@printClose, with
// @printComma after every element except the last one.
func.func @za0h() {
  %c0 = arith.constant 0 : index
  %c0_i16 = arith.constant 0 : i16
  %c1_i16 = arith.constant 1 : i16
  %c1_index = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale

  // "svl" refers to the Streaming Vector Length and "svl_h" the number of
  // 16-bit elements in a vector of SVL bits.
  %svl_h = arith.muli %c8, %vscale : index

  // Index of the last element of a row. It is loop-invariant, so it is
  // computed once here instead of on every iteration of the print loops.
  %last_i = arith.subi %svl_h, %c1_index : index

  // Allocate memory and fill each "row" with row number.
  %size = arith.muli %svl_h, %svl_h : index
  %mem1 = memref.alloca(%size) : memref<?xi16>
  %init_0 = arith.constant 0 : i16
  scf.for %i = %c0 to %size step %svl_h iter_args(%val = %init_0) -> (i16) {
    %splat_val = vector.broadcast %val : i16 to vector<[8]xi16>
    vector.store %splat_val, %mem1[%i] : memref<?xi16>, vector<[8]xi16>
    %val_next = arith.addi %val, %c1_i16 : i16

    // Read the row back from memory and print it.
    %av = vector.load %mem1[%i] : memref<?xi16>, vector<[8]xi16>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_h step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[8]xi16>
      // Widen to i64, the only integer width @printI64 accepts.
      %elem_i64 = llvm.zext %elem : i16 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      // No separator after the last element of the row.
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
    scf.yield %val_next : i16
  }

  // Load ZA0.H from memory
  %za0_h = vector.load %mem1[%c0] : memref<?xi16>, vector<[8]x[8]xi16>

  // Allocate new memory to store ZA0.H to
  %mem2 = memref.alloca(%size) : memref<?xi16>

  // Zero new memory
  scf.for %i = %c0 to %size step %c1_index {
    memref.store %c0_i16, %mem2[%i] : memref<?xi16>
  }

  // Dump zeroed memory
  scf.for %i = %c0 to %size step %svl_h {
    %av = vector.load %mem2[%i] : memref<?xi16>, vector<[8]xi16>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_h step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[8]xi16>
      %elem_i64 = llvm.zext %elem : i16 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  // Store ZA0.H to memory
  vector.store %za0_h, %mem2[%c0] : memref<?xi16>, vector<[8]x[8]xi16>

  // Dump memory after storing ZA0.H
  scf.for %i = %c0 to %size step %svl_h {
    %av = vector.load %mem2[%i] : memref<?xi16>, vector<[8]xi16>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_h step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[8]xi16>
      %elem_i64 = llvm.zext %elem : i16 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  return
}
// Round-trips a tile of f16 elements through a 2-D scalable vector value.
//
// Steps:
//   1. Fill %mem1 row by row (row r holds the seed 0.1 plus r), printing each
//      row as it is written.
//   2. Load the whole buffer as a vector<[8]x[8]xf16> ("ZA0.H").
//   3. Zero-fill a second buffer %mem2 and print it.
//   4. Store the loaded tile to %mem2 and print it again.
//
// Rows are printed one per line, between @printOpen/@printClose, with
// @printComma after every element except the last one.
func.func @za0h_f16() {
  %c0 = arith.constant 0 : index
  %c0_f16 = arith.constant 0.0 : f16
  %c1_f16 = arith.constant 1.0 : f16
  %c1_index = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale

  // "svl" refers to the Streaming Vector Length and "svl_h" the number of
  // 16-bit elements in a vector of SVL bits.
  %svl_h = arith.muli %c8, %vscale : index

  // Index of the last element of a row. It is loop-invariant, so it is
  // computed once here instead of on every iteration of the print loops.
  %last_i = arith.subi %svl_h, %c1_index : index

  // Allocate memory and fill each "row" with a per-row value: the seed is 0.1
  // and 1.0 is added for every subsequent row.
  %size = arith.muli %svl_h, %svl_h : index
  %mem1 = memref.alloca(%size) : memref<?xf16>
  %init_0 = arith.constant 0.1 : f16
  scf.for %i = %c0 to %size step %svl_h iter_args(%val = %init_0) -> (f16) {
    %splat_val = vector.broadcast %val : f16 to vector<[8]xf16>
    vector.store %splat_val, %mem1[%i] : memref<?xf16>, vector<[8]xf16>
    %val_next = arith.addf %val, %c1_f16 : f16

    // Read the row back from memory and print it.
    %av = vector.load %mem1[%i] : memref<?xf16>, vector<[8]xf16>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_h step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[8]xf16>
      // Widen to f64, the only float width @printF64 accepts.
      %elem_f64 = llvm.fpext %elem : f16 to f64
      llvm.call @printF64(%elem_f64) : (f64) -> ()
      // No separator after the last element of the row.
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
    scf.yield %val_next : f16
  }

  // Load ZA0.H from memory
  %za0_h = vector.load %mem1[%c0] : memref<?xf16>, vector<[8]x[8]xf16>

  // Allocate new memory to store ZA0.H to
  %mem2 = memref.alloca(%size) : memref<?xf16>

  // Zero new memory
  scf.for %i = %c0 to %size step %c1_index {
    memref.store %c0_f16, %mem2[%i] : memref<?xf16>
  }

  // Dump zeroed memory
  scf.for %i = %c0 to %size step %svl_h {
    %av = vector.load %mem2[%i] : memref<?xf16>, vector<[8]xf16>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_h step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[8]xf16>
      %elem_f64 = llvm.fpext %elem : f16 to f64
      llvm.call @printF64(%elem_f64) : (f64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  // Store ZA0.H to memory
  vector.store %za0_h, %mem2[%c0] : memref<?xf16>, vector<[8]x[8]xf16>

  // Dump memory after storing ZA0.H
  scf.for %i = %c0 to %size step %svl_h {
    %av = vector.load %mem2[%i] : memref<?xf16>, vector<[8]xf16>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_h step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[8]xf16>
      %elem_f64 = llvm.fpext %elem : f16 to f64
      llvm.call @printF64(%elem_f64) : (f64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  return
}
// Round-trips a tile of i32 elements through a 2-D scalable vector value.
//
// Steps:
//   1. Fill %mem1 so that every element of row r holds r, printing each row
//      as it is written.
//   2. Load the whole buffer as a vector<[4]x[4]xi32> ("ZA0.S").
//   3. Zero-fill a second buffer %mem2 and print it.
//   4. Store the loaded tile to %mem2 and print it again.
//
// Rows are printed one per line, between @printOpen/@printClose, with
// @printComma after every element except the last one.
func.func @za0s() {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c1_index = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %vscale = vector.vscale

  // "svl" refers to the Streaming Vector Length and "svl_s" the number of
  // 32-bit elements in a vector of SVL bits.
  %svl_s = arith.muli %c4, %vscale : index

  // Index of the last element of a row. It is loop-invariant, so it is
  // computed once here instead of on every iteration of the print loops.
  %last_i = arith.subi %svl_s, %c1_index : index

  // Allocate memory and fill each "row" with row number.
  %size = arith.muli %svl_s, %svl_s : index
  %mem1 = memref.alloca(%size) : memref<?xi32>
  %init_0 = arith.constant 0 : i32
  scf.for %i = %c0 to %size step %svl_s iter_args(%val = %init_0) -> (i32) {
    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
    %val_next = arith.addi %val, %c1_i32 : i32

    // Read the row back from memory and print it.
    %av = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_s step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[4]xi32>
      // Widen to i64, the only integer width @printI64 accepts.
      %elem_i64 = llvm.zext %elem : i32 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      // No separator after the last element of the row.
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
    scf.yield %val_next : i32
  }

  // Load ZA0.S from memory
  %za0_s = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

  // Allocate new memory to store ZA0.S to
  %mem2 = memref.alloca(%size) : memref<?xi32>

  // Zero new memory
  scf.for %i = %c0 to %size step %c1_index {
    memref.store %c0_i32, %mem2[%i] : memref<?xi32>
  }

  // Dump zeroed memory
  scf.for %i = %c0 to %size step %svl_s {
    %av = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_s step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[4]xi32>
      %elem_i64 = llvm.zext %elem : i32 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  // Store ZA0.S to memory
  vector.store %za0_s, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

  // Dump memory after storing ZA0.S
  scf.for %i = %c0 to %size step %svl_s {
    %av = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_s step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[4]xi32>
      %elem_i64 = llvm.zext %elem : i32 to i64
      llvm.call @printI64(%elem_i64) : (i64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  return
}
// Round-trips a tile of f32 elements through a 2-D scalable vector value.
//
// Steps:
//   1. Fill %mem1 row by row (row r holds the seed 0.1 plus r), printing each
//      row as it is written.
//   2. Load the whole buffer as a vector<[4]x[4]xf32> ("ZA0.S").
//   3. Zero-fill a second buffer %mem2 and print it.
//   4. Store the loaded tile to %mem2 and print it again.
//
// Rows are printed one per line, between @printOpen/@printClose, with
// @printComma after every element except the last one.
func.func @za0s_f32() {
  %c0 = arith.constant 0 : index
  %c0_f32 = arith.constant 0.0 : f32
  %c1_f32 = arith.constant 1.0 : f32
  %c1_index = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %vscale = vector.vscale

  // "svl" refers to the Streaming Vector Length and "svl_s" the number of
  // 32-bit elements in a vector of SVL bits.
  %svl_s = arith.muli %c4, %vscale : index

  // Index of the last element of a row. It is loop-invariant, so it is
  // computed once here instead of on every iteration of the print loops.
  %last_i = arith.subi %svl_s, %c1_index : index

  // Allocate memory and fill each "row" with a per-row value: the seed is 0.1
  // and 1.0 is added for every subsequent row.
  %size = arith.muli %svl_s, %svl_s : index
  %mem1 = memref.alloca(%size) : memref<?xf32>
  %init_0 = arith.constant 0.1 : f32
  scf.for %i = %c0 to %size step %svl_s iter_args(%val = %init_0) -> (f32) {
    %splat_val = vector.broadcast %val : f32 to vector<[4]xf32>
    vector.store %splat_val, %mem1[%i] : memref<?xf32>, vector<[4]xf32>
    %val_next = arith.addf %val, %c1_f32 : f32

    // Read the row back from memory and print it.
    %av = vector.load %mem1[%i] : memref<?xf32>, vector<[4]xf32>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_s step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[4]xf32>
      // Widen to f64, the only float width @printF64 accepts.
      %elem_f64 = llvm.fpext %elem : f32 to f64
      llvm.call @printF64(%elem_f64) : (f64) -> ()
      // No separator after the last element of the row.
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
    scf.yield %val_next : f32
  }

  // Load ZA0.S from memory
  %za0_s = vector.load %mem1[%c0] : memref<?xf32>, vector<[4]x[4]xf32>

  // Allocate new memory to store ZA0.S to
  %mem2 = memref.alloca(%size) : memref<?xf32>

  // Zero new memory
  scf.for %i = %c0 to %size step %c1_index {
    memref.store %c0_f32, %mem2[%i] : memref<?xf32>
  }

  // Dump zeroed memory
  scf.for %i = %c0 to %size step %svl_s {
    %av = vector.load %mem2[%i] : memref<?xf32>, vector<[4]xf32>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_s step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[4]xf32>
      %elem_f64 = llvm.fpext %elem : f32 to f64
      llvm.call @printF64(%elem_f64) : (f64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  // Store ZA0.S to memory
  vector.store %za0_s, %mem2[%c0] : memref<?xf32>, vector<[4]x[4]xf32>

  // Dump memory after storing ZA0.S
  scf.for %i = %c0 to %size step %svl_s {
    %av = vector.load %mem2[%i] : memref<?xf32>, vector<[4]xf32>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_s step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[4]xf32>
      %elem_f64 = llvm.fpext %elem : f32 to f64
      llvm.call @printF64(%elem_f64) : (f64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  return
}
// Round-trips a tile of i64 elements through a 2-D scalable vector value.
//
// Steps:
//   1. Fill %mem1 so that every element of row r holds r, printing each row
//      as it is written.
//   2. Load the whole buffer as a vector<[2]x[2]xi64> ("ZA0.D").
//   3. Zero-fill a second buffer %mem2 and print it.
//   4. Store the loaded tile to %mem2 and print it again.
//
// Rows are printed one per line, between @printOpen/@printClose, with
// @printComma after every element except the last one.
func.func @za0d() {
  %c0 = arith.constant 0 : index
  %c0_i64 = arith.constant 0 : i64
  %c1_i64 = arith.constant 1 : i64
  %c1_index = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %vscale = vector.vscale

  // "svl" refers to the Streaming Vector Length and "svl_d" the number of
  // 64-bit elements in a vector of SVL bits.
  %svl_d = arith.muli %c2, %vscale : index

  // Index of the last element of a row. It is loop-invariant, so it is
  // computed once here instead of on every iteration of the print loops.
  %last_i = arith.subi %svl_d, %c1_index : index

  // Allocate memory and fill each "row" with row number.
  %size = arith.muli %svl_d, %svl_d : index
  %mem1 = memref.alloca(%size) : memref<?xi64>
  %init_0 = arith.constant 0 : i64
  scf.for %i = %c0 to %size step %svl_d iter_args(%val = %init_0) -> (i64) {
    %splat_val = vector.broadcast %val : i64 to vector<[2]xi64>
    vector.store %splat_val, %mem1[%i] : memref<?xi64>, vector<[2]xi64>
    %val_next = arith.addi %val, %c1_i64 : i64

    // Read the row back from memory and print it. Elements are already i64,
    // so they are passed to @printI64 without widening.
    %av = vector.load %mem1[%i] : memref<?xi64>, vector<[2]xi64>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_d step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[2]xi64>
      llvm.call @printI64(%elem) : (i64) -> ()
      // No separator after the last element of the row.
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
    scf.yield %val_next : i64
  }

  // Load ZA0.D from memory
  %za0_d = vector.load %mem1[%c0] : memref<?xi64>, vector<[2]x[2]xi64>

  // Allocate new memory to store ZA0.D to
  %mem2 = memref.alloca(%size) : memref<?xi64>

  // Zero new memory
  scf.for %i = %c0 to %size step %c1_index {
    memref.store %c0_i64, %mem2[%i] : memref<?xi64>
  }

  // Dump zeroed memory
  scf.for %i = %c0 to %size step %svl_d {
    %av = vector.load %mem2[%i] : memref<?xi64>, vector<[2]xi64>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_d step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[2]xi64>
      llvm.call @printI64(%elem) : (i64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  // Store ZA0.D to memory
  vector.store %za0_d, %mem2[%c0] : memref<?xi64>, vector<[2]x[2]xi64>

  // Dump memory after storing ZA0.D
  scf.for %i = %c0 to %size step %svl_d {
    %av = vector.load %mem2[%i] : memref<?xi64>, vector<[2]xi64>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_d step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[2]xi64>
      llvm.call @printI64(%elem) : (i64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  return
}
// Round-trips a tile of f64 elements through a 2-D scalable vector value.
//
// Steps:
//   1. Fill %mem1 so that every element of row r holds r (as f64), printing
//      each row as it is written.
//   2. Load the whole buffer as a vector<[2]x[2]xf64> ("ZA0.D").
//   3. Zero-fill a second buffer %mem2 and print it.
//   4. Store the loaded tile to %mem2 and print it again.
//
// Rows are printed one per line, between @printOpen/@printClose, with
// @printComma after every element except the last one.
func.func @za0d_f64() {
  %c0 = arith.constant 0 : index
  %c0_f64 = arith.constant 0.0 : f64
  %c1_f64 = arith.constant 1.0 : f64
  %c1_index = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %vscale = vector.vscale

  // "svl" refers to the Streaming Vector Length and "svl_d" the number of
  // 64-bit elements in a vector of SVL bits.
  %svl_d = arith.muli %c2, %vscale : index

  // Index of the last element of a row. It is loop-invariant, so it is
  // computed once here instead of on every iteration of the print loops.
  %last_i = arith.subi %svl_d, %c1_index : index

  // Allocate memory and fill each "row" with row number.
  %size = arith.muli %svl_d, %svl_d : index
  %mem1 = memref.alloca(%size) : memref<?xf64>
  %init_0 = arith.constant 0.0 : f64
  scf.for %i = %c0 to %size step %svl_d iter_args(%val = %init_0) -> (f64) {
    %splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
    vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
    %val_next = arith.addf %val, %c1_f64 : f64

    // Read the row back from memory and print it. Elements are already f64,
    // so they are passed to @printF64 without widening.
    %av = vector.load %mem1[%i] : memref<?xf64>, vector<[2]xf64>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_d step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[2]xf64>
      llvm.call @printF64(%elem) : (f64) -> ()
      // No separator after the last element of the row.
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
    scf.yield %val_next : f64
  }

  // Load ZA0.D from memory
  %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

  // Allocate new memory to store ZA0.D to
  %mem2 = memref.alloca(%size) : memref<?xf64>

  // Zero new memory
  scf.for %i = %c0 to %size step %c1_index {
    memref.store %c0_f64, %mem2[%i] : memref<?xf64>
  }

  // Dump zeroed memory
  scf.for %i = %c0 to %size step %svl_d {
    %av = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_d step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[2]xf64>
      llvm.call @printF64(%elem) : (f64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  // Store ZA0.D to memory
  vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

  // Dump memory after storing ZA0.D
  scf.for %i = %c0 to %size step %svl_d {
    %av = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
    llvm.call @printOpen() : () -> ()
    scf.for %i2 = %c0 to %svl_d step %c1_index {
      %elem = vector.extractelement %av[%i2 : index] : vector<[2]xf64>
      llvm.call @printF64(%elem) : (f64) -> ()
      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
      scf.if %isNotLastIter {
        llvm.call @printComma() : () -> ()
      }
    }
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
  }

  return
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment