amx simulator and hardware tests
(gist dougallj/7cba721da1a94da725ee37c1e9cd1f21 — three files concatenated below)

First file: amx.h — inline-assembly macros for the undocumented Apple AMX
instructions, plus helpers to load/store the full AMX register state.
#pragma once

#include <stdint.h>
#include <string.h> /* memset, used by store_amx_state below */
// TODO: is it possible to not go via x0? I'm guessing not without messing with
// the compile process, but still, kinda ugly. might at least be possible to
// force the compiler to get the value in x0 itself
// TODO: do I need memory as an input?
//
// Each macro emits one undocumented Apple AMX instruction, encoded raw via
// .word as (0x201000 | (op << 5) | operand-register). The 64-bit operand V
// (an address and/or packed control bits, depending on the op) is moved into
// x0 first, so every macro clobbers x0.
//
// Ops 0-7: loads/stores. V = address in bits 0..55, register index in the
// top byte (packed by the callers below).
#define AMX_LDX(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (0 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_LDY(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (1 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_STX(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (2 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_STY(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (3 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_LDZ(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (4 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_STZ(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (5 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// Interleaved z load/store (see amx_state_ldzi/stzi in simulator.h for the
// element layout).
#define AMX_LDZI(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (6 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_STZI(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (7 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// TODO: probably shouldn't say these clobber memory?
// Ops 8-9: extract rows between register files.
#define AMX_EXTRX(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (8 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_EXTRY(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (9 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// Ops 10-16: multiply-accumulate families (fp64/fp32/int16/fp16).
#define AMX_FMA64(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (10 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_FMS64(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (11 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_FMA32(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (12 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_FMS32(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (13 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_MAC16(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (14 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_FMA16(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (15 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
#define AMX_FMS16(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (16 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// Op 17 with immediate 0/1: enable/disable AMX on this core. The leading
// nops are presumably required padding — TODO confirm.
#define AMX_START() \
  __asm__ volatile( \
      "nop \r\n nop \r\n nop \r\n .word (0x201000 | (17 << 5) | 0)" :: \
      : "memory")
#define AMX_STOP() \
  __asm__ volatile( \
      "nop \r\n nop \r\n nop \r\n .word (0x201000 | (17 << 5) | 1)" :: \
      : "memory")
// horizontal multiply uint16_ts? (doesn't mac16 have a flag for this?)
// z0[i] += x0[i] + y0[i]
#define AMX_VECINT(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (18 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// horizontal multiply float16_ts? (doesn't fma16 have a flag for this?)
// z0[i] += x0[i] + y0[i]
#define AMX_VECFP(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (19 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// uint16_t matrix multiply? (doesn't mac16 do this?)
#define AMX_MATINT(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (20 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// float16_t matrix multiply? (doesn't fma16 do this?)
#define AMX_MATFP(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (21 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// looks only at z0, clears it, and generates a 64-bit value in x0[0]:
// [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0] -> 0xffffffffffffffff
// [0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] -> 0xf0
// [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] -> 0xfedcba9876543210
// [0x0, 0x10000000, 0x20000000, 0x30000000, 0x40000000, 0x50000000, 0x60000000,
// 0x70000000, 0x80000000, 0x90000000, 0xA0000000, 0xB0000000, 0xC0000000,
// 0xD0000000, 0xE0000000, 0xF0000000] -> fffffff0f6543210
#define AMX_GENLUT(V) \
  __asm__ volatile( \
      "mov x0, %0 \r\n .word (0x201000 | (22 << 5) | 0)" ::"r"((uint64_t)V) \
      : "x0", "memory")
// ARM half-precision float (clang's _Float16).
typedef _Float16 float16;
// One 64-byte AMX register row, viewable as any of the element types the
// tests use.
union amx_row {
  // not all supported types, just useful ones
  uint8_t u8[64];
  uint16_t u16[32];
  uint32_t u32[16];
  uint64_t u64[8];
  float16 f16[32];
  float f32[16];
  double f64[8];
};
// The full AMX register file: 8 x rows, 8 y rows, and a 64-row z accumulator
// (0x200 + 0x200 + 0x1000 bytes). Field order matters: the test driver
// memcpy()s flat buffers over this struct.
struct amx_state {
  union amx_row x[8];
  union amx_row y[8];
  union amx_row z[64];
};
void store_amx_state(struct amx_state *state) { | |
memset(state, 0xAA, sizeof *state); | |
for (uint64_t i = 0; i < 8; i++) { | |
AMX_STX((i << 56) | (uint64_t)&state->x[i]); | |
} | |
for (uint64_t i = 0; i < 8; i++) { | |
AMX_STY((i << 56) | (uint64_t)&state->y[i]); | |
} | |
for (uint64_t i = 0; i < 64; i++) { | |
AMX_STZ((i << 56) | (uint64_t)&state->z[i]); | |
} | |
} | |
void load_amx_state(struct amx_state *state) { | |
for (uint64_t i = 0; i < 8; i++) { | |
AMX_LDX((i << 56) | (uint64_t)&state->x[i]); | |
} | |
for (uint64_t i = 0; i < 8; i++) { | |
AMX_LDY((i << 56) | (uint64_t)&state->y[i]); | |
} | |
for (uint64_t i = 0; i < 64; i++) { | |
AMX_LDZ((i << 56) | (uint64_t)&state->z[i]); | |
} | |
} |
Second file: the test driver — loads known data, runs each AMX instruction on
hardware and in the simulator side by side, and diffs the resulting states.
#include <assert.h> | |
#include <stdbool.h> | |
#include <stdint.h> | |
#include <stdio.h> | |
#include <string.h> | |
#include "amx.h" | |
#include "simulator.h" | |
// Read the real AMX state back from hardware and compare it against the
// simulator's. On mismatch, print both sides via diff_amx_state (flags pick
// the element format; PF_SKIP_A/PF_SKIP_B select which side is shown) and
// return 0. Returns 1 on match.
static int check_state(struct amx_state *sim_state, int flags) {
  // static to keep the 4.5 KiB buffer off the stack; aligned(0x80) —
  // NOTE(review): presumably for the AMX store alignment requirement, confirm.
  __attribute__((aligned(0x80))) static struct amx_state state;
  store_amx_state(&state);
  if (memcmp(sim_state, &state, sizeof state)) {
    printf("state mismatch!\n");
    printf("real state:\n");
    diff_amx_state(&state, sim_state, flags | PF_SKIP_B);
    printf("simulated state:\n");
    diff_amx_state(&state, sim_state, flags | PF_SKIP_A);
    return 0;
  }
  return 1;
}
// Primality by trial division: v is prime iff v >= 2 and no d in
// [2, sqrt(v)] divides it.
bool is_prime(uint64_t v) {
  if (v < 2) {
    return false;
  }
  uint64_t d = 2;
  while (d * d <= v) {
    if (v % d == 0) {
      return false;
    }
    d++;
  }
  return true;
}
// Fill a flat amx_state-shaped float buffer: the x and y rows (16 rows of 16
// floats) get consecutive primes starting from the first prime >= 1000, and
// the z rows get the counting sequence 1, 2, 3, ...
static void init_test_f32s(float *data) {
  uint64_t o = 0;
  for (uint64_t v = 1000; o < (8 + 8) * 16; v++) {
    if (is_prime(v)) {
      data[o++] = v;
    }
  }
  int counter = 1;
  for (o = (8 + 8) * 16; o < (8 + 8 + 64) * 16; o++) {
    data[o] = counter++;
  }
}
// Begin one test: zero the simulator state and enable AMX on this core.
static void test_start(struct amx_state *sim_state) {
  amx_state_zero(sim_state);
  AMX_START();
}
// End one test: disable AMX (the sim_state parameter is unused; kept for
// symmetry with test_start).
static void test_stop(struct amx_state *sim_state) { AMX_STOP(); }
// LDX/LDY/LDZ/LDZI
// Each helper applies one load operand both to the simulator and to the real
// hardware, then verifies the two states still agree.
static void test_ldx(struct amx_state *sim_state, uint64_t op) {
  amx_state_ldx(sim_state, op);
  AMX_LDX(op);
  check_state(sim_state, PF_F32);
}
static void test_ldy(struct amx_state *sim_state, uint64_t op) {
  amx_state_ldy(sim_state, op);
  AMX_LDY(op);
  check_state(sim_state, PF_F32);
}
static void test_ldz(struct amx_state *sim_state, uint64_t op) {
  amx_state_ldz(sim_state, op);
  AMX_LDZ(op);
  check_state(sim_state, PF_F32);
}
static void test_ldzi(struct amx_state *sim_state, uint64_t op) {
  amx_state_ldzi(sim_state, op);
  AMX_LDZI(op);
  check_state(sim_state, PF_F32);
}
// Exercise LDX/LDY/LDZ/LDZI for every register index, once with no high
// operand bit set and once for each single bit in 56..63 (register-select
// and flag bits).
static void test_loads(void) {
  __attribute__((aligned(0x80))) static float test_data[(8 + 8 + 64) * 16];
  init_test_f32s(test_data);
  struct amx_state sim_state;
  // TODO: test alignment checks somehow
  for (int test_bit = 56 - 1; test_bit < 64; test_bit++) {
    // test_bit == 55 doubles as the "no extra bit" baseline iteration.
    uint64_t extra_bit = (test_bit == 55 ? 0 : (1ull << test_bit));
    for (uint64_t reg = 0; reg < 8; reg++) {
      test_start(&sim_state);
      test_ldx(&sim_state, (uint64_t)test_data | (reg << 56) | extra_bit);
      test_stop(&sim_state);
    }
    for (uint64_t reg = 0; reg < 8; reg++) {
      test_start(&sim_state);
      test_ldy(&sim_state, (uint64_t)test_data | (reg << 56) | extra_bit);
      test_stop(&sim_state);
    }
    for (uint64_t reg = 0; reg < 64; reg++) {
      test_start(&sim_state);
      test_ldz(&sim_state, (uint64_t)test_data | (reg << 56) | extra_bit);
      test_stop(&sim_state);
    }
    for (uint64_t reg = 0; reg < 64; reg++) {
      test_start(&sim_state);
      test_ldzi(&sim_state, (uint64_t)test_data | (reg << 56) | extra_bit);
      test_stop(&sim_state);
    }
  }
}
// STX/STY/STZ/STZI
// Load one known state into both hardware and the simulator, then for every
// register and every single high operand bit, store from each side into a
// zeroed buffer and compare the buffers byte-for-byte. (The buffers are
// larger than any single store writes; both start zeroed, so comparing the
// whole buffer also checks nothing was written out of bounds.)
static void test_stores(void) {
  struct amx_state sim_state;
  static float test_data[(8 + 8 + 64) * 16];
  init_test_f32s(test_data);
  __attribute__((aligned(0x80))) static float store_buffer1[0x100];
  __attribute__((aligned(0x80))) static float store_buffer2[0x100];
  test_start(&sim_state);
  memcpy(&sim_state, test_data, sizeof sim_state);
  load_amx_state(&sim_state);
  for (int test_bit = 56 - 1; test_bit < 64; test_bit++) {
    // test_bit == 55 doubles as the "no extra bit" baseline iteration.
    uint64_t extra_bit = (test_bit == 55 ? 0 : (1ull << test_bit));
    for (uint64_t reg = 0; reg < 8; reg++) {
      memset(store_buffer1, 0, sizeof store_buffer1);
      memset(store_buffer2, 0, sizeof store_buffer2);
      AMX_STX((uint64_t)store_buffer1 | (reg << 56) | extra_bit);
      amx_state_stx(&sim_state,
                    (uint64_t)store_buffer2 | (reg << 56) | extra_bit);
      if (memcmp(store_buffer1, store_buffer2, sizeof store_buffer1)) {
        printf("store test failed!\n");
      }
    }
    for (uint64_t reg = 0; reg < 8; reg++) {
      memset(store_buffer1, 0, sizeof store_buffer1);
      memset(store_buffer2, 0, sizeof store_buffer2);
      AMX_STY((uint64_t)store_buffer1 | (reg << 56) | extra_bit);
      amx_state_sty(&sim_state,
                    (uint64_t)store_buffer2 | (reg << 56) | extra_bit);
      if (memcmp(store_buffer1, store_buffer2, sizeof store_buffer1)) {
        printf("store test failed!\n");
      }
    }
    for (uint64_t reg = 0; reg < 64; reg++) {
      memset(store_buffer1, 0, sizeof store_buffer1);
      memset(store_buffer2, 0, sizeof store_buffer2);
      AMX_STZ((uint64_t)store_buffer1 | (reg << 56) | extra_bit);
      amx_state_stz(&sim_state,
                    (uint64_t)store_buffer2 | (reg << 56) | extra_bit);
      if (memcmp(store_buffer1, store_buffer2, sizeof store_buffer1)) {
        printf("store test failed!\n");
      }
    }
    for (uint64_t reg = 0; reg < 64; reg++) {
      memset(store_buffer1, 0, sizeof store_buffer1);
      memset(store_buffer2, 0, sizeof store_buffer2);
      AMX_STZI((uint64_t)store_buffer1 | (reg << 56) | extra_bit);
      amx_state_stzi(&sim_state,
                    (uint64_t)store_buffer2 | (reg << 56) | extra_bit);
      if (memcmp(store_buffer1, store_buffer2, sizeof store_buffer1)) {
        // trailing "1" distinguishes this (STZI) failure from the loops above
        printf("store test failed! 1\n");
      }
    }
  }
  test_stop(&sim_state);
}
// EXTRX
// Load initial_state into hardware and simulator, verify the load round-trip,
// then run one EXTRX operand on both sides and compare.
static void extrx_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  check_state(&sim_state, PF_F32);
  AMX_EXTRX(operand);
  amx_state_extrx(&sim_state, operand);
  check_state(&sim_state, PF_U32);
  test_stop(&sim_state);
}
// Enumerate EXTRX operands over a hand-tuned bit mask and check each against
// the simulator. The initial state tags every 32-bit element with its source
// (0xAA.. = x, 0xBB.. = y, 0xCC.. = z) plus its index, so any wrong routing
// shows up in the diff.
void test_extrx(void) {
  uint32_t initial_state[(8 + 8 + 64) * 16];
  for (int i = 0; i < 8 * 16; i++) {
    initial_state[i] = 0xAA0000 | i;
  }
  for (int i = 8 * 16; i < (8 + 8) * 16; i++) {
    initial_state[i] = 0xBB00000 | i;
  }
  for (int i = (8 + 8) * 16; i < (8 + 8 + 64) * 16; i++) {
    initial_state[i] = 0xCC000000 | i;
  }
  // TODO: bit 26
  uint64_t mask = 0x800000001ff7fdc0 & ~(1 << 26);
  // for perf, skip some offsets
  mask &= ~(0x1FFull << 10);
  mask |= 0x147ull << 10;
  uint64_t operand = 0;
  do {
    extrx_test(initial_state, operand);
    // ryg's texture tiling and swizzling loop — steps operand through every
    // subset of the bits in mask
    operand = (operand - mask) & mask;
  } while (operand);
  // NOTE(review): double-underscore identifiers are reserved for the
  // implementation; consider renaming this macro.
#define __amx_op8_to_x(o) extrx_test(initial_state, o)
  // bit 26 test cases:
  __amx_op8_to_x(0x4000000);
  __amx_op8_to_x(0x4800000);
  __amx_op8_to_x(0x5000000);
  __amx_op8_to_x(0x5800000);
  __amx_op8_to_x(0x6000000);
  __amx_op8_to_x(0x6800000);
  __amx_op8_to_x(0x7000000);
  __amx_op8_to_x(0x7800000);
  // some failing bit 26 test cases:
  //__amx_op8_to_x(0x4110000);
  //__amx_op8_to_x(0x4910000);
  //__amx_op8_to_x(0x5110000);
  //__amx_op8_to_x(0x5910000);
  //__amx_op8_to_x(0x6110000);
  //__amx_op8_to_x(0x6910000);
  //__amx_op8_to_x(0x7110000);
  //__amx_op8_to_x(0x7910000);
  //__amx_op8_to_x(0x8000000004004500);
  //__amx_op8_to_x(0x8000000004204000);
  //__amx_op8_to_x(0x8000000004304000);
  //__amx_op8_to_x(0x8000000004604000);
  //__amx_op8_to_x(0x8000000004704000);
  //__amx_op8_to_x(0x8000000004804000);
  //__amx_op8_to_x(0x8000000004a04000);
  //__amx_op8_to_x(0x8000000004b04000);
  //__amx_op8_to_x(0x8000000004e04000);
  //__amx_op8_to_x(0x8000000004f04000);
  //__amx_op8_to_x(0x8000000005004040);
  //__amx_op8_to_x(0x8000000005204000);
  //__amx_op8_to_x(0x8000000005304000);
  //__amx_op8_to_x(0x8000000005604000);
  //__amx_op8_to_x(0x8000000005704000);
  //__amx_op8_to_x(0x8000000005804080);
  //__amx_op8_to_x(0x8000000005a04000);
  //__amx_op8_to_x(0x8000000005b04000);
  //__amx_op8_to_x(0x8000000005e04000);
  //__amx_op8_to_x(0x8000000005f04000);
  //__amx_op8_to_x(0x8000000006004540);
  //__amx_op8_to_x(0x8000000006204040);
  //__amx_op8_to_x(0x8000000006304040);
  //__amx_op8_to_x(0x8000000006604040);
  //__amx_op8_to_x(0x8000000006704040);
  //__amx_op8_to_x(0x8000000006804040);
  //__amx_op8_to_x(0x8000000006a04040);
  //__amx_op8_to_x(0x8000000006b04040);
  //__amx_op8_to_x(0x8000000006e04040);
  //__amx_op8_to_x(0x8000000006f04040);
  //__amx_op8_to_x(0x8000000007204040);
  //__amx_op8_to_x(0x8000000007304040);
  //__amx_op8_to_x(0x8000000007604040);
  //__amx_op8_to_x(0x8000000007704040);
  //__amx_op8_to_x(0x8000000007a04040);
  //__amx_op8_to_x(0x8000000007b04040);
  //__amx_op8_to_x(0x8000000007e04040);
  //__amx_op8_to_x(0x8000000007f04000);
  //__amx_op8_to_x(0x8000000007f04040);
  // TODO: commented out things (and the rest, as some probably interact with
  // other bits)
  // single-bit sweep (bits 41-46 not yet simulated correctly):
  for (int i = 0; i < 41; i++) {
    extrx_test(initial_state, (1ull << i));
  }
  // extrx_test(initial_state, (1ull << 41));
  // extrx_test(initial_state, (1ull << 42));
  // extrx_test(initial_state, (1ull << 43));
  // extrx_test(initial_state, (1ull << 44));
  // extrx_test(initial_state, (1ull << 45));
  // extrx_test(initial_state, (1ull << 46));
  for (int i = 47; i < 64; i++) {
    extrx_test(initial_state, (1ull << i));
  }
}
// EXTRY
// Load initial_state into hardware and simulator, verify the load round-trip,
// then run one EXTRY operand on both sides and compare.
static void extry_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  check_state(&sim_state, PF_F32);
  AMX_EXTRY(operand);
  amx_state_extry(&sim_state, operand);
  check_state(&sim_state, PF_U32);
  test_stop(&sim_state);
}
// Enumerate EXTRY operands over a bit mask (bit 14 excluded — not fully
// simulated) and check each against the simulator. Elements are tagged by
// source file (0xAA.. = x, 0xBB.. = y, 0xCC.. = z) plus index.
void test_extry(void) {
  uint32_t initial_state[(8 + 8 + 64) * 16];
  for (int i = 0; i < 8 * 16; i++) {
    initial_state[i] = 0xAA0000 | i;
  }
  for (int i = 8 * 16; i < (8 + 8) * 16; i++) {
    initial_state[i] = 0xBB00000 | i;
  }
  for (int i = (8 + 8) * 16; i < (8 + 8 + 64) * 16; i++) {
    initial_state[i] = 0xCC000000 | i;
  }
  // still can't handle (1ull << 14) and some higher bits
  // uint64_t mask = (1ull << 63) | (1ull << 29) | (1ull << 28) | (1ull << 27) |
  // (1ull << 26) | (1ull << 13) | (1ull << 12) | (1ull << 11) | (1ull << 10) |
  // 0x40 | 0x20 | 1 | (63<<20);
  uint64_t mask = (0x800000003ff04dc0 & ~(1ull << 14)) | 1;
  uint64_t operand = 0;
  do {
    extry_test(initial_state, operand);
    // ryg's subset-enumeration step (see test_extrx)
    operand = (operand - mask) & mask;
  } while (operand);
  // bit 14 tests that already pass:
  extry_test(initial_state, 0x8000000004004000);
  extry_test(initial_state, 0x8000000004104000);
  extry_test(initial_state, 0x8000000004104040);
  extry_test(initial_state, 0x8000000004204000);
  extry_test(initial_state, 0x8000000004304000);
  extry_test(initial_state, 0x8000000004404040);
  extry_test(initial_state, 0x8000000004404100);
  extry_test(initial_state, 0x8000000004504000);
  extry_test(initial_state, 0x8000000004604000);
  extry_test(initial_state, 0x8000000004704040);
  extry_test(initial_state, 0x8000000004804080);
  extry_test(initial_state, 0x8000000004804400);
  extry_test(initial_state, 0x8000000004c040c0);
  extry_test(initial_state, 0x8000000004c04500);
  extry_test(initial_state, 0x8000000005004040);
  extry_test(initial_state, 0x8000000005004100);
  extry_test(initial_state, 0x8000000005404140);
  extry_test(initial_state, 0x8000000005804180);
  extry_test(initial_state, 0x8000000005804440);
  extry_test(initial_state, 0x8000000005c041c0);
  extry_test(initial_state, 0x8000000005c04540);
  extry_test(initial_state, 0x8000000006004080);
  extry_test(initial_state, 0x8000000006004400);
  extry_test(initial_state, 0x8000000006404180);
  extry_test(initial_state, 0x8000000006404440);
  extry_test(initial_state, 0x8000000006804480);
  extry_test(initial_state, 0x8000000006c044c0);
  extry_test(initial_state, 0x8000000006c04580);
  extry_test(initial_state, 0x80000000070040c0);
  extry_test(initial_state, 0x8000000007004500);
  extry_test(initial_state, 0x80000000074041c0);
  extry_test(initial_state, 0x8000000007404540);
  extry_test(initial_state, 0x80000000078044c0);
  extry_test(initial_state, 0x8000000007804580);
  extry_test(initial_state, 0x8000000007c045c0);
}
// FMA32
// Run one FMA32 operand on hardware and in the simulator, then compare the
// resulting states as f32.
static void fma32_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  // check_state(&sim_state, PF_F32);
  AMX_FMA32(operand);
  amx_state_fma32(&sim_state, operand);
  check_state(&sim_state, PF_F32);
  test_stop(&sim_state);
}
// Same as fma32_test, for the subtracting variant (FMS32).
static void fms32_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  // check_state(&sim_state, PF_F32);
  AMX_FMS32(operand);
  amx_state_fms32(&sim_state, operand);
  check_state(&sim_state, PF_F32);
  test_stop(&sim_state);
}
// Enumerate FMA32/FMS32 operands over a reduced mask (full offset sweep is
// too slow), then sweep individual bits. Bits 32-37, 41-46 and 60-61 are
// not yet simulated and stay commented out.
static void test_fma32_fms32(void) {
  static float test_data[(8 + 8 + 64) * 16];
  init_test_f32s(test_data);
  // slow, but passes
  // TODO: fields at bit 32/41/60
  // uint64_t mask = 0xa000da4f3bf709f8 & ~((0x3Full << 32) | (0x3Full << 41) |
  // (0x3ull << 60));
  uint64_t offset_mask = 0x101; // optimised - thorough is 0x1FF
  uint64_t mask = (1ull << 63) | (1ull << 27) | (1ull << 28) | (1ull << 29) |
                  (63 << 20) | (offset_mask << 10) | offset_mask;
  uint64_t operand = 0;
  do {
    fma32_test(test_data, operand);
    // TODO: this is almost working, but when we clear we have problems with
    // negative zero
    fms32_test(test_data, operand & ~(1ull << 27));
    operand = (operand - mask) & mask;
  } while (operand);
  // TODO: commented out things (and the rest, as some probably interact with
  // other bits)
  for (int i = 0; i < 32; i++) {
    fma32_test(test_data, (1ull << i));
  }
  // fma32_test(test_data, (1ull << 32));
  // fma32_test(test_data, (1ull << 33));
  // fma32_test(test_data, (1ull << 34));
  // fma32_test(test_data, (1ull << 35));
  // fma32_test(test_data, (1ull << 36));
  // fma32_test(test_data, (1ull << 37));
  fma32_test(test_data, (1ull << 38));
  fma32_test(test_data, (1ull << 39));
  fma32_test(test_data, (1ull << 40));
  // fma32_test(test_data, (1ull << 41));
  // fma32_test(test_data, (1ull << 42));
  // fma32_test(test_data, (1ull << 43));
  // fma32_test(test_data, (1ull << 44));
  // fma32_test(test_data, (1ull << 45));
  // fma32_test(test_data, (1ull << 46));
  fma32_test(test_data, (1ull << 47));
  fma32_test(test_data, (1ull << 48));
  fma32_test(test_data, (1ull << 49));
  fma32_test(test_data, (1ull << 50));
  fma32_test(test_data, (1ull << 51));
  fma32_test(test_data, (1ull << 52));
  fma32_test(test_data, (1ull << 53));
  fma32_test(test_data, (1ull << 54));
  fma32_test(test_data, (1ull << 55));
  fma32_test(test_data, (1ull << 56));
  fma32_test(test_data, (1ull << 57));
  fma32_test(test_data, (1ull << 58));
  fma32_test(test_data, (1ull << 59));
  // fma32_test(test_data, (1ull << 60));
  // fma32_test(test_data, (1ull << 61));
  fma32_test(test_data, (1ull << 62));
  fma32_test(test_data, (1ull << 63));
}
// FMA64/FMS64
// Fill a flat amx_state-shaped double buffer: x and y rows (16 rows of 8
// doubles) get consecutive primes from the first prime >= 1000; z rows are
// all ones.
static void init_test_f64s(double *test) {
  uint64_t o = 0;
  for (uint64_t v = 1000; o < (8 + 8) * 8; v++) {
    if (is_prime(v)) {
      test[o++] = v;
    }
  }
  for (o = (8 + 8) * 8; o < (8 + 8 + 64) * 8; o++) {
    test[o] = 1.0;
  }
}
// Run one FMA64 operand on hardware and in the simulator, then compare the
// resulting states as f64.
static void fma64_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  AMX_FMA64(operand);
  amx_state_fma64(&sim_state, operand);
  check_state(&sim_state, PF_F64);
  test_stop(&sim_state);
}
// Same as fma64_test, for the subtracting variant (FMS64).
static void fms64_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  AMX_FMS64(operand);
  amx_state_fms64(&sim_state, operand);
  check_state(&sim_state, PF_F64);
  test_stop(&sim_state);
}
// Enumerate FMA64/FMS64 operands over a reduced offset mask (full sweep is
// too slow) and check each against the simulator.
static void test_fma64_fms64(void) {
  static double test_data[(8 + 8 + 64) * 8];
  init_test_f64s(test_data);
  fma64_test(test_data, 0);
  // slow, but passes
  // uint64_t mask = 0xa000da4f3bf709f8 & ~((0x3Full << 32) | (0x3Full << 41));
  uint64_t offset_mask = 0x101; // optimised - thorough is 0x1FF
  uint64_t mask = (1ull << 63) | (1ull << 27) | (1ull << 28) | (1ull << 29) |
                  (63 << 20) | (offset_mask << 10) | offset_mask;
  uint64_t operand = 0;
  do {
    fma64_test(test_data, operand);
    // TODO: this is almost working, but when we clear we have problems with
    // negative zero
    fms64_test(test_data, operand & ~(1ull << 27));
    // ryg's texture tiling and swizzling loop
    operand = (operand - mask) & mask;
  } while (operand);
}
// FMA16/FMS16
// Run one FMA16 operand on hardware and in the simulator; compare as raw u16.
// On mismatch, re-print the diff as 32-bit words along with the operand.
// NOTE(review): "%llx" assumes uint64_t == unsigned long long; PRIx64 from
// <inttypes.h> would be portable.
static void fma16_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  AMX_FMA16(operand);
  amx_state_fma16(&sim_state, operand);
  if (!check_state(&sim_state, PF_U16)) {
    // check_state(&sim_state, PF_F16);
    check_state(&sim_state, PF_F32);
    printf("^ %llx\n\n", operand);
  }
  test_stop(&sim_state);
}
// Same as fma16_test, for the subtracting variant (FMS16).
static void fms16_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  AMX_FMS16(operand);
  amx_state_fms16(&sim_state, operand);
  if (!check_state(&sim_state, PF_U16)) {
    // check_state(&sim_state, PF_F16);
    check_state(&sim_state, PF_F32);
    printf("^ %llx\n\n", operand);
  }
  test_stop(&sim_state);
}
// Fill a flat amx_state-shaped float16 buffer: x and y rows (16 rows of 32
// halves) get small consecutive primes (starting at 5); z rows are all ones.
static void init_test_f16s(float16 *test) {
  uint64_t o = 0;
  for (uint64_t v = 5; o < (8 + 8) * 32; v++) {
    if (is_prime(v)) {
      test[o++] = v;
    }
  }
  for (o = (8 + 8) * 32; o < (8 + 8 + 64) * 32; o++) {
    test[o] = 1.0f;
  }
}
// Enumerate FMA16/FMS16 operands over a reduced mask and check each against
// the simulator.
static void test_fma16_fms16(void) {
  // TODO: There's something wrong with NaN accuracy - probably in the others as
  // well but it shows up here if we let offset_mask be odd.
  static float16 test_data[(8 + 8 + 64) * 32];
  init_test_f16s(test_data);
  uint64_t offset_mask = 0x102; // optimised - thorough is 0x1FF
  uint64_t mask = (1ull << 63) | (1ull << 62) | (1ull << 27) | (1ull << 28) |
                  (1ull << 29) | (63 << 20) | (offset_mask << 10) | offset_mask;
  uint64_t operand = 0;
  do {
    fma16_test(test_data, operand);
    // TODO: this is almost working, but when we clear we have problems with
    // negative zero
    fms16_test(test_data, operand & ~(1ull << 27));
    // ryg's texture tiling and swizzling loop
    operand = (operand - mask) & mask;
  } while (operand);
}
// MAC16
// Run one MAC16 operand on hardware and in the simulator; compare as raw u16
// and print the operand on mismatch.
static void mac16_test(void *initial_state, uint64_t operand) {
  struct amx_state sim_state;
  test_start(&sim_state);
  memcpy(&sim_state, initial_state, sizeof sim_state);
  load_amx_state(&sim_state);
  AMX_MAC16(operand);
  amx_state_mac16(&sim_state, operand);
  if (!check_state(&sim_state, PF_U16)) {
    printf("^ %llx\n\n", operand);
  }
  test_stop(&sim_state);
}
// Fill a flat amx_state-shaped uint16_t buffer: x and y rows get primes
// starting just below 0x8000, so products overflow 16 bits and exercise
// truncation; z rows are all ones.
static void init_test_u16s(uint16_t *test) {
  uint64_t o = 0;
  for (uint64_t v = 0x8000 - 100; o < (8 + 8) * 32; v++) {
    if (is_prime(v)) {
      test[o++] = v;
    }
  }
  for (o = (8 + 8) * 32; o < (8 + 8 + 64) * 32; o++) {
    test[o] = 1;
  }
}
// Enumerate MAC16 operands over a reduced offset mask and check each against
// the simulator.
static void test_mac16(void) {
  static uint16_t test_data[(8 + 8 + 64) * 32];
  init_test_u16s(test_data);
  uint64_t offset_mask = 0x101; // optimised - thorough is 0x1FF
  uint64_t mask = (1ull << 63) | (1ull << 62) | (1ull << 27) | (1ull << 28) |
                  (1ull << 29) | (63 << 20) | (offset_mask << 10) | offset_mask;
  uint64_t operand = 0;
  do {
    mac16_test(test_data, operand);
    operand = (operand - mask) & mask;
  } while (operand);
}
// Run the full hardware-vs-simulator test suite. Each test* function loads a
// known state, executes an AMX instruction on hardware and in the simulator,
// and diffs the results.
int main(int argc, char **argv) {
  (void)argc; // no command-line options
  (void)argv;
  test_loads();
  test_stores();
  test_mac16();
  test_fma16_fms16();
  test_fma32_fms32();
  test_fma64_fms64();
  test_extrx();
  test_extry();
  return 0;
}
Third file: simulator.h — a software model of the AMX register state and the
per-instruction state transitions that the tests above compare against.
#pragma once | |
#include <assert.h> | |
#include <stdbool.h> | |
#include <stdint.h> | |
#include <stdio.h> | |
#include <string.h> | |
#include "amx.h" | |
#ifdef __aarch64__
// Scalar fused multiply-add helpers built on the native fmadd/fmsub
// instructions (single rounding), so the simulator rounds the same way the
// hardware presumably does — TODO confirm.
// NOTE(review): non-static functions in a header will cause duplicate-symbol
// errors if this header is included from more than one translation unit.
double fma64(double a, double b, double c) {
  double out;
  asm("fmadd %d[out], %d[a], %d[b], %d[c]"
      : [out] "=w"(out)
      : [a] "w"(a), [b] "w"(b), [c] "w"(c));
  return out;
}
float fma32(float a, float b, float c) {
  float out;
  asm("fmadd %s[out], %s[a], %s[b], %s[c]"
      : [out] "=w"(out)
      : [a] "w"(a), [b] "w"(b), [c] "w"(c));
  return out;
}
// fused a*b - c (note: fmsub computes -(a*b) + c per the A64 ISA — the
// callers pass a pre-negated factor; verify sign convention against use)
float fms32(float a, float b, float c) {
  float out;
  asm("fmsub %s[out], %s[a], %s[b], %s[c]"
      : [out] "=w"(out)
      : [a] "w"(a), [b] "w"(b), [c] "w"(c));
  return out;
}
float16 fma16(float16 a, float16 b, float16 c) {
  float16 out;
  asm("fmadd %h[out], %h[a], %h[b], %h[c]"
      : [out] "=w"(out)
      : [a] "w"(a), [b] "w"(b), [c] "w"(c));
  return out;
}
#else
// NOTE(review): this branch is dead while the #error stands, and it is
// missing fms32/fma16 counterparts.
#error "TODO: portable fma64/fma32/fma16 implementations"
float fma32(float a, float b, float c) {
#pragma clang fp contract(fast)
  return a * b + c;
}
double fma64(double a, double b, double c) {
#pragma clang fp contract(fast)
  return a * b + c;
}
#endif
// 16-bit multiply-accumulate: a * b + c, truncated to the low 16 bits.
// Widened to 32 bits so the multiply cannot overflow signed int (uint16
// operands would otherwise promote to int, where 0xFFFF * 0xFFFF overflows).
uint16_t mac16(uint16_t a, uint16_t b, uint16_t c) {
  uint32_t acc = (uint32_t)a * (uint32_t)b + (uint32_t)c;
  return (uint16_t)acc;
}
// Reset the simulated AMX state to all zeroes.
void amx_state_zero(struct amx_state *state) {
  memset(state, 0, sizeof *state);
}
// Load/store operand layout: bits 0..55 hold the address, the top byte holds
// the register index, and bit 62 requests a double-width (two-row, 128-byte)
// transfer.
#define LDST_ADDRESS_MASK ((1ull << 56) - 1)
#define LDST_DOUBLE_WIDTH (1ull << 62)
// Simulate a load into a register file. `mask` is the register-index mask
// (7 for x/y, 0x3F for z); a double-width load fills reg and reg+1, with
// reg+1 wrapping via the mask.
static void amx_state_load_impl(union amx_row *rows, int mask,
                                uint64_t operand) {
  uint64_t double_width = (operand & LDST_DOUBLE_WIDTH);
  // double-width transfers require a 128-byte-aligned address
  if (double_width && (operand & 0x7F) != 0) {
    // TODO: some way to test exceptions
    printf("error: bad alignment\n");
  }
  char *addr = (char *)(operand & LDST_ADDRESS_MASK);
  uint64_t reg = (operand >> 56) & mask;
  memcpy(&rows[reg], addr, 0x40);
  if (double_width) {
    memcpy(&rows[(reg + 1) & mask], addr + 0x40, 0x40);
  }
}
// Simulate a store from a register file; mirror image of
// amx_state_load_impl (same operand layout, masks, and wrap-around).
static void amx_state_store_impl(union amx_row *rows, int mask,
                                 uint64_t operand) {
  uint64_t double_width = (operand & LDST_DOUBLE_WIDTH);
  // double-width transfers require a 128-byte-aligned address
  if (double_width && (operand & 0x7F) != 0) {
    // TODO: some way to test exceptions
    printf("error: bad alignment\n");
  }
  char *addr = (char *)(operand & LDST_ADDRESS_MASK);
  uint64_t reg = (operand >> 56) & mask;
  memcpy(addr, &rows[reg], 0x40);
  if (double_width) {
    memcpy(addr + 0x40, &rows[(reg + 1) & mask], 0x40);
  }
}
// Per-register-file wrappers: x and y have 8 rows (index mask 7), z has 64
// rows (mask 0x3F).
void amx_state_ldx(struct amx_state *state, uint64_t operand) {
  amx_state_load_impl(state->x, 7, operand);
}
void amx_state_ldy(struct amx_state *state, uint64_t operand) {
  amx_state_load_impl(state->y, 7, operand);
}
void amx_state_ldz(struct amx_state *state, uint64_t operand) {
  amx_state_load_impl(state->z, 0x3F, operand);
}
void amx_state_stx(struct amx_state *state, uint64_t operand) {
  amx_state_store_impl(state->x, 7, operand);
}
void amx_state_sty(struct amx_state *state, uint64_t operand) {
  amx_state_store_impl(state->y, 7, operand);
}
void amx_state_stz(struct amx_state *state, uint64_t operand) {
  amx_state_store_impl(state->z, 0x3F, operand);
}
// Simulate LDZI: load 64 bytes as 16 32-bit words, interleaved across a pair
// of z rows. Word i lands in row (reg & ~1) + (i & 1); within that row, the
// low bit of reg selects the upper/lower 8-word half ((reg & 1) << 3) and
// i >> 1 indexes inside the half.
void amx_state_ldzi(struct amx_state *state, uint64_t operand) {
  char *addr = (char *)(operand & LDST_ADDRESS_MASK);
  uint32_t row[16];
  memcpy(row, addr, sizeof row);
  uint64_t reg = (operand >> 56) & 0x3F;
  for (int i = 0; i < 16; i++) {
    state->z[(reg & ~1) + (i & 1)].u32[((reg & 1) << 3) + (i >> 1)] = row[i];
  }
}
// Simulate STZI: the exact inverse of amx_state_ldzi — gather the same
// interleaved pair-of-rows layout back into 16 32-bit words and store them.
void amx_state_stzi(struct amx_state *state, uint64_t operand) {
  uint64_t reg = (operand >> 56) & 0x3F;
  char *addr = (char *)(operand & LDST_ADDRESS_MASK);
  uint32_t row[16];
  for (int i = 0; i < 16; i++) {
    row[i] = state->z[(reg & ~1) + (i & 1)].u32[((reg & 1) << 3) + (i >> 1)];
  }
  memcpy(addr, row, sizeof row);
}
// Copy `size` bytes out of the x register file starting at byte `offset`,
// wrapping around the 512-byte (8 rows * 64 B) file. Byte-at-a-time so any
// unaligned offset wraps correctly.
static void load_from_x(void *output, struct amx_state *state, size_t offset,
                        size_t size) {
  char *p = (char *)output;
  for (size_t i = 0; i < size; i++) {
    memcpy(p++, ((char *)&state->x) + ((offset + i) & 0x1FF), 1);
  }
}
// Same, for the y register file (also 512 bytes).
static void load_from_y(void *output, struct amx_state *state, size_t offset,
                        size_t size) {
  char *p = (char *)output;
  for (size_t i = 0; i < size; i++) {
    memcpy(p++, ((char *)&state->y) + ((offset + i) & 0x1FF), 1);
  }
}
// Same, for the z file (64 rows * 64 B = 4096 bytes, mask 0xFFF).
static void load_from_z(void *output, struct amx_state *state, size_t offset,
                        size_t size) {
  char *p = (char *)output;
  for (size_t i = 0; i < size; i++) {
    memcpy(p++, ((char *)&state->z) + ((offset + i) & 0xFFF), 1);
  }
}
static void store_to_x(void *output, struct amx_state *state, size_t offset, | |
size_t size) { | |
char *p = (char *)output; | |
for (size_t i = 0; i < size; i++) { | |
memcpy(((char *)&state->x) + ((offset + i) & 0x1FF), p++, 1); | |
} | |
} | |
static void store_to_y(void *output, struct amx_state *state, size_t offset, | |
size_t size) { | |
char *p = (char *)output; | |
for (size_t i = 0; i < size; i++) { | |
memcpy(((char *)&state->y) + ((offset + i) & 0x1FF), p++, 1); | |
} | |
} | |
// Operand flag bits shared by the FMA/FMS/MAC family: skipping the x or y
// input substitutes 1, skipping z drops the accumulator (starts from 0).
#define FMA_SKIP_Z_INPUT (1ull << 27)
#define FMA_SKIP_Y_INPUT (1ull << 28)
#define FMA_SKIP_X_INPUT (1ull << 29)
// f32 fused multiply-add/subtract (`sub` negates the product).
// Operand fields as decoded below: bits [8:0] y byte offset, [18:10] x byte
// offset, [25:20] z row; bit 63 selects vector mode (one z row) vs.
// outer-product mode (16x16 over interleaved z rows).
static void amx_state_fmas32_impl(struct amx_state *state, uint64_t operand,
                                  bool sub) {
  float x[16];
  float y[16];
  // TODO: need to wrap around byte offsets
  uint64_t y_offset = operand & 0x1FF;
  uint64_t x_offset = (operand >> 10) & 0x1FF;
  uint64_t z_offset = (operand >> 20) & 63;
  assert((sizeof x) == 0x40);
  if ((operand & FMA_SKIP_Y_INPUT) && (operand & FMA_SKIP_X_INPUT)) {
    // Both inputs skipped: operands become zero (not 1*1).
    memset(&x, 0, sizeof x);
    memset(&y, 0, sizeof y);
  } else {
    if (operand & FMA_SKIP_X_INPUT) {
      for (int i = 0; i < 16; i++) {
        x[i] = 1.0;
      }
    } else {
      load_from_x(x, state, x_offset, sizeof x);
    }
    if (operand & FMA_SKIP_Y_INPUT) {
      for (int i = 0; i < 16; i++) {
        y[i] = 1.0;
      }
    } else {
      load_from_y(y, state, y_offset, sizeof y);
    }
  }
  // Subtraction is implemented by negating the x operand of the fma.
  float sub_mul = sub ? -1.0 : 1.0;
  if (operand & (1ull << 63)) {
    // Vector mode: lane-wise fma into a single z row.
    for (int i = 0; i < 16; i++) {
      float *z = &state->z[z_offset].f32[i];
      *z =
          fma32(sub_mul * x[i], y[i], (operand & FMA_SKIP_Z_INPUT) ? 0.0f : *z);
    }
  } else {
    // Outer-product mode: column j lives at every 4th z row; only the low
    // 2 bits of z_offset select the interleaved plane.
    z_offset &= 3;
    for (int i = 0; i < 16; i++) {
      for (int j = 0; j < 16; j++) {
        float *z = &state->z[(j * 4) + z_offset].f32[i];
        *z = fma32(sub_mul * x[i], y[j],
                   (operand & FMA_SKIP_Z_INPUT) ? 0.0f : *z);
      }
    }
  }
}
// FMA32: f32 fused multiply-add (see amx_state_fmas32_impl for decoding).
void amx_state_fma32(struct amx_state *state, uint64_t operand) {
  amx_state_fmas32_impl(state, operand, false);
}
// FMS32: f32 fused multiply-subtract (negated product).
void amx_state_fms32(struct amx_state *state, uint64_t operand) {
  amx_state_fmas32_impl(state, operand, true);
}
// f64 fused multiply-add/subtract; same operand layout as the f32 variant
// but with 8 lanes per row. Bit 63: vector mode; otherwise 8x8 outer
// product over every 8th z row (plane = low 3 bits of z_offset).
static void amx_state_fmas64_impl(struct amx_state *state, uint64_t operand,
                                  bool sub) {
  double x[8];
  double y[8];
  // TODO: need to wrap around byte offsets
  uint64_t y_offset = operand & 0x1FF;
  uint64_t x_offset = (operand >> 10) & 0x1FF;
  uint64_t z_offset = (operand >> 20) & 63;
  assert((sizeof x) == 0x40);
  if ((operand & FMA_SKIP_Y_INPUT) && (operand & FMA_SKIP_X_INPUT)) {
    // Both inputs skipped: operands become zero (not 1*1).
    memset(&x, 0, sizeof x);
    memset(&y, 0, sizeof y);
  } else {
    if (operand & FMA_SKIP_X_INPUT) {
      for (int i = 0; i < 8; i++) {
        x[i] = 1.0;
      }
    } else {
      load_from_x(x, state, x_offset, sizeof x);
    }
    if (operand & FMA_SKIP_Y_INPUT) {
      for (int i = 0; i < 8; i++) {
        y[i] = 1.0;
      }
    } else {
      load_from_y(y, state, y_offset, sizeof y);
    }
  }
  // Subtraction is implemented by negating the x operand of the fma.
  double sub_mul = sub ? -1.0 : 1.0;
  if (operand & (1ull << 63)) {
    // Vector mode: lane-wise fma into a single z row.
    // NOTE(review): the 0.0f literal promotes exactly to 0.0 here, so it is
    // harmless, though 0.0 would match the f64 element type.
    for (int i = 0; i < 8; i++) {
      double *z = &state->z[z_offset].f64[i];
      *z =
          fma64(sub_mul * x[i], y[i], (operand & FMA_SKIP_Z_INPUT) ? 0.0f : *z);
    }
  } else {
    // Outer-product mode: column j lives at every 8th z row.
    z_offset &= 7;
    for (int i = 0; i < 8; i++) {
      for (int j = 0; j < 8; j++) {
        double *z = &state->z[(j * 8) + z_offset].f64[i];
        *z = fma64(sub_mul * x[i], y[j],
                   (operand & FMA_SKIP_Z_INPUT) ? 0.0f : *z);
      }
    }
  }
}
// FMA64: f64 fused multiply-add (see amx_state_fmas64_impl for decoding).
void amx_state_fma64(struct amx_state *state, uint64_t operand) {
  amx_state_fmas64_impl(state, operand, false);
}
// FMS64: f64 fused multiply-subtract (negated product).
void amx_state_fms64(struct amx_state *state, uint64_t operand) {
  amx_state_fmas64_impl(state, operand, true);
}
// f16 fused multiply-add/subtract, 32 lanes per row. Bit 63: vector mode.
// In outer-product mode, bit 62 switches to a widening variant that
// multiplies in f32 and accumulates into interleaved f32 z lanes.
static void amx_state_fmas16_impl(struct amx_state *state, uint64_t operand,
                                  bool sub) {
  float16 x[32];
  float16 y[32];
  // TODO: need to wrap around byte offsets
  uint64_t y_offset = operand & 0x1FF;
  uint64_t x_offset = (operand >> 10) & 0x1FF;
  uint64_t z_offset = (operand >> 20) & 63;
  assert((sizeof x) == 0x40);
  if ((operand & FMA_SKIP_Y_INPUT) && (operand & FMA_SKIP_X_INPUT)) {
    // Both inputs skipped: operands become zero (not 1*1).
    memset(&x, 0, sizeof x);
    memset(&y, 0, sizeof y);
  } else {
    if (operand & FMA_SKIP_X_INPUT) {
      for (int i = 0; i < 32; i++) {
        x[i] = (float16)1.0;
      }
    } else {
      load_from_x(x, state, x_offset, sizeof x);
    }
    if (operand & FMA_SKIP_Y_INPUT) {
      for (int i = 0; i < 32; i++) {
        y[i] = (float16)1.0;
      }
    } else {
      load_from_y(y, state, y_offset, sizeof y);
    }
  }
  // Subtraction is implemented by negating the x operand of the fma.
  float16 sub_mul = sub ? -1.0 : 1.0;
  if (operand & (1ull << 63)) {
    // Vector mode: lane-wise f16 fma into a single z row.
    for (int i = 0; i < 32; i++) {
      float16 *z = &state->z[z_offset].f16[i];
      *z = fma16(sub_mul * x[i], y[i],
                 (operand & FMA_SKIP_Z_INPUT) ? (float16)0.0f : *z);
    }
  } else {
    // Outer-product mode over interleaved row pairs; only bit 0 of z_offset
    // selects the plane in the pure-f16 case.
    z_offset &= 1;
    for (int i = 0; i < 32; i++) {
      for (int j = 0; j < 32; j++) {
        if (operand & (1ull << 62)) {
          // Widening: f32 product/accumulate; z plane comes from (i & 1)
          // rather than z_offset, and each f32 lane covers two f16 lanes.
          float *z = &state->z[(j * 2) + (i & 1)].f32[i >> 1];
          float acc = (operand & FMA_SKIP_Z_INPUT) ? 0 : *z;
          acc = fma32((float)sub_mul * (float)x[i], (float)y[j], acc);
          *z = acc;
        } else {
          float16 *z = &state->z[(j * 2) + z_offset].f16[i];
          *z = fma16(sub_mul * x[i], y[j],
                     (operand & FMA_SKIP_Z_INPUT) ? (float16)0.0f : *z);
        }
      }
    }
  }
}
// FMA16: f16 fused multiply-add (see amx_state_fmas16_impl for decoding).
void amx_state_fma16(struct amx_state *state, uint64_t operand) {
  amx_state_fmas16_impl(state, operand, false);
}
// FMS16: f16 fused multiply-subtract (negated product).
void amx_state_fms16(struct amx_state *state, uint64_t operand) {
  amx_state_fmas16_impl(state, operand, true);
}
// MAC16: 16-bit integer multiply-accumulate. Same operand layout as the f16
// FMA: bit 63 = vector mode; in outer-product mode bit 62 selects a widening
// variant (signed 16x16 products accumulated into interleaved 32-bit lanes).
void amx_state_mac16(struct amx_state *state, uint64_t operand) {
  uint16_t x[32];
  uint16_t y[32];
  // TODO: need to wrap around byte offsets
  uint64_t y_offset = operand & 0x1FF;
  uint64_t x_offset = (operand >> 10) & 0x1FF;
  uint64_t z_offset = (operand >> 20) & 63;
  assert((sizeof x) == 0x40);
  if ((operand & FMA_SKIP_Y_INPUT) && (operand & FMA_SKIP_X_INPUT)) {
    // Both inputs skipped: operands become zero (not 1*1).
    memset(&x, 0, sizeof x);
    memset(&y, 0, sizeof y);
  } else {
    if (operand & FMA_SKIP_X_INPUT) {
      for (int i = 0; i < 32; i++) {
        x[i] = (uint16_t)1;
      }
    } else {
      load_from_x(x, state, x_offset, sizeof x);
    }
    if (operand & FMA_SKIP_Y_INPUT) {
      for (int i = 0; i < 32; i++) {
        y[i] = (uint16_t)1;
      }
    } else {
      load_from_y(y, state, y_offset, sizeof y);
    }
  }
  if (operand & (1ull << 63)) {
    // Vector mode: lane-wise mac16 into a single z row.
    for (int i = 0; i < 32; i++) {
      uint16_t *z = &state->z[z_offset].u16[i];
      *z = mac16(x[i], y[i], (operand & FMA_SKIP_Z_INPUT) ? (uint16_t)0 : *z);
    }
  } else {
    // Outer-product mode over interleaved row pairs.
    z_offset &= 1;
    for (int i = 0; i < 32; i++) {
      for (int j = 0; j < 32; j++) {
        if (operand & (1ull << 62)) {
          // Widening: both inputs are sign-extended (int16_t casts) before
          // the multiply; plane comes from (i & 1) rather than z_offset.
          uint32_t *z = &state->z[(j * 2) + (i & 1)].u32[i >> 1];
          uint32_t acc = (operand & FMA_SKIP_Z_INPUT) ? 0 : *z;
          acc += (uint32_t)((int64_t)(int16_t)x[i] * (int64_t)(int16_t)y[j]);
          *z = acc;
        } else {
          uint16_t *z = &state->z[(j * 2) + z_offset].u16[i];
          *z = mac16(x[i], y[j],
                     (operand & FMA_SKIP_Z_INPUT) ? (uint16_t)0 : *z);
        }
      }
    }
  }
}
// EXTRX: copy one 64-byte row into x. Destination byte offset is operand
// bits [18:10]; source row is bits [25:20]. With bit 27 set the source is
// the y file (the byte offset wraps mod 0x200 inside load_from_y) and the
// destination offset is forced to a row boundary; otherwise the source is z.
void amx_state_extrx(struct amx_state *state, uint64_t operand) {
  // uint64_t y_offset = operand & 0x1FF;
  uint64_t x_offset = (operand >> 10) & 0x1FF;
  uint64_t z_offset = (operand >> 20) & 63;
  uint32_t buffer[16];
  if ((operand & (1ull << 27))) {
    x_offset &= ~0x3F; // row-align the destination
    load_from_y(buffer, state, z_offset * 0x40, 0x40);
  } else {
    load_from_z(buffer, state, z_offset * 0x40, 0x40);
  }
  store_to_x(buffer, state, x_offset, 0x40);
}
// EXTRY: extract data from z into y (or, for some encodings, into x).
// Bits [8:0]: destination byte offset; bits [25:20]: z row/column selector.
// Bit 26 selects the column-extract family (operation selector below);
// otherwise bit 27 moves an x row into y, and bits [29:28] pick the element
// width of a plain z column extract into y.
void amx_state_extry(struct amx_state *state, uint64_t operand) {
  // I have misgivings about calling this "extry" as it does sometimes move to
  // "x"
  uint64_t y_offset = operand & 0x1FF;
  uint64_t z_offset = (operand >> 20) & 63;
  if (operand & (1ull << 26)) {
    // Column-extract family: gather one element per z row (or per 2nd/4th/
    // 8th row for wider elements) into a 64-byte buffer, then store it.
    // TODO: might be a good place to use the "union amx_row" type
    uint8_t buffer[64];
    // 5-bit operation selector: operand bits [14:11] plus bit 63 on top.
    uint64_t operation = ((operand >> 11) & 0xF) | ((operand >> 63) << 4);
    switch (operation) {
    case 0x00: {
      // u8 column: byte (z_offset & 63) from each of the 64 rows.
      for (int i = 0; i < 64; i++) {
        buffer[i] = state->z[i].u8[z_offset & 63];
      }
    } break;
    case 0x0B: {
      // TODO: mishandles z_offset
      for (int i = 0; i < 64; i++) {
        buffer[i] = state->z[i].u8[z_offset & 63];
      }
    } break;
    case 0x0D: {
      // TODO: mishandles z_offset
      for (int i = 0; i < 64; i++) {
        buffer[i] = state->z[i].u8[z_offset & 63];
      }
    } break;
    case 0x09: {
      // u16 column gathered from interleaved row pairs.
      // TODO: mishandles z_offset
      uint16_t buffer1[32];
      for (int i = 0; i < 32; i++) {
        buffer1[i] =
            state->z[((i & ~1) << 1) + (z_offset & ((1 << 1) - 1)) + (i & 1)]
                .u16[(z_offset >> 1) & 31];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    case 0x0A: {
      // u16 column with a row-pair swap driven by bit 1 of z_offset.
      // TODO: still mishandles z_offset
      uint16_t buffer1[32];
      for (int i = 0; i < 32; i++) {
        buffer1[i] =
            state->z[((i << 1) + (z_offset & ((1 << 1) - 1))) ^ (z_offset & 2)]
                .u16[((z_offset >> 2) & 31)];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    case 0x11: {
      // u64 column: every 8th row; plane from the low 3 bits of z_offset.
      uint64_t buffer1[8];
      for (int i = 0; i < 8; i++) {
        buffer1[i] = state->z[(i << 3) + (z_offset & ((1 << 3) - 1))]
                         .u64[(z_offset >> 3) & 7];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    // (does this have an undiscovered difference?)
    case 0x08:
    case 0x18: {
      // u32 column: every 4th row; plane from the low 2 bits of z_offset.
      uint32_t buffer1[16];
      for (int i = 0; i < 16; i++) {
        buffer1[i] = state->z[(i << 2) + (z_offset & ((1 << 2) - 1))]
                         .u32[(z_offset >> 2) & 15];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    default: {
      // Remaining encodings behave as a plain u16 column extract.
      uint16_t buffer1[32];
      for (int i = 0; i < 32; i++) {
        buffer1[i] = state->z[(i << 1) + (z_offset & ((1 << 1) - 1))]
                         .u16[(z_offset >> 1) & 31];
      }
      memcpy(buffer, buffer1, sizeof buffer);
    } break;
    }
    // Bit 10 picks the destination register file.
    if (operand & (1ull << 10)) {
      store_to_y(buffer, state, y_offset, 0x40);
    } else {
      store_to_x(buffer, state, y_offset, 0x40);
    }
  } else if ((operand & (1ull << 27))) {
    // Move one x row into y; destination offset forced to a row boundary.
    y_offset &= ~0x3F;
    uint32_t buffer[16];
    load_from_x(buffer, state, z_offset * 0x40, 0x40);
    store_to_y(buffer, state, y_offset, 0x40);
  } else {
    // Plain z column extract into y; bits 29:28 select the element width.
    // TODO: rewrite as a switch on (operand >> 28) & 3?
    if ((operand & (1ull << 29)) && (operand & (1ull << 28))) {
      // Byte mode: write only the low byte of each u16 lane of y,
      // preserving the existing high byte (read-modify-write).
      uint16_t buffer[32];
      load_from_y(buffer, state, y_offset, 0x40);
      for (int i = 0; i < 32; i++) {
        buffer[i] &= 0xFF00;
        buffer[i] |= state->z[(i << 1) + (z_offset & ((1 << 1) - 1))]
                         .u16[(z_offset >> 1) & 31] &
                     0xFF;
      }
      store_to_y(buffer, state, y_offset, 0x40);
    } else if (operand & (1ull << 29)) {
      // u16 column.
      uint16_t buffer[32];
      for (int i = 0; i < 32; i++) {
        buffer[i] = state->z[(i << 1) + (z_offset & ((1 << 1) - 1))]
                        .u16[(z_offset >> 1) & 31];
      }
      store_to_y(buffer, state, y_offset, 0x40);
    } else if (operand & (1ull << 28)) {
      // u32 column.
      uint32_t buffer[16];
      for (int i = 0; i < 16; i++) {
        buffer[i] = state->z[(i << 2) + (z_offset & ((1 << 2) - 1))]
                        .u32[(z_offset >> 2) & 15];
      }
      store_to_y(buffer, state, y_offset, 0x40);
    } else {
      // u64 column.
      uint64_t buffer[8];
      for (int i = 0; i < 8; i++) {
        buffer[i] = state->z[(i << 3) + (z_offset & ((1 << 3) - 1))]
                        .u64[(z_offset >> 3) & 7];
      }
      store_to_y(buffer, state, y_offset, 0x40);
    }
  }
}
// print flags | |
// Formatting/filter flags for print_amx_row / print_amx_state /
// diff_amx_state.
enum print_flags {
  PF_TYPE_MASK = 0x0F, // low nibble: element type used for formatting
  PF_U32 = 0x00,
  PF_U16 = 0x01,
  PF_U8 = 0x02,
  PF_F16 = 0x03,
  PF_F32 = 0x04,
  PF_F64 = 0x05,
  PF_U64 = 0x06,
  PF_SKIP_X = 0x10,         // omit the x register file
  PF_SKIP_Y = 0x20,         // omit the y register file
  PF_SKIP_Z = 0x40,         // omit the z register file
  PF_SKIP_ZERO_ROWS = 0x80, // suppress rows whose 16 u32 words are all zero
  // for diffs
  PF_SKIP_A = 0x10000, // diff_amx_state: don't print mismatching "a" rows
  PF_SKIP_B = 0x20000, // diff_amx_state: don't print mismatching "b" rows
};
// Print one 64-byte register row to stdout, formatted per the PF_* flags.
// `reg` is a one-character register-file tag ('x'/'y'/'z'), `num` the row
// index. With PF_SKIP_ZERO_ROWS set, all-zero rows are silently skipped.
void print_amx_row(char reg, int num, union amx_row *r, int flags) {
  bool should_print = true;
  if (flags & PF_SKIP_ZERO_ROWS) {
    // Scan the row as 16 u32 words; print only if any is nonzero.
    should_print = false;
    for (uint64_t o = 0; o < 16; o++) {
      if (r->u32[o] != 0) {
        should_print = true;
        break;
      }
    }
  }
  if (should_print) {
    // Pad single-digit row numbers so the columns line up.
    if (num < 10)
      printf(" ");
    printf(" %c%d:", reg, num);
    switch (flags & PF_TYPE_MASK) {
    case PF_U32:
      for (uint64_t o = 0; o < 16; o++) {
        printf(" %8x", r->u32[o]);
      }
      break;
    case PF_U16:
      for (uint64_t o = 0; o < 32; o++) {
        printf(" %4x", r->u16[o]);
      }
      break;
    case PF_U8:
      for (uint64_t o = 0; o < 64; o++) {
        printf(" %2x", r->u8[o]);
      }
      break;
    case PF_F16:
      for (uint64_t o = 0; o < 32; o++) {
        printf(" %16f", (float)r->f16[o]);
      }
      break;
    case PF_F32:
      for (uint64_t o = 0; o < 16; o++) {
        printf(" %16f", r->f32[o]);
      }
      break;
    case PF_F64:
      for (uint64_t o = 0; o < 8; o++) {
        printf(" %16lf", r->f64[o]);
      }
      break;
    case PF_U64:
      for (uint64_t o = 0; o < 8; o++) {
        // NOTE(review): %llx assumes uint64_t == unsigned long long (true on
        // Apple platforms); "%16" PRIx64 would be portable.
        printf(" %16llx", r->u64[o]);
      }
      break;
    default:
      assert(0 && "invalid print flags");
    }
    printf("\n");
  }
}
void print_amx_state(struct amx_state *state, int flags) { | |
if (!(flags & PF_SKIP_X)) { | |
for (int i = 0; i < 8; i++) { | |
print_amx_row('x', i, &state->x[i], flags); | |
} | |
} | |
if (!(flags & PF_SKIP_Y)) { | |
for (int i = 0; i < 8; i++) { | |
print_amx_row('y', i, &state->y[i], flags); | |
} | |
} | |
if (!(flags & PF_SKIP_Z)) { | |
for (int i = 0; i < 64; i++) { | |
print_amx_row('z', i, &state->z[i], flags); | |
} | |
} | |
} | |
int diff_amx_state(struct amx_state *a, struct amx_state *b, int flags) { | |
int same = 1; | |
if (!(flags & PF_SKIP_X)) { | |
for (int i = 0; i < 8; i++) { | |
if (memcmp(&a->x[i], &b->x[i], sizeof(union amx_row))) { | |
if (!(flags & PF_SKIP_A)) | |
print_amx_row('x', i, &a->x[i], flags); | |
if (!(flags & PF_SKIP_B)) | |
print_amx_row('x', i, &b->x[i], flags); | |
same = 0; | |
} | |
} | |
} | |
if (!(flags & PF_SKIP_Y)) { | |
for (int i = 0; i < 8; i++) { | |
if (memcmp(&a->y[i], &b->y[i], sizeof(union amx_row))) { | |
if (!(flags & PF_SKIP_A)) | |
print_amx_row('y', i, &a->y[i], flags); | |
if (!(flags & PF_SKIP_B)) | |
print_amx_row('y', i, &b->y[i], flags); | |
same = 0; | |
} | |
} | |
} | |
if (!(flags & PF_SKIP_Z)) { | |
for (int i = 0; i < 64; i++) { | |
if (memcmp(&a->z[i], &b->z[i], sizeof(union amx_row))) { | |
if (!(flags & PF_SKIP_A)) | |
print_amx_row('z', i, &a->z[i], flags); | |
if (!(flags & PF_SKIP_B)) | |
print_amx_row('z', i, &b->z[i], flags); | |
same = 0; | |
} | |
} | |
} | |
return same; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment