This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# List the dynamic libraries (dyld images) that become loaded as a result of
# `import torch`: snapshot the loaded-image set before and after the import.
import os
from ctypes import cdll, c_char_p, c_uint32

# The _dyld_* image-enumeration functions are called through libSystem here.
libdyld = cdll.LoadLibrary("libSystem.dylib")
libdyld._dyld_image_count.restype = c_uint32
libdyld._dyld_get_image_name.restype = c_char_p
libdyld._dyld_get_image_name.argtypes = [c_uint32]


def loaded_image_names():
    """Return the set of path strings of all images currently loaded by dyld.

    Paths are decoded with ``os.fsdecode`` rather than ASCII, so a path with
    non-ASCII components does not raise ``UnicodeDecodeError``.

    ``_dyld_get_image_name`` can return NULL (``None`` through ``c_char_p``)
    if an image is unloaded between reading the count and the lookup. Such
    entries are skipped instead of crashing on ``None.decode``.
    """
    names = set()
    for index in range(libdyld._dyld_image_count()):
        raw_name = libdyld._dyld_get_image_name(index)
        if raw_name is not None:
            names.add(os.fsdecode(raw_name))
    return names


before_torch = loaded_image_names()
# Deliberately imported between the two snapshots; presumably the images
# torch pulls in are found afterwards as after_torch - before_torch.
import torch  # noqa: E402
after_torch = loaded_image_names()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Metal | |
import MetalPerformanceShadersGraph | |
func calculateExpMetal(device: MTLDevice, ibuf: MTLBuffer, obuf: MTLBuffer, nelem: Int, fastMathEnabled: Bool = false) { | |
let shader_source = """ | |
#include <metal_stdlib> | |
using namespace metal; | |
kernel void do_exp(constant float *input [[buffer(0)]], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Metal | |
let shader_source = """ | |
#include <metal_stdlib> | |
using namespace metal; | |
kernel void nextafter_pred(device float *data [[buffer(0)]], | |
device bool *pred [[buffer(1)]], | |
uint thread_index [[thread_position_in_grid]]) { | |
data[thread_index] = nextafter(float(thread_index) - 8.0, 1e4); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Setup for what appears to be an MPSGraph repro. The snippet continues past
// these lines: z, td and cmdBuf are created here but not yet used.
import MetalPerformanceShadersGraph | |
let graph = MPSGraph() | |
// Two constant float32 tensors filled with 1, shapes [32, 4096, 40] and
// [32, 40, 4096], fed to a (batched) matrix multiplication.
// NOTE(review): the result is presumably [32, 4096, 4096]; confirm.
let x = graph.constant(1, shape: [32, 4096, 40], dataType: .float32) | |
let y = graph.constant(1, shape: [32, 40, 4096], dataType: .float32) | |
let z = graph.matrixMultiplication(primary: x, secondary: y, name: nil) | |
// Force-unwrapped: traps if no default Metal device is available.
let device = MTLCreateSystemDefaultDevice()! | |
// 16384 bytes == 64 * 64 elements * 4 bytes, the size of the int32 tensor
// data wrapped around this buffer on the next line.
let buf = device.makeBuffer(length: 16384)! | |
let td = MPSGraphTensorData(buf, shape: [64, 64], dataType: .int32) | |
// Command buffer built on a freshly created command queue (force-unwrapped).
let cmdBuf = MPSCommandBuffer(from: device.makeCommandQueue()!)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Benchmark relative performance of torch.mm and torch.bmm with single batch | |
import torch | |
import time | |
def benchmark_fn(fn, args, warmup=5, cycles=300, use_kineto=False) -> float: | |
if use_kineto: | |
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p: | |
fn(*args) | |
return sum([e.cuda_time for e in p.key_averages()]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import triton | |
import triton.language as tl | |
@triton.jit | |
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): | |
xnumel = 10 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
def to_float8(x, dtype=torch.float8_e4m3fn): | |
finfo = torch.finfo(dtype) | |
# Calculate the scale as dtype max divided by absmax | |
scale = finfo.max / x.abs().max().clamp(min=1e-12) | |
# scale and clamp the tensor to bring it to | |
# the representative range of float8 data type | |
# (as default cast is unsaturated) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// My attempt at FP8 matmul implementation | |
#include <iostream> | |
#include <vector> | |
#include <numeric> | |
#include <cublasLt.h> | |
#include <cuda_fp8.h> | |
#include <stdio.h> | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# For some reason this does not reproduce when copy-and-pasted (rather than used as the raw file), but otherwise it should hang | |
import re | |
pat=re.compile('\\.\\. (code-block|math)::.*$\\n*(?P<S2VCUH>(?P<first>(^(?P<indent>[ ]+).*$\\n))(?P<other>(^([ \\t]+.*|[ \\t]*)$\\n)*))(?:(^(?![ \\t]+.*$))|\\Z)', re.MULTILINE) | |
text="""##################################################################### | |
We get the following performance profiling table for the eager-mode model (omitting some columns): | |
.. code-block:: shell | |
------------------------- ------------ ------------ ------------ ------------ | |
Name CPU total % CPU total CPU time avg # of Calls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Adapted from https://rosettacode.org/wiki/Square_root_by_hand | |
def next_digit(val, k):
    """Pick the next root digit in the long-hand square-root procedure.

    Candidates 1, 2, ..., 10 are tried in order. The first candidate ``c``
    with ``val < c * (k + c)`` is found, and ``c - 1`` (a digit 0..9) is
    returned. For non-negative ``k`` and ``val`` this is the largest digit
    ``d`` with ``d * (k + d) <= val``.
    Assumption (not checked here): ``k`` is the caller's "20 * root-so-far"
    term and ``val`` the current remainder.

    Raises:
        RuntimeError: if none of the candidates up to 10 exceeds ``val``.
    """
    first_too_big = next(
        (cand for cand in range(1, 11) if val < cand * (k + cand)),
        None,
    )
    if first_too_big is None:
        raise RuntimeError("Impossible")
    return first_too_big - 1
def compute_sqrt(val=2, num_char=500): |
NewerOlder