Hacking AIR to open up the Apple GPU to general-purpose linear algebra
//
// main.swift
// HackingAIR
//
// Created by Philip Turner on 5/27/24.
//
import Foundation
import Metal
// Hacking AIR to open up the Apple GPU to general-purpose linear algebra.
//
// ========================================================================== //
// Methods
// ========================================================================== //
//
// The shader code 'metalSimdgroupEvent' exposes all five async-copy-related
// instructions to the Metal JIT compiler. I cannot stress enough how important
// this is for M1 and M2 users.
// - Asynchronously copy from device -> threadgroup, 1D.
// - Asynchronously copy from threadgroup -> device, 1D.
// - Asynchronously copy from device -> threadgroup, 2D.
// - Asynchronously copy from threadgroup -> device, 2D.
// - Wait on a simdgroup event.
//
// I have only tested the 2D async copies, which are required for matrix
// multiplication. The 1D functions are exposed but their correctness is
// unverified.
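//
// As a usage sketch (mirroring the 'prefetch' function further down in this
// file), a kernel stages a 2D tile from device memory into threadgroup memory
// like this; the tile sizes and 'leading_dim' are placeholders:
//
//   threadgroup float *dst = ...;   // tile with 32 floats per row
//   const device float *src = ...;  // matrix with 'leading_dim' floats per row
//   simdgroup_event event;
//   event.async_copy(dst, 32, ushort2(32, 32),
//                    src, leading_dim, ushort2(32, 32));
//   simdgroup_event::wait(1, &event);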
//
// Insights about invoking the generation of LLVM bitcode:
// - To add 'nounwind' to a function, prepend the declaration with:
//   __attribute__((__nothrow__))
// - To add 'nocapture' to an argument, add the following after the argument
//   type and name:
//   __attribute__((__noescape__))
// - I have not yet figured out how to trigger the 'argmemonly' attribute, but
//   my intuition says there is a hidden way to activate it from plain C.
// - AIR seems to accept the functions without these annotations. Therefore, I
//   omitted them from the final product to keep it legible; a fully annotated
//   declaration is sketched below.
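//   For reference, the 1D device -> threadgroup declaration with both
//   annotations applied would look roughly like this (a sketch only; the
//   un-annotated form below is what AIR actually receives):
//     __attribute__((__nothrow__))
//     thread _simdgroup_event_t*
//     __metal_simdgroup_async_copy_1d(
//       ulong, ulong,
//       threadgroup void *dst __attribute__((__noescape__)),
//       const device void *src __attribute__((__noescape__)),
//       ulong)
//     __asm("air.simdgroup_async_copy_1d.p3i8.p1i8");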
//
// ========================================================================== //
// Results (Correctness)
// ========================================================================== //
//
// Checked against the analytical result of multiplying a second-order
// periodic Laplacian matrix by a random matrix. The GPU output agrees with
// the analytical result to within Float32 rounding error.
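// Concretely, every output entry must satisfy
//   C[m][n] = B[(m - 1) mod size][n] - 2 * B[m][n] + B[(m + 1) mod size][n],
// which is exactly what the validation loop at the end of this file checks.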
//
// ========================================================================== //
// Results (Performance)
// ========================================================================== //
//
// M1 Max, macOS Sonoma, Metal JIT Compiler
// - problemSize = 64 | 45 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 96 | 112 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 128 | 217 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 192 | 503 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 256 | 886 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 384 | 3448 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 512 | 5314 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 640 | 6533 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 768 | 7030 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 896 | 7154 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1024 | 7001 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1152 | 8145 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1280 | 7810 GFLOPS (48x48x24, 128 threads/group)
//
// M4, iOS 17, Metal JIT Compiler
// - problemSize = 64 | 14 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 96 | 80 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 128 | 162 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 192 | 369 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 256 | 620 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 384 | 1256 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 512 | 1746 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 640 | 1960 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 768 | 1955 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 896 | 2040 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1024 | 2076 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1152 | 2118 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1280 | 2130 GFLOPS (32x32x8, 32 threads/group)
//
// Compare to reference data for a pre-compiled shader with async copy.
// https://gist.github.com/philipturner/3bda14e876a635e73745c42f2eb240c8
//
// M1 Max
// - problemSize = 64 | 45 vs 44 GFLOPS
// - problemSize = 96 | 112 vs 109 GFLOPS
// - problemSize = 128 | 217 vs 214 GFLOPS
// - problemSize = 192 | 503 vs 495 GFLOPS
// - problemSize = 256 | 886 vs 883 GFLOPS
// - problemSize = 384 | 3448 vs 2912 GFLOPS
// - problemSize = 512 | 5314 vs 4091 GFLOPS
// - problemSize = 640 | 6533 vs 6440 GFLOPS
// - problemSize = 768 | 7030 vs 7017 GFLOPS
// - problemSize = 896 | 7154 vs 7136 GFLOPS
// - problemSize = 1024 | 7001 vs 6966 GFLOPS
// - problemSize = 1152 | 8145 vs 8144 GFLOPS
// - problemSize = 1280 | 7810 vs 7813 GFLOPS
//
// M4
// - problemSize = 64 | 14 vs 39 GFLOPS
// - problemSize = 96 | 80 vs 32 GFLOPS
// - problemSize = 128 | 162 vs 94 GFLOPS
// - problemSize = 192 | 369 vs 364 GFLOPS
// - problemSize = 256 | 620 vs 654 GFLOPS
// - problemSize = 384 | 1256 vs 1270 GFLOPS
// - problemSize = 512 | 1746 vs 1626 GFLOPS
// - problemSize = 640 | 1960 vs 1947 GFLOPS
// - problemSize = 768 | 1955 vs 1955 GFLOPS
// - problemSize = 896 | 2040 vs 2034 GFLOPS
// - problemSize = 1024 | 2076 vs 2078 GFLOPS
// - problemSize = 1152 | 2118 vs 2119 GFLOPS
// - problemSize = 1280 | 2130 vs 2129 GFLOPS
//
// For M1 Max, performance is the same. This means we are triggering the same
// AIR codegen as Xcode 14.2.
//
// For M4, the first few numbers appeared to match MFA without async copy.
// However, the remaining numbers match MFA with async copy. The numbers for
// large matrices are more important, as they reveal the maximum performance
// attainable with a specific design choice. ALU utilization is ~61%, which is
// dismal compared to the ~91% achieved without async copy. This result means
// the async copy instructions are correctly generated on iOS. We will need
// them for the older A14, A15, and A16 chips.
//
// ========================================================================== //
// Discussion
// ========================================================================== //
//
// I now have hope that I can access hardware-accelerated matrix multiplication
// from OpenCL. Therefore, I could target both Apple GPUs and the AMD 7900 XTX
// from the same shader source. It will require a bit of extra time to adapt
// this bitcode injection to the cl2Metal compiler. I'm deferring that
// investigation to a later date.
//
// By itself, this workaround does not solve the problem of people needing to
// rely on MPS for matrix multiplications. There needs to be an accompanying
// reference implementation of a near-optimal matrix multiplication kernel for
// all hardware (M1/M2 and M3/M4). In addition, this code only covers Float32.
// That is the hardest data type to run at maximum ALU utilization, because it
// requires so much L1 bandwidth. If Float32 can run at maximum performance,
// extending the code to Float16 should not be difficult.
//
// For BFloat16, the reference implementation should include graphs and
// statistics about the inferior performance of MSL's built-in type / MLX's
// decoding algorithm on M1. It should provide an alternative algorithm for
// BF16 -> FP32 decoding and strongly encourage users to activate it on M1.
// In addition, users should be conscious of how BF16 numbers are stored in
// registers on M1. BF16 numbers will most likely be stored in FP32 registers,
// to avoid redundantly decoding them every time they are accessed.
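//
// As a point of reference, the standard bit-shift decoding places the 16 BF16
// bits into the upper half of a 32-bit word (a minimal sketch of the general
// technique, not necessarily the algorithm the reference implementation would
// choose):
//
//   ushort bf16_bits = ...;
//   float decoded = as_type<float>(uint(bf16_bits) << 16);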
// MARK: - Shader Sources
let metalSimdgroupEvent: String = """
// -*- Metal -*-
//===-- metal_simdgroup_event ---------------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_EVENT
#define __METAL_SIMDGROUP_EVENT
#if !defined(__HAVE_SIMDGROUP_FUTURE__)
// Invoking the generation of LLVM bitcode for async copies.
//
// %struct._simdgroup_event_t = type opaque
//
struct _simdgroup_event_t;
// Invoking the generation of LLVM bitcode for async copies.
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
ulong, ulong, threadgroup void *, const device void *, ulong)
__asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
ulong, ulong, device void *, const threadgroup void *, ulong)
__asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
// i64, i64, i8 addrspace(3)* nocapture writeonly,
// i64, i64, <2 x i64>, i8 addrspace(1)* nocapture readonly,
// i64, i64, <2 x i64>, <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong, threadgroup void *,
ulong, ulong, ulong2, const device void *,
ulong, ulong, ulong2, long2, int)
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
// i64, i64, i8 addrspace(1)* nocapture writeonly,
// i64, i64, <2 x i64>, i8 addrspace(3)* nocapture readonly,
// i64, i64, <2 x i64>, <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong, device void *,
ulong, ulong, ulong2, const threadgroup void *,
ulong, ulong, ulong2, long2, int)
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: convergent nounwind
// declare void
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
// local_unnamed_addr #3
//
void __metal_wait_simdgroup_events(
int, thread _simdgroup_event_t**)
__asm("air.wait_simdgroup_events");
#endif
#pragma METAL internals : enable
namespace metal
{
#if !defined(__HAVE_SIMDGROUP_FUTURE__)
enum class simdgroup_async_copy_clamp_mode {
clamp_to_zero = 0,
clamp_to_edge = 1
};
#endif
struct simdgroup_event {
METAL_FUNC simdgroup_event() thread {}
template <typename T>
METAL_FUNC void async_copy(threadgroup T *dst, const device T *src, ulong n_elements) thread {
event = __metal_simdgroup_async_copy_1d(sizeof(T), alignof(T), reinterpret_cast<threadgroup void *>(dst), reinterpret_cast<const device void *>(src), n_elements);
}
template <typename T>
METAL_FUNC void async_copy(device T *dst, const threadgroup T *src, ulong n_elements) thread {
event = __metal_simdgroup_async_copy_1d(sizeof(T), alignof(T), reinterpret_cast<device void *>(dst), reinterpret_cast<const threadgroup void *>(src), n_elements);
}
template <typename T>
METAL_FUNC void async_copy(threadgroup T *dst, ushort dst_elements_per_row, ushort2 dst_tile_dimensions, const device T *src, uint src_elements_per_row, ushort2 src_tile_dimensions, bool transpose_matrix = false, simdgroup_async_copy_clamp_mode clamp_mode = simdgroup_async_copy_clamp_mode::clamp_to_zero) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(sizeof(T), alignof(T), reinterpret_cast<threadgroup void *>(dst), ushort(dst_elements_per_row), 1, ulong2(dst_tile_dimensions), reinterpret_cast<const device void *>(src), uint(src_elements_per_row), 1, ulong2(src_tile_dimensions), long2(0), static_cast<int>(clamp_mode));
}
template <typename T>
METAL_FUNC void async_copy(device T *dst, uint dst_elements_per_row, ushort2 dst_tile_dimensions, const threadgroup T *src, ushort src_elements_per_row, ushort2 src_tile_dimensions, bool transpose_matrix = false) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(sizeof(T), alignof(T), reinterpret_cast<device void *>(dst), uint(dst_elements_per_row), 1, ulong2(dst_tile_dimensions), reinterpret_cast<const threadgroup void *>(src), ushort(src_elements_per_row), 1, ulong2(src_tile_dimensions), long2(0), 0);
}
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
#if defined(__HAVE_SIMDGROUP_FUTURE__)
__metal_wait_simdgroup_events(count, reinterpret_cast<thread __metal_simdgroup_event_t*>(events));
#else
__metal_wait_simdgroup_events(count, reinterpret_cast<thread _simdgroup_event_t**>(events));
#endif
}
private:
// Invoking the generation of LLVM bitcode for async copies.
//
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
//
#if defined(__HAVE_SIMDGROUP_FUTURE__)
__metal_simdgroup_event_t event;
#else
thread _simdgroup_event_t* event;
#endif
};
} // namespace metal
#pragma METAL internals : disable
#endif // __METAL_SIMDGROUP_EVENT
"""
let metalSimdgroupMatrixStorage: String = """
// -*- Metal -*-
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
#define __METAL_SIMDGROUP_MATRIX_STORAGE
// Contains C++ symbols accessible to a developer through automatic code
// completion in Xcode 14.2. Formatted with the same style as the Metal Standard
// Library for consistency with other Metal code.
#if defined(__HAVE_SIMDGROUP_MATRIX__)
#pragma METAL internals : enable
namespace metal
{
template <typename T>
struct simdgroup_matrix_storage {
typedef vec<T, 64> storage_type;
storage_type t;
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
return reinterpret_cast<thread vec<T, 2>*>(&t);
}
METAL_FUNC simdgroup_matrix_storage() thread = default;
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
*(this->thread_elements()) = thread_elements;
}
METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) {
// https://patents.google.com/patent/US11256518B2
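// Each lane owns a two-element fragment of the 8x8 tile (32 lanes x 2 = 64
// elements). This function maps the lane ID to the (column, row) origin of
// that fragment.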
ushort lane_id = thread_index_in_simdgroup;
ushort quad_id = lane_id / 4;
constexpr ushort QUADRANT_SPAN_M = 4;
constexpr ushort THREADS_PER_QUADRANT = 8;
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
return ushort2(N_in_simd, M_in_simd);
}
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
} else {
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
}
}
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
} else {
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
}
}
// WARNING: All load and store functions assume the X dimension is divisible by 2.
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
}
}
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
}
}
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
}
}
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
}
}
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
template <typename U, typename V>
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
if (!accumulate) {
*(thread_elements()) = vec<T, 2>(0);
}
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
}
};
} // namespace metal
#pragma METAL internals : disable
#endif
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
"""
let GEMM = """
//
// GEMM.metal
// MetalFlashAttention
//
// Created by Philip Turner on 6/23/23.
//
#include <metal_stdlib>
\(metalSimdgroupEvent)
\(metalSimdgroupMatrixStorage)
using namespace metal;
// MARK: - Function Constants
// Dimensions of each matrix.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];
// Whether each matrix is transposed.
constant bool A_trans [[function_constant(10)]];
constant bool B_trans [[function_constant(11)]];
constant bool D_trans [[function_constant(13)]];
constant uint A_leading_dim = A_trans ? M : K;
constant uint B_leading_dim = B_trans ? K : N;
// Alpha and beta constants from BLAS.
constant float alpha [[function_constant(20)]];
constant float beta [[function_constant(21)]];
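// Block sizes: each simdgroup accumulates an M_simd x N_simd tile of C, and
// each iteration of the outer loop advances K_simd elements along K.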
constant ushort M_simd [[function_constant(200)]];
constant ushort N_simd [[function_constant(201)]];
constant ushort K_simd [[function_constant(202)]];
// Elide work on the edge when matrix dimension < SRAM block dimension.
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
constant ushort M_splits [[function_constant(210)]];
constant ushort N_splits [[function_constant(211)]];
constant ushort M_group = M_simd * M_splits;
constant ushort N_group = N_simd * N_splits;
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);
// There is no padding for M reads/writes.
// There is no padding for N reads/writes.
constant ushort K_simd_unpadded = (K % K_simd == 0) ? K_simd : (K % K_simd);
constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;
constant ushort A_sram_length = (M_simd / 8) * 1;
constant ushort B_sram_length = 1 * (N_simd / 8);
constant ushort A_block_length = M_group * K_simd;
// Threadgroup block must fit entire C accumulator and partial sums.
constant ushort A_sram_offset = 0;
constant ushort B_sram_offset = A_sram_offset + A_sram_length;
constant ushort C_sram_offset = B_sram_offset + B_sram_length;
constant ushort A_block_offset = 0;
constant ushort B_block_offset = A_block_offset + A_block_length;
// MARK: - Utilities
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// A_sram[M_simd][8]
return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// B_sram[8][N_simd]
return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// C_sram[M_simd][N_simd]
return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC void prefetch(threadgroup T *A_block, device T *A,
ushort2 A_tile_src, uint2 A_offset,
threadgroup T *B_block, device T *B,
ushort2 B_tile_src, uint2 B_offset, uint k)
{
A_tile_src.x = min(uint(K_simd), K - k);
B_tile_src.y = min(uint(K_simd), K - k);
auto A_src = simdgroup_matrix_storage<T>::apply_offset(
A, A_leading_dim, A_offset, A_trans);
auto B_src = simdgroup_matrix_storage<T>::apply_offset(
B, B_leading_dim, B_offset, B_trans);
// Rounded-up ceiling for the threadgroup block.
const uint K_edge_floor = K - K_simd_unpadded;
const uint K_edge_ceil = K_edge_floor + K_simd_padded;
ushort K_padded;
if (K_edge_floor == K_simd) {
K_padded = K_simd;
} else {
K_padded = min(uint(K_simd), K_edge_ceil - k);
}
ushort2 A_tile_dst(K_padded, A_tile_src.y);
ushort2 B_tile_dst(B_tile_src.x, K_padded);
simdgroup_event events[2];
events[0].async_copy(A_block, A_block_leading_dim, A_tile_dst, A_src,
A_leading_dim, A_tile_src, A_trans);
events[1].async_copy(B_block, B_block_leading_dim, B_tile_dst, B_src,
B_leading_dim, B_tile_src, B_trans);
simdgroup_event::wait(2, events);
}
// One iteration of the MACC loop, effectively k=8 iterations.
template <typename T>
METAL_FUNC void multiply_accumulate(thread simdgroup_matrix_storage<T> *sram,
const threadgroup T *A_block,
const threadgroup T *B_block,
bool accumulate = true)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
ushort2 origin(0, m);
A_sram(sram, origin)->load(A_block, A_block_leading_dim, origin, A_trans);
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, 0);
B_sram(sram, origin)->load(B_block, B_block_leading_dim, origin, B_trans);
}
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
auto A = A_sram(sram, ushort2(0, m));
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto B = B_sram(sram, ushort2(n, 0));
auto C = C_sram(sram, ushort2(n, m));
C->multiply(*A, *B, accumulate);
}
}
}
template <typename T>
METAL_FUNC void partial_store(thread simdgroup_matrix_storage<T> *sram,
threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
if (is_k_summation) {
C_sram(sram, origin)->store(C_block, N_simd, origin);
} else {
C_sram(sram, origin)->store(C_block, N_group, origin);
}
}
}
}
template <typename T>
METAL_FUNC void partial_accumulate(thread simdgroup_matrix_storage<T> *sram,
threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
auto B = B_sram(sram, ushort2(n, 0));
if (is_k_summation) {
B->load(C_block, N_simd, origin);
} else {
B->load(C_block, N_group, origin);
}
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
auto B = B_sram(sram, ushort2(n, 0));
auto C = C_sram(sram, origin);
if (is_k_summation) {
C->thread_elements()[0] += B->thread_elements()[0];
} else {
float2 C_old = float2(B->thread_elements()[0]);
float2 C_new = float2(C->thread_elements()[0]);
C->thread_elements()[0] = vec<T, 2>(fast::fma(C_old, beta, C_new));
}
}
}
}
template <typename T>
METAL_FUNC void async_access_accumulator(threadgroup T *C_block, device T *C,
uint2 C_offset, bool is_store)
{
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
min(uint(M_group), M - C_offset.y));
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, C_offset);
simdgroup_event event;
if (is_store) {
event.async_copy(C_src, N, C_tile, C_block, N_group, C_tile);
} else {
event.async_copy(C_block, N_group, C_tile, C_src, N, C_tile);
simdgroup_event::wait(1, &event);
}
}
template <typename T>
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
device T *C, bool m_is_edge, bool n_is_edge)
{
const ushort m_start = (m_is_edge) ? M_modulo : 0;
const ushort n_start = (n_is_edge) ? N_modulo : 0;
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
#pragma clang loop unroll(full)
for (ushort m = m_start; m < m_end; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = n_start; n < n_end; n += 8) {
ushort2 origin(n, m);
C_sram(sram, origin)->store(C, N, origin);
}
}
}
// MARK: - Kernels
kernel void sgemm(device float *A [[buffer(0)]],
device float *B [[buffer(1)]],
device float *C [[buffer(2)]],
threadgroup float *threadgroup_block [[threadgroup(0)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
simdgroup_matrix_storage<float> sram[1024];
auto A_block = threadgroup_block + A_block_offset;
auto B_block = threadgroup_block + B_block_offset;
ushort2 sid(sidx % N_splits, sidx / N_splits);
ushort2 offset_in_simd = simdgroup_matrix_storage<float>::offset(lane_id);
uint2 A_offset(0, gid.y * M_group);
uint2 B_offset(gid.x * N_group, 0);
{
uint C_base_offset_x = B_offset.x + sid.x * N_simd;
uint C_base_offset_y = A_offset.y + sid.y * M_simd;
if (C_base_offset_x >= N || C_base_offset_y >= M) {
return;
}
}
ushort2 offset_in_group(sid.x * N_simd + offset_in_simd.x,
sid.y * M_simd + offset_in_simd.y);
ushort2 A_tile_src;
ushort2 B_tile_src;
if (sidx == 0) {
A_tile_src.y = min(uint(M_group), M - A_offset.y);
B_tile_src.x = min(uint(N_group), N - B_offset.x);
prefetch(A_block, A, A_tile_src, A_offset,
B_block, B, B_tile_src, B_offset, 0);
}
if (K > K_simd) {
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
*C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<float>(0);
}
}
}
for (uint K_floor = 0; K_floor < K; K_floor += K_simd) {
ushort2 A_block_offset(offset_in_simd.x, offset_in_group.y);
ushort2 B_block_offset(offset_in_group.x, offset_in_simd.y);
auto A_block_src = simdgroup_matrix_storage<float>::apply_offset(
A_block, A_block_leading_dim, A_block_offset, A_trans);
auto B_block_src = simdgroup_matrix_storage<float>::apply_offset(
B_block, B_block_leading_dim, B_block_offset, B_trans);
threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma clang loop unroll(full)
for (ushort k = 0; k < K_simd_padded; k += 8) {
bool accumulate = !(K <= K_simd && k == 0);
multiply_accumulate(sram, A_block_src, B_block_src, accumulate);
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}
if (K_floor + K_simd < K) {
#pragma clang loop unroll(full)
for (ushort k = K_simd_padded; k < K_simd; k += 8) {
multiply_accumulate(sram, A_block_src, B_block_src);
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sidx == 0) {
uint K_next = K_floor + K_simd;
A_offset.x = K_next;
B_offset.y = K_next;
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, K_next);
}
}
}
if (alpha != 1) {
#pragma clang loop unroll(full)
for (int m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (int n = 0; n < N_padded; n += 8) {
C_sram(sram, ushort2(n, m))->thread_elements()[0] *= alpha;
}
}
}
uint2 C_offset(B_offset.x, A_offset.y);
ushort2 C_block_offset = offset_in_group.xy;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (beta != 0) {
if (sidx == 0) {
async_access_accumulator(threadgroup_block, C, C_offset, false);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
auto C_block = simdgroup_matrix_storage<float>::apply_offset(
threadgroup_block, N_group, C_block_offset);
partial_accumulate(sram, C_block, false);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if ((M % 8 != 0) || (N % 8 != 0)) {
auto C_block = simdgroup_matrix_storage<float>::apply_offset(
threadgroup_block, N_group, C_block_offset);
partial_store(sram, C_block, false);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sidx == 0) {
async_access_accumulator(threadgroup_block, C, C_offset, true);
}
} else {
uint2 matrix_origin = C_offset + uint2(C_block_offset);
auto C_src = simdgroup_matrix_storage<float>::apply_offset(
C, N, matrix_origin);
store_accumulator(sram, C_src, false, false);
const uint M_edge_floor = M - M % M_simd;
const uint N_edge_floor = N - N % N_simd;
if (matrix_origin.y < M_edge_floor) {
store_accumulator(sram, C_src, true, false);
}
if (matrix_origin.x < N_edge_floor) {
store_accumulator(sram, C_src, false, true);
if (matrix_origin.y < M_edge_floor) {
store_accumulator(sram, C_src, true, true);
}
}
}
}
"""
// MARK: - Utilities
struct MTLContext {
var device: MTLDevice
var commandQueue: MTLCommandQueue
init() {
device = MTLCreateSystemDefaultDevice()!
commandQueue = device.makeCommandQueue()!
}
}
#if os(macOS)
public class MTLCompiler {
public internal(set) var buildProductsDirectory: URL
var xcodePath: String = "/Applications/Xcode.app"
public init() {
// let temporaryDirectory = FileManager.default.temporaryDirectory
let temporaryDirectory = URL(filePath: "/Users/philipturner/Desktop")
buildProductsDirectory = temporaryDirectory
.appendingPathComponent("metal-compiler")
}
public func setXcodePath(_ path: String) {
self.xcodePath = path
}
// This has about 200 ms latency, so invoke it sparingly. In the final
// version of this library, we'll compile all GPU shaders beforehand and cache
// the Metal resource objects at runtime.
//
// Near-term, create a key-value cache of the shader sources to the
// compiled objects. This could speed up some workflows during development of
// the library.
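//
// A minimal sketch of that cache (hypothetical 'compileCached' wrapper; not
// wired into compile() below):
//
//   private var libraryCache: [String: MTLLibrary] = [:]
//
//   public func compileCached(_ source: String) -> MTLLibrary {
//     if let library = libraryCache[source] {
//       return library
//     }
//     let library = compile(source)
//     libraryCache[source] = library
//     return library
//   }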
public func compile(_ source: String) -> MTLLibrary {
//
// build.swift
// MetalFlashAttention
//
// Created by Philip Turner on 6/26/23.
//
// MARK: - Parse Arguments
struct BuildSettings {
var externalMetallibPath: String? = nil
var platform: Platform? = nil
var verbose: Bool = false
var xcodePath: String = ""
enum Platform {
case iOS
case macOS
var metalToolsPath: String {
switch self {
case .iOS:
return "ios"
case .macOS:
return "macos"
}
}
var deploymentVersionArgument: String {
switch self {
case .iOS:
return "-mios-version-min=16.0.0"
case .macOS:
return "-mmacosx-version-min=13.0.0"
}
}
var xcrunSDK: String {
switch self {
case .iOS:
return "iphoneos"
case .macOS:
return "macosx"
}
}
}
func metalToolPath(executable: String) -> String {
guard let metalToolsPath = platform?.metalToolsPath else {
fatalError("Must specify platform before locating Metal tools.")
}
var output = xcodePath
output += "/Contents/Developer/Toolchains/XcodeDefault.xctoolchain"
output += "/usr/metal/\(metalToolsPath)/bin/\(executable)"
return output
}
func xcrunMetalArguments(executable: String) -> [String] {
guard let xcrunSDK = platform?.xcrunSDK else {
fatalError("Must specify platform before locating Metal tools.")
}
return ["-sdk", xcrunSDK, executable]
}
}
var settings = BuildSettings()
settings.verbose = true
settings.platform = .macOS
settings.xcodePath = self.xcodePath
// MARK: - Prepare File Directories
func directoryExists(url: URL) -> Bool {
var isDirectory: ObjCBool = false
let succeeded = FileManager.default.fileExists(
atPath: url.path, isDirectory: &isDirectory)
return succeeded && isDirectory.boolValue
}
func fileExists(url: URL) -> Bool {
var isDirectory: ObjCBool = false
let succeeded = FileManager.default.fileExists(
atPath: url.path, isDirectory: &isDirectory)
return succeeded && !isDirectory.boolValue
}
func assertDirectoryExists(url: URL, line: UInt = #line) {
guard directoryExists(url: url) else {
fatalError("""
Line \(line):
Directory not found at '\(url.path)'.
""")
}
}
func assertFileExists(url: URL, line: UInt = #line) {
guard fileExists(url: url) else {
fatalError("""
Line \(line):
File not found at '\(url.path)'.
""")
}
}
func touchDirectory(url: URL) {
if !directoryExists(url: url) {
try! FileManager.default.createDirectory(
at: url, withIntermediateDirectories: false)
}
assertDirectoryExists(url: url)
}
let workDir = self.buildProductsDirectory
touchDirectory(url: workDir)
let buildDir = workDir.appending(component: "build")
let libDir = buildDir.appending(component: "lib")
let srcDir = buildDir.appending(component: "src")
touchDirectory(url: buildDir)
touchDirectory(url: libDir)
touchDirectory(url: srcDir)
func createFile(source: String, name: String) {
guard let data = source.data(using: .utf8) else {
fatalError("Could not encode source string as UTF-8.")
}
let destinationURL = srcDir.appendingPathComponent(name)
try! data.write(to: destinationURL)
}
createFile(
source: metalSimdgroupEvent,
name: "metal_simdgroup_event")
createFile(
source: metalSimdgroupMatrixStorage,
name: "metal_simdgroup_matrix_storage")
createFile(
source: source,
name: "File.metal")
let metalURL = srcDir.appendingPathComponent("File.metal")
let airURL = buildDir.appendingPathComponent("File.air")
let airPath = airURL.relativePath
try? FileManager.default.removeItem(atPath: airPath)
guard FileManager.default.createFile(
atPath: airPath, contents: nil) else {
fatalError("Could not create destination path '\(airPath)'.")
}
// MARK: - Compile AIR File
// Arguments to invoke the Metal compiler with.
var arguments: [String] = []
arguments.append(settings.platform!.deploymentVersionArgument)
arguments.append("-c")
// Suppress compiler warnings unless the user enters '--verbose'.
if settings.verbose {
arguments.append("-Wno-unused-function")
arguments.append("-Wno-unused-variable")
} else {
arguments.append("-w")
}
let process = Process()
let toolPath = settings.metalToolPath(executable: "metal")
process.executableURL = URL(filePath: toolPath)
process.arguments = arguments + [metalURL.path, "-o", airPath]
try! process.run()
process.waitUntilExit()
if process.terminationStatus != 0 {
fatalError("Could not compile source.")
}
// MARK: - Compile Metal Library
arguments = []
// Package the metallib using the up-to-date Xcode version.
func runProcess() {
let process = try! Process.run(
URL(fileURLWithPath: "/usr/bin/xcrun"),
arguments: arguments)
process.waitUntilExit()
}
let metallibName = "File.metallib"
let metallibURL = libDir.appending(component: metallibName)
arguments = []
arguments += settings.xcrunMetalArguments(executable: "metal")
arguments.append(settings.platform!.deploymentVersionArgument)
arguments.append(airPath)
arguments.append("-o")
arguments.append(metallibURL.path)
runProcess()
// MARK: - Instantiate Resource Object
let device = MTLCreateSystemDefaultDevice()!
let library = try! device.makeLibrary(URL: metallibURL)
//try! FileManager.default.removeItem(at: workDir)
return library
}
}
#endif
// MARK: - Script
func runApplication() {
print("Hello, console.")
let problemSize: Int = 64
var A = [Float](repeating: .zero, count: problemSize * problemSize)
var B = [Float](repeating: .zero, count: problemSize * problemSize)
var C = [Float](repeating: .zero, count: problemSize * problemSize)
// Initialize A as the 2nd-order periodic Laplacian.
for diagonalID in 0..<problemSize {
let diagonalAddress = diagonalID * problemSize + diagonalID
A[diagonalAddress] = -2
let leftColumnID = (diagonalID + problemSize - 1) % problemSize
let leftSubDiagonalAddress = diagonalID * problemSize + leftColumnID
A[leftSubDiagonalAddress] = 1
let rightColumnID = (diagonalID + problemSize + 1) % problemSize
let rightSubDiagonalAddress = diagonalID * problemSize + rightColumnID
A[rightSubDiagonalAddress] = 1
}
// Initialize B to random numbers.
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
let entry = Float.random(in: 0..<1)
B[address] = entry
}
}
// Multiply A with B.
#if false
for m in 0..<problemSize {
for n in 0..<problemSize {
var dotProduct: Float = .zero
for k in 0..<problemSize {
// Read from the input matrices.
let lhsAddress = m * problemSize + k
let rhsAddress = k * problemSize + n
let lhsValue = A[lhsAddress]
let rhsValue = B[rhsAddress]
dotProduct += lhsValue * rhsValue
}
// Write to the output matrix.
let address = m * problemSize + n
C[address] = dotProduct
}
}
#else
do {
// Initialize the context.
let context = MTLContext()
// let compiler = MTLCompiler()
// let library = compiler.compile(GEMM)
let library = try! context.device.makeLibrary(source: GEMM, options: nil)
// Set the function constants.
let constants = MTLFunctionConstantValues()
var M: Int = problemSize
var N: Int = problemSize
var K: Int = problemSize
var transpose: Bool = false
var alpha: Float = 1
var beta: Float = 0
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&transpose, type: .bool, index: 10)
constants.setConstantValue(&transpose, type: .bool, index: 11)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
var M_simd: UInt16 = 32
var N_simd: UInt16 = 32
var K_simd: UInt16 = 8
var M_splits: UInt16 = 1
var N_splits: UInt16 = 1
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)
let function = try! library.makeFunction(
name: "sgemm", constantValues: constants)
let pipeline = try! context.device
.makeComputePipelineState(function: function)
// Allocate threadgroup blocks.
let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group
var blockElements = A_block_length + B_block_length;
// Match the kernel's edge-storage condition: (M % 8 != 0) || (N % 8 != 0).
if (M % 8 != 0) || (N % 8 != 0) {
let C_block_length = M_group * N_group;
blockElements = max(C_block_length, blockElements)
}
let blockBytes = blockElements * UInt16(4)
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
let gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_group),
height: ceilDivide(target: M, granularity: M_group),
depth: 1)
let groupSize = MTLSize(
width: Int(32 * M_splits * N_splits),
height: 1,
depth: 1)
// Create the buffers.
func createBuffer(data: [Float]) -> MTLBuffer {
// Allocate enough memory to store everything in Float32.
let bufferSize = problemSize * problemSize * 4
let buffer = context.device.makeBuffer(length: bufferSize)!
// Copy the data into the buffer.
let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
pointer.initialize(from: data, count: problemSize * problemSize)
// Return the buffer object.
return buffer
}
let bufferA = createBuffer(data: A)
let bufferB = createBuffer(data: B)
let bufferC = createBuffer(data: C)
// Profile the latency of matrix multiplication.
for _ in 0..<15 {
let duplicatedCommandCount: Int = 20
// Encode the GPU command.
let commandBuffer = context.commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder()!
encoder.setComputePipelineState(pipeline)
encoder.setThreadgroupMemoryLength(Int(blockBytes), index: 0)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
for _ in 0..<duplicatedCommandCount {
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize)
}
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Determine the time taken.
let start = commandBuffer.gpuStartTime
let end = commandBuffer.gpuEndTime
let latency = end - start
let latencyMicroseconds = Int(latency / 1e-6)
// Determine the amount of work done.
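// An N x N by N x N multiplication performs N^3 multiplies and N^3 adds.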
var operations = 2 * problemSize * problemSize * problemSize
operations = operations * duplicatedCommandCount
let gflops = Int(Double(operations) / Double(latency) / 1e9)
// Report the results.
print(latencyMicroseconds, "μs", gflops, "GFLOPS")
}
// Copy the results to C.
do {
let rawPointer = bufferC.contents()
let castedPointer = rawPointer.assumingMemoryBound(to: Float.self)
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
let entry = castedPointer[address]
C[address] = entry
}
}
}
}
#endif
// Check the results.
for m in 0..<problemSize {
for n in 0..<problemSize {
// Find the source row IDs.
let leftRowID = (m + problemSize - 1) % problemSize
let centerRowID = m
let rightRowID = (m + problemSize + 1) % problemSize
// Find the source values.
let leftSource = B[leftRowID * problemSize + n]
let centerSource = B[centerRowID * problemSize + n]
let rightSource = B[rightRowID * problemSize + n]
// Find the expected and actual values.
let expected = leftSource - 2 * centerSource + rightSource
let actual = C[m * problemSize + n]
// Report the results.
let error = (expected - actual).magnitude
if error > 1e-5 {
print("error: \(error) / ~1.000")
}
}
}
}