Skip to content

Instantly share code, notes, and snippets.

HDCharles /
Created June 4, 2024 20:21
doing lm_eval's work
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict
path_to_hf_checkpoint = "/home/cdhernandez/local/gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B"
HDCharles / gist:888bc5973198ca447046b974439dca03
Last active March 28, 2024 20:35
repro for subclass issue
import torch
import torch.nn as nn
from torch.utils._pytree import tree_flatten, tree_unflatten
class MultiTensor(torch.Tensor):
def __new__(cls, input, **kwargs):
if isinstance(input, (list, tuple)):
input = input[0]
kwargs["dtype"]=kwargs.get("dtype", input.dtype)
HDCharles /
Last active March 1, 2024 17:00
script for comparing performance of several linear triton kernels across several shapes
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton import Config
from torch._inductor import config
from torch import _dynamo
aten = torch.ops.aten
def get_configs_io_bound():
HDCharles /
Created February 24, 2024 16:46
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
from torch._inductor import config
from torch import _dynamo
torch._inductor.config.coordinate_descent_tuning = True
HDCharles /
Created January 25, 2024 03:07
compare bitsandbytes with torchao
# Comparing torchao #
# and bitsandbytes #
# Set up Your Environment
# --------------------------------
# First, let's configure your environment. This guide requires you to use CUDA 12.1.
# We have run this tutorial on an A100-PG509-200 power-limited to 330.00 W. If you
# are using different hardware, you might see different performance numbers.
/home/cdhernandez/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/ UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
/home/cdhernandez/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/ UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
/home/cdhernandez/local/diffusers/src/diffusers/utils/ UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
Namespace(no_bf16=False, no_sdpa=False, batch_size=1, num_inference_steps=30, enable_fused_projections=True, upcast_vae=False, compile_unet=True, compile_vae=True, compile_mode='max-autotune', change_comp_config=True, do_quan
"""Full definition of a LLaMA Language Model, all of it in this single file.
Based on the nanoGPT implementation:
# mypy: ignore-errors
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
HDCharles / gist:44952fc614a75ad083f5054d50ef5341
Created September 19, 2023 23:56
not using block pointers
def matmul_kernel_with_block_pointers(
# Pointers to matrices
a_ptr, b_ptr, c_ptr, s1_ptr, s2_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak,
HDCharles / gist:b2d8c916cfc4629d3f81f09de734e577
Created August 14, 2023 16:21
microbenchmarks for mixed-dtype kernels
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
import nvtx
import time
HDCharles /
Created August 1, 2023 17:33
benchmarking mixed-dtype matmuls
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
import nvtx
import time