Skip to content

Instantly share code, notes, and snippets.

@stephenroller
Created August 1, 2019 13:56
Show Gist options
  • Save stephenroller/ae87eef9d704fcf5797067f15a4b742e to your computer and use it in GitHub Desktop.
Save stephenroller/ae87eef9d704fcf5797067f15a4b742e to your computer and use it in GitHub Desktop.
FusedLayerNorm cannot handle batchsize >= 2**16
#!/usr/bin/env python
"""
Results of running. Seems indifferent to --dim and --eps.
$ python flntest.py --batchsize 65535
Worse case difference: 2.86102294921875e-06
Average case difference: 3.698113104633194e-08
$ python flntest.py --batchsize 65536
Worse case difference: 14.040550231933594
Average case difference: 1.076439619064331
Failure
"""
import sys
import argparse
import torch
import torch.nn.functional as F
import apex.normalization.fused_layer_norm as apexnorm
def test(batchsize, dim, eps):
    """Compare apex's FusedLayerNorm against torch's reference layer_norm.

    Builds a random affine layer norm (weight, bias) and a random input of
    shape (batchsize, dim), runs both implementations, prints the worst-case
    and mean absolute differences, and reports whether they agree.

    Args:
        batchsize: number of rows in the random input. Per the module
            docstring, results diverge once this reaches 2**16.
        dim: feature dimension that is normalized over.
        eps: epsilon passed identically to both implementations.

    Returns:
        bool: True when the outputs agree within atol=1e-5.
    """
    # NOTE: requires a CUDA device — apex's fused kernel is GPU-only.
    weight = torch.randn(dim).cuda()
    bias = torch.randn(dim).cuda()
    # random input batch
    X = torch.randn(batchsize, dim).cuda()
    # apex's fused CUDA kernel
    Yapx = apexnorm.FusedLayerNormAffineFunction.apply(X, weight, bias, (dim,), eps)
    # torch's reference implementation
    Ypyt = F.layer_norm(X, (dim,), weight, bias, eps)
    # compute the elementwise difference once; .item() converts the 0-dim
    # tensors to plain Python floats so the prints show bare numbers
    diff = (Yapx - Ypyt).abs()
    print("Worse case difference: {}".format(diff.max().item()))
    print("Average case difference: {}".format(diff.mean().item()))
    return Yapx.allclose(Ypyt, atol=1e-5)
def main():
    """Parse CLI flags, run the comparison, and exit nonzero on a mismatch."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batchsize', type=int, default=128)
    parser.add_argument('-d', '--dim', type=int, default=512)
    parser.add_argument('-e', '--eps', type=float, default=1e-6)
    opts = parser.parse_args()
    passed = test(opts.batchsize, opts.dim, opts.eps)
    # guard clause: success needs no further action
    if passed:
        return
    print("Failure")
    sys.exit(1)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment