Skip to content

Instantly share code, notes, and snippets.

@malfet
malfet / dyld.py
Created June 5, 2024 21:04
Print shared libraries loaded by PyTorch on MacOS
from ctypes import cdll, c_char_p, c_uint32
# Load libSystem (which exports the dyld image-enumeration API) and declare the
# ctypes signatures of the two functions used below.
libdyld = cdll.LoadLibrary("libSystem.dylib")
libdyld._dyld_image_count.restype = c_uint32
libdyld._dyld_get_image_name.restype = c_char_p
libdyld._dyld_get_image_name.argtypes = [c_uint32]


def _loaded_image_names():
    """Return the set of paths of all images currently loaded by dyld.

    Paths are decoded with ``os.fsdecode`` (UTF-8 + ``surrogateescape`` on
    macOS) instead of ASCII, so libraries located under non-ASCII directories
    no longer raise ``UnicodeDecodeError``. ``_dyld_get_image_name`` returns
    NULL (``None`` through ctypes) for an out-of-range index -- e.g. if an
    image is unloaded between the count and the lookup -- and such entries
    are skipped rather than crashing on ``.decode``.
    """
    import os

    names = set()
    for index in range(libdyld._dyld_image_count()):
        raw_name = libdyld._dyld_get_image_name(index)
        if raw_name is not None:
            names.add(os.fsdecode(raw_name))
    return names


# Snapshot the loaded images before and after importing torch so the two sets
# can be compared to see which shared libraries the torch import brought in.
before_torch = _loaded_image_names()
import torch  # noqa: E402 -- deliberately imported after the first snapshot
after_torch = _loaded_image_names()
import Metal
import MetalPerformanceShadersGraph
func calculateExpMetal(device: MTLDevice, ibuf: MTLBuffer, obuf: MTLBuffer, nelem: Int, fastMathEnabled: Bool = false) {
let shader_source = """
#include <metal_stdlib>
using namespace metal;
kernel void do_exp(constant float *input [[buffer(0)]],
@malfet
malfet / subnormals_metal.swift
Created May 7, 2024 00:47
Check if `nextafter(0.0, 1.0)` is greater than zero on a Metal device
import Metal
let shader_source = """
#include <metal_stdlib>
using namespace metal;
kernel void nextafter_pred(device float *data [[buffer(0)]],
device bool *pred [[buffer(1)]],
uint thread_index [[thread_position_in_grid]]) {
data[thread_index] = nextafter(float(thread_index) - 8.0, 1e4);
@malfet
malfet / mps_matmul.swift
Created January 9, 2024 02:29
Swift example that runs matrix multiplication on MPS
import MetalPerformanceShadersGraph
// Build an MPSGraph that matrix-multiplies two constant float32 tensors
// shaped [32, 4096, 40] and [32, 40, 4096], each filled with the scalar 1.
// NOTE(review): with rank-3 inputs this is presumably a batched matmul
// producing a [32, 4096, 4096] result -- confirm against the MPSGraph docs.
let graph = MPSGraph()
let x = graph.constant(1, shape: [32, 4096, 40], dataType: .float32)
let y = graph.constant(1, shape: [32, 40, 4096], dataType: .float32)
let z = graph.matrixMultiplication(primary: x, secondary: y, name: nil)
// Default Metal device; the force-unwrap traps if no device is available.
let device = MTLCreateSystemDefaultDevice()!
// 16384 bytes == 64 * 64 * 4, i.e. exactly the size of the [64, 64] int32
// tensor data that wraps this buffer on the next line.
let buf = device.makeBuffer(length: 16384)!
let td = MPSGraphTensorData(buf, shape: [64, 64], dataType: .int32)
// MPS command buffer created from a fresh command queue on the same device.
// NOTE(review): z, td and cmdBuf are not used within this excerpt; the code
// that encodes and runs the graph is not visible here.
let cmdBuf = MPSCommandBuffer(from: device.makeCommandQueue()!)
@malfet
malfet / mm_bmm-perf.py
Last active February 16, 2024 00:27
Measure performance difference of `torch.mm` vs `torch.bmm`
# Benchmark relative performance of torch.mm and torch.bmm with single batch
import torch
import time
def benchmark_fn(fn, args, warmup=5, cycles=300, use_kineto=False) -> float:
if use_kineto:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
fn(*args)
return sum([e.cuda_time for e in p.key_averages()])
@malfet
malfet / test_trition.py
Last active January 6, 2024 19:47
Test triton
import triton
import triton.language as tl
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
xnumel = 10
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
import torch
import torch.nn.functional as F
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / x.abs().max().clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representable range of the float8 data type
# (as default cast is unsaturated)
// My attempt at FP8 matmul implementation
#include <iostream>
#include <vector>
#include <numeric>
#include <cublasLt.h>
#include <cuda_fp8.h>
#include <stdio.h>
@malfet
malfet / slowregexp.py
Last active June 29, 2023 15:29
Catastrophic backtracking in regexp
# For some reason this does not reproduce when copied and pasted (rather than run from the raw file), but otherwise it should hang
import re
# Matches a reST-style `.. code-block::` / `.. math::` directive line, any
# blank lines after it, the first indented line (capturing its indent), and
# every following indented-or-blank line, up to the first non-indented line
# or end of input. The nested repetition over lines that may be empty is what
# lets this pattern backtrack catastrophically on some inputs; it is kept
# verbatim on purpose. Written as adjacent raw-string pieces, which
# concatenate to exactly the same pattern text as the escaped single literal.
pat = re.compile(
    r'\.\. (code-block|math)::.*$\n*'
    r'(?P<S2VCUH>'
    r'(?P<first>(^(?P<indent>[ ]+).*$\n))'
    r'(?P<other>(^([ \t]+.*|[ \t]*)$\n)*)'
    r')'
    r'(?:(^(?![ \t]+.*$))|\Z)',
    re.MULTILINE,
)
text="""#####################################################################
We get the following performance profiling table for the eager-mode model (omitting some columns):
.. code-block:: shell
------------------------- ------------ ------------ ------------ ------------
Name CPU total % CPU total CPU time avg # of Calls
@malfet
malfet / computesqrt.py
Last active March 31, 2023 01:50
Spigot algorithm for computing digits of square root
#!/usr/bin/env python3
# Adapted from https://rosettacode.org/wiki/Square_root_by_hand
def next_digit(val, k):
    """Pick the next digit of the digit-by-digit square-root spigot.

    Finds the first ``d`` in 1..10 for which ``val < d * (k + d)`` and returns
    ``d - 1``. For ``k >= 0`` the product grows with ``d``, so the result is
    the largest digit ``e`` in 0..9 with ``e * (k + e) <= val``.

    Raises:
        RuntimeError: ("Impossible") if no ``d`` in 1..10 satisfies the bound.
    """
    first_too_big = next(
        (digit for digit in range(1, 11) if val < digit * (k + digit)),
        None,
    )
    if first_too_big is None:
        raise RuntimeError("Impossible")
    return first_too_big - 1
def compute_sqrt(val=2, num_char=500):