@tengomucho
Created May 13, 2024 20:24
Errors trying FSDP on TPU

Results in torch_xla 2.2.0

Traceback (most recent call last):
  File "/home/amoran/optimum-tpu/alvaro/tuning_gemma2b.py", line 50, in <module>
    trainer = Trainer(
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/transformers/trainer.py", line 659, in __init__
    xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
AttributeError: module 'torch_xla.distributed.spmd' has no attribute 'set_global_mesh'
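With torch_xla 2.2.0 the failure happens already in Trainer.__init__: transformers calls xs.set_global_mesh(...), and torch_xla.distributed.spmd does not expose that helper in this release. A quick way to confirm what the installed wheel provides (an added snippet, not part of the original script):

# Added check: does the installed torch_xla ship the SPMD helper the Trainer calls?
import torch_xla
import torch_xla.distributed.spmd as xs

print(torch_xla.__version__)
print(hasattr(xs, "set_global_mesh"))  # False on 2.2.0 (hence the AttributeError), True on 2.3.0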

Results in torch_xla 2.3.0

Exception in thread Thread-3 (_loader_worker):
Traceback (most recent call last):
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/accelerate/data_loader.py", line 464, in __iter__
    next_batch = next(dataloader_iter)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 621, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
StopIteration

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/distributed/parallel_loader.py", line 152, in _loader_worker
    _, data = next(data_iter)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/accelerate/data_loader.py", line 472, in __iter__
    yield current_batch
UnboundLocalError: local variable 'current_batch' referenced before assignment

There seems to be not a single sample in your epoch_iterator, stopping training at step 0! This is expected if you're using an IterableDataset and set num_steps (100) higher than the number of available samples.
{'train_runtime': 0.0044, 'train_samples_per_second': 0.0, 'train_steps_per_second': 22822.418, 'train_loss': 0.0, 'epoch': 0}
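The torch_xla 2.3.0 run gets past the mesh setup but crashes in the loader worker: accelerate's data_loader.py peeks one batch ahead, and when the wrapped dataloader produces no batches at all (consistent with the "not a single sample in your epoch_iterator" message above), the StopIteration handler reaches yield current_batch before that name was ever bound. A minimal sketch of that pattern, reconstructed from the traceback rather than copied from the accelerate source:

# Simplified illustration of the failure pattern, not the accelerate implementation.
def peek_ahead_iter(dataloader):
    dataloader_iter = iter(dataloader)
    try:
        current_batch = next(dataloader_iter)  # raises immediately if the loader is empty
        while True:
            next_batch = next(dataloader_iter)  # look one batch ahead
            yield current_batch
            current_batch = next_batch
    except StopIteration:
        # On an empty loader, current_batch was never assigned.
        yield current_batch  # UnboundLocalError: 'current_batch' referenced before assignment

try:
    list(peek_ahead_iter([]))  # an empty "dataloader", like the 0-sample epoch above
except UnboundLocalError as e:
    print(e)

The script below is the one used for both runs.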

import os

os.environ["PJRT_DEVICE"] = "TPU"

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Enable SPMD mode execution
import torch_xla.runtime as xr
xr.use_spmd()

import torch_xla.core.xla_model as xm

text = "Quote: Imagination is more"
device = xm.xla_device()
model.to(device)
inputs = tokenizer(text, return_tensors="pt").to(device)

from datasets import load_dataset

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True,
}

# Finally, set up the trainer and train the model.
trainer = Trainer(
    model=model,
    train_dataset=data,
    args=TrainingArguments(
        per_device_train_batch_size=64,  # This is actually the global batch size for SPMD.
        num_train_epochs=100,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
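The "not a single sample in your epoch_iterator" message suggests the trainer's dataloader was already empty before the worker thread crashed. A quick sanity check on the tokenized dataset (added here as a sketch; not something the original run printed):

# Added sanity check: confirm the tokenized dataset actually contains samples.
print(data)                # DatasetDict; english_quotes ships a single "train" split
print(len(data["train"]))  # number of quotes available for training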