init_per_param
# Methods factored out of torch.optim.Adam, shown as they would live on the
# Adam class. `_get_scalar_dtype` is the helper adam.py imports from
# torch.optim.optimizer.
import torch
from torch.optim.optimizer import _get_scalar_dtype


def init_state_per_param(self, param, param_group):
    state = self.state[param]
    if len(state) == 0:
        # note(crcrpar): [special device hosting for step]
        # Deliberately host `step` on CPU if both capturable and fused are off.
        # This is because kernel launches are costly on CUDA and XLA.
        state['step'] = (
            torch.zeros((), dtype=_get_scalar_dtype(is_fused=param_group['fused']), device=param.device)
            if param_group['capturable'] or param_group['fused']
            else torch.tensor(0.0, dtype=_get_scalar_dtype())
        )
        # Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        if param_group['amsgrad']:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
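
# For reference, a sketch of what `_get_scalar_dtype` resolves to in
# torch/optim/optimizer.py (check your PyTorch version): fused kernels want a
# float32 `step` tensor, otherwise the default scalar dtype is used.
#
#     def _get_scalar_dtype(is_fused=None):
#         if is_fused:
#             return torch.float32
#         return torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
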
def _init_group(
    self,
    group,
    params_with_grad,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps,
):
    has_complex = False
    for p in group['params']:
        if p.grad is not None:
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
            grads.append(p.grad)

            # Lazy state initialization. We call this function directly in
            # dynamo to avoid double compilation.
            self.init_state_per_param(p, group)
            state = self.state[p]

            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            if group['amsgrad']:
                max_exp_avg_sqs.append(state['max_exp_avg_sq'])
            if group['differentiable'] and state['step'].requires_grad:
                raise RuntimeError('`requires_grad` is not supported for `step` in differentiable mode')

            # Foreach without capturable does not support a tensor lr
            if group['foreach'] and torch.is_tensor(group['lr']) and not group['capturable']:
                raise RuntimeError('lr as a Tensor is not supported for capturable=False and foreach=True')

            state_steps.append(state['step'])
    return has_complex
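
# Hedged sketch, not part of this gist: how a `step()` would typically consume
# `_init_group`, dispatching to the functional `adam` from torch/optim/adam.py.
# Abridged: the real implementation uses the `_use_grad_for_differentiable`
# decorator instead of `no_grad`, plus a CUDA-graph capture health check.
@torch.no_grad()
def step(self, closure=None):
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad, grads, state_steps = [], [], []
        exp_avgs, exp_avg_sqs, max_exp_avg_sqs = [], [], []

        has_complex = self._init_group(
            group, params_with_grad, grads,
            exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
        )

        torch.optim.adam.adam(
            params_with_grad, grads, exp_avgs, exp_avg_sqs,
            max_exp_avg_sqs, state_steps,
            amsgrad=group['amsgrad'],
            has_complex=has_complex,
            beta1=group['betas'][0],
            beta2=group['betas'][1],
            lr=group['lr'],
            weight_decay=group['weight_decay'],
            eps=group['eps'],
            maximize=group['maximize'],
            foreach=group['foreach'],
            capturable=group['capturable'],
            differentiable=group['differentiable'],
            fused=group['fused'],
            grad_scale=getattr(self, 'grad_scale', None),
            found_inf=getattr(self, 'found_inf', None),
        )
    return loss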
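
# Hedged illustration of the dynamo note in `_init_group` (the training-loop
# names below are hypothetical): with state init factored into its own method,
# it can be run eagerly once after the first backward pass, so a compiled
# `step` never traces the lazy-init branch and is not compiled twice. `opt` is
# assumed to be an Adam instance exposing the methods above; `model`,
# `loss_fn`, and `inputs` are assumed defined.
opt = torch.optim.Adam(model.parameters())
loss_fn(model(inputs)).backward()  # populate .grad before the warm-up

for group in opt.param_groups:
    for p in group['params']:
        if p.grad is not None:
            opt.init_state_per_param(p, group)

@torch.compile()
def compiled_step():
    opt.step()

compiled_step()  # lazy-init branch already resolved; single compilation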