@sailist · Created November 30, 2023 07:35
online-softmax python demo
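A small self-contained demo of the online softmax. Alongside the torch.softmax reference and the naive / max-subtracting ("safe") formulations, it implements the single-pass recurrence m_j = max(m_{j-1}, x_j), d_j = d_{j-1} * exp(m_{j-1} - m_j) + exp(x_j - m_j), first as a sequential loop and then as a reduction over (m, d) states that can be evaluated in parallel, and checks that the variants agree.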
from functools import reduce
from dataclasses import dataclass
import torch
from math import exp
def ground_truth(x):
    y = torch.softmax(x, dim=-1)
    return y


def naive_attn(x):
    d = torch.exp(x)
    return d / torch.sum(d, dim=-1, keepdim=True)


def safe_attn(x):
    maxx = torch.max(x, dim=-1, keepdim=True)[0]
    d = torch.exp(x - maxx)
    print(torch.sum(d, dim=-1, keepdim=True))
    return d / torch.sum(d, dim=-1, keepdim=True)
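
# Quick check (added for illustration, not part of the original gist): with large
# inputs naive_attn overflows -- exp of a float32 above roughly 88 is inf, so the
# division yields nan -- while safe_attn subtracts the row max first and stays finite.
big_x = torch.tensor([1000.0, 1001.0, 1002.0])
print(naive_attn(big_x))  # nan
print(safe_attn(big_x))   # finite, matches ground_truth(big_x)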
def online_raw_v0(x):
    # Sequential online softmax: one pass maintains the running max m[j] and the
    # running denominator d[j], rescaling d whenever the max changes.
    # At j == 0 the j - 1 index wraps to -1 and reads the -inf / 0 initial values.
    m = [float("-inf")] * len(x)
    d = [0] * len(x)
    for j in range(len(x)):
        m[j] = max(m[j - 1], x[j])
        d[j] = d[j - 1] * exp(m[j - 1] - m[j]) + exp(x[j] - m[j])
    y = [0] * len(x)
    for j in range(len(x)):
        y[j] = exp(x[j] - m[-1]) / d[-1]
    print(m[-1], d[-1])  # debug: final max and denominator
    return torch.tensor(y)
@dataclass
class MD:
    m: float  # running max
    d: float  # running denominator, i.e. sum of exp(x - m)


def reduce_md_op(a: MD, b: MD):
    # Merge two partial softmax states: keep the larger max and rescale the other
    # state's denominator into the new reference frame.
    if a is None:
        return b
    elif b is None:
        return a
    a_bigger = a.m > b.m
    bigger_m = a if a_bigger else b
    smaller_m = b if a_bigger else a
    res = MD(0, 0)
    res.m = bigger_m.m
    res.d = bigger_m.d + smaller_m.d * exp(smaller_m.m - bigger_m.m)
    return res
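
# Worked example (added, not in the original gist): merging the single-element
# states for [1.0, 3.0] keeps m = 3.0 and gives d = 1 * exp(1.0 - 3.0) + 1 ≈ 1.1353,
# which is exactly sum(exp(x - max(x))), the safe-softmax denominator for those elements.
print(reduce_md_op(MD(1.0, 1.0), MD(3.0, 1.0)))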
def online_parallel_v0(x):
    # Same computation expressed as a left fold over per-element (m, d) states.
    mds = [MD(xi, 1) for xi in x]
    md = reduce(reduce_md_op, mds)
    y = [0] * len(x)
    for j in range(len(x)):
        y[j] = exp(x[j] - md.m) / md.d
    return torch.tensor(y)
def online_parallel_v1(x):
    """
    This reduce operation is associative and commutative.
    reference:
    - https://github.com/NVIDIA/online-softmax/blob/master/online_softmax_benchmark.cu#L169
    """
    # Split the input into two halves, reduce each half independently, then merge
    # the two partial states -- valid because reduce_md_op is associative.
    mds = [MD(xi, 1) for xi in x]
    md = reduce(reduce_md_op, mds[: len(x) // 2])
    md2 = reduce(reduce_md_op, mds[len(x) // 2 :])
    md = reduce_md_op(md, md2)
    y = [0] * len(x)
    for j in range(len(x)):
        y[j] = exp(x[j] - md.m) / md.d
    return torch.tensor(y)
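
def online_parallel_chunked(x, n_chunks=4):
    """
    Illustrative sketch (not in the original gist): because reduce_md_op is
    associative, the input can be split into any number of chunks, each chunk
    reduced to a partial MD state independently (e.g. by separate threads or
    blocks), and the partial states merged afterwards.
    """
    chunk = max(1, len(x) // n_chunks)
    partials = [
        reduce(reduce_md_op, [MD(xi, 1) for xi in x[i : i + chunk]])
        for i in range(0, len(x), chunk)
    ]
    md = reduce(reduce_md_op, partials)
    return torch.tensor([exp(xi - md.m) / md.d for xi in x])
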
x = torch.rand(120)
seq_v0_res = online_raw_v0(x.flatten().tolist())
parallel_v0_res = online_parallel_v0(x.flatten().tolist())
parallel_v1_res = online_parallel_v1(x.flatten().tolist())
print((seq_v0_res == parallel_v0_res).all())
print((seq_v0_res == parallel_v1_res).all())
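
# Additional check (added, not in the original gist): compare against torch.softmax.
# Exact equality is not expected across the float32 / float64 code paths, so use allclose.
gt = ground_truth(x)
print(torch.allclose(gt, seq_v0_res, atol=1e-6))
print(torch.allclose(gt, parallel_v1_res, atol=1e-6))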