init_per_param
import torch

def init_state_per_param(self, param, param_group):
    state = self.state[param]
    if len(state) == 0:
        # note(crcrpar): [special device hosting for step]
        # Deliberately host `step` on CPU if both capturable and fused are off.
        # This is because kernel launches are costly on CUDA and XLA.
        state['step'] = (
            torch.zeros((), dtype=_get_scalar_dtype(is_fused=param_group['fused']), device=param.device)
            if param_group['capturable'] or param_group['fused']
            else torch.tensor(0.0, dtype=_get_scalar_dtype())
        )
        # Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        if param_group['amsgrad']:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
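The snippet relies on a `_get_scalar_dtype` helper that the gist does not include. Below is a minimal sketch of it, assuming the convention used elsewhere in torch.optim: fused implementations keep `step` in float32, and the host-side scalar otherwise follows the default dtype.

# Sketch of the _get_scalar_dtype helper assumed above (not part of the gist).
# Assumption: fused kernels expect a float32 step; otherwise the step dtype
# tracks torch.get_default_dtype(), falling back to float32.
def _get_scalar_dtype(is_fused=None):
    if is_fused:
        return torch.float32
    return (
        torch.float64 if torch.get_default_dtype() == torch.float64
        else torch.float32
    )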
def _init_group(
    self,
    group,
    params_with_grad,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps
):
    has_complex = False
    for p in group['params']:
        if p.grad is not None:
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
            grads.append(p.grad)
            # Lazy state initialization.
            # We call this function directly in dynamo to
            # avoid double compilation.
            self.init_state_per_param(p, group)
            state = self.state[p]
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            if group['amsgrad']:
                max_exp_avg_sqs.append(state['max_exp_avg_sq'])
            if group['differentiable'] and state['step'].requires_grad:
                raise RuntimeError('`requires_grad` is not supported for `step` in differentiable mode')
            # Foreach without capturable does not support a tensor lr
            if group['foreach'] and torch.is_tensor(group['lr']) and not group['capturable']:
                raise RuntimeError('lr as a Tensor is not supported for capturable=False and foreach=True')
            state_steps.append(state['step'])
    return has_complex
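To show how the two hooks fit together, here is a minimal driver. `AdamSketch` is a hypothetical harness invented for illustration, not the real `torch.optim.Adam`; it only supplies the group defaults the functions read (`lr`, `amsgrad`, `foreach`, `capturable`, `differentiable`, `fused`) and then exercises `_init_group` on a toy model.

# Hypothetical harness (not the real torch.optim.Adam): wires the two
# functions above onto an Optimizer subclass so _init_group can be run.
class AdamSketch(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        defaults = dict(lr=lr, amsgrad=False, foreach=False,
                        capturable=False, differentiable=False, fused=False)
        super().__init__(params, defaults)

AdamSketch.init_state_per_param = init_state_per_param
AdamSketch._init_group = _init_group

model = torch.nn.Linear(4, 2)
opt = AdamSketch(model.parameters())
model(torch.randn(3, 4)).sum().backward()

for group in opt.param_groups:
    params_with_grad, grads = [], []
    exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps = [], [], [], []
    has_complex = opt._init_group(
        group, params_with_grad, grads,
        exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps)
    # With capturable=False and fused=False, every `step` is hosted on CPU
    # as a scalar tensor, per the note in init_state_per_param.
    assert all(s.device.type == 'cpu' for s in state_steps)
    assert not has_complex

Splitting the lazy state init into its own `init_state_per_param` method lets dynamo call it directly per parameter, which appears to be how the comment's "double compilation" of the `_init_group` loop is avoided.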