Example demonstrating graph variation when callbacks are accumulated in the backward pass
import os
import functools

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, auto_wrap, default_auto_wrap_policy
from transformers import BertTokenizer, BertForMaskedLM, BertConfig
num_iterations = 6
# Toggle coalescing of the all-gather / reduce-scatter collectives.
coalesce_cc_ops = True
fsdp_params = dict(flatten_parameters=False,
                   shard_param_on_dim_0=True,
                   optimization_barrier_in_forward=True,
                   optimization_barrier_in_backward=True,
                   disable_reshard_on_root=False,
                   coalesce_all_gather_ops=coalesce_cc_ops,
                   reduce_scatter_bucket_size_mb=20 if coalesce_cc_ops else 0)
def fsdp_wrap(module, device):
    min_num_params = 1e7
    # Auto-wrap every submodule above the parameter threshold, then wrap the
    # root module itself.
    with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
        auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                             min_num_params=min_num_params)
        module = auto_wrap(module.to(device),
                           auto_wrap_policy=auto_wrap_policy)
    for sub_module in module.modules():
        if isinstance(sub_module, FSDP):
            sub_module.set_gradient_divide_factors(1, 1, True)
    return FSDP(module, **fsdp_params)
# The following will not repro the problem:
# def fsdp_wrap(module, device):
#     module = FSDP(module.to(device), **fsdp_params)
#     module.set_gradient_divide_factors(1, 1, True)
#     return module
def _mp_fn(index):
    if index == 0:
        # Dump the per-step HLO graphs (with IR/HLO debug annotations) from the
        # master process so they can be compared across iterations.
        import datetime
        ts = datetime.datetime.now().strftime("%m%d-%H:%M:%S")
        os.environ['XLA_SAVE_TENSORS_FILE'] = f'reduce_scatter_test-{ts}.hlo'
        os.environ['XLA_SAVE_TENSORS_FMT'] = 'hlo'
        os.environ['XLA_IR_DEBUG'] = '1'
        os.environ['XLA_HLO_DEBUG'] = '1'
    device = xm.xla_device()
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # A tiny BERT (1 layer, 1 attention head) keeps the dumped graphs small.
    configuration = BertConfig(num_hidden_layers=1, num_attention_heads=1)
    model = BertForMaskedLM(configuration).to(device)
    model = fsdp_wrap(model, device)
    xm.master_print('=== === === === ====== Model:', flush=True)
    xm.master_print(model, flush=True)
    inputs = tokenizer("The capital of France is [MASK].",
                       return_tensors="pt").to(device)
    labels = tokenizer("The capital of France is Paris.",
                       return_tensors="pt")["input_ids"].to(device)
    labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels,
                         -100)
    for i in range(num_iterations):
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        # Cut the trace here so each iteration produces its own graph dump.
        xm.mark_step()
    xm.rendezvous('some tag')
    xm.master_print('=== ~~~~~~~~~~~~~ Training done.', flush=True)
if __name__ == '__main__':
    xmp.spawn(_mp_fn)
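
Not part of the original gist: the script is launched directly (xmp.spawn forks one process per XLA device), and the per-step HLO text lands in the file named by XLA_SAVE_TENSORS_FILE. Besides diffing that dump, a quick way to confirm that the traced graph keeps changing across iterations is torch_xla's metrics report, which counts compilations. A minimal sketch of what could be appended at the end of _mp_fn (counter names in the report may vary across torch_xla versions):

# Hypothetical addition at the end of _mp_fn, not in the original gist.
# If every iteration re-traced to the same graph, it would compile once and
# then hit the compilation cache; a growing compile count across the loop
# indicates per-step graph variation.
import torch_xla.debug.metrics as met

xm.master_print(met.metrics_report(), flush=True)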