@dmarx
Last active November 15, 2023 17:18
no idea if this would work, just sketching out the idea
from math import floor, sqrt  # used in the "what i'm suggesting" block below

eff_batch_size = 2048

#########################
# what you're doing-ish #
#########################
gradient_accumulation_steps = 4
micro_batch_size = eff_batch_size // gradient_accumulation_steps

# fetch_data, loss, optimizer, and (below) sample are placeholders for whatever you already have
eff_minibatch_loss = 0
minibatch = fetch_data(eff_batch_size)
for i in range(gradient_accumulation_steps):
    # python slices exclude the end index, so no "-1" here
    start, end = i * micro_batch_size, (i + 1) * micro_batch_size
    x = minibatch[start:end]
    micro_batch_loss = loss(x)
    # scale each term so the sum is the mean loss over the effective batch
    eff_minibatch_loss += micro_batch_loss / gradient_accumulation_steps
eff_minibatch_loss.backward()
optimizer.step()
optimizer.zero_grad()
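
# --- hedged, concrete version of the block above ---
# A minimal runnable sketch of the same gradient-accumulation loop, assuming
# PyTorch. The Linear model, SGD optimizer, and random tensors (all named
# toy_*) are hypothetical stand-ins for fetch_data/loss/optimizer, not part
# of the original sketch; calling backward() per micro-batch here also keeps
# only one graph alive at a time instead of summing the losses first.
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_model = nn.Linear(16, 1)
toy_opt = torch.optim.SGD(toy_model.parameters(), lr=1e-2)
toy_loss_fn = nn.MSELoss()

toy_inputs = torch.randn(2048, 16)   # stand-in for fetch_data(2048)
toy_targets = torch.randn(2048, 1)

accum_steps = 4
mb_size = 2048 // accum_steps

toy_opt.zero_grad()
for i in range(accum_steps):
    lo, hi = i * mb_size, (i + 1) * mb_size
    # scale so the accumulated gradient matches the mean over the full batch
    micro_loss = toy_loss_fn(toy_model(toy_inputs[lo:hi]), toy_targets[lo:hi]) / accum_steps
    micro_loss.backward()
toy_opt.step()
toy_opt.zero_grad()
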
###########################
# what i'm suggesting-ish #
###########################
# you probably don't need to turn GAS up this high, but you'll def want to go bigger than 4
gradient_accumulation_steps = floor(sqrt(eff_batch_size))  # ~45 for a 2048 effective batch
# doesn't need to be scaled relative to GAS, just however big you can handle
micro_batch_size = eff_batch_size // 4

eff_minibatch_loss = 0
minibatch = fetch_data(eff_batch_size)
for _ in range(gradient_accumulation_steps):
    # bootstrap that shit, i.e. resample micro-batches with replacement
    x = sample(minibatch, micro_batch_size)
    micro_batch_loss = loss(x)
    eff_minibatch_loss += micro_batch_loss / gradient_accumulation_steps
eff_minibatch_loss.backward()
optimizer.step()
optimizer.zero_grad()
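
# --- hedged, concrete version of the suggestion above ---
# Same toy setup, assuming PyTorch; the bootstrap is made explicit with
# torch.randint, which draws micro-batch indices *with replacement* from the
# fetched effective batch. Everything named boot_* is hypothetical scaffolding,
# not part of the original sketch. Since each resampled index is uniform over
# the effective batch, each micro-batch loss equals the full-batch mean loss
# in expectation, so the 1/accum_steps scaling keeps the accumulated gradient
# an unbiased estimate of the full-batch gradient.
import torch
import torch.nn as nn
from math import floor, sqrt

torch.manual_seed(0)
boot_model = nn.Linear(16, 1)
boot_opt = torch.optim.SGD(boot_model.parameters(), lr=1e-2)
boot_loss_fn = nn.MSELoss()

eff_bs = 2048
boot_inputs = torch.randn(eff_bs, 16)   # stand-in for fetch_data(eff_bs)
boot_targets = torch.randn(eff_bs, 1)

accum_steps = floor(sqrt(eff_bs))  # 45 resamples
mb_size = eff_bs // 4              # whatever fits on the device

boot_opt.zero_grad()
for _ in range(accum_steps):
    idx = torch.randint(0, eff_bs, (mb_size,))  # sample indices with replacement
    micro_loss = boot_loss_fn(boot_model(boot_inputs[idx]), boot_targets[idx]) / accum_steps
    micro_loss.backward()  # accumulate gradients across bootstrap resamples
boot_opt.step()
boot_opt.zero_grad()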