@anijain2305
Created February 5, 2022 01:01
WARNING: torch sampled_addmm does not support batch indices, so it was benchmarked by iterating over batches. Its timings can be improved significantly.
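To make the warning concrete, here is a minimal pure-Python sketch of what "iterating over batches" could look like for an SDDMM (sampled dense-dense matrix multiplication). It is not the benchmark's code and does not call torch. The names `sddmm_csr` and `batched_sddmm` are hypothetical, the sparsity pattern is assumed to be one CSR pattern shared by every batch, and the second operand is assumed to be stored row-wise so that each output entry is a dot product of two rows.

```python
def sddmm_csr(crow, col, a, b):
    """Reference SDDMM for one 2-D problem (hypothetical helper).

    For every position (i, j) stored in the CSR pattern (crow, col),
    compute dot(a[i], b[j]). Positions outside the pattern are skipped,
    which is where the savings over a dense matmul come from.
    """
    values = []
    for i in range(len(crow) - 1):
        for p in range(crow[i], crow[i + 1]):
            j = col[p]
            values.append(sum(x * y for x, y in zip(a[i], b[j])))
    return values


def batched_sddmm(crow, col, a_batch, b_batch):
    # One independent 2-D call per batch element. Issuing B separate
    # calls like this is the overhead the warning refers to.
    return [sddmm_csr(crow, col, a, b) for a, b in zip(a_batch, b_batch)]


# 2x2 pattern with three stored entries: (0, 1), (1, 0), (1, 1).
crow, col = [0, 1, 3], [1, 0, 1]
a_batch = [[[1, 2], [3, 4]], [[1, 1], [1, 1]]]
b_batch = [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]
result = batched_sddmm(crow, col, a_batch, b_batch)
```

A kernel with native batch support would handle all batch elements in one launch instead of one call per batch element, which is consistent with the torch_sddmm column growing roughly in proportion to B in the tables.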
Swin Transformer
[----------------------------------------------------------- sddmm ------------------------------------------------------------]
| torch_dense | torch_sddmm | csr_sputnik | csr_ge | coo_ge | csr_to_coo
1 threads: ---------------------------------------------------------------------------------------------------------------------
B= 96, M=3136, K= 32, prob=0.0000 | 9532.0 | 47358.3 | 20816.9 | 106767.1 | 44071.4 | 259.6
B= 192, M= 784, K= 32, prob=0.7500 | 1168.4 | 9095.9 | 728.3 | 2942.2 | 1480.8 | 117.1
B= 384, M= 196, K= 32, prob=0.9375 | 171.0 | 17444.5 | 103.0 | 91.8 | 148.3 | 117.3
B= 768, M= 49, K= 32, prob=0.9844 | 78.6 | 34361.4 | 18.7 | 12.8 | 145.7 | 116.3
Times are in microseconds (us).
ViT
[---------------------------------------------------------- sddmm ----------------------------------------------------------]
| torch_dense | torch_sddmm | csr_sputnik | csr_ge | coo_ge | csr_to_coo
1 threads: ------------------------------------------------------------------------------------------------------------------
B= 192, M= 785, K= 64, prob=0.7000 | 1988.2 | 9588.9 | 1292.4 | 4141.0 | 2110.6 | 116.8
B= 192, M= 785, K= 64, prob=0.8000 | 1988.2 | 8635.6 | 919.8 | 2755.1 | 1453.8 | 116.9
B= 192, M= 785, K= 64, prob=0.8500 | 1988.2 | 8647.1 | 810.6 | 2076.6 | 1125.0 | 117.0
B= 192, M= 785, K= 64, prob=0.9000 | 1987.9 | 8579.0 | 710.8 | 1406.7 | 800.2 | 117.0
B= 192, M= 785, K= 64, prob=0.9300 | 1988.0 | 8640.9 | 709.3 | 1002.2 | 594.1 | 117.4
B= 192, M= 785, K= 64, prob=0.9500 | 1988.1 | 8591.8 | 709.2 | 740.2 | 466.4 | 117.2
B= 192, M= 785, K= 64, prob=0.9700 | 1988.1 | 8688.1 | 708.6 | 455.3 | 325.2 | 117.2
B= 192, M= 197, K= 64, prob=0.7000 | 171.4 | 9203.8 | 108.1 | 255.3 | 245.5 | 116.8
B= 192, M= 197, K= 64, prob=0.8000 | 171.4 | 8667.0 | 90.0 | 173.3 | 193.1 | 116.5
B= 192, M= 197, K= 64, prob=0.8500 | 171.6 | 8637.7 | 67.4 | 127.3 | 169.9 | 116.1
B= 192, M= 197, K= 64, prob=0.9000 | 171.4 | 8672.2 | 59.4 | 84.1 | 147.2 | 116.5
B= 192, M= 197, K= 64, prob=0.9300 | 171.6 | 8687.2 | 58.8 | 60.6 | 146.9 | 116.0
B= 192, M= 197, K= 64, prob=0.9500 | 171.4 | 8630.8 | 57.7 | 44.3 | 147.5 | 116.5
B= 192, M= 197, K= 64, prob=0.9700 | 171.6 | 8639.0 | 57.2 | 29.2 | 145.4 | 115.9
B= 384, M= 785, K= 64, prob=0.7000 | 3950.5 | 19193.6 | 2591.5 | 8274.8 | 4112.5 | 116.8
B= 384, M= 785, K= 64, prob=0.8000 | 3952.1 | 18430.4 | 1833.5 | 5547.4 | 2809.3 | 117.1
B= 384, M= 785, K= 64, prob=0.8500 | 3950.5 | 17386.7 | 1606.8 | 4146.5 | 2138.9 | 116.4
B= 384, M= 785, K= 64, prob=0.9000 | 3950.2 | 17315.2 | 1413.4 | 2823.7 | 1495.1 | 117.3
B= 384, M= 785, K= 64, prob=0.9300 | 3950.1 | 17212.3 | 1411.8 | 2014.5 | 1085.9 | 117.1
B= 384, M= 785, K= 64, prob=0.9500 | 3950.4 | 17334.1 | 1410.9 | 1445.7 | 811.6 | 117.5
B= 384, M= 785, K= 64, prob=0.9700 | 3950.3 | 17285.7 | 1410.0 | 917.9 | 547.0 | 117.4
B= 384, M= 197, K= 64, prob=0.7000 | 312.1 | 18479.7 | 207.6 | 506.6 | 385.1 | 116.7
B= 384, M= 197, K= 64, prob=0.8000 | 311.6 | 18497.8 | 178.6 | 348.3 | 286.6 | 116.0
B= 384, M= 197, K= 64, prob=0.8500 | 311.8 | 17386.4 | 148.8 | 265.8 | 241.5 | 117.1
B= 384, M= 197, K= 64, prob=0.9000 | 311.6 | 17367.0 | 136.6 | 188.0 | 201.9 | 116.1
B= 384, M= 197, K= 64, prob=0.9300 | 311.8 | 17271.5 | 123.2 | 132.8 | 171.5 | 116.9
B= 384, M= 197, K= 64, prob=0.9500 | 311.6 | 17271.2 | 114.9 | 104.6 | 156.9 | 116.4
B= 384, M= 197, K= 64, prob=0.9700 | 311.6 | 17306.8 | 110.0 | 69.4 | 147.6 | 116.1
Times are in microseconds (us).
Basic cases
[---------------------------------------------------------- sddmm ----------------------------------------------------------]
| torch_dense | torch_sddmm | csr_sputnik | csr_ge | coo_ge | csr_to_coo
1 threads: ------------------------------------------------------------------------------------------------------------------
B= 32, M=1024, K= 32, prob=0.9000 | 343.3 | 1443.3 | 197.8 | 341.8 | 264.6 | 116.9
B= 32, M=1024, K= 32, prob=0.9300 | 343.8 | 1437.4 | 197.6 | 241.7 | 216.7 | 116.6
B= 32, M=1024, K= 32, prob=0.9500 | 343.3 | 1443.0 | 197.6 | 175.5 | 185.5 | 117.4
B= 32, M=1024, K= 32, prob=0.9700 | 343.6 | 1445.3 | 197.5 | 108.4 | 154.9 | 117.2
B= 32, M=1024, K= 32, prob=0.9800 | 343.4 | 1436.8 | 197.3 | 74.9 | 146.8 | 116.6
B= 32, M=1024, K= 32, prob=0.9900 | 343.5 | 1440.1 | 197.2 | 41.3 | 146.8 | 117.4
B= 32, M=1024, K= 32, prob=0.9950 | 343.2 | 1441.6 | 197.2 | 23.5 | 146.5 | 117.0
B= 32, M=1024, K= 32, prob=0.9990 | 343.7 | 1435.8 | 197.0 | 12.6 | 147.0 | 116.8
B= 32, M=1024, K=128, prob=0.9000 | 382.4 | 1447.9 | 310.8 | 584.9 | 437.7 | 117.3
B= 32, M=1024, K=128, prob=0.9300 | 383.0 | 1445.3 | 250.6 | 417.5 | 345.4 | 117.0
B= 32, M=1024, K=128, prob=0.9500 | 382.2 | 1455.9 | 217.5 | 307.3 | 280.2 | 116.9
B= 32, M=1024, K=128, prob=0.9700 | 382.8 | 1443.7 | 209.6 | 191.4 | 216.7 | 116.5
B= 32, M=1024, K=128, prob=0.9800 | 382.4 | 1439.3 | 210.9 | 135.8 | 187.7 | 117.3
B= 32, M=1024, K=128, prob=0.9900 | 382.8 | 1438.7 | 208.8 | 78.6 | 153.8 | 116.6
B= 32, M=1024, K=128, prob=0.9950 | 382.3 | 1445.9 | 207.5 | 50.2 | 151.2 | 116.6
B= 32, M=1024, K=128, prob=0.9990 | 382.9 | 1441.7 | 199.1 | 12.7 | 146.3 | 117.2
B= 8, M=4096, K= 32, prob=0.9000 | 1363.4 | 1181.9 | 958.9 | 1563.0 | 739.1 | 122.8
B= 8, M=4096, K= 32, prob=0.9300 | 1365.9 | 1002.3 | 925.4 | 1105.4 | 552.9 | 122.7
B= 8, M=4096, K= 32, prob=0.9500 | 1364.0 | 852.7 | 886.5 | 799.0 | 426.1 | 118.5
B= 8, M=4096, K= 32, prob=0.9700 | 1364.6 | 754.5 | 834.1 | 485.9 | 302.6 | 120.3
B= 8, M=4096, K= 32, prob=0.9800 | 1364.9 | 707.7 | 825.7 | 319.4 | 239.9 | 121.7
B= 8, M=4096, K= 32, prob=0.9900 | 1363.6 | 642.0 | 780.4 | 161.6 | 174.1 | 122.5
B= 8, M=4096, K= 32, prob=0.9950 | 1363.8 | 588.0 | 780.0 | 84.8 | 149.7 | 121.4
B= 8, M=4096, K= 32, prob=0.9990 | 1363.7 | 574.6 | 779.9 | 23.2 | 150.5 | 122.8
B= 8, M=4096, K=128, prob=0.9000 | 1500.9 | 2001.2 | 1492.9 | 2251.7 | 1450.8 | 118.6
B= 8, M=4096, K=128, prob=0.9300 | 1501.6 | 1692.7 | 1336.1 | 1597.7 | 1040.2 | 118.0
B= 8, M=4096, K=128, prob=0.9500 | 1500.8 | 1455.2 | 1189.8 | 1169.3 | 778.2 | 117.7
B= 8, M=4096, K=128, prob=0.9700 | 1500.8 | 1234.7 | 1021.0 | 733.4 | 516.6 | 118.3
B= 8, M=4096, K=128, prob=0.9800 | 1503.7 | 1100.3 | 1002.5 | 513.8 | 385.0 | 118.6
B= 8, M=4096, K=128, prob=0.9900 | 1500.8 | 959.4 | 793.9 | 281.9 | 254.2 | 117.8
B= 8, M=4096, K=128, prob=0.9950 | 1501.2 | 846.2 | 780.4 | 157.3 | 190.2 | 116.9
B= 8, M=4096, K=128, prob=0.9990 | 1501.0 | 824.8 | 780.5 | 48.6 | 154.6 | 121.4
Times are in microseconds (us).
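The rows above can be compared programmatically. The following is a small sketch, not part of the original benchmark, that parses one row and reports the fastest kernel and its speedup over torch_dense. The helper names `parse_row` and `fastest_kernel` are hypothetical, and excluding csr_to_coo from the comparison assumes that column measures a format conversion rather than an SDDMM kernel.

```python
COLUMNS = ["torch_dense", "torch_sddmm", "csr_sputnik",
           "csr_ge", "coo_ge", "csr_to_coo"]


def parse_row(line):
    """Split a table row into its config string and a {column: time_us} dict."""
    config, *cells = [part.strip() for part in line.split("|")]
    times = dict(zip(COLUMNS, map(float, cells)))
    return config, times


def fastest_kernel(times, exclude=("csr_to_coo",)):
    # Assumption: csr_to_coo is a conversion cost, so it is left out by default.
    candidates = {k: v for k, v in times.items() if k not in exclude}
    name = min(candidates, key=candidates.get)
    return name, candidates[name]


# Last row of the Swin Transformer table, copied verbatim.
row = "B= 768, M= 49, K= 32, prob=0.9844 | 78.6 | 34361.4 | 18.7 | 12.8 | 145.7 | 116.3"
config, times = parse_row(row)
name, best = fastest_kernel(times)
speedup = times["torch_dense"] / best
print(f"{config}: fastest={name} ({best} us), {speedup:.1f}x vs torch_dense")
```

Applying the same helpers to every row makes it easy to see at which sparsity level each sparse kernel overtakes the dense baseline.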