@dlibenzi
Created June 1, 2020 19:05
import sys
import torch
import torch_xla
import torch_xla.core.functions as xf
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
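
# big_mm() computes full_weight @ x on every replica while each replica holds
# only a column shard w (N x Ko) of the full weight (N x (Ko * WORLD_SIZE)):
# it all-gathers every replica's x columns, multiplies the local shard against
# its own Ko-row slice of the gathered input, all-reduces (sums) the partial
# products across replicas, and narrows out this replica's M result columns.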
def big_mm(w, x, split=1):
  ordinal = xm.get_ordinal()
  # w = N x Ko
  # WG = Ko * WORLD_SIZE
  # x = WG x M
  assert x.size(0) // xm.xrt_world_size() == w.size(1)
  splits = []
  if split != 1:
    size = x.size(1)
    assert size % split == 0
    split_size = size // split
    splits = torch.split(x, split_size, dim=1)
  else:
    splits.append(x)
  results = []
  for xs in splits:
    # xg = WG x (M * WORLD_SIZE)
    xg = xf.all_gather(xs, dim=1)
    # xgn = Ko x (M * WORLD_SIZE)
    xgn = torch.narrow(xg, 0, ordinal * w.size(1), w.size(1))
    # wxg = N x (M * WORLD_SIZE)
    wxg = w @ xgn
    # rwxg = N x (M * WORLD_SIZE)
    rwxg = xf.all_reduce(xm.REDUCE_SUM, wxg)
    # wx = N x M
    wx = torch.narrow(rwxg, 1, ordinal * xs.size(1), xs.size(1))
    results.append(wx)
  return torch.cat(results, dim=1) if len(results) > 1 else results[0]
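

# _mp_fn() runs in every spawned process: with identical RNG seeds it builds
# the same full weight wg and input x on each replica, narrows out this
# replica's Ko-column shard w, and checks that big_mm(w, x) matches the
# reference full matmul wg @ x.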
def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) != 'CPU':
    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
        use_full_mat_mul_precision=True)
    torch.manual_seed(11)
    xm.set_rng_state(11)
    KO = 2
    wsize = KO * xm.xrt_world_size()
    wg = torch.randn(3, wsize, device=device, requires_grad=True)
    w = torch.narrow(wg, 1, index * KO, KO)
    x = torch.randn(wsize, 4, device=device)
    mm = wg @ x
    bmm = big_mm(w, x, split=1)
    mm_cpu = mm.cpu()
    bmm_cpu = bmm.cpu()
    if not mm_cpu.allclose(bmm_cpu, rtol=1e-04, atol=1e-04):
      print('big_mm() produced wrong result', file=sys.stderr)
      print('[{}]\n{}\n{}'.format(index, mm_cpu, bmm_cpu), file=sys.stderr)
      sys.exit(1)
  else:
    print(
        'Default device {} does not support replication'.format(device),
        file=sys.stderr)
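

# nprocs=None makes xmp.spawn() start one process per available XLA device,
# each calling _mp_fn() with its replica index.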
if __name__ == '__main__':
  xmp.spawn(_mp_fn, nprocs=None)
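
For intuition, here is a minimal single-process sketch of the same decomposition, using plain PyTorch only (it is not part of the gist; WORLD_SIZE below is a stand-in for xm.xrt_world_size()). Splitting the weight into Ko-column shards, multiplying each shard against the matching Ko rows of x, and summing the partial products reproduces the full matmul; that sum is what the all_reduce of wxg performs across replicas in big_mm().

import torch

torch.manual_seed(11)
WORLD_SIZE, N, KO, M = 4, 3, 2, 4

wg = torch.randn(N, KO * WORLD_SIZE)  # full weight, N x (KO * WORLD_SIZE)
x = torch.randn(KO * WORLD_SIZE, M)   # shared input, (KO * WORLD_SIZE) x M

# Column-shard the weight; shard the contraction dimension of x the same way.
w_shards = torch.split(wg, KO, dim=1)  # WORLD_SIZE pieces of shape N x KO
x_shards = torch.split(x, KO, dim=0)   # WORLD_SIZE pieces of shape KO x M

# Each "replica" contributes one partial product; summing them plays the role
# of all_reduce(REDUCE_SUM) in big_mm().
partials = [ws @ xs for ws, xs in zip(w_shards, x_shards)]
result = torch.stack(partials).sum(dim=0)

assert torch.allclose(result, wg @ x, rtol=1e-4, atol=1e-4)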