@will-thompson-k
Created August 15, 2023 15:09
Multi-GPU inference via DDP for pytorch-lightning / lightning-ai models
import torch
import pytorch_lightning as pl
...
# override this hook on the pytorch-lightning model
def on_predict_epoch_end(self, results):
    # gather all per-rank results onto each device.
    # `all_gather` is a helper defined elsewhere in the gist; WORLD_SIZE is the
    # number of processes created by the pl.Trainer.
    results = all_gather(results[0], WORLD_SIZE, self._device)
    # move everything to the cpu and concatenate into a single tensor
    results = torch.concat([x.cpu() for x in results], dim=1)
    # note: the gathered output will not preserve input order.
    # suggestion: return an index from predict_step and sort by that index here.
...
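
A minimal sketch of how the pieces could fit together. The `all_gather` helper below is one possible implementation (the gist does not show its body), and the names `model` and `predict_dataloader` are illustrative, not part of the original:

import torch
import torch.distributed as dist
import pytorch_lightning as pl

def all_gather(data, world_size, device):
    # hypothetical helper matching the call above: every rank contributes its
    # `data` tensor and receives a list with one tensor per rank.
    # assumes the tensor has the same shape on every rank.
    data = data.to(device)
    gathered = [torch.zeros_like(data) for _ in range(world_size)]
    dist.all_gather(gathered, data)
    return gathered

# run prediction across 4 GPUs with the DDP strategy; each process runs its
# shard of the dataloader, and on_predict_epoch_end gathers the results.
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="ddp")
predictions = trainer.predict(model, dataloaders=predict_dataloader)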