Skip to content

Instantly share code, notes, and snippets.

@merrymercy
Last active March 3, 2021 19:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save merrymercy/c8206743fb35c1c173e99966bd122502 to your computer and use it in GitHub Desktop.
Save merrymercy/c8206743fb35c1c173e99966bd122502 to your computer and use it in GitHub Desktop.
# Style 1: a single decorated `step` function that computes gradients and applies the optimizer update
@auto_parallel
def step(batch, weight):
    """Run one training step: compute gradients, then apply the optimizer.

    Args:
        batch: Passed as the first argument to ``grad(loss_func)``.
        weight: Passed as the second argument to ``grad(loss_func)``.

    Returns:
        Whatever ``optimier_step`` returns for the computed gradients.
    """
    # NOTE(review): if `grad` is `jax.grad`, its default `argnums=0`
    # differentiates with respect to `batch`, not `weight` -- confirm the
    # argument order expected by `loss_func`.
    gradients = grad(loss_func)(batch, weight)
    # Open question from the author: do not know where to insert pmean
    # in this single-function style.
    # REQUIREMENT: new_weight and weight maps
    # NOTE(review): presumably the returned weights must use the same
    # mapping/layout as the input `weight` -- TODO confirm.
    return optimier_step(gradients)
# Driver loop for Style 1: feed the weights returned by `step` back into
# the next call, `epoch` times. The loop index itself is not used.
for _ in range(epoch):
    weight = step(batch, weight)
# Style 2: gradient computation and weight update written as two separate functions
def grad_func(batch, weight):
    """Apply ``grad(loss_func)`` to ``(batch, weight)`` and return the result.

    NOTE(review): if `grad` is `jax.grad`, the default `argnums=0` means the
    derivative is taken with respect to `batch` -- confirm this is intended.
    """
    loss_gradient = grad(loss_func)
    return loss_gradient(batch, weight)
def update_func(grad, weight):
    """Apply the optimizer step to the given gradients.

    Args:
        grad: Gradients forwarded to ``optimier_step``.
        weight: Current weights. Accepted for interface symmetry but not
            read by the body, because ``optimier_step`` is only given the
            gradients (NOTE(review): confirm the optimizer does not need
            the current weights).

    Returns:
        Whatever ``optimier_step`` returns.
    """
    # The parameter is named `grad`; the previous body read an undefined
    # name `grads`, which raised NameError on every call.
    new_weight = optimier_step(grad)
    return new_weight
# Driver loop for Style 2: compute the gradients and apply the update as two
# separate calls. The previous loop called `step`, which left `grad_func`
# and `update_func` unused.
# NOTE(review): presumably the gap between the two calls is where a
# gradient reduction such as pmean would go -- confirm.
for _ in range(epoch):
    grads = grad_func(batch, weight)
    weight = update_func(grads, weight)
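# Auto parallel: `auto_parallel` called as a function that returns a parallel step plus batch/weight preprocessing callables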
def train_parallel(batch, weights):
    """Train for ``n_epoch`` iterations using the step built by ``auto_parallel``.

    ``auto_parallel`` is called once with ``step_serial``, ``batch`` and
    ``weights`` and its result is unpacked into three callables:
    ``step_parallel``, ``process_batch`` and ``process_weight``.

    Args:
        batch: Input batch; passed through ``process_batch`` on every
            iteration before being given to ``step_parallel``.
        weights: Initial weights; passed through ``process_weight`` once
            before the loop.

    Returns:
        The weights produced by the last ``step_parallel`` call, or the
        output of ``process_weight`` if ``n_epoch`` is 0.

    NOTE(review): ``step_serial`` and ``n_epoch`` are not defined in the
    visible source; they must be provided by the enclosing module.
    """
    step_parallel, process_batch, process_weight = auto_parallel(
        step_serial, batch, weights
    )
    weights = process_weight(weights)
    for _ in range(n_epoch):
        # Kept inside the loop as in the original: process_batch may not be
        # a pure function, so it is not hoisted.
        pbatch = process_batch(batch)
        weights = step_parallel(pbatch, weights)
    # Previously the function fell off the end and the trained weights were
    # discarded; return them so callers can use the result.
    return weights
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment