Skip to content

Instantly share code, notes, and snippets.

@dougallj

dougallj/amx.h Secret

Created December 26, 2020 02:05
Show Gist options
  • Save dougallj/7cba721da1a94da725ee37c1e9cd1f21 to your computer and use it in GitHub Desktop.
amx simulator and hardware tests
#pragma once
#include <stdint.h>
#include <string.h>
// TODO: is it possible to not go via x0? I'm guessing not without messing with
// the compile process, but still, kinda ugly. might at least be possible to
// force the compiler to get the value in x0 itself
// TODO: do I need memory as an input?
//
// Each macro below moves the 64-bit operand V into x0, then emits a raw
// instruction word: 0x201000 | (opcode << 5) | operand-register (x0 = 0).
// NOTE(review): these appear to be undocumented Apple AMX coprocessor
// instructions; the encoding is inferred from the constants used here and
// should be confirmed against an external AMX reference.
//
// Opcodes 0-7: loads/stores between memory and the x/y/z register files.
#define AMX_LDX(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (0 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_LDY(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (1 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_STX(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (2 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_STY(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (3 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_LDZ(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (4 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_STZ(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (5 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
// LDZI/STZI: interleaved z load/store (see amx_state_ldzi/stzi simulation).
#define AMX_LDZI(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (6 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_STZI(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (7 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
// TODO: probably shouldn't say these clobber memory?
// Opcodes 8-16: register-to-register extract and multiply-accumulate ops.
#define AMX_EXTRX(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (8 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_EXTRY(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (9 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_FMA64(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (10 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_FMS64(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (11 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_FMA32(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (12 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_FMS32(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (13 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_MAC16(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (14 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_FMA16(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (15 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
#define AMX_FMS16(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (16 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
// Opcode 17 enables (|0) / disables (|1) the AMX unit; no operand register,
// hence no mov and no x0 clobber. NOTE(review): the nops presumably pad for
// alignment or decode reasons -- unconfirmed.
#define AMX_START() \
__asm__ volatile( \
"nop \r\n nop \r\n nop \r\n .word (0x201000 | (17 << 5) | 0)" :: \
: "memory")
#define AMX_STOP() \
__asm__ volatile( \
"nop \r\n nop \r\n nop \r\n .word (0x201000 | (17 << 5) | 1)" :: \
: "memory")
// Opcodes 18-21: vector/matrix integer and floating-point ops.
// horizontal multiply uint16_ts? (doesn't mac16 have a flag for this?)
// z0[i] += x0[i] + y0[i]
#define AMX_VECINT(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (18 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
// horizontal multiply float16_ts? (doesn't fma16 have a flag for this?)
// z0[i] += x0[i] + y0[i]
#define AMX_VECFP(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (19 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
// uint16_t matrix multiply? (doesn't mac16 do this?)
#define AMX_MATINT(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (20 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
// float16_t matrix multiply? (doesn't fma16 do this?)
#define AMX_MATFP(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (21 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
// Opcode 22: generate lookup table.
// looks only at z0, clears it, and generates a 64-bit value in x0[0]:
// [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0] -> 0xffffffffffffffff
// [0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] -> 0xf0
// [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] -> 0xfedcba9876543210
// [0x0, 0x10000000, 0x20000000, 0x30000000, 0x40000000, 0x50000000, 0x60000000,
// 0x70000000, 0x80000000, 0x90000000, 0xA0000000, 0xB0000000, 0xC0000000,
// 0xD0000000, 0xE0000000, 0xF0000000] -> fffffff0f6543210
#define AMX_GENLUT(V) \
__asm__ volatile( \
"mov x0, %0 \r\n .word (0x201000 | (22 << 5) | 0)" ::"r"((uint64_t)V) \
: "x0", "memory")
// 16-bit IEEE half-precision float (Clang/GCC extension type).
typedef _Float16 float16;
// One 64-byte AMX register row, viewable as any useful element type.
union amx_row {
// not all supported types, just useful ones
uint8_t u8[64];
uint16_t u16[32];
uint32_t u32[16];
uint64_t u64[8];
float16 f16[32];
float f32[16];
double f64[8];
};
// Full architectural AMX state: 8 x rows + 8 y rows + 64 z rows (4 KiB of z).
struct amx_state {
union amx_row x[8];
union amx_row y[8];
union amx_row z[64];
};
</struct amx_state>
// Read the entire hardware AMX register state into *state via STX/STY/STZ.
// Store operands encode the register index in bits 56+ and the destination
// address in the low bits. The 0xAA pre-fill makes any bytes the hardware
// stores fail to write stand out when diffing.
void store_amx_state(struct amx_state *state) {
memset(state, 0xAA, sizeof *state);
for (uint64_t i = 0; i < 8; i++) {
AMX_STX((i << 56) | (uint64_t)&state->x[i]);
}
for (uint64_t i = 0; i < 8; i++) {
AMX_STY((i << 56) | (uint64_t)&state->y[i]);
}
for (uint64_t i = 0; i < 64; i++) {
AMX_STZ((i << 56) | (uint64_t)&state->z[i]);
}
}
// Write *state into the hardware AMX registers via LDX/LDY/LDZ, one 64-byte
// row at a time (register index in bits 56+, source address in the low bits).
void load_amx_state(struct amx_state *state) {
for (uint64_t i = 0; i < 8; i++) {
AMX_LDX((i << 56) | (uint64_t)&state->x[i]);
}
for (uint64_t i = 0; i < 8; i++) {
AMX_LDY((i << 56) | (uint64_t)&state->y[i]);
}
for (uint64_t i = 0; i < 64; i++) {
AMX_LDZ((i << 56) | (uint64_t)&state->z[i]);
}
}
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "amx.h"
#include "simulator.h"
// Compare the simulator state against the real hardware state (read back via
// store_amx_state). On mismatch, print both sides with diff_amx_state --
// PF_SKIP_B/PF_SKIP_A select which column is shown -- and return 0.
// Returns 1 when the states match. flags selects the element formatting
// (PF_F32, PF_U16, ...; declared in simulator.h).
static int check_state(struct amx_state *sim_state, int flags) {
__attribute__((aligned(0x80))) static struct amx_state state;
store_amx_state(&state);
if (memcmp(sim_state, &state, sizeof state)) {
printf("state mismatch!\n");
printf("real state:\n");
diff_amx_state(&state, sim_state, flags | PF_SKIP_B);
printf("simulated state:\n");
diff_amx_state(&state, sim_state, flags | PF_SKIP_A);
return 0;
}
return 1;
}
// Trial-division primality test; used to generate distinctive test values.
bool is_prime(uint64_t v) {
  if (v < 2) {
    return false;
  }
  if (v % 2 == 0) {
    return v == 2;
  }
  for (uint64_t d = 3; d * d <= v; d += 2) {
    if (v % d == 0) {
      return false;
    }
  }
  return true;
}
// Fill the x/y portion of a test buffer (first 16 rows of 16 floats) with
// consecutive primes >= 1009, and the z portion (remaining 64 rows) with the
// sequence 1, 2, 3, ...
static void init_test_f32s(float *data) {
  uint64_t filled = 0;
  for (uint64_t v = 1000; filled < (8 + 8) * 16; v++) {
    if (is_prime(v)) {
      data[filled++] = v;
    }
  }
  int next = 1;
  for (uint64_t i = (8 + 8) * 16; i < (8 + 8 + 64) * 16; i++) {
    data[i] = next++;
  }
}
// Begin a test: zero the simulator state and enable the hardware AMX unit.
static void test_start(struct amx_state *sim_state) {
amx_state_zero(sim_state);
AMX_START();
}
// End a test: disable the hardware AMX unit (simulator needs no teardown).
static void test_stop(struct amx_state *sim_state) { AMX_STOP(); }
// LDX/LDY/LDZ/LDZI
// Each helper applies the same load operand to the simulator and to the
// hardware, then verifies both states still agree.
static void test_ldx(struct amx_state *sim_state, uint64_t op) {
amx_state_ldx(sim_state, op);
AMX_LDX(op);
check_state(sim_state, PF_F32);
}
static void test_ldy(struct amx_state *sim_state, uint64_t op) {
amx_state_ldy(sim_state, op);
AMX_LDY(op);
check_state(sim_state, PF_F32);
}
static void test_ldz(struct amx_state *sim_state, uint64_t op) {
amx_state_ldz(sim_state, op);
AMX_LDZ(op);
check_state(sim_state, PF_F32);
}
static void test_ldzi(struct amx_state *sim_state, uint64_t op) {
amx_state_ldzi(sim_state, op);
AMX_LDZI(op);
check_state(sim_state, PF_F32);
}
// Run every load op against every register index, with each high operand bit
// (56..63) set in turn; test_bit == 55 encodes "no extra bit set".
static void test_loads(void) {
__attribute__((aligned(0x80))) static float test_data[(8 + 8 + 64) * 16];
init_test_f32s(test_data);
struct amx_state sim_state;
// TODO: test alignment checks somehow
for (int test_bit = 56 - 1; test_bit < 64; test_bit++) {
uint64_t extra_bit = (test_bit == 55 ? 0 : (1ull << test_bit));
for (uint64_t reg = 0; reg < 8; reg++) {
test_start(&sim_state);
test_ldx(&sim_state, (uint64_t)test_data | (reg << 56) | extra_bit);
test_stop(&sim_state);
}
for (uint64_t reg = 0; reg < 8; reg++) {
test_start(&sim_state);
test_ldy(&sim_state, (uint64_t)test_data | (reg << 56) | extra_bit);
test_stop(&sim_state);
}
for (uint64_t reg = 0; reg < 64; reg++) {
test_start(&sim_state);
test_ldz(&sim_state, (uint64_t)test_data | (reg << 56) | extra_bit);
test_stop(&sim_state);
}
for (uint64_t reg = 0; reg < 64; reg++) {
test_start(&sim_state);
test_ldzi(&sim_state, (uint64_t)test_data | (reg << 56) | extra_bit);
test_stop(&sim_state);
}
}
}
// STX/STY/STZ/STZI
// Load identical known data into the hardware and the simulator, then store
// each register (with each high operand bit 56..63 set in turn; test_bit ==
// 55 encodes "no extra bit") from both sides and compare the buffers.
//
// Fix: the four failure paths previously printed the same indistinguishable
// message (the STZI one with a stray "1"), so a failure could not be
// attributed to an op, register, or bit; each now identifies all three.
static void test_stores(void) {
  struct amx_state sim_state;
  static float test_data[(8 + 8 + 64) * 16];
  init_test_f32s(test_data);
  __attribute__((aligned(0x80))) static float store_buffer1[0x100];
  __attribute__((aligned(0x80))) static float store_buffer2[0x100];
  test_start(&sim_state);
  memcpy(&sim_state, test_data, sizeof sim_state);
  load_amx_state(&sim_state);
  for (int test_bit = 56 - 1; test_bit < 64; test_bit++) {
    uint64_t extra_bit = (test_bit == 55 ? 0 : (1ull << test_bit));
    for (uint64_t reg = 0; reg < 8; reg++) {
      memset(store_buffer1, 0, sizeof store_buffer1);
      memset(store_buffer2, 0, sizeof store_buffer2);
      AMX_STX((uint64_t)store_buffer1 | (reg << 56) | extra_bit);
      amx_state_stx(&sim_state,
                    (uint64_t)store_buffer2 | (reg << 56) | extra_bit);
      if (memcmp(store_buffer1, store_buffer2, sizeof store_buffer1)) {
        printf("stx store test failed! (reg=%llu, bit=%d)\n",
               (unsigned long long)reg, test_bit);
      }
    }
    for (uint64_t reg = 0; reg < 8; reg++) {
      memset(store_buffer1, 0, sizeof store_buffer1);
      memset(store_buffer2, 0, sizeof store_buffer2);
      AMX_STY((uint64_t)store_buffer1 | (reg << 56) | extra_bit);
      amx_state_sty(&sim_state,
                    (uint64_t)store_buffer2 | (reg << 56) | extra_bit);
      if (memcmp(store_buffer1, store_buffer2, sizeof store_buffer1)) {
        printf("sty store test failed! (reg=%llu, bit=%d)\n",
               (unsigned long long)reg, test_bit);
      }
    }
    for (uint64_t reg = 0; reg < 64; reg++) {
      memset(store_buffer1, 0, sizeof store_buffer1);
      memset(store_buffer2, 0, sizeof store_buffer2);
      AMX_STZ((uint64_t)store_buffer1 | (reg << 56) | extra_bit);
      amx_state_stz(&sim_state,
                    (uint64_t)store_buffer2 | (reg << 56) | extra_bit);
      if (memcmp(store_buffer1, store_buffer2, sizeof store_buffer1)) {
        printf("stz store test failed! (reg=%llu, bit=%d)\n",
               (unsigned long long)reg, test_bit);
      }
    }
    for (uint64_t reg = 0; reg < 64; reg++) {
      memset(store_buffer1, 0, sizeof store_buffer1);
      memset(store_buffer2, 0, sizeof store_buffer2);
      AMX_STZI((uint64_t)store_buffer1 | (reg << 56) | extra_bit);
      amx_state_stzi(&sim_state,
                     (uint64_t)store_buffer2 | (reg << 56) | extra_bit);
      if (memcmp(store_buffer1, store_buffer2, sizeof store_buffer1)) {
        printf("stzi store test failed! (reg=%llu, bit=%d)\n",
               (unsigned long long)reg, test_bit);
      }
    }
  }
  test_stop(&sim_state);
}
// EXTRX
// Load initial_state into hardware and simulator, verify they agree, then
// apply one EXTRX operand to both and compare again (printed as u32s since
// extract moves raw bytes).
static void extrx_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
check_state(&sim_state, PF_F32);
AMX_EXTRX(operand);
amx_state_extrx(&sim_state, operand);
check_state(&sim_state, PF_U32);
test_stop(&sim_state);
}
// Exhaustively test EXTRX over all combinations of the operand bits in
// `mask`. Each register file is filled with a distinctive tag (x: 0xAA...,
// y: 0xBB..., z: 0xCC...) plus the element index so misrouted bytes are
// identifiable.
void test_extrx(void) {
uint32_t initial_state[(8 + 8 + 64) * 16];
for (int i = 0; i < 8 * 16; i++) {
initial_state[i] = 0xAA0000 | i;
}
for (int i = 8 * 16; i < (8 + 8) * 16; i++) {
initial_state[i] = 0xBB00000 | i;
}
for (int i = (8 + 8) * 16; i < (8 + 8 + 64) * 16; i++) {
initial_state[i] = 0xCC000000 | i;
}
// TODO: bit 26
uint64_t mask = 0x800000001ff7fdc0 & ~(1 << 26);
// for perf, skip some offsets
mask &= ~(0x1FFull << 10);
mask |= 0x147ull << 10;
uint64_t operand = 0;
do {
extrx_test(initial_state, operand);
// ryg's texture tiling and swizzling loop
// (enumerates every submask of `mask`, ending back at 0)
operand = (operand - mask) & mask;
} while (operand);
// NOTE(review): double-underscore names are reserved for the implementation;
// kept as-is since it is a local test convenience macro.
#define __amx_op8_to_x(o) extrx_test(initial_state, o)
// bit 26 test cases:
__amx_op8_to_x(0x4000000);
__amx_op8_to_x(0x4800000);
__amx_op8_to_x(0x5000000);
__amx_op8_to_x(0x5800000);
__amx_op8_to_x(0x6000000);
__amx_op8_to_x(0x6800000);
__amx_op8_to_x(0x7000000);
__amx_op8_to_x(0x7800000);
// some failing bit 26 test cases:
//__amx_op8_to_x(0x4110000);
//__amx_op8_to_x(0x4910000);
//__amx_op8_to_x(0x5110000);
//__amx_op8_to_x(0x5910000);
//__amx_op8_to_x(0x6110000);
//__amx_op8_to_x(0x6910000);
//__amx_op8_to_x(0x7110000);
//__amx_op8_to_x(0x7910000);
//__amx_op8_to_x(0x8000000004004500);
//__amx_op8_to_x(0x8000000004204000);
//__amx_op8_to_x(0x8000000004304000);
//__amx_op8_to_x(0x8000000004604000);
//__amx_op8_to_x(0x8000000004704000);
//__amx_op8_to_x(0x8000000004804000);
//__amx_op8_to_x(0x8000000004a04000);
//__amx_op8_to_x(0x8000000004b04000);
//__amx_op8_to_x(0x8000000004e04000);
//__amx_op8_to_x(0x8000000004f04000);
//__amx_op8_to_x(0x8000000005004040);
//__amx_op8_to_x(0x8000000005204000);
//__amx_op8_to_x(0x8000000005304000);
//__amx_op8_to_x(0x8000000005604000);
//__amx_op8_to_x(0x8000000005704000);
//__amx_op8_to_x(0x8000000005804080);
//__amx_op8_to_x(0x8000000005a04000);
//__amx_op8_to_x(0x8000000005b04000);
//__amx_op8_to_x(0x8000000005e04000);
//__amx_op8_to_x(0x8000000005f04000);
//__amx_op8_to_x(0x8000000006004540);
//__amx_op8_to_x(0x8000000006204040);
//__amx_op8_to_x(0x8000000006304040);
//__amx_op8_to_x(0x8000000006604040);
//__amx_op8_to_x(0x8000000006704040);
//__amx_op8_to_x(0x8000000006804040);
//__amx_op8_to_x(0x8000000006a04040);
//__amx_op8_to_x(0x8000000006b04040);
//__amx_op8_to_x(0x8000000006e04040);
//__amx_op8_to_x(0x8000000006f04040);
//__amx_op8_to_x(0x8000000007204040);
//__amx_op8_to_x(0x8000000007304040);
//__amx_op8_to_x(0x8000000007604040);
//__amx_op8_to_x(0x8000000007704040);
//__amx_op8_to_x(0x8000000007a04040);
//__amx_op8_to_x(0x8000000007b04040);
//__amx_op8_to_x(0x8000000007e04040);
//__amx_op8_to_x(0x8000000007f04000);
//__amx_op8_to_x(0x8000000007f04040);
// TODO: commented out things (and the rest, as some probably interact with
// other bits)
// Single-bit sweep; bits 41-46 are skipped (known not to match yet).
for (int i = 0; i < 41; i++) {
extrx_test(initial_state, (1ull << i));
}
// extrx_test(initial_state, (1ull << 41));
// extrx_test(initial_state, (1ull << 42));
// extrx_test(initial_state, (1ull << 43));
// extrx_test(initial_state, (1ull << 44));
// extrx_test(initial_state, (1ull << 45));
// extrx_test(initial_state, (1ull << 46));
for (int i = 47; i < 64; i++) {
extrx_test(initial_state, (1ull << i));
}
}
// EXTRY
// Same shape as extrx_test: sync both states, apply one EXTRY operand to
// hardware and simulator, and compare.
static void extry_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
check_state(&sim_state, PF_F32);
AMX_EXTRY(operand);
amx_state_extry(&sim_state, operand);
check_state(&sim_state, PF_U32);
test_stop(&sim_state);
}
// Exhaustively test EXTRY over all submasks of `mask` (same tagged fill as
// test_extrx), plus a hand-collected list of bit-14 operands known to pass.
void test_extry(void) {
uint32_t initial_state[(8 + 8 + 64) * 16];
for (int i = 0; i < 8 * 16; i++) {
initial_state[i] = 0xAA0000 | i;
}
for (int i = 8 * 16; i < (8 + 8) * 16; i++) {
initial_state[i] = 0xBB00000 | i;
}
for (int i = (8 + 8) * 16; i < (8 + 8 + 64) * 16; i++) {
initial_state[i] = 0xCC000000 | i;
}
// still can't handle (1ull << 14) and some higher bits
// uint64_t mask = (1ull << 63) | (1ull << 29) | (1ull << 28) | (1ull << 27) |
// (1ull << 26) | (1ull << 13) | (1ull << 12) | (1ull << 11) | (1ull << 10) |
// 0x40 | 0x20 | 1 | (63<<20);
uint64_t mask = (0x800000003ff04dc0 & ~(1ull << 14)) | 1;
uint64_t operand = 0;
do {
extry_test(initial_state, operand);
// enumerate every submask of `mask`
operand = (operand - mask) & mask;
} while (operand);
// bit 14 tests that already pass:
extry_test(initial_state, 0x8000000004004000);
extry_test(initial_state, 0x8000000004104000);
extry_test(initial_state, 0x8000000004104040);
extry_test(initial_state, 0x8000000004204000);
extry_test(initial_state, 0x8000000004304000);
extry_test(initial_state, 0x8000000004404040);
extry_test(initial_state, 0x8000000004404100);
extry_test(initial_state, 0x8000000004504000);
extry_test(initial_state, 0x8000000004604000);
extry_test(initial_state, 0x8000000004704040);
extry_test(initial_state, 0x8000000004804080);
extry_test(initial_state, 0x8000000004804400);
extry_test(initial_state, 0x8000000004c040c0);
extry_test(initial_state, 0x8000000004c04500);
extry_test(initial_state, 0x8000000005004040);
extry_test(initial_state, 0x8000000005004100);
extry_test(initial_state, 0x8000000005404140);
extry_test(initial_state, 0x8000000005804180);
extry_test(initial_state, 0x8000000005804440);
extry_test(initial_state, 0x8000000005c041c0);
extry_test(initial_state, 0x8000000005c04540);
extry_test(initial_state, 0x8000000006004080);
extry_test(initial_state, 0x8000000006004400);
extry_test(initial_state, 0x8000000006404180);
extry_test(initial_state, 0x8000000006404440);
extry_test(initial_state, 0x8000000006804480);
extry_test(initial_state, 0x8000000006c044c0);
extry_test(initial_state, 0x8000000006c04580);
extry_test(initial_state, 0x80000000070040c0);
extry_test(initial_state, 0x8000000007004500);
extry_test(initial_state, 0x80000000074041c0);
extry_test(initial_state, 0x8000000007404540);
extry_test(initial_state, 0x80000000078044c0);
extry_test(initial_state, 0x8000000007804580);
extry_test(initial_state, 0x8000000007c045c0);
}
// FMA32
// Apply one FMA32 operand to hardware and simulator and compare states.
static void fma32_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
// check_state(&sim_state, PF_F32);
AMX_FMA32(operand);
amx_state_fma32(&sim_state, operand);
check_state(&sim_state, PF_F32);
test_stop(&sim_state);
}
// Same, for the fused-multiply-subtract variant.
static void fms32_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
// check_state(&sim_state, PF_F32);
AMX_FMS32(operand);
amx_state_fms32(&sim_state, operand);
check_state(&sim_state, PF_F32);
test_stop(&sim_state);
}
// Sweep FMA32/FMS32 over all submasks of a reduced operand mask, then a
// single-bit sweep (bits with known-unmodelled fields commented out).
static void test_fma32_fms32(void) {
static float test_data[(8 + 8 + 64) * 16];
init_test_f32s(test_data);
// slow, but passes
// TODO: fields at bit 32/41/60
// uint64_t mask = 0xa000da4f3bf709f8 & ~((0x3Full << 32) | (0x3Full << 41) |
// (0x3ull << 60));
uint64_t offset_mask = 0x101; // optimised - thorough is 0x1FF
uint64_t mask = (1ull << 63) | (1ull << 27) | (1ull << 28) | (1ull << 29) |
(63 << 20) | (offset_mask << 10) | offset_mask;
uint64_t operand = 0;
do {
fma32_test(test_data, operand);
// TODO: this is almost working, but when we clear we have problems with
// negative zero
fms32_test(test_data, operand & ~(1ull << 27));
operand = (operand - mask) & mask;
} while (operand);
// TODO: commented out things (and the rest, as some probably interact with
// other bits)
for (int i = 0; i < 32; i++) {
fma32_test(test_data, (1ull << i));
}
// fma32_test(test_data, (1ull << 32));
// fma32_test(test_data, (1ull << 33));
// fma32_test(test_data, (1ull << 34));
// fma32_test(test_data, (1ull << 35));
// fma32_test(test_data, (1ull << 36));
// fma32_test(test_data, (1ull << 37));
fma32_test(test_data, (1ull << 38));
fma32_test(test_data, (1ull << 39));
fma32_test(test_data, (1ull << 40));
// fma32_test(test_data, (1ull << 41));
// fma32_test(test_data, (1ull << 42));
// fma32_test(test_data, (1ull << 43));
// fma32_test(test_data, (1ull << 44));
// fma32_test(test_data, (1ull << 45));
// fma32_test(test_data, (1ull << 46));
fma32_test(test_data, (1ull << 47));
fma32_test(test_data, (1ull << 48));
fma32_test(test_data, (1ull << 49));
fma32_test(test_data, (1ull << 50));
fma32_test(test_data, (1ull << 51));
fma32_test(test_data, (1ull << 52));
fma32_test(test_data, (1ull << 53));
fma32_test(test_data, (1ull << 54));
fma32_test(test_data, (1ull << 55));
fma32_test(test_data, (1ull << 56));
fma32_test(test_data, (1ull << 57));
fma32_test(test_data, (1ull << 58));
fma32_test(test_data, (1ull << 59));
// fma32_test(test_data, (1ull << 60));
// fma32_test(test_data, (1ull << 61));
fma32_test(test_data, (1ull << 62));
fma32_test(test_data, (1ull << 63));
}
// FMA64/FMS64
// Fill the x/y portion (first 16 rows of 8 doubles) with consecutive primes
// >= 1009, and every z element with 1.0.
static void init_test_f64s(double *test) {
  uint64_t filled = 0;
  for (uint64_t v = 1000; filled < (8 + 8) * 8; v++) {
    if (is_prime(v)) {
      test[filled++] = v;
    }
  }
  for (uint64_t i = (8 + 8) * 8; i < (8 + 8 + 64) * 8; i++) {
    test[i] = 1.0;
  }
}
// Apply one FMA64 operand to hardware and simulator and compare states.
static void fma64_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
AMX_FMA64(operand);
amx_state_fma64(&sim_state, operand);
check_state(&sim_state, PF_F64);
test_stop(&sim_state);
}
// Same, for the fused-multiply-subtract variant.
static void fms64_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
AMX_FMS64(operand);
amx_state_fms64(&sim_state, operand);
check_state(&sim_state, PF_F64);
test_stop(&sim_state);
}
// Sweep FMA64/FMS64 over all submasks of a reduced operand mask.
static void test_fma64_fms64(void) {
static double test_data[(8 + 8 + 64) * 8];
init_test_f64s(test_data);
fma64_test(test_data, 0);
// slow, but passes
// uint64_t mask = 0xa000da4f3bf709f8 & ~((0x3Full << 32) | (0x3Full << 41));
uint64_t offset_mask = 0x101; // optimised - thorough is 0x1FF
uint64_t mask = (1ull << 63) | (1ull << 27) | (1ull << 28) | (1ull << 29) |
(63 << 20) | (offset_mask << 10) | offset_mask;
uint64_t operand = 0;
do {
fma64_test(test_data, operand);
// TODO: this is almost working, but when we clear we have problems with
// negative zero
fms64_test(test_data, operand & ~(1ull << 27));
// ryg's texture tiling and swizzling loop
operand = (operand - mask) & mask;
} while (operand);
}
// FMA16/FMS16
// Apply one FMA16 operand to hardware and simulator; on mismatch, re-print
// the diff as f32 (the f16 comparison is bit-exact u16) plus the operand.
static void fma16_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
AMX_FMA16(operand);
amx_state_fma16(&sim_state, operand);
if (!check_state(&sim_state, PF_U16)) {
// check_state(&sim_state, PF_F16);
check_state(&sim_state, PF_F32);
printf("^ %llx\n\n", operand);
}
test_stop(&sim_state);
}
// Same, for the fused-multiply-subtract variant.
static void fms16_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
AMX_FMS16(operand);
amx_state_fms16(&sim_state, operand);
if (!check_state(&sim_state, PF_U16)) {
// check_state(&sim_state, PF_F16);
check_state(&sim_state, PF_F32);
printf("^ %llx\n\n", operand);
}
test_stop(&sim_state);
}
// Fill the x/y portion (first 16 rows of 32 halves) with consecutive primes
// starting from 5 (small, so they stay representable in f16), and every z
// element with 1.0.
static void init_test_f16s(float16 *test) {
  uint64_t filled = 0;
  for (uint64_t v = 5; filled < (8 + 8) * 32; v++) {
    if (is_prime(v)) {
      test[filled++] = v;
    }
  }
  for (uint64_t i = (8 + 8) * 32; i < (8 + 8 + 64) * 32; i++) {
    test[i] = 1.0f;
  }
}
// Sweep FMA16/FMS16 over all submasks of a reduced operand mask (bit 62 is
// included here: mixed-precision f32 accumulation mode).
static void test_fma16_fms16(void) {
// TODO: There's something wrong with NaN accuracy - probably in the others as
// well but it shows up here if we let offset_mask be odd.
static float16 test_data[(8 + 8 + 64) * 32];
init_test_f16s(test_data);
uint64_t offset_mask = 0x102; // optimised - thorough is 0x1FF
uint64_t mask = (1ull << 63) | (1ull << 62) | (1ull << 27) | (1ull << 28) |
(1ull << 29) | (63 << 20) | (offset_mask << 10) | offset_mask;
uint64_t operand = 0;
do {
fma16_test(test_data, operand);
// TODO: this is almost working, but when we clear we have problems with
// negative zero
fms16_test(test_data, operand & ~(1ull << 27));
// ryg's texture tiling and swizzling loop
operand = (operand - mask) & mask;
} while (operand);
}
// MAC16
// Apply one MAC16 operand to hardware and simulator; results are compared as
// raw u16s, and the failing operand is printed on mismatch.
static void mac16_test(void *initial_state, uint64_t operand) {
struct amx_state sim_state;
test_start(&sim_state);
memcpy(&sim_state, initial_state, sizeof sim_state);
load_amx_state(&sim_state);
AMX_MAC16(operand);
amx_state_mac16(&sim_state, operand);
if (!check_state(&sim_state, PF_U16)) {
printf("^ %llx\n\n", operand);
}
test_stop(&sim_state);
}
// Fill the x/y portion (first 16 rows of 32 u16s) with consecutive primes
// starting just below 0x8000 (so products exercise wraparound), and every z
// element with 1.
static void init_test_u16s(uint16_t *test) {
  uint64_t filled = 0;
  for (uint64_t v = 0x8000 - 100; filled < (8 + 8) * 32; v++) {
    if (is_prime(v)) {
      test[filled++] = v;
    }
  }
  for (uint64_t i = (8 + 8) * 32; i < (8 + 8 + 64) * 32; i++) {
    test[i] = 1;
  }
}
// Sweep MAC16 over all submasks of a reduced operand mask (bit 62 included).
static void test_mac16(void) {
static uint16_t test_data[(8 + 8 + 64) * 32];
init_test_u16s(test_data);
uint64_t offset_mask = 0x101; // optimised - thorough is 0x1FF
uint64_t mask = (1ull << 63) | (1ull << 62) | (1ull << 27) | (1ull << 28) |
(1ull << 29) | (63 << 20) | (offset_mask << 10) | offset_mask;
uint64_t operand = 0;
do {
mac16_test(test_data, operand);
// enumerate every submask of `mask`
operand = (operand - mask) & mask;
} while (operand);
}
// Run the full AMX test suite: every test drives the real hardware and the
// simulator with the same operands and compares the resulting state.
// Fixes: removed the stale "TODO: test stores" comment (test_stores exists
// and is called below) and silenced unused-parameter warnings.
int main(int argc, char **argv) {
  (void)argc;
  (void)argv;
  test_loads();
  test_stores();
  test_mac16();
  test_fma16_fms16();
  test_fma32_fms32();
  test_fma64_fms64();
  test_extrx();
  test_extry();
  return 0;
}
#pragma once
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "amx.h"
#ifdef __aarch64__
// Scalar fused multiply-add reference helpers, written in inline asm so the
// compiler cannot contract or reassociate them differently from the AMX
// hardware being modelled. Each computes a * b + c (fms32: c - a * b, per
// the A64 FMSUB definition) in the given precision.
double fma64(double a, double b, double c) {
double out;
asm("fmadd %d[out], %d[a], %d[b], %d[c]"
: [out] "=w"(out)
: [a] "w"(a), [b] "w"(b), [c] "w"(c));
return out;
}
float fma32(float a, float b, float c) {
float out;
asm("fmadd %s[out], %s[a], %s[b], %s[c]"
: [out] "=w"(out)
: [a] "w"(a), [b] "w"(b), [c] "w"(c));
return out;
}
float fms32(float a, float b, float c) {
float out;
asm("fmsub %s[out], %s[a], %s[b], %s[c]"
: [out] "=w"(out)
: [a] "w"(a), [b] "w"(b), [c] "w"(c));
return out;
}
float16 fma16(float16 a, float16 b, float16 c) {
float16 out;
asm("fmadd %h[out], %h[a], %h[b], %h[c]"
: [out] "=w"(out)
: [a] "w"(a), [b] "w"(b), [c] "w"(c));
return out;
}
#else
// Non-aarch64 fallback is intentionally unfinished: the #error fires before
// the contraction-based definitions below (and fma16/fms32 are missing).
#error "TODO: portable fma64/fma32/fma16 implementations"
float fma32(float a, float b, float c) {
#pragma clang fp contract(fast)
return a * b + c;
}
double fma64(double a, double b, double c) {
#pragma clang fp contract(fast)
return a * b + c;
}
#endif
// 16-bit multiply-accumulate with wraparound: the low 16 bits of a * b + c
// (reference behavior for the AMX mac16 op; cannot overflow 32 bits).
uint16_t mac16(uint16_t a, uint16_t b, uint16_t c) {
  uint32_t acc = (uint32_t)a * (uint32_t)b + (uint32_t)c;
  return (uint16_t)acc;
}
// Reset the simulated AMX state to all-zero (matches the state after
// AMX_START on hardware, as assumed by test_start).
void amx_state_zero(struct amx_state *state) {
memset(state, 0, sizeof *state);
}
// Load/store operand layout: bits 0..55 are the address, bits 56+ the
// register index (masked per register file), bit 62 requests a 128-byte
// (register-pair) transfer.
#define LDST_ADDRESS_MASK ((1ull << 56) - 1)
#define LDST_DOUBLE_WIDTH (1ull << 62)
// Simulate a 64-byte register load (128 bytes with LDST_DOUBLE_WIDTH; the
// second row wraps within the same register file). mask is 7 for x/y,
// 0x3F for z. Double-width transfers require 128-byte address alignment.
static void amx_state_load_impl(union amx_row *rows, int mask,
uint64_t operand) {
uint64_t double_width = (operand & LDST_DOUBLE_WIDTH);
if (double_width && (operand & 0x7F) != 0) {
// TODO: some way to test exceptions
printf("error: bad alignment\n");
}
char *addr = (char *)(operand & LDST_ADDRESS_MASK);
uint64_t reg = (operand >> 56) & mask;
memcpy(&rows[reg], addr, 0x40);
if (double_width) {
memcpy(&rows[(reg + 1) & mask], addr + 0x40, 0x40);
}
}
// Simulate a 64-byte register store (128 bytes with LDST_DOUBLE_WIDTH; the
// second row wraps within the same register file). mask is 7 for x/y,
// 0x3F for z. Double-width transfers require 128-byte address alignment.
static void amx_state_store_impl(union amx_row *rows, int mask,
                                 uint64_t operand) {
  bool pair = (operand & LDST_DOUBLE_WIDTH) != 0;
  if (pair && (operand & 0x7F) != 0) {
    // TODO: some way to test exceptions
    printf("error: bad alignment\n");
  }
  char *dst = (char *)(operand & LDST_ADDRESS_MASK);
  uint64_t row = (operand >> 56) & mask;
  memcpy(dst, &rows[row], 0x40);
  if (pair) {
    memcpy(dst + 0x40, &rows[(row + 1) & mask], 0x40);
  }
}
// Per-register-file load/store entry points; x/y have 8 rows (mask 7),
// z has 64 rows (mask 0x3F).
void amx_state_ldx(struct amx_state *state, uint64_t operand) {
amx_state_load_impl(state->x, 7, operand);
}
void amx_state_ldy(struct amx_state *state, uint64_t operand) {
amx_state_load_impl(state->y, 7, operand);
}
void amx_state_ldz(struct amx_state *state, uint64_t operand) {
amx_state_load_impl(state->z, 0x3F, operand);
}
void amx_state_stx(struct amx_state *state, uint64_t operand) {
amx_state_store_impl(state->x, 7, operand);
}
void amx_state_sty(struct amx_state *state, uint64_t operand) {
amx_state_store_impl(state->y, 7, operand);
}
void amx_state_stz(struct amx_state *state, uint64_t operand) {
amx_state_store_impl(state->z, 0x3F, operand);
}
// Simulate LDZI: load 16 u32s and scatter them interleaved across the z-row
// pair (reg & ~1, reg | 1) -- element i lands in the row selected by i & 1,
// packed into the half of that row selected by reg & 1.
void amx_state_ldzi(struct amx_state *state, uint64_t operand) {
char *addr = (char *)(operand & LDST_ADDRESS_MASK);
uint32_t row[16];
memcpy(row, addr, sizeof row);
uint64_t reg = (operand >> 56) & 0x3F;
for (int i = 0; i < 16; i++) {
state->z[(reg & ~1) + (i & 1)].u32[((reg & 1) << 3) + (i >> 1)] = row[i];
}
}
// Simulate STZI: the exact inverse of amx_state_ldzi -- gather 16 u32s
// interleaved from the z-row pair and store them contiguously to memory.
void amx_state_stzi(struct amx_state *state, uint64_t operand) {
uint64_t reg = (operand >> 56) & 0x3F;
char *addr = (char *)(operand & LDST_ADDRESS_MASK);
uint32_t row[16];
for (int i = 0; i < 16; i++) {
row[i] = state->z[(reg & ~1) + (i & 1)].u32[((reg & 1) << 3) + (i >> 1)];
}
memcpy(addr, row, sizeof row);
}
// Byte-granular reads/writes on the register files, wrapping the byte offset
// at the end of the file (x/y: 0x200 bytes, z: 0x1000 bytes). This models
// AMX operand byte offsets, which may straddle row boundaries.
static void load_from_x(void *output, struct amx_state *state, size_t offset,
                        size_t size) {
  char *dst = (char *)output;
  const char *base = (const char *)&state->x;
  for (size_t i = 0; i < size; i++) {
    dst[i] = base[(offset + i) & 0x1FF];
  }
}
static void load_from_y(void *output, struct amx_state *state, size_t offset,
                        size_t size) {
  char *dst = (char *)output;
  const char *base = (const char *)&state->y;
  for (size_t i = 0; i < size; i++) {
    dst[i] = base[(offset + i) & 0x1FF];
  }
}
static void load_from_z(void *output, struct amx_state *state, size_t offset,
                        size_t size) {
  char *dst = (char *)output;
  const char *base = (const char *)&state->z;
  for (size_t i = 0; i < size; i++) {
    dst[i] = base[(offset + i) & 0xFFF];
  }
}
// Note: despite the parameter name, `output` is the *source* for the stores.
static void store_to_x(void *output, struct amx_state *state, size_t offset,
                       size_t size) {
  const char *src = (const char *)output;
  char *base = (char *)&state->x;
  for (size_t i = 0; i < size; i++) {
    base[(offset + i) & 0x1FF] = src[i];
  }
}
static void store_to_y(void *output, struct amx_state *state, size_t offset,
                       size_t size) {
  const char *src = (const char *)output;
  char *base = (char *)&state->y;
  for (size_t i = 0; i < size; i++) {
    base[(offset + i) & 0x1FF] = src[i];
  }
}
// FMA operand flags: skip an input entirely -- a skipped z reads as 0, a
// skipped x or y reads as 1.0; x and y both skipped reads both as 0.
#define FMA_SKIP_Z_INPUT (1ull << 27)
#define FMA_SKIP_Y_INPUT (1ull << 28)
#define FMA_SKIP_X_INPUT (1ull << 29)
// Simulate FMA32/FMS32 (sub selects subtract, modelled as negating x).
// Operand fields: bits 0-8 y byte offset, 10-18 x byte offset, 20-25 z row,
// bit 63 selects vector mode (z[z_offset][i] += x[i]*y[i]); otherwise an
// outer product accumulates into every 4th z row starting at z_offset & 3.
static void amx_state_fmas32_impl(struct amx_state *state, uint64_t operand,
bool sub) {
float x[16];
float y[16];
// TODO: need to wrap around byte offsets
uint64_t y_offset = operand & 0x1FF;
uint64_t x_offset = (operand >> 10) & 0x1FF;
uint64_t z_offset = (operand >> 20) & 63;
assert((sizeof x) == 0x40);
if ((operand & FMA_SKIP_Y_INPUT) && (operand & FMA_SKIP_X_INPUT)) {
memset(&x, 0, sizeof x);
memset(&y, 0, sizeof y);
} else {
if (operand & FMA_SKIP_X_INPUT) {
for (int i = 0; i < 16; i++) {
x[i] = 1.0;
}
} else {
load_from_x(x, state, x_offset, sizeof x);
}
if (operand & FMA_SKIP_Y_INPUT) {
for (int i = 0; i < 16; i++) {
y[i] = 1.0;
}
} else {
load_from_y(y, state, y_offset, sizeof y);
}
}
float sub_mul = sub ? -1.0 : 1.0;
if (operand & (1ull << 63)) {
for (int i = 0; i < 16; i++) {
float *z = &state->z[z_offset].f32[i];
*z =
fma32(sub_mul * x[i], y[i], (operand & FMA_SKIP_Z_INPUT) ? 0.0f : *z);
}
} else {
z_offset &= 3;
for (int i = 0; i < 16; i++) {
for (int j = 0; j < 16; j++) {
float *z = &state->z[(j * 4) + z_offset].f32[i];
*z = fma32(sub_mul * x[i], y[j],
(operand & FMA_SKIP_Z_INPUT) ? 0.0f : *z);
}
}
}
}
// Public FMA32/FMS32 entry points over the shared implementation.
void amx_state_fma32(struct amx_state *state, uint64_t operand) {
amx_state_fmas32_impl(state, operand, false);
}
void amx_state_fms32(struct amx_state *state, uint64_t operand) {
amx_state_fmas32_impl(state, operand, true);
}
// Simulate FMA64/FMS64; same structure as the f32 version but with 8
// doubles per row, and the outer product stepping every 8th z row from
// z_offset & 7.
static void amx_state_fmas64_impl(struct amx_state *state, uint64_t operand,
bool sub) {
double x[8];
double y[8];
// TODO: need to wrap around byte offsets
uint64_t y_offset = operand & 0x1FF;
uint64_t x_offset = (operand >> 10) & 0x1FF;
uint64_t z_offset = (operand >> 20) & 63;
assert((sizeof x) == 0x40);
if ((operand & FMA_SKIP_Y_INPUT) && (operand & FMA_SKIP_X_INPUT)) {
memset(&x, 0, sizeof x);
memset(&y, 0, sizeof y);
} else {
if (operand & FMA_SKIP_X_INPUT) {
for (int i = 0; i < 8; i++) {
x[i] = 1.0;
}
} else {
load_from_x(x, state, x_offset, sizeof x);
}
if (operand & FMA_SKIP_Y_INPUT) {
for (int i = 0; i < 8; i++) {
y[i] = 1.0;
}
} else {
load_from_y(y, state, y_offset, sizeof y);
}
}
double sub_mul = sub ? -1.0 : 1.0;
if (operand & (1ull << 63)) {
for (int i = 0; i < 8; i++) {
double *z = &state->z[z_offset].f64[i];
*z =
fma64(sub_mul * x[i], y[i], (operand & FMA_SKIP_Z_INPUT) ? 0.0f : *z);
}
} else {
z_offset &= 7;
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 8; j++) {
double *z = &state->z[(j * 8) + z_offset].f64[i];
*z = fma64(sub_mul * x[i], y[j],
(operand & FMA_SKIP_Z_INPUT) ? 0.0f : *z);
}
}
}
}
// Public FMA64/FMS64 entry points over the shared implementation.
void amx_state_fma64(struct amx_state *state, uint64_t operand) {
amx_state_fmas64_impl(state, operand, false);
}
void amx_state_fms64(struct amx_state *state, uint64_t operand) {
amx_state_fmas64_impl(state, operand, true);
}
// Simulate FMA16/FMS16; 32 halves per row. In matrix mode, bit 62 selects
// mixed-precision accumulation: products are computed in f32 and summed into
// f32 z elements spread across row pairs (indexed by i & 1 / i >> 1);
// otherwise pure f16, accumulating into every 2nd z row from z_offset & 1.
static void amx_state_fmas16_impl(struct amx_state *state, uint64_t operand,
bool sub) {
float16 x[32];
float16 y[32];
// TODO: need to wrap around byte offsets
uint64_t y_offset = operand & 0x1FF;
uint64_t x_offset = (operand >> 10) & 0x1FF;
uint64_t z_offset = (operand >> 20) & 63;
assert((sizeof x) == 0x40);
if ((operand & FMA_SKIP_Y_INPUT) && (operand & FMA_SKIP_X_INPUT)) {
memset(&x, 0, sizeof x);
memset(&y, 0, sizeof y);
} else {
if (operand & FMA_SKIP_X_INPUT) {
for (int i = 0; i < 32; i++) {
x[i] = (float16)1.0;
}
} else {
load_from_x(x, state, x_offset, sizeof x);
}
if (operand & FMA_SKIP_Y_INPUT) {
for (int i = 0; i < 32; i++) {
y[i] = (float16)1.0;
}
} else {
load_from_y(y, state, y_offset, sizeof y);
}
}
float16 sub_mul = sub ? -1.0 : 1.0;
if (operand & (1ull << 63)) {
// Vector mode: elementwise into a single z row.
for (int i = 0; i < 32; i++) {
float16 *z = &state->z[z_offset].f16[i];
*z = fma16(sub_mul * x[i], y[i],
(operand & FMA_SKIP_Z_INPUT) ? (float16)0.0f : *z);
}
} else {
z_offset &= 1;
for (int i = 0; i < 32; i++) {
for (int j = 0; j < 32; j++) {
if (operand & (1ull << 62)) {
float *z = &state->z[(j * 2) + (i & 1)].f32[i >> 1];
float acc = (operand & FMA_SKIP_Z_INPUT) ? 0 : *z;
acc = fma32((float)sub_mul * (float)x[i], (float)y[j], acc);
*z = acc;
} else {
float16 *z = &state->z[(j * 2) + z_offset].f16[i];
*z = fma16(sub_mul * x[i], y[j],
(operand & FMA_SKIP_Z_INPUT) ? (float16)0.0f : *z);
}
}
}
}
}
// AMX "fma16": 16-bit float fused multiply-add into z (add variant of the
// shared fmas16 implementation).
void amx_state_fma16(struct amx_state *state, uint64_t operand) {
  amx_state_fmas16_impl(state, operand, false);
}
// AMX "fms16": 16-bit float fused multiply-subtract into z (subtract variant
// of the shared fmas16 implementation; the x operand is negated).
void amx_state_fms16(struct amx_state *state, uint64_t operand) {
  amx_state_fmas16_impl(state, operand, true);
}
// AMX "mac16": 16-bit integer multiply-accumulate.
//
// Operand decoding, as used below (mirrors the float FMA ops):
//   bits [8:0]   byte offset into the y register file
//   bits [18:10] byte offset into the x register file
//   bits [25:20] z row offset
//   bit  63      1 = vector mode (one z row), 0 = outer-product mode
//   bit  62      (outer-product mode only) widen: accumulate in u32 z lanes
//   FMA_SKIP_X/Y_INPUT substitute 1 (both set: zero products);
//   FMA_SKIP_Z_INPUT discards the existing accumulator value.
void amx_state_mac16(struct amx_state *state, uint64_t operand) {
  uint16_t x[32];
  uint16_t y[32];
  // TODO: need to wrap around byte offsets
  uint64_t y_offset = operand & 0x1FF;
  uint64_t x_offset = (operand >> 10) & 0x1FF;
  uint64_t z_offset = (operand >> 20) & 63;
  assert((sizeof x) == 0x40);
  if ((operand & FMA_SKIP_Y_INPUT) && (operand & FMA_SKIP_X_INPUT)) {
    memset(&x, 0, sizeof x);
    memset(&y, 0, sizeof y);
  } else {
    if (operand & FMA_SKIP_X_INPUT) {
      for (int i = 0; i < 32; i++) {
        x[i] = (uint16_t)1;
      }
    } else {
      load_from_x(x, state, x_offset, sizeof x);
    }
    if (operand & FMA_SKIP_Y_INPUT) {
      for (int i = 0; i < 32; i++) {
        y[i] = (uint16_t)1;
      }
    } else {
      load_from_y(y, state, y_offset, sizeof y);
    }
  }
  if (operand & (1ull << 63)) {
    // Vector mode: z[z_offset].u16[i] accumulates x[i] * y[i] via mac16.
    for (int i = 0; i < 32; i++) {
      uint16_t *z = &state->z[z_offset].u16[i];
      *z = mac16(x[i], y[i], (operand & FMA_SKIP_Z_INPUT) ? (uint16_t)0 : *z);
    }
  } else {
    z_offset &= 1;
    for (int i = 0; i < 32; i++) {
      for (int j = 0; j < 32; j++) {
        if (operand & (1ull << 62)) {
          // Widening mode: signed 16x16 product, truncated into an unsigned
          // 32-bit accumulator in an interleaved layout — row
          // (j * 2) + (i & 1), word i >> 1. z_offset is not applied here.
          uint32_t *z = &state->z[(j * 2) + (i & 1)].u32[i >> 1];
          uint32_t acc = (operand & FMA_SKIP_Z_INPUT) ? 0 : *z;
          acc += (uint32_t)((int64_t)(int16_t)x[i] * (int64_t)(int16_t)y[j]);
          *z = acc;
        } else {
          // 16-bit outer product into every other z row, offset z_offset & 1.
          uint16_t *z = &state->z[(j * 2) + z_offset].u16[i];
          *z = mac16(x[i], y[j],
                     (operand & FMA_SKIP_Z_INPUT) ? (uint16_t)0 : *z);
        }
      }
    }
  }
}
// AMX "extrx": copy one 64-byte row into the x register file.
// With bit 27 set the source is a y row (and the x destination is forced to
// row alignment); otherwise the source is z row `z_offset`.
void amx_state_extrx(struct amx_state *state, uint64_t operand) {
  // uint64_t y_offset = operand & 0x1FF;
  uint64_t dst_offset = (operand >> 10) & 0x1FF;
  uint64_t src_row = (operand >> 20) & 63;
  uint32_t row[16];
  bool from_y = (operand & (1ull << 27)) != 0;
  if (from_y) {
    dst_offset &= ~0x3F; // destination must be 64-byte aligned in this form
    load_from_y(row, state, src_row * 0x40, 0x40);
  } else {
    load_from_z(row, state, src_row * 0x40, 0x40);
  }
  store_to_x(row, state, dst_offset, 0x40);
}
// AMX "extry": extract a column (one element per z row group) or move a row,
// writing the result to y — or, for some encodings, to x (see the original
// author's note below). Several sub-encodings carry TODOs where z_offset
// handling is known not to match hardware yet.
void amx_state_extry(struct amx_state *state, uint64_t operand) {
  // I have misgivings about calling this "extry" as it does sometimes move to
  // "x"
  uint64_t y_offset = operand & 0x1FF;
  uint64_t z_offset = (operand >> 20) & 63;
  if (operand & (1ull << 26)) {
    // Generalized column-extract forms: build a 64-byte buffer, where
    // `operation` (bits [14:11] plus bit 63) selects element width / layout.
    // TODO: might be a good place to use the "union amx_row" type
    uint8_t buffer[64];
    uint64_t operation = ((operand >> 11) & 0xF) | ((operand >> 63) << 4);
    switch (operation) {
    case 0x00: {
      // 8-bit column extract: byte (z_offset & 63) from each of the 64 rows.
      for (int i = 0; i < 64; i++) {
        buffer[i] = state->z[i].u8[z_offset & 63];
      }
    } break;
    case 0x0B: {
      // Currently modeled the same as case 0x00.
      // TODO: mishandles z_offset
      for (int i = 0; i < 64; i++) {
        buffer[i] = state->z[i].u8[z_offset & 63];
      }
    } break;
    case 0x0D: {
      // Currently modeled the same as case 0x00.
      // TODO: mishandles z_offset
      for (int i = 0; i < 64; i++) {
        buffer[i] = state->z[i].u8[z_offset & 63];
      }
    } break;
    case 0x09: {
      // 16-bit extract over pair-interleaved rows.
      // TODO: mishandles z_offset
      uint16_t buffer1[32];
      for (int i = 0; i < 32; i++) {
        buffer1[i] =
            state->z[((i & ~1) << 1) + (z_offset & ((1 << 1) - 1)) + (i & 1)]
                .u16[(z_offset >> 1) & 31];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    case 0x0A: {
      // 16-bit extract, alternate interleaving (row index xor'd with
      // z_offset bit 1).
      // TODO: still mishandles z_offset
      uint16_t buffer1[32];
      for (int i = 0; i < 32; i++) {
        buffer1[i] =
            state->z[((i << 1) + (z_offset & ((1 << 1) - 1))) ^ (z_offset & 2)]
                .u16[((z_offset >> 2) & 31)];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    case 0x11: {
      // 64-bit column extract: one u64 from every 8th row.
      uint64_t buffer1[8];
      for (int i = 0; i < 8; i++) {
        buffer1[i] = state->z[(i << 3) + (z_offset & ((1 << 3) - 1))]
                         .u64[(z_offset >> 3) & 7];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    // (does this have an undiscovered difference?)
    case 0x08:
    case 0x18: {
      // 32-bit column extract: one u32 from every 4th row.
      uint32_t buffer1[16];
      for (int i = 0; i < 16; i++) {
        buffer1[i] = state->z[(i << 2) + (z_offset & ((1 << 2) - 1))]
                         .u32[(z_offset >> 2) & 15];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    default: {
      // All remaining encodings: 16-bit column extract from every other row.
      uint16_t buffer1[32];
      for (int i = 0; i < 32; i++) {
        buffer1[i] = state->z[(i << 1) + (z_offset & ((1 << 1) - 1))]
                         .u16[(z_offset >> 1) & 31];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    }
    // Bit 10 picks the destination register file (y vs x); either way the
    // destination byte offset comes from the y_offset field.
    if (operand & (1ull << 10)) {
      store_to_y(buffer, state, y_offset, 0x40);
    } else {
      store_to_x(buffer, state, y_offset, 0x40);
    }
  } else if ((operand & (1ull << 27))) {
    // Row move: copy x row `z_offset` into a row-aligned y destination.
    y_offset &= ~0x3F;
    uint32_t buffer[16];
    load_from_x(buffer, state, z_offset * 0x40, 0x40);
    store_to_y(buffer, state, y_offset, 0x40);
  } else {
    // Column extract into y; bits 29:28 select the element width.
    // TODO: rewrite as a switch on (operand >> 28) & 3?
    if ((operand & (1ull << 29)) && (operand & (1ull << 28))) {
      // 8-bit: merge the low byte of each extracted 16-bit element into the
      // existing y contents (high bytes preserved).
      uint16_t buffer[32];
      load_from_y(buffer, state, y_offset, 0x40);
      for (int i = 0; i < 32; i++) {
        buffer[i] &= 0xFF00;
        buffer[i] |= state->z[(i << 1) + (z_offset & ((1 << 1) - 1))]
                         .u16[(z_offset >> 1) & 31] &
                     0xFF;
      }
      store_to_y(buffer, state, y_offset, 0x40);
    } else if (operand & (1ull << 29)) {
      // 16-bit column extract.
      uint16_t buffer[32];
      for (int i = 0; i < 32; i++) {
        buffer[i] = state->z[(i << 1) + (z_offset & ((1 << 1) - 1))]
                        .u16[(z_offset >> 1) & 31];
      }
      store_to_y(buffer, state, y_offset, 0x40);
    } else if (operand & (1ull << 28)) {
      // 32-bit column extract.
      uint32_t buffer[16];
      for (int i = 0; i < 16; i++) {
        buffer[i] = state->z[(i << 2) + (z_offset & ((1 << 2) - 1))]
                        .u32[(z_offset >> 2) & 15];
      }
      store_to_y(buffer, state, y_offset, 0x40);
    } else {
      // 64-bit column extract.
      uint64_t buffer[8];
      for (int i = 0; i < 8; i++) {
        buffer[i] = state->z[(i << 3) + (z_offset & ((1 << 3) - 1))]
                        .u64[(z_offset >> 3) & 7];
      }
      store_to_y(buffer, state, y_offset, 0x40);
    }
  }
}
// print flags
// Flags controlling print_amx_row / print_amx_state / diff_amx_state output.
enum print_flags {
  PF_TYPE_MASK = 0x0F, // low nibble: element formatting used per row
  PF_U32 = 0x00,       // 16 x 32-bit hex
  PF_U16 = 0x01,       // 32 x 16-bit hex
  PF_U8 = 0x02,        // 64 x 8-bit hex
  PF_F16 = 0x03,       // 32 x half floats (printed widened to float)
  PF_F32 = 0x04,       // 16 x floats
  PF_F64 = 0x05,       // 8 x doubles
  PF_U64 = 0x06,       // 8 x 64-bit hex
  PF_SKIP_X = 0x10,         // omit the x register file
  PF_SKIP_Y = 0x20,         // omit the y register file
  PF_SKIP_Z = 0x40,         // omit the z register file
  PF_SKIP_ZERO_ROWS = 0x80, // suppress rows that are entirely zero
  // for diffs
  PF_SKIP_A = 0x10000, // diff: don't print rows from the first state
  PF_SKIP_B = 0x20000, // diff: don't print rows from the second state
};
// Print one 64-byte AMX row, prefixed "<reg><num>:". `flags` is a
// combination of print_flags: the low nibble (PF_TYPE_MASK) selects element
// formatting, PF_SKIP_ZERO_ROWS suppresses rows whose 16 u32 words are all
// zero.
void print_amx_row(char reg, int num, union amx_row *r, int flags) {
  bool should_print = true;
  if (flags & PF_SKIP_ZERO_ROWS) {
    should_print = false;
    for (uint64_t o = 0; o < 16; o++) {
      if (r->u32[o] != 0) {
        should_print = true;
        break;
      }
    }
  }
  if (!should_print)
    return;
  if (num < 10)
    printf(" "); // pad single-digit row numbers so columns line up
  printf(" %c%d:", reg, num);
  switch (flags & PF_TYPE_MASK) {
  case PF_U32:
    for (uint64_t o = 0; o < 16; o++) {
      printf(" %8x", r->u32[o]);
    }
    break;
  case PF_U16:
    for (uint64_t o = 0; o < 32; o++) {
      printf(" %4x", r->u16[o]);
    }
    break;
  case PF_U8:
    for (uint64_t o = 0; o < 64; o++) {
      printf(" %2x", r->u8[o]);
    }
    break;
  case PF_F16:
    for (uint64_t o = 0; o < 32; o++) {
      printf(" %16f", (float)r->f16[o]);
    }
    break;
  case PF_F32:
    for (uint64_t o = 0; o < 16; o++) {
      printf(" %16f", r->f32[o]);
    }
    break;
  case PF_F64:
    for (uint64_t o = 0; o < 8; o++) {
      printf(" %16lf", r->f64[o]);
    }
    break;
  case PF_U64:
    for (uint64_t o = 0; o < 8; o++) {
      // Cast: uint64_t is not necessarily unsigned long long (it is
      // unsigned long on LP64 glibc), and a format/argument mismatch with
      // %llx is undefined behavior.
      printf(" %16llx", (unsigned long long)r->u64[o]);
    }
    break;
  default:
    assert(0 && "invalid print flags");
  }
  printf("\n");
}
void print_amx_state(struct amx_state *state, int flags) {
if (!(flags & PF_SKIP_X)) {
for (int i = 0; i < 8; i++) {
print_amx_row('x', i, &state->x[i], flags);
}
}
if (!(flags & PF_SKIP_Y)) {
for (int i = 0; i < 8; i++) {
print_amx_row('y', i, &state->y[i], flags);
}
}
if (!(flags & PF_SKIP_Z)) {
for (int i = 0; i < 64; i++) {
print_amx_row('z', i, &state->z[i], flags);
}
}
}
int diff_amx_state(struct amx_state *a, struct amx_state *b, int flags) {
int same = 1;
if (!(flags & PF_SKIP_X)) {
for (int i = 0; i < 8; i++) {
if (memcmp(&a->x[i], &b->x[i], sizeof(union amx_row))) {
if (!(flags & PF_SKIP_A))
print_amx_row('x', i, &a->x[i], flags);
if (!(flags & PF_SKIP_B))
print_amx_row('x', i, &b->x[i], flags);
same = 0;
}
}
}
if (!(flags & PF_SKIP_Y)) {
for (int i = 0; i < 8; i++) {
if (memcmp(&a->y[i], &b->y[i], sizeof(union amx_row))) {
if (!(flags & PF_SKIP_A))
print_amx_row('y', i, &a->y[i], flags);
if (!(flags & PF_SKIP_B))
print_amx_row('y', i, &b->y[i], flags);
same = 0;
}
}
}
if (!(flags & PF_SKIP_Z)) {
for (int i = 0; i < 64; i++) {
if (memcmp(&a->z[i], &b->z[i], sizeof(union amx_row))) {
if (!(flags & PF_SKIP_A))
print_amx_row('z', i, &a->z[i], flags);
if (!(flags & PF_SKIP_B))
print_amx_row('z', i, &b->z[i], flags);
same = 0;
}
}
}
return same;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment