benchmark-fused-adam
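# Micro-benchmark: compares the non-fused (single-tensor) Adam step against the
# fused Adam kernel (torch._fused_adam_) on CPU tensors. Both implementations
# receive the same kwargs; state_steps are re-created before every call so each
# step starts from the same count, and flush() streams through two 256 MB
# buffers between iterations to clear the CPU caches before timing.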
from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.optim.optimizer import _get_value, _dispatch_sqrt
NPARAM = 10
TENSOR_SIZE = 1024 * 1024
kwargs = {}
kwargs['params'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)]
kwargs['grads'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)]
kwargs['exp_avgs'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)]
kwargs['exp_avg_sqs'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)]
# amsgrad is False below, so max_exp_avg_sqs is left empty.
kwargs['max_exp_avg_sqs'] = []
kwargs['state_steps'] = [torch.tensor([10.0]) for _ in range(NPARAM)]
kwargs['grad_scale'] = None
kwargs['found_inf'] = None
kwargs['amsgrad'] = False
kwargs['beta1'] = 0.9
kwargs['beta2'] = 0.999
kwargs['lr'] = 0.1
kwargs['weight_decay'] = 0.0
kwargs['eps'] = 1e-8
kwargs['has_complex'] = False
kwargs['maximize'] = False
kwargs['capturable'] = False
kwargs['differentiable'] = False

def _single_tensor_adam(params: List[Tensor],
                        grads: List[Tensor],
                        exp_avgs: List[Tensor],
                        exp_avg_sqs: List[Tensor],
                        max_exp_avg_sqs: List[Tensor],
                        state_steps: List[Tensor],
                        grad_scale: Optional[Tensor],
                        found_inf: Optional[Tensor],
                        *,
                        amsgrad: bool,
                        has_complex: bool,
                        beta1: float,
                        beta2: float,
                        lr: Union[float, Tensor],
                        weight_decay: float,
                        eps: float,
                        maximize: bool,
                        capturable: bool,
                        differentiable: bool):
    assert grad_scale is None and found_inf is None
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]
        # update step
        step_t += 1
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)
        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
        step = _get_value(step_t)
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr / bias_correction1
        bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
        param.addcdiv_(exp_avg, denom, value=-step_size)

def _fused_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,  # Needed for consistency.
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    lr_dict = {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
    grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device, _), ((device_params,
                       device_grads,
                       device_exp_avgs,
                       device_exp_avg_sqs,
                       device_max_exp_avg_sqs,
                       device_state_steps,), _) in grouped_tensors.items():
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
            device_grad_scale = grad_scale_dict[device]
        if found_inf is not None:
            if device not in found_inf_dict:  # key by device, matching the grad_scale branch
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)
            device_found_inf = found_inf_dict[device]
        if lr_dict is not None and device not in lr_dict:
            lr_dict[device] = lr.to(device=device, non_blocking=True)
            lr = lr_dict[device]
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adam_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps))

a = torch.ones(256 * 1024 * 1024 // 4, dtype=torch.float)
b = torch.ones(256 * 1024 * 1024 // 4, dtype=torch.float)


def flush():
    global a, b
    a += b


WARMUP = 100
ITERS = 1000
import time


def bench(func, kwargs):
    duration = 0
    for _ in range(WARMUP):
        flush()
        kwargs['state_steps'] = [torch.tensor([10.0]) for _ in range(NPARAM)]
        func(**kwargs)
    for _ in range(ITERS):
        flush()
        kwargs['state_steps'] = [torch.tensor([10.0]) for _ in range(NPARAM)]
        start = time.time()
        func(**kwargs)
        end = time.time()
        duration += (end - start)
    return duration

print("non-fused", bench(_single_tensor_adam, kwargs), "s")
print("fused", bench(_fused_adam, kwargs), "s")