Skip to content

Instantly share code, notes, and snippets.

@muellerzr
Last active June 14, 2022 18:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save muellerzr/bcc12e97d04681e439f8263e2b12d0a9 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.data_loader import prepare_data_loader
from accelerate.state import AcceleratorState
from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
from accelerate.utils import DistributedType, gather, is_torch_version, set_seed, synchronize_rng_states
from torch_xla.distributed.parallel_loader import MpDeviceLoader
def main():
    """Benchmark Accelerate's prepared DataLoader against torch_xla's
    MpDeviceLoader over identical data, and verify both yield the same
    gathered batches.

    Prints the wall-clock iteration time of each loader per process and
    asserts the concatenated results match across the two strategies.
    """
    accelerator = Accelerator()
    state = AcceleratorState()
    # 32 samples per process so every process gets the same number of
    # batches regardless of world size.
    length = 32 * state.num_processes
    dl = DataLoader(range(length), batch_size=8)
    # Accelerate-managed loader: shards across processes AND moves each
    # batch onto the device itself (put_on_device=True).
    dl_ac = prepare_data_loader(
        dl, state.device, state.num_processes, state.process_index, put_on_device=True
    )

    def _time_dl(dl):
        # Iterate the loader, gathering each batch across processes;
        # return the concatenated tensor and elapsed wall time in seconds.
        start_time = time.time()
        result = []
        for batch in dl:
            result.append(gather(batch))
        result = torch.cat(result)
        return result, time.time() - start_time

    dl_ac_result, dl_ac_time = _time_dl(dl_ac)
    # Baseline: shard with Accelerate but delegate device placement to
    # torch_xla's MpDeviceLoader (put_on_device=False).
    # FIX: the original passed state.process_index in the `device` slot and
    # omitted process_index entirely (positional order is
    # dataloader, device, num_processes, process_index); it also contained a
    # stray zero-argument prepare_data_loader() call that raised TypeError.
    dl_reg = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=False,
    )
    dl_reg = MpDeviceLoader(dl_reg, state.device)
    dl_reg_result, dl_reg_time = _time_dl(dl_reg)
    print(f'Accelerator DataLoader Time on device {state.device}, process {state.process_index}: {dl_ac_time} (s)')
    print(f'XLA DataLoader Time on device {state.device}, process {state.process_index}: {dl_reg_time} (s)')
    # Verify BEFORE announcing success (the original printed first, so the
    # message appeared even when the assert was about to fail).
    assert torch.allclose(dl_ac_result, dl_reg_result), "Tensors not the same!"
    print("Test Results align")
def _mp_fn(index):
    # For xla_spawn (TPUs)
    # `index` is the spawned process ordinal supplied by the launcher; it is
    # unused here because main() reads rank info from AcceleratorState.
    main()
if __name__ == "__main__":
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment