@piercelamb
Created December 19, 2022 22:54
accelerator_accumulate
def train_epoch(model, dataloader, optimizer, lr_scheduler,
                accelerator, config, model_name, global_step):
    # Run one epoch with gradient accumulation handled by Accelerate.
    # (The fragment originally relied on outer-scope variables; the
    # wrapper and explicit global_step return make it self-contained.)
    total_loss = 0.0
    for batch in dataloader:
        if config.is_comparison:
            # get_model_specific_batch is a helper defined elsewhere in
            # the gist author's code; it reshapes the batch so each
            # model under comparison receives only the inputs it expects.
            batch = get_model_specific_batch(batch, model_name)
        # accumulate() defers gradient sync and turns optimizer.step()
        # into a no-op until the configured number of micro-batches
        # has been processed.
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            global_step += 1
    return total_loss, global_step
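
A minimal usage sketch for the function above, assuming the Hugging Face Accelerate API. The toy model, dataset, scheduler, and config below are illustrative stand-ins, not part of the original gist; passing gradient_accumulation_steps=4 to Accelerator makes accumulate() run four micro-batches per optimizer step.

from types import SimpleNamespace

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

class ToyModel(torch.nn.Module):
    # Mimics the Hugging Face interface the loop relies on:
    # model(**batch) returns an object with a .loss attribute.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 2)

    def forward(self, input_ids=None, labels=None):
        logits = self.linear(input_ids)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return SimpleNamespace(loss=loss)

def collate(examples):
    # TensorDataset yields (x, y) tuples; build the keyword batch
    # that model(**batch) expects.
    xs, ys = zip(*examples)
    return {"input_ids": torch.stack(xs), "labels": torch.stack(ys)}

accelerator = Accelerator(gradient_accumulation_steps=4)
model = ToyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8, collate_fn=collate)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# prepare() wraps everything so accumulate()/backward() coordinate
# gradient sync and skipped optimizer steps automatically.
model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dataloader, lr_scheduler
)
config = SimpleNamespace(is_comparison=False)  # skips the comparison branch
total_loss, global_step = train_epoch(
    model, dataloader, optimizer, lr_scheduler, accelerator,
    config, model_name="toy", global_step=0,
)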