Andreas Köpf (andreaskoepf)
"""
Modified version of https://github.com/Dao-AILab/flash-attention/blob/87a1277653fc55cd615f5341255e00c69d5c00a1/flash_attn/flash_attn_triton.py
Experiments with attention bias by andreas.koepf
The main fix was "fixing the fix", i.e. removing lines like these from the original:
```
# BUG: have to store and immediately load
# tl.store(t_ptrs, o_scale)
# o_scale = tl.load(t_ptrs)
```
"""
*Experimental* implementation of FlashAttention in Triton.
Tested with triton==2.0.0.dev20221202.
Triton 2.0 has a new backend (MLIR), but it doesn't yet seem to work for head dimensions
other than 64:
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
We'll update this implementation with the new Triton backend once this is fixed.
We use the FlashAttention implementation from Phil Tillet as a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
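For context, the quoted workaround sat in the kernel's online-softmax update. A minimal sketch of that step (names like qk, acc, m_i, l_i follow Tillet's tutorial kernel and are illustrative, not this gist's exact code):
m_ij = tl.max(qk, 1)                 # row maxima of this block's attention scores
m_i_new = tl.maximum(m_i, m_ij)      # updated running maxima
p = tl.exp(qk - m_i_new[:, None])    # probabilities relative to the new maxima
o_scale = tl.exp(m_i - m_i_new)      # rescale factor for the old accumulator
acc = acc * o_scale[:, None]         # used directly, no tl.store/tl.load round-trip
acc += tl.dot(p.to(v.dtype), v)      # accumulate this block's contribution
l_i = l_i * o_scale + tl.sum(p, 1)   # rescaled running softmax denominator
m_i = m_i_new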
@andreaskoepf
andreaskoepf / load_hf_ds.py
Created November 6, 2023 08:47
OASST1 Huggingface compatible dataset generation scripts
from datasets import load_dataset

ds = load_dataset("/path/oasst1", name='ready')
train = ds['train']
val = ds['validation']
print(f'{len(train)=}')
print(f'{len(val)=}')
for i in range(5):
    print(train[i]["message_tree_id"])
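For a quick check without local files, the same data is also published on the Hugging Face Hub (assuming the public OpenAssistant/oasst1 layout matches the local copy used above):
from datasets import load_dataset

ds = load_dataset("OpenAssistant/oasst1")  # public mirror on the Hub
print(ds["train"][0]["message_tree_id"])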
@andreaskoepf
andreaskoepf / fix_embedding_size.py
Created August 29, 2023 14:06
load model & pad embedding layers to multiple of N (e.g. 128)
import argparse
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str, help="checkpoint path or model name")
    parser.add_argument("--dtype", type=str, default="auto", help="auto, fp16, bf16 or fp32")
@andreaskoepf
andreaskoepf / llama2_system_prompt.txt
Created August 21, 2023 01:37
llama2 system prompt
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.
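For reference, this system prompt is normally wrapped in the Llama 2 chat template; a minimal sketch in Python (the template comes from Meta's llama repo, the user message is made up):
system_prompt = "You are a helpful, respectful and honest assistant. ..."  # the text above
user_msg = "What is the capital of France?"  # hypothetical user turn
prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_msg} [/INST]"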
@andreaskoepf
andreaskoepf / export_model.py
Last active August 17, 2023 22:08
Simplified export model script
import argparse
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Push checkpoints in HF transformers format to the Huggingface Hub.",
from typing import Optional

import torch


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, scaling_factor: float = 1.0
) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device).float() / scaling_factor  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    # completed as in the reference Llama implementation (the preview cuts off here):
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (end, dim // 2)
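Usage sketch: dividing t by scaling_factor implements linear RoPE position interpolation, compressing positions by 1/scaling_factor to extend the usable context window (the shape shown assumes the completion above):
freqs_cis = precompute_freqs_cis(dim=128, end=4096, scaling_factor=2.0)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4096, 64]) torch.complex64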
@andreaskoepf
andreaskoepf / install_tgi.sh
Last active August 29, 2023 08:51
Install Huggingface TGI v1.0 without docker
# sent to me by tju01, thx
# install base tools
apt update
apt install protobuf-compiler libssl-dev gcc pkg-config g++ make
# install rust
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
source "$HOME/.cargo/env"
# adapted from: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import checkpoint
from einops import rearrange, repeat
import triton
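The preview ends at the imports; a sketch of the patch-embedding stage these imports typically serve in vit-pytorch (sizes are illustrative, not this gist's actual code):
import torch
from torch import nn
from einops.layers.torch import Rearrange

# split a 224x224 RGB image into 16x16 patches and project each to a 512-dim token
to_patch_embedding = nn.Sequential(
    Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=16, p2=16),
    nn.LayerNorm(16 * 16 * 3),
    nn.Linear(16 * 16 * 3, 512),
    nn.LayerNorm(512),
)
x = torch.randn(1, 3, 224, 224)
print(to_patch_embedding(x).shape)  # torch.Size([1, 196, 512])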