Created
January 10, 2022 14:18
-
-
Save zhanghb97/db87cd22d330ba6424b31c70b135b0ca to your computer and use it in GitHub Desktop.
Test RVV Strip-Mining
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
using namespace std; | |
// Define Memref Descriptor.
//
// Plain aggregate describing a rank-1 buffer of `T`. An instance is handed by
// pointer to the MLIR-generated C interface function declared further down,
// so the member order and types must not be changed.
template <class T>
struct MemRef_descriptor_ {
  T *allocated;         // Pointer recorded as the buffer's allocation base.
  T *aligned;           // Pointer through which the elements are accessed.
  intptr_t offset;      // Offset, in elements, applied on top of `aligned`.
  intptr_t sizes[1];    // Extent of the single dimension, in elements.
  intptr_t strides[1];  // Step between consecutive elements, in elements.
};
// Constructor | |
template<class T> | |
struct MemRef_descriptor_<T> *MemRef_Descriptor(T *allocated, T *aligned, | |
intptr_t offset, intptr_t sizes[1], intptr_t strides[1]) { | |
struct MemRef_descriptor_<T> *n = (struct MemRef_descriptor_<T> *)malloc(sizeof(*n)); | |
n->allocated = allocated; | |
n->aligned = aligned; | |
n->offset = offset; | |
for (int i = 0; i < 1; i++) | |
n->sizes[i] = sizes[i]; | |
for (int j = 0; j < 1; j++) | |
n->strides[j] = strides[j]; | |
return n; | |
} | |
typedef struct MemRef_descriptor_<int> *MemRef_descriptor_i32; | |
// Declare the interface. | |
extern "C" { | |
void _mlir_ciface_riscvv_loop_stripmining(MemRef_descriptor_i32 mem); | |
} | |
// Test driver: fills a 20-element buffer with 0..19, runs the strip-mined
// kernel over it in place, and prints the resulting elements.
//
// Returns 0 on success, 1 if the descriptor could not be allocated.
int main(int argc, char *argv[]) {
  // Command-line arguments are not used.
  (void)argc;
  (void)argv;

  enum { kNumElements = 20 };
  int input[kNumElements];
  for (int i = 0; i < kNumElements; i++)
    input[i] = i;

  intptr_t sizes[1] = {kNumElements};
  // Strides are counted in elements; a contiguous 1-D buffer steps by 1.
  intptr_t strides[1] = {1};
  MemRef_descriptor_i32 mem =
      MemRef_Descriptor<int>(input, input, 0, sizes, strides);
  if (!mem) {
    fprintf(stderr, "failed to allocate memref descriptor\n");
    return 1;
  }

  _mlir_ciface_riscvv_loop_stripmining(mem);

  printf("[ ");
  for (int i = 0; i < kNumElements; ++i) {
    printf("%d ", mem->aligned[i]);
  }
  printf("]\n");

  // The descriptor comes from malloc() inside MemRef_Descriptor.
  free(mem);
  return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Adds the constant 2 to every element of %m in place, processing the buffer
// in chunks whose length is chosen per iteration by riscvv.setvl.
func @riscvv_loop_stripmining(%m: memref<?xi32>) {
  %c0_idx = arith.constant 0 : index
  // Initial "application vector length": the number of elements in %m.
  %init_avl = memref.dim %m, %c0_idx : memref<?xi32>
  %init_idx = arith.constant 0 : index
  // Scalar operand added to each element.
  %c2_i32 = arith.constant 2 : i32
  // Operands forwarded to riscvv.setvl. NOTE(review): presumably the element
  // width and register-group multiplier; confirm the expected encoding
  // against the riscvv dialect definition.
  %sew = arith.constant 32 : index
  %lmul = arith.constant 2 : index
  // While loop.
  %a1, %a2 = scf.while (%avl = %init_avl, %idx = %init_idx) : (index, index) -> (index, index) {
    // If avl greater than zero.
    %cond = arith.cmpi sgt, %avl, %c0_idx : index
    // Pass avl, idx to the after region.
    scf.condition(%cond) %avl, %idx : index, index
  } do {
  ^bb0(%avl : index, %idx : index):
    // Perform the calculation according to the vl.
    %vl = riscvv.setvl %avl, %sew, %lmul : index
    %input_vector = riscvv.load %m[%idx], %vl : memref<?xi32>, vector<[8]xi32>, index
    %result_vector = riscvv.add %input_vector, %c2_i32, %vl : vector<[8]xi32>, i32, index
    riscvv.store %result_vector, %m[%idx], %vl : vector<[8]xi32>, memref<?xi32>, index
    // Update idx and avl: advance the position by %vl and reduce the
    // remaining element count by the same amount.
    %new_idx = arith.addi %idx, %vl : index
    %new_avl = arith.subi %avl, %vl : index
    scf.yield %new_avl, %new_idx : index, index
  }
  return
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment