Skip to content

Instantly share code, notes, and snippets.

@zhanghb97
Created January 10, 2022 14:18
Show Gist options
  • Save zhanghb97/db87cd22d330ba6424b31c70b135b0ca to your computer and use it in GitHub Desktop.
Save zhanghb97/db87cd22d330ba6424b31c70b135b0ca to your computer and use it in GitHub Desktop.
Test RVV Strip-Mining
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
using namespace std;
// Define Memref Descriptor.
// Rank-1 descriptor record passed by pointer to the MLIR-generated entry
// point. The field order is part of the binary contract with that code, so it
// must not be rearranged.
template<class T>
struct MemRef_descriptor_ {
  T *allocated;         // Buffer pointer as originally allocated.
  T *aligned;           // Pointer used for element access (see main's print loop).
  intptr_t offset;      // Offset value stored alongside the pointers.
  intptr_t sizes[1];    // Extent of the single dimension.
  intptr_t strides[1];  // Stride of the single dimension.
};

// Constructor
// Allocates a descriptor with malloc() and copies the arguments into it.
//
// Parameters:
//   allocated, aligned - buffer pointers; stored as-is, the data is neither
//                        copied nor owned by the descriptor.
//   offset             - stored as-is.
//   sizes, strides     - one-element arrays; element 0 of each is copied.
//
// Returns the new descriptor, or nullptr if the allocation failed. The caller
// owns the descriptor and must release it with free().
template<class T>
MemRef_descriptor_<T> *MemRef_Descriptor(T *allocated, T *aligned,
                                         intptr_t offset, intptr_t sizes[1],
                                         intptr_t strides[1]) {
  MemRef_descriptor_<T> *n = static_cast<MemRef_descriptor_<T> *>(
      std::malloc(sizeof(MemRef_descriptor_<T>)));
  // Report allocation failure to the caller instead of writing through null.
  if (n == nullptr)
    return nullptr;
  n->allocated = allocated;
  n->aligned = aligned;
  n->offset = offset;
  // The descriptor is rank 1, so there is exactly one size and one stride.
  n->sizes[0] = sizes[0];
  n->strides[0] = strides[0];
  return n;
}
typedef struct MemRef_descriptor_<int> *MemRef_descriptor_i32;
// Declare the interface.
// C-linkage symbol defined outside this translation unit; judging by its name
// it is presumably the C interface wrapper of @riscvv_loop_stripmining.
// It receives a pointer to the descriptor.
extern "C" void _mlir_ciface_riscvv_loop_stripmining(MemRef_descriptor_i32 mem);
// Test driver: fills a 20-element int buffer with 0..19, wraps it in a memref
// descriptor, runs the strip-mining kernel on it and prints the buffer as
// "[ v0 v1 ... ]". Returns 0 on success, 1 if the descriptor cannot be
// allocated.
int main(int argc, char *argv[]) {
  (void)argc;
  (void)argv;

  const int kSize = 20;
  int input[kSize];
  for (int i = 0; i < kSize; i++)
    input[i] = i;

  intptr_t sizes[1] = {kSize};
  // Memref strides are counted in elements, so a contiguous rank-1 buffer has
  // stride 1 (not its length).
  intptr_t strides[1] = {1};

  MemRef_descriptor_i32 mem = MemRef_Descriptor<int>(input, input, 0, sizes, strides);
  if (mem == nullptr) {
    fprintf(stderr, "failed to allocate memref descriptor\n");
    return 1;
  }

  _mlir_ciface_riscvv_loop_stripmining(mem);

  printf("[ ");
  for (int i = 0; i < kSize; ++i) {
    printf("%d ", mem->aligned[i]);
  }
  printf("]\n");

  // The descriptor was obtained from malloc(); release it before exiting.
  free(mem);
  return 0;
}
// Walks the 1-D i32 memref %m in chunks of %vl elements. Each iteration loads
// a chunk starting at %idx, adds the constant 2 to it and stores the result
// back to the same position. The function returns nothing; %m is updated in
// place.
func @riscvv_loop_stripmining(%m: memref<?xi32>) {
%c0_idx = arith.constant 0 : index
// Initial application vector length: the size of dimension 0 of %m.
%init_avl = memref.dim %m, %c0_idx : memref<?xi32>
%init_idx = arith.constant 0 : index
// Scalar operand added to every loaded vector.
%c2_i32 = arith.constant 2 : i32
// Operands handed to riscvv.setvl below; presumably the element width in bits
// and the register grouping factor -- confirm against the riscvv dialect.
%sew = arith.constant 32 : index
%lmul = arith.constant 2 : index
// While loop.
// Loop-carried state: the remaining element count (avl) and the current
// position (idx). The loop results %a1 and %a2 are not used afterwards.
%a1, %a2 = scf.while (%avl = %init_avl, %idx = %init_idx) : (index, index) -> (index, index) {
// If avl greater than zero.
%cond = arith.cmpi sgt, %avl, %c0_idx : index
// Pass avl, idx to the after region.
scf.condition(%cond) %avl, %idx : index, index
} do {
^bb0(%avl : index, %idx : index):
// Perform the calculation according to the vl.
// NOTE(review): %vl is presumably the number of elements handled in this
// iteration (at most %avl); the loop only terminates if it is positive --
// confirm against the riscvv.setvl definition.
%vl = riscvv.setvl %avl, %sew, %lmul : index
%input_vector = riscvv.load %m[%idx], %vl : memref<?xi32>, vector<[8]xi32>, index
%result_vector = riscvv.add %input_vector, %c2_i32, %vl : vector<[8]xi32>, i32, index
riscvv.store %result_vector, %m[%idx], %vl : vector<[8]xi32>, memref<?xi32>, index
// Update idx and avl.
// Advance the position by %vl and reduce the remaining count by %vl.
%new_idx = arith.addi %idx, %vl : index
%new_avl = arith.subi %avl, %vl : index
scf.yield %new_avl, %new_idx : index, index
}
return
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment