Created
April 1, 2024 07:50
-
-
Save zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7 to your computer and use it in GitHub Desktop.
benchmark-fused-adam
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Optional, Union | |
import torch | |
from torch import Tensor | |
from torch.optim.optimizer import _get_value, _dispatch_sqrt | |
NPARAM = 10 | |
TENSOR_SIZE = 1024 * 1024 | |
kwargs = {} | |
kwargs['params'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)] | |
kwargs['grads'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)] | |
kwargs['exp_avgs'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)] | |
kwargs['exp_avg_sqs'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)] | |
kwargs['max_exp_avg_sqs'] = [torch.randn(TENSOR_SIZE) for _ in range(NPARAM)] | |
kwargs['max_exp_avg_sqs'] = [] | |
kwargs['state_steps'] = [torch.tensor([10.0]) for _ in range(NPARAM)] | |
kwargs['grad_scale'] = None | |
kwargs['found_inf'] = None | |
kwargs['amsgrad'] = False | |
kwargs['beta1'] = 0.9 | |
kwargs['beta2'] = 0.999 | |
kwargs['lr'] = 0.1 | |
kwargs['weight_decay'] = 0.0 | |
kwargs['eps'] = 1e-8 | |
kwargs['has_complex'] = False | |
kwargs['maximize'] = False | |
kwargs['capturable'] = False | |
kwargs['differentiable'] = False | |
def _single_tensor_adam(params: List[Tensor], | |
grads: List[Tensor], | |
exp_avgs: List[Tensor], | |
exp_avg_sqs: List[Tensor], | |
max_exp_avg_sqs: List[Tensor], | |
state_steps: List[Tensor], | |
grad_scale: Optional[Tensor], | |
found_inf: Optional[Tensor], | |
*, | |
amsgrad: bool, | |
has_complex: bool, | |
beta1: float, | |
beta2: float, | |
lr: Union[float, Tensor], | |
weight_decay: float, | |
eps: float, | |
maximize: bool, | |
capturable: bool, | |
differentiable: bool): | |
assert grad_scale is None and found_inf is None | |
for i, param in enumerate(params): | |
grad = grads[i] if not maximize else -grads[i] | |
exp_avg = exp_avgs[i] | |
exp_avg_sq = exp_avg_sqs[i] | |
step_t = state_steps[i] | |
# update step | |
step_t += 1 | |
if weight_decay != 0: | |
grad = grad.add(param, alpha=weight_decay) | |
# Decay the first and second moment running average coefficient | |
exp_avg.lerp_(grad, 1 - beta1) | |
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) | |
step = _get_value(step_t) | |
bias_correction1 = 1 - beta1 ** step | |
bias_correction2 = 1 - beta2 ** step | |
step_size = lr / bias_correction1 | |
bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) | |
if amsgrad: | |
# Maintains the maximum of all 2nd moment running avg. till now | |
torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) | |
# Use the max. for normalizing running avg. of gradient | |
denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) | |
else: | |
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) | |
param.addcdiv_(exp_avg, denom, value=-step_size) | |
def _fused_adam( | |
params: List[Tensor], | |
grads: List[Tensor], | |
exp_avgs: List[Tensor], | |
exp_avg_sqs: List[Tensor], | |
max_exp_avg_sqs: List[Tensor], | |
state_steps: List[Tensor], | |
grad_scale: Optional[Tensor], | |
found_inf: Optional[Tensor], | |
*, | |
amsgrad: bool, | |
has_complex: bool, # Needed for consistency. | |
beta1: float, | |
beta2: float, | |
lr: Union[float, Tensor], | |
weight_decay: float, | |
eps: float, | |
maximize: bool, | |
capturable: bool, # Needed for consistency. | |
differentiable: bool, | |
) -> None: | |
grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None | |
found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None | |
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer | |
# treating it as a scalar. | |
lr_dict = {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None | |
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( | |
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]) | |
for (device, _), ((device_params, | |
device_grads, | |
device_exp_avgs, | |
device_exp_avg_sqs, | |
device_max_exp_avg_sqs, | |
device_state_steps,), _) in grouped_tensors.items(): | |
device_grad_scale, device_found_inf = None, None | |
if grad_scale is not None: | |
if device not in grad_scale_dict: | |
grad_scale_dict[device] = grad_scale.to(device, non_blocking=True) | |
device_grad_scale = grad_scale_dict[device] | |
if found_inf is not None: | |
if found_inf not in found_inf_dict: | |
found_inf_dict[device] = found_inf.to(device, non_blocking=True) | |
device_found_inf = found_inf_dict[device] | |
if lr_dict is not None and device not in lr_dict: | |
lr_dict[device] = lr.to(device=device, non_blocking=True) | |
lr = lr_dict[device] | |
torch._foreach_add_(device_state_steps, 1) | |
torch._fused_adam_( | |
device_params, | |
device_grads, | |
device_exp_avgs, | |
device_exp_avg_sqs, | |
device_max_exp_avg_sqs, | |
device_state_steps, | |
amsgrad=amsgrad, | |
lr=lr, | |
beta1=beta1, | |
beta2=beta2, | |
weight_decay=weight_decay, | |
eps=eps, | |
maximize=maximize, | |
grad_scale=device_grad_scale, | |
found_inf=device_found_inf, | |
) | |
if device_found_inf is not None: | |
torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps)) | |
a = torch.ones(256 * 1024 * 1024 // 4, dtype=torch.float) | |
b = torch.ones(256 * 1024 * 1024 // 4, dtype=torch.float) | |
def flush(): | |
global a, b | |
a += b | |
WARMUP=100 | |
ITERS=1000 | |
import time | |
def bench(func, kwargs): | |
duration = 0 | |
for _ in range(WARMUP): | |
flush() | |
kwargs['state_steps'] = [torch.tensor([10.0]) for _ in range(NPARAM)] | |
func(**kwargs) | |
for _ in range(ITERS): | |
flush() | |
kwargs['state_steps'] = [torch.tensor([10.0]) for _ in range(NPARAM)] | |
start = time.time() | |
func(**kwargs) | |
end = time.time() | |
duration += (end - start) | |
return duration | |
print("non-fused", bench(_single_tensor_adam, kwargs), "s") | |
print("fused", bench(_fused_adam, kwargs), "s") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment