@cynthia
Created January 13, 2020 13:33
# mish_original(), (c) 2019 Diganta Misra
# https://github.com/digantamisra98/Mish/blob/master/LICENSE
#
# Other functions MIT licensed, (c) 2019 Sangwhan Moon
# In [53]: %timeit mish_original(x)
# 30.2 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# In [54]: %timeit mish1(x)
# 194 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# In [55]: %timeit mish1s(x)
# 12.7 µs ± 36.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# In [56]: %timeit mish2(x)
# 149 µs ± 182 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# In [57]: %timeit mish2s(x)
# 12.4 µs ± 35.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
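import torch

def mish_original(x):
    # Reference form: x * tanh(softplus(x))
    return x * torch.tanh(torch.nn.functional.softplus(x))

def mish1(x):
    # Rational form in y = exp(-x), algebraically equal to the reference
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * (y ** 2))

@torch.jit.script
def mish1s(x):
    # mish1 compiled with TorchScript
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * (y ** 2))

def mish2(x):
    # Same as mish1, but computes r = 1 + 2y once and reuses it
    y = torch.exp(-x)
    r = 1 + 2 * y
    return x * r / (r + 2 * (y ** 2))

@torch.jit.script
def mish2s(x):
    # mish2 compiled with TorchScript
    y = torch.exp(-x)
    r = 1 + 2 * y
    return x * r / (r + 2 * (y ** 2))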
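
The rational forms above rely on the identity tanh(ln(1 + e^x)) = (1 + 2y) / (1 + 2y + 2y²) with y = e^(-x). As a rough sanity check, the scalar sketch below compares the two forms using only the standard library `math` module rather than torch. The names `mish_reference`, `mish_rational` and `max_abs_diff` are illustrative and not part of the gist.

```python
import math

def mish_reference(x: float) -> float:
    # x * tanh(softplus(x)), with softplus(x) = log(1 + exp(x))
    return x * math.tanh(math.log1p(math.exp(x)))

def mish_rational(x: float) -> float:
    # Same expression as mish2, evaluated on a plain Python float
    y = math.exp(-x)
    r = 1 + 2 * y
    return x * r / (r + 2 * y * y)

def max_abs_diff(points):
    # Largest absolute disagreement between the two forms over the given inputs
    return max(abs(mish_reference(x) - mish_rational(x)) for x in points)

# Sample the range [-20, 20] in steps of 0.1 (an arbitrary choice for this sketch)
points = [i / 10 for i in range(-200, 201)]
print(max_abs_diff(points))
```

The check only covers a moderate input range. In float32, exp(-x) overflows to inf for x below roughly -88, so the rational forms can return nan there while mish_original stays finite. That may be worth guarding against before using them in training code.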
cynthia commented Jan 13, 2020

Profiling over 100 runs after 10 warmup runs.
Profiling on Tesla P100-SXM2-16GB
Testing on torch.float16:
relu_fwd: 157.4µs ± 4.237µs (148.8µs - 196.3µs)
relu_bwd: 237.7µs ± 6.121µs (233.2µs - 283.4µs)
softplus_fwd: 268.9µs ± 7.435µs (248.1µs - 288.5µs)
softplus_bwd: 370.0µs ± 10.66µs (365.2µs - 398.5µs)
mish_pt_fwd: 404.9µs ± 7.229µs (390.9µs - 446.3µs)
mish_pt_bwd: 582.3µs ± 2.989µs (574.8µs - 589.2µs)
mish_cuda_fwd: 601.7µs ± 6.060µs (598.5µs - 646.5µs)
mish_cuda_bwd: 382.3µs ± 357.4ns (381.7µs - 383.1µs)
Testing on torch.float32:
relu_fwd: 200.5µs ± 3.429µs (197.0µs - 220.0µs)
relu_bwd: 321.0µs ± 2.286µs (316.1µs - 326.0µs)
softplus_fwd: 235.3µs ± 1.350µs (233.7µs - 242.7µs)
softplus_bwd: 317.1µs ± 2.095µs (314.9µs - 331.6µs)
mish_pt_fwd: 641.2µs ± 13.21µs (628.4µs - 712.0µs)
mish_pt_bwd: 678.6µs ± 2.878µs (672.5µs - 687.2µs)
mish_cuda_fwd: 619.2µs ± 4.172µs (616.1µs - 645.7µs)
mish_cuda_bwd: 413.6µs ± 1.047µs (412.3µs - 422.2µs)
Testing on torch.float64:
relu_fwd: 340.5µs ± 4.800µs (337.9µs - 384.7µs)
relu_bwd: 596.1µs ± 1.813µs (594.8µs - 613.3µs)
softplus_fwd: 477.7µs ± 3.943µs (458.6µs - 500.9µs)
softplus_bwd: 606.6µs ± 5.622µs (604.6µs - 652.6µs)
mish_pt_fwd: 1.106ms ± 5.360µs (1.075ms - 1.132ms)
mish_pt_bwd: 1.260ms ± 3.745µs (1.255ms - 1.283ms)
mish_cuda_fwd: 655.8µs ± 2.783µs (651.7µs - 660.3µs)
mish_cuda_bwd: 849.2µs ± 2.882µs (846.0µs - 854.0µs)
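
The script that produced these numbers is not included in the gist. The sketch below shows one way a "100 runs after 10 warmup runs" measurement reporting mean ± standard deviation and (min - max) could be structured. `profile`, `fmt` and the `sync` hook are hypothetical names. On a GPU one would presumably pass `torch.cuda.synchronize` as `sync`, so that asynchronous kernels have finished before the clock is read.

```python
import statistics
import time

def profile(fn, runs=100, warmup=10, sync=None):
    """Time fn() and return (mean, std, min, max) in seconds."""
    def timed_call():
        if sync is not None:
            sync()  # make sure earlier work has finished
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the work launched by fn()
        return time.perf_counter() - start

    for _ in range(warmup):
        timed_call()  # warmup results are discarded
    samples = [timed_call() for _ in range(runs)]
    return (statistics.mean(samples), statistics.pstdev(samples),
            min(samples), max(samples))

def fmt(seconds):
    # Format a duration in microseconds, similar to the logs above
    return f"{seconds * 1e6:.1f}µs"

# Placeholder CPU workload, used only to exercise the harness
mean, std, lo, hi = profile(lambda: sum(range(1000)))
print(f"sum_range: {fmt(mean)} ± {fmt(std)} ({fmt(lo)} - {fmt(hi)})")
```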

cynthia commented Jan 13, 2020

Before:

Profiling over 100 runs after 10 warmup runs.
Profiling on Tesla P100-SXM2-16GB
Testing on torch.float16:
relu_fwd: 158.1µs ± 3.901µs (150.1µs - 187.8µs)
relu_bwd: 237.4µs ± 4.772µs (231.5µs - 269.5µs)
softplus_fwd: 284.6µs ± 4.315µs (281.3µs - 320.5µs)
softplus_bwd: 391.0µs ± 3.010µs (387.6µs - 414.0µs)
mish_pt_fwd: 581.5µs ± 7.529µs (570.7µs - 599.7µs)
mish_pt_bwd: 909.6µs ± 14.15µs (888.2µs - 963.1µs)
mish_cuda_fwd: 287.1µs ± 2.513µs (263.6µs - 294.5µs)
mish_cuda_bwd: 382.6µs ± 2.828µs (381.6µs - 410.2µs)
Testing on torch.float32:
relu_fwd: 199.6µs ± 2.611µs (196.4µs - 221.9µs)
relu_bwd: 320.7µs ± 3.608µs (315.7µs - 347.3µs)
softplus_fwd: 235.4µs ± 1.209µs (232.9µs - 241.7µs)
softplus_bwd: 316.0µs ± 6.212µs (311.3µs - 357.2µs)
mish_pt_fwd: 655.6µs ± 3.063µs (650.2µs - 671.0µs)
mish_pt_bwd: 1.244ms ± 2.747µs (1.238ms - 1.253ms)
mish_cuda_fwd: 292.7µs ± 1.124µs (291.2µs - 299.9µs)
mish_cuda_bwd: 414.0µs ± 3.148µs (412.3µs - 443.8µs)
Testing on torch.float64:
relu_fwd: 340.0µs ± 1.951µs (332.0µs - 353.1µs)
relu_bwd: 595.8µs ± 3.008µs (594.4µs - 625.3µs)
softplus_fwd: 478.5µs ± 3.584µs (460.0µs - 503.9µs)
softplus_bwd: 606.0µs ± 2.856µs (604.5µs - 634.0µs)
mish_pt_fwd: 1.253ms ± 3.263µs (1.249ms - 1.279ms)
mish_pt_bwd: 2.430ms ± 6.424µs (2.426ms - 2.479ms)
mish_cuda_fwd: 646.6µs ± 12.28µs (639.3µs - 704.5µs)
mish_cuda_bwd: 847.1µs ± 536.3ns (846.1µs - 851.0µs)
