Created January 13, 2020 13:33
# mish_original(), (c) 2019 Diganta Misra
# https://github.com/digantamisra98/Mish/blob/master/LICENSE
#
# Other functions MIT licensed, (c) 2019 Sangwhan Moon
# In [53]: %timeit mish_original(x)
# 30.2 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# In [54]: %timeit mish1(x)
# 194 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# In [55]: %timeit mish1s(x)
# 12.7 µs ± 36.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# In [56]: %timeit mish2(x)
# 149 µs ± 182 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# In [57]: %timeit mish2s(x)
# 12.4 µs ± 35.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
import torch

def mish_original(x):
    # Reference form: x * tanh(softplus(x)).
    return x * torch.tanh(torch.nn.functional.softplus(x))

def mish1(x):
    # Same function rewritten as a rational expression in y = exp(-x).
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * (y ** 2))

@torch.jit.script
def mish1s(x):
    y = torch.exp(-x)
    return x * (1 + 2 * y) / (1 + 2 * y + 2 * (y ** 2))

def mish2(x):
    # As mish1, with the shared term 1 + 2 * y computed once.
    y = torch.exp(-x)
    r = 1 + 2 * y
    return x * r / (r + 2 * (y ** 2))

@torch.jit.script
def mish2s(x):
    y = torch.exp(-x)
    r = 1 + 2 * y
    return x * r / (r + 2 * (y ** 2))
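
The rational variants rely on the identity tanh(ln(1 + e^x)) = (1 + 2y) / (1 + 2y + 2y^2) with y = e^(-x), which follows from writing tanh(ln u) as (u^2 - 1) / (u^2 + 1) and dividing through by e^(2x). The following is a minimal scalar sketch of that equivalence using only the standard library; the names `mish_reference` and `mish_rational` are hypothetical and are not part of the gist.

```python
import math

def mish_reference(x: float) -> float:
    # x * tanh(softplus(x)), where softplus(x) = log(1 + exp(x)).
    return x * math.tanh(math.log1p(math.exp(x)))

def mish_rational(x: float) -> float:
    # Same function as a rational expression in y = exp(-x),
    # mirroring the structure of mish2 above.
    y = math.exp(-x)
    r = 1 + 2 * y
    return x * r / (r + 2 * y * y)

# Spot-check that the two forms agree on a few moderate inputs.
for x in (-5.0, -1.0, 0.0, 0.5, 3.0):
    assert math.isclose(mish_reference(x), mish_rational(x),
                        rel_tol=1e-9, abs_tol=1e-12)
```

One caveat worth checking before relying on the rational form: `exp(-x)` overflows for very negative inputs. In float32 this happens below roughly x = -88, where the expression becomes -inf / inf and evaluates to NaN instead of a value near zero.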
Before:
Profiling over 100 runs after 10 warmup runs.
Profiling on Tesla P100-SXM2-16GB
Testing on torch.float16:
relu_fwd: 158.1µs ± 3.901µs (150.1µs - 187.8µs)
relu_bwd: 237.4µs ± 4.772µs (231.5µs - 269.5µs)
softplus_fwd: 284.6µs ± 4.315µs (281.3µs - 320.5µs)
softplus_bwd: 391.0µs ± 3.010µs (387.6µs - 414.0µs)
mish_pt_fwd: 581.5µs ± 7.529µs (570.7µs - 599.7µs)
mish_pt_bwd: 909.6µs ± 14.15µs (888.2µs - 963.1µs)
mish_cuda_fwd: 287.1µs ± 2.513µs (263.6µs - 294.5µs)
mish_cuda_bwd: 382.6µs ± 2.828µs (381.6µs - 410.2µs)
Testing on torch.float32:
relu_fwd: 199.6µs ± 2.611µs (196.4µs - 221.9µs)
relu_bwd: 320.7µs ± 3.608µs (315.7µs - 347.3µs)
softplus_fwd: 235.4µs ± 1.209µs (232.9µs - 241.7µs)
softplus_bwd: 316.0µs ± 6.212µs (311.3µs - 357.2µs)
mish_pt_fwd: 655.6µs ± 3.063µs (650.2µs - 671.0µs)
mish_pt_bwd: 1.244ms ± 2.747µs (1.238ms - 1.253ms)
mish_cuda_fwd: 292.7µs ± 1.124µs (291.2µs - 299.9µs)
mish_cuda_bwd: 414.0µs ± 3.148µs (412.3µs - 443.8µs)
Testing on torch.float64:
relu_fwd: 340.0µs ± 1.951µs (332.0µs - 353.1µs)
relu_bwd: 595.8µs ± 3.008µs (594.4µs - 625.3µs)
softplus_fwd: 478.5µs ± 3.584µs (460.0µs - 503.9µs)
softplus_bwd: 606.0µs ± 2.856µs (604.5µs - 634.0µs)
mish_pt_fwd: 1.253ms ± 3.263µs (1.249ms - 1.279ms)
mish_pt_bwd: 2.430ms ± 6.424µs (2.426ms - 2.479ms)
mish_cuda_fwd: 646.6µs ± 12.28µs (639.3µs - 704.5µs)
mish_cuda_bwd: 847.1µs ± 536.3ns (846.1µs - 851.0µs)
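
The log above reports mean ± standard deviation and a (min - max) range over 100 timed runs after 10 warmup runs. The script that produced it is not shown, so the following is only a sketch of how such a harness might look, using the standard library; `profile` and `format_stats` are hypothetical names, and the use of the sample standard deviation is an assumption.

```python
import statistics
import time
from typing import Callable, Dict

def profile(fn: Callable[[], object], runs: int = 100,
            warmup: int = 10) -> Dict[str, float]:
    """Time fn() and return mean, std, min and max in microseconds."""
    for _ in range(warmup):
        fn()  # Warmup runs are executed but not recorded.
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        # For GPU work the device would need to be synchronized here before
        # reading the clock, because CUDA kernels launch asynchronously.
        samples.append((time.perf_counter() - start) * 1e6)
    return {
        "mean": statistics.mean(samples),
        "std": statistics.stdev(samples),  # Assumption: sample std. dev.
        "min": min(samples),
        "max": max(samples),
    }

def format_stats(name: str, s: Dict[str, float]) -> str:
    # Roughly mirrors the layout of the log lines above, in µs only.
    return (f"{name}: {s['mean']:.1f}µs ± {s['std']:.3f}µs "
            f"({s['min']:.1f}µs - {s['max']:.1f}µs)")

stats = profile(lambda: sum(range(1000)))
print(format_stats("sum_range", stats))
```

When timing PyTorch kernels on a GPU, calling `torch.cuda.synchronize()` after the operation (or timing with CUDA events) is the usual way to avoid measuring only the kernel launch.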