@hjm-aws
Last active November 15, 2022 18:58
Example demonstrating graph variation when callbacks are accumulated in the backward pass
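
# Repro: wrap a 1-layer BERT masked-LM model with torch_xla FSDP (auto-wrapped
# submodules plus a root wrapper) and run a few forward/backward iterations,
# saving the IR/HLO graphs from rank 0 so that graph changes across iterations
# can be inspected.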
import os
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, auto_wrap, default_auto_wrap_policy
from transformers import BertTokenizer, BertForMaskedLM, BertConfig
import functools

num_iterations = 6
# Toggle coalescing of the collective ops (coalesced all-gathers and a bucketed
# reduce-scatter) in the FSDP wrappers.
coalesce_cc_ops = True

fsdp_params = dict(flatten_parameters=False,
                   shard_param_on_dim_0=True,
                   optimization_barrier_in_forward=True,
                   optimization_barrier_in_backward=True,
                   disable_reshard_on_root=False,
                   coalesce_all_gather_ops=coalesce_cc_ops,
                   reduce_scatter_bucket_size_mb=20 if coalesce_cc_ops else 0)


def fsdp_wrap(module, device):
    min_num_params = 1e7
    with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
        auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                             min_num_params=min_num_params)
        module = auto_wrap(module.to(device),
                           auto_wrap_policy=auto_wrap_policy)
    # Use gradient divide factors of 1 (no division) on every nested FSDP
    # instance.
    for sub_module in module.modules():
        if isinstance(sub_module, FSDP):
            sub_module.set_gradient_divide_factors(1, 1, True)
    return FSDP(module, **fsdp_params)


# The following will not repro the problem:
# def fsdp_wrap(module, device):
#     module = FSDP(module.to(device), **fsdp_params)
#     module.set_gradient_divide_factors(1, 1, True)
#     return module


def _mp_fn(index):
    # Save the IR/HLO graphs from rank 0 so the per-step graphs can be compared.
    if index == 0:
        import datetime
        ts = datetime.datetime.now().strftime("%m%d-%H:%M:%S")
        os.environ['XLA_SAVE_TENSORS_FILE'] = f'reduce_scatter_test-{ts}.hlo'
        os.environ['XLA_SAVE_TENSORS_FMT'] = 'hlo'
        os.environ['XLA_IR_DEBUG'] = '1'
        os.environ['XLA_HLO_DEBUG'] = '1'
    device = xm.xla_device()
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    configuration = BertConfig(num_hidden_layers=1, num_attention_heads=1)
    model = BertForMaskedLM(configuration).to(device)
    model = fsdp_wrap(model, device)
    xm.master_print('=== === === === ====== Model:', flush=True)
    xm.master_print(model, flush=True)

    inputs = tokenizer("The capital of France is [MASK].",
                       return_tensors="pt").to(device)
    labels = tokenizer("The capital of France is Paris.",
                       return_tensors="pt")["input_ids"].to(device)
    # Only compute the masked-LM loss on the masked position.
    labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels,
                         -100)

    # One training step per mark_step(); the point of the repro is to compare
    # the graphs generated across these steps.
    for i in range(num_iterations):
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        xm.mark_step()

    xm.rendezvous('some tag')
    xm.master_print('=== ~~~~~~~~~~~~~ Training done.', flush=True)


if __name__ == '__main__':
    xmp.spawn(_mp_fn)
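
To check for the graph variation mentioned in the description, one option (not part of the original gist) is to split the file written via XLA_SAVE_TENSORS_FILE into individual graphs and fingerprint them; if the traced graph is stable, the later steps should all produce the same fingerprint. The sketch below assumes each dumped graph begins with an "HloModule" line, which is the standard HLO text header; adjust the delimiter if the dump format differs.

# Hypothetical helper, not part of the gist: fingerprint each HLO graph in the
# dump file. Assumes graphs are delimited by lines starting with "HloModule".
import hashlib
import sys


def hlo_fingerprints(path):
    graphs, current = [], []
    with open(path) as f:
        for line in f:
            if line.startswith('HloModule'):
                if current:
                    graphs.append(''.join(current))
                current = [line]
            elif current:
                current.append(line)
    if current:
        graphs.append(''.join(current))
    return [hashlib.sha1(g.encode()).hexdigest() for g in graphs]


if __name__ == '__main__':
    for i, digest in enumerate(hlo_fingerprints(sys.argv[1])):
        print(f'graph {i}: {digest}')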