Created
November 30, 2023 07:35
-
-
Save sailist/17fd2fb8ab9551ff5176cd01246945eb to your computer and use it in GitHub Desktop.
Online-softmax Python demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce | |
from dataclasses import dataclass | |
import torch | |
from math import exp | |
def ground_truth(x):
    """Reference softmax over the last dimension, straight from torch."""
    return torch.softmax(x, dim=-1)
def naive_attn(x):
    """Softmax over the last dim without the max-subtraction trick.

    NOTE(review): exp() can overflow for large inputs — kept as-is because
    the gist deliberately contrasts this with safe_attn.
    """
    exps = torch.exp(x)
    denom = exps.sum(dim=-1, keepdim=True)
    return exps / denom
def safe_attn(x):
    """Numerically stable softmax over the last dimension.

    Args:
        x: torch.Tensor; softmax is taken along dim=-1.

    Returns:
        Tensor of the same shape as x containing softmax(x).
    """
    # Subtracting the per-row max keeps every exp() argument <= 0,
    # so exp() stays in (0, 1] and cannot overflow.
    maxx = torch.max(x, dim=-1, keepdim=True)[0]
    d = torch.exp(x - maxx)
    # Removed a leftover debug print of the denominator.
    return d / torch.sum(d, dim=-1, keepdim=True)
def online_raw_v0(x):
    """Sequential online softmax (single pass over x for max and denominator).

    Maintains a running maximum and a denominator rescaled to that maximum,
    then normalizes in a second pass — the same recurrence the original
    expressed with full-length lists and Python's negative-index wraparound
    (m[-1]/d[-1] seeding the j=0 step), which this version makes explicit.
    Also drops a leftover debug print of the final (max, denom) pair.

    Args:
        x: sequence of Python floats.

    Returns:
        1-D torch.Tensor containing softmax(x); empty tensor for empty input.
    """
    m_run = float("-inf")  # running maximum of elements seen so far
    d_run = 0.0            # running sum of exp(x_i - m_run)
    for xj in x:
        m_new = max(m_run, xj)
        # Rescale the old denominator to the new max, then add this term.
        # On the first step d_run == 0, so the exp(-inf) factor is harmless.
        d_run = d_run * exp(m_run - m_new) + exp(xj - m_new)
        m_run = m_new
    return torch.tensor([exp(xj - m_run) / d_run for xj in x])
@dataclass
class MD:
    """Partial softmax state: running (max, denominator) pair."""

    # Running maximum of the elements folded into this state so far.
    m: float
    # Running sum of exp(x_i - m) over those elements.
    d: float
def reduce_md_op(a: MD, b: MD):
    """Combine two partial (max, denominator) softmax states.

    This operation is associative and commutative; None acts as the
    identity element, so it composes cleanly under functools.reduce.
    """
    if a is None:
        return b
    if b is None:
        return a
    # Keep the larger maximum; rescale the other state's denominator to it.
    hi, lo = (a, b) if a.m > b.m else (b, a)
    return MD(hi.m, hi.d + lo.d * exp(lo.m - hi.m))
def online_parallel_v0(x):
    """Softmax via a single left-to-right reduction over MD states."""
    # Each element starts as its own state: max = x_i, denominator = 1.
    state = reduce(reduce_md_op, [MD(v, 1) for v in x])
    return torch.tensor([exp(v - state.m) / state.d for v in x])
def online_parallel_v1(x):
    """
    Softmax via a two-way split reduction — demonstrates that the MD
    combine step is associative and commutative, so the halves can be
    reduced independently and merged afterwards.

    reference:
    - https://github.com/NVIDIA/online-softmax/blob/master/online_softmax_benchmark.cu#L169
    """
    states = [MD(v, 1) for v in x]
    half = len(x) // 2
    left = reduce(reduce_md_op, states[:half])
    right = reduce(reduce_md_op, states[half:])
    total = reduce_md_op(left, right)
    return torch.tensor([exp(v - total.m) / total.d for v in x])
# Demo: run all three online-softmax variants on the same random vector
# and report element-wise equality against the sequential baseline.
x = torch.rand(120)
_vals = x.flatten().tolist()
seq_v0_res = online_raw_v0(_vals)
parallel_v0_res = online_parallel_v0(_vals)
parallel_v1_res = online_parallel_v1(_vals)
for _res in (parallel_v0_res, parallel_v1_res):
    print((seq_v0_res == _res).all())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment