Skip to content

Instantly share code, notes, and snippets.

@cpuhrsch
Created August 16, 2023 20:26
Show Gist options
Save cpuhrsch/7fec60079cbe2daeff59c0577f933320 to your computer and use it in GitHub Desktop.
sparse.py
import torch
import torch.nn.functional as F
import itertools
import torch.utils.benchmark as benchmark
import math
# Shared tensor settings for the whole benchmark: half precision on the GPU.
device = "cuda"
dtype = torch.float16
def create_blocked_tensor(M, N, blocksize, sparsity):
    """Return an ``M x N`` 0/1 tensor with block-structured random sparsity.

    The matrix is tiled into ``blocksize x blocksize`` blocks. Each block is
    independently all ones with probability ``1 - sparsity`` and all zeros
    otherwise, so in expectation a ``sparsity`` fraction of blocks is empty.
    The tensor is created with the module-level ``dtype`` and ``device``.

    Args:
        M: Number of rows; must be a multiple of ``blocksize``.
        N: Number of columns; must be a multiple of ``blocksize``.
        blocksize: Edge length of each square block; must be positive.
        sparsity: Probability in ``[0, 1]`` that a block is all zeros.

    Returns:
        A contiguous ``(M, N)`` tensor containing only zeros and ones.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]`` (including NaN), or
            if ``blocksize`` is not positive or does not evenly divide M and N.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity should be a value between 0 and 1")
    # The integer divisions below would otherwise silently drop a partial
    # trailing block and return a tensor smaller than the requested M x N.
    if blocksize <= 0 or M % blocksize or N % blocksize:
        raise ValueError(
            f"M ({M}) and N ({N}) must be multiples of a positive "
            f"blocksize ({blocksize})")
    # One Bernoulli draw per block: 1 keeps the block, 0 zeroes it out.
    A = torch.bernoulli(torch.full((M // blocksize, N // blocksize),
                                   1 - sparsity, dtype=dtype, device=device))
    # Expand every per-block entry into a blocksize x blocksize tile.
    A = torch.repeat_interleave(A, blocksize, dim=0)
    A = torch.repeat_interleave(A, blocksize, dim=1)
    return A.contiguous()
def benchmark_in_us(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return its mean runtime in microseconds.

    The repetition count is chosen by ``Timer.blocked_autorange``. The mean
    (in seconds) is scaled to microseconds and truncated, not rounded, to int.
    """
    namespace = {"f": f, "args": args, "kwargs": kwargs}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=namespace)
    measurement = timer.blocked_autorange()
    mean_seconds = measurement.mean
    return int(mean_seconds * 1e6)
def run_benchmark(x, b, weightsize, batchsize, seqlen, blocksize, sparsity):
    """Time dense vs. block-sparse (BSR) ``F.linear`` with the same weight.

    Builds a ``weightsize x weightsize`` block-sparse weight, then times
    ``F.linear(x, weight, b)`` first with the dense tensor and then with its
    BSR conversion.

    Returns:
        ``(csv_row, ratio)``, where ``ratio = dense_time / sparse_time``.
        Values above 1 mean the BSR path was faster. The CSV columns are
        weightsize, batchsize, blocksize, seqlen, sparsity, dense_time,
        sparse_time, ratio.
    """
    dense_weight = create_blocked_tensor(weightsize, weightsize,
                                         blocksize=blocksize,
                                         sparsity=sparsity)
    bsr_weight = dense_weight.to_sparse_bsr(blocksize=blocksize)

    dense_us = benchmark_in_us(F.linear, x, dense_weight, b)
    sparse_us = benchmark_in_us(F.linear, x, bsr_weight, b)
    speedup = dense_us / sparse_us

    fields = (weightsize, batchsize, blocksize, seqlen, sparsity,
              dense_us, sparse_us, speedup)
    row = ",".join(str(field) for field in fields)
    return row, speedup
def create_experiments():
    """Return the full grid of benchmark configurations as a list of tuples.

    Each tuple is ``(weightsize, batchsize, seqlen, blocksize, sparsity)``.
    Tuples follow ``itertools.product`` order, with weight sizes 8192, 4096,
    2048, 1024 (largest first). Sparsity is in percent: 10, 20, ..., 90, then
    95 and 99.
    """
    weight_sizes = [2 ** exponent for exponent in (13, 12, 11, 10)]
    batch_sizes = [64, 128, 256]
    sequence_lengths = [256, 512]
    block_sizes = [32, 64]
    sparsity_percents = [*range(10, 100, 10), 95, 99]
    grid = itertools.product(weight_sizes, batch_sizes, sequence_lengths,
                             block_sizes, sparsity_percents)
    return list(grid)
# Sweep every configuration and write a CSV to stdout. Only rows where the
# BSR weight beat the dense one (ratio > 1) are included. The header is
# printed first and each qualifying row is printed as soon as it is measured,
# so an exception partway through the long sweep does not discard rows that
# were already collected. ``positives`` still holds every printed row.
positives = []
experiments = create_experiments()
print(",".join(["weightsize", "batchsize", "blocksize", "seqlen", "sparsity",
                "dense_time", "sparse_time", "ratio"]))
for weightsize, batchsize, seqlen, blocksize, sparsity in experiments:
    # Fresh random activations and bias for every configuration.
    x = torch.randn(batchsize, seqlen, weightsize, dtype=dtype, device=device)
    b = torch.randn(weightsize, dtype=dtype, device=device)
    # create_experiments() yields sparsity in percent; the benchmark takes a
    # fraction in [0, 1].
    result, ratio = run_benchmark(
        x, b, weightsize, batchsize, seqlen, blocksize, sparsity / 100.)
    if ratio > 1.0:
        positives.append(result)
        print(result)
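@cpuhrsch (author) commented with the output of running sparse.py. Only configurations where the block-sparse (BSR) weight was faster than the dense weight (ratio > 1) are listed. dense_time and sparse_time are mean runtimes in microseconds, and ratio = dense_time / sparse_time. The GPU model used for this run is not stated.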

weightsize,batchsize,blocksize,seqlen,sparsity,dense_time,sparse_time,ratio
8192,64,32,256,0.8,8131,6321,1.2863470969783262
8192,64,32,256,0.9,8104,3244,2.4981504315659677
8192,64,32,256,0.95,8042,2020,3.981188118811881
8192,64,32,256,0.99,7946,1558,5.100128369704749
8192,64,64,256,0.5,8266,7610,1.0862023653088042
8192,64,64,256,0.6,8218,6288,1.3069338422391859
8192,64,64,256,0.7,8187,4898,1.671498570845243
8192,64,64,256,0.8,8142,3861,2.108780108780109
8192,64,64,256,0.9,8061,2157,3.737134909596662
8192,64,64,256,0.95,8014,1569,5.107711918419375
8192,64,64,256,0.99,7926,1085,7.305069124423963
8192,64,32,512,0.8,16140,11955,1.3500627352572145
8192,64,32,512,0.9,15958,6533,2.442675646716669
8192,64,32,512,0.95,15940,3980,4.005025125628141
8192,64,32,512,0.99,15757,2418,6.516542597187758
8192,64,64,512,0.5,16490,15432,1.068558838776568
8192,64,64,512,0.6,16373,12775,1.2816438356164384
8192,64,64,512,0.7,16286,9777,1.6657461388974122
8192,64,64,512,0.8,16191,7034,2.3018197327267558
8192,64,64,512,0.9,16140,4284,3.7675070028011204
8192,64,64,512,0.95,15977,3033,5.267721727662381
8192,64,64,512,0.99,15611,2074,7.5270009643201545
8192,128,32,256,0.8,16149,11899,1.3571728716698883
8192,128,32,256,0.9,15947,6578,2.4242930982061415
8192,128,32,256,0.95,15955,3993,3.995742549461558
8192,128,32,256,0.99,15769,2390,6.597907949790795
8192,128,64,256,0.5,16743,15667,1.068679389800217
8192,128,64,256,0.6,16420,12649,1.298126334097557
8192,128,64,256,0.7,16309,9862,1.653721354694788
8192,128,64,256,0.8,16239,6947,2.337555779473154
8192,128,64,256,0.9,16181,4160,3.8896634615384613
8192,128,64,256,0.95,16046,3057,5.248936866208702
8192,128,64,256,0.99,15728,2035,7.728746928746928
8192,128,32,512,0.8,32701,23634,1.3836422103748836
8192,128,32,512,0.9,32654,13049,2.5024139780826116
8192,128,32,512,0.95,32069,8202,3.9099000243842967
8192,128,32,512,0.99,31674,4982,6.357687675632276
8192,128,64,512,0.5,33014,31721,1.0407616405535765
8192,128,64,512,0.6,33110,26052,1.2709196990634117
8192,128,64,512,0.7,32806,20657,1.5881299317422666
8192,128,64,512,0.8,32533,15117,2.15208043924059
8192,128,64,512,0.9,32418,9437,3.43520186499947
8192,128,64,512,0.95,31939,6906,4.624818997972778
8192,128,64,512,0.99,31457,4441,7.083314568790813
8192,256,32,256,0.8,32576,23851,1.3658127541822145
8192,256,32,256,0.9,32333,13163,2.456354934285497
8192,256,32,256,0.95,32046,7768,4.125386199794026
8192,256,32,256,0.99,31725,4901,6.47316874107325
8192,256,64,256,0.5,32959,32037,1.0287792240222242
8192,256,64,256,0.6,33096,25863,1.279665932026447
8192,256,64,256,0.7,33006,20750,1.5906506024096385
8192,256,64,256,0.8,32514,15121,2.1502546127901594
8192,256,64,256,0.9,32649,9827,3.322377124249517
8192,256,64,256,0.95,31853,6768,4.7064125295508275
8192,256,64,256,0.99,31504,4632,6.801381692573402
8192,256,32,512,0.8,65223,49911,1.3067860792210133
8192,256,32,512,0.9,65173,26720,2.439109281437126
8192,256,32,512,0.95,64679,16855,3.837377632749926
8192,256,32,512,0.99,63603,9025,7.0474238227146815
8192,256,64,512,0.5,66261,62287,1.0638014352914733
8192,256,64,512,0.6,65846,50688,1.2990451388888888
8192,256,64,512,0.7,65337,39494,1.6543525598825137
8192,256,64,512,0.8,65206,28054,2.3243031296784773
8192,256,64,512,0.9,64540,17535,3.68063872255489
8192,256,64,512,0.95,64093,12162,5.269939154744286
8192,256,64,512,0.99,63471,8309,7.638825370080635
4096,64,32,256,0.8,2211,1679,1.3168552709946397
4096,64,32,256,0.9,2188,980,2.23265306122449
4096,64,32,256,0.95,2188,737,2.9687924016282223
4096,64,32,256,0.99,2165,526,4.11596958174905
4096,64,64,256,0.5,2243,2113,1.0615238996687175
4096,64,64,256,0.6,2237,1708,1.309718969555035
4096,64,64,256,0.7,2236,1344,1.6636904761904763
4096,64,64,256,0.8,2195,1071,2.049486461251167
4096,64,64,256,0.9,2188,770,2.8415584415584414
4096,64,64,256,0.95,2165,633,3.420221169036335
4096,64,64,256,0.99,2165,435,4.977011494252873
4096,64,32,512,0.8,4415,3305,1.3358547655068078
4096,64,32,512,0.9,4393,2208,1.9895833333333333
4096,64,32,512,0.95,4326,1426,3.0336605890603088
4096,64,32,512,0.99,4306,1046,4.1166347992351815
4096,64,64,512,0.5,4461,4165,1.0710684273709483
4096,64,64,512,0.6,4444,3386,1.312463083284111
4096,64,64,512,0.7,4397,2734,1.6082662765179225
4096,64,64,512,0.8,4382,2065,2.1220338983050846
4096,64,64,512,0.9,4350,1518,2.8656126482213438
4096,64,64,512,0.95,4305,1240,3.471774193548387
4096,64,64,512,0.99,4307,880,4.894318181818182
4096,128,32,256,0.8,4371,3320,1.316566265060241
4096,128,32,256,0.9,4351,2016,2.158234126984127
4096,128,32,256,0.95,4311,1406,3.0661450924608817
4096,128,32,256,0.99,4306,1031,4.176527643064985
4096,128,64,256,0.5,4427,4040,1.0957920792079208
4096,128,64,256,0.6,4413,3434,1.2850902737332557
4096,128,64,256,0.7,4392,2831,1.5513952666902155
4096,128,64,256,0.8,4351,2089,2.082814743896601
4096,128,64,256,0.9,4340,1531,2.834748530372306
4096,128,64,256,0.95,4306,1237,3.481002425222312
4096,128,64,256,0.99,4306,941,4.575982996811902
4096,128,32,512,0.8,8683,6357,1.3658958628283782
4096,128,32,512,0.9,8673,4071,2.130434782608696
4096,128,32,512,0.95,8586,2852,3.0105189340813463
4096,128,32,512,0.99,8584,2112,4.0643939393939394
4096,128,64,512,0.5,8765,8143,1.0763846248311433
4096,128,64,512,0.6,8764,6915,1.2673897324656545
4096,128,64,512,0.7,8723,5530,1.577396021699819
4096,128,64,512,0.8,8679,4137,2.097897026831037
4096,128,64,512,0.9,8585,3061,2.804639006860503
4096,128,64,512,0.95,8584,2434,3.5267050123253902
4096,128,64,512,0.99,8585,1769,4.853024307518372
4096,256,32,256,0.8,8684,6390,1.358998435054773
4096,256,32,256,0.9,8657,3967,2.1822535921351145
4096,256,32,256,0.95,8585,2907,2.953216374269006
4096,256,32,256,0.99,8584,2085,4.117026378896883
4096,256,64,256,0.5,8808,8458,1.0413809411208323
4096,256,64,256,0.6,8732,6778,1.288285629979345
4096,256,64,256,0.7,8694,5561,1.5633878798777199
4096,256,64,256,0.8,8674,4333,2.001846295868913
4096,256,64,256,0.9,8585,3088,2.7801165803108807
4096,256,64,256,0.95,8584,2505,3.426746506986028
4096,256,64,256,0.99,8586,1873,4.584089695675387
4096,256,32,512,0.8,17291,13604,1.271023228462217
4096,256,32,512,0.9,17291,7995,2.162726704190119
4096,256,32,512,0.95,17114,5813,2.944090830896267
4096,256,32,512,0.99,17116,4014,4.2640757349277525
4096,256,64,512,0.5,17605,16694,1.0545705043728286
4096,256,64,512,0.6,17454,13535,1.2895456224602881
4096,256,64,512,0.7,17467,11141,1.56781258414864
4096,256,64,512,0.8,17292,8272,2.0904255319148937
4096,256,64,512,0.9,17150,6105,2.809172809172809
4096,256,64,512,0.95,17117,5010,3.416566866267465
4096,256,64,512,0.99,17112,3362,5.08982748364069
2048,64,32,256,0.8,581,519,1.1194605009633911
2048,64,32,256,0.9,576,366,1.5737704918032787
2048,64,32,256,0.95,569,317,1.7949526813880126
2048,64,32,256,0.99,562,277,2.0288808664259927
2048,64,64,256,0.6,584,539,1.0834879406307978
2048,64,64,256,0.7,583,452,1.2898230088495575
2048,64,64,256,0.8,577,380,1.518421052631579
2048,64,64,256,0.9,573,324,1.7685185185185186
2048,64,64,256,0.95,568,281,2.02135231316726
2048,64,64,256,0.99,557,276,2.0181159420289854
2048,64,32,512,0.8,1139,1006,1.13220675944334
2048,64,32,512,0.9,1125,731,1.5389876880984952
2048,64,32,512,0.95,1117,646,1.7291021671826625
2048,64,32,512,0.99,1090,458,2.3799126637554586
2048,64,64,512,0.5,1151,1139,1.0105355575065846
2048,64,64,512,0.6,1143,1042,1.0969289827255277
2048,64,64,512,0.7,1137,910,1.2494505494505495
2048,64,64,512,0.8,1127,762,1.479002624671916
2048,64,64,512,0.9,1123,662,1.6963746223564955
2048,64,64,512,0.95,1104,569,1.9402460456942003
2048,64,64,512,0.99,1092,403,2.7096774193548385
2048,128,32,256,0.8,1138,1004,1.1334661354581674
2048,128,32,256,0.9,1127,729,1.5459533607681757
2048,128,32,256,0.95,1114,631,1.7654516640253566
2048,128,32,256,0.99,1090,468,2.3290598290598292
2048,128,64,256,0.6,1144,1087,1.0524379024839006
2048,128,64,256,0.7,1139,891,1.2783389450056117
2048,128,64,256,0.8,1128,793,1.4224464060529634
2048,128,64,256,0.9,1117,661,1.6898638426626325
2048,128,64,256,0.95,1102,541,2.0369685767097967
2048,128,64,256,0.99,1086,411,2.6423357664233578
2048,128,32,512,0.8,2253,2040,1.1044117647058824
2048,128,32,512,0.9,2227,1426,1.5617110799438991
2048,128,32,512,0.95,2219,1237,1.793856103476152
2048,128,32,512,0.99,2172,931,2.3329752953813103
2048,128,64,512,0.6,2279,2091,1.0899091343854614
2048,128,64,512,0.7,2258,1790,1.2614525139664805
2048,128,64,512,0.8,2246,1481,1.5165428764348414
2048,128,64,512,0.9,2227,1248,1.7844551282051282
2048,128,64,512,0.95,2177,1077,2.021355617455896
2048,128,64,512,0.99,2154,716,3.0083798882681565
2048,256,32,256,0.8,2252,1977,1.139099645928174
2048,256,32,256,0.9,2229,1440,1.5479166666666666
2048,256,32,256,0.95,2227,1231,1.809098294069862
2048,256,32,256,0.99,2158,894,2.413870246085011
2048,256,64,256,0.6,2284,2072,1.1023166023166022
2048,256,64,256,0.7,2256,1734,1.301038062283737
2048,256,64,256,0.8,2230,1446,1.5421853388658369
2048,256,64,256,0.9,2229,1242,1.7946859903381642
2048,256,64,256,0.95,2195,1099,1.997270245677889
2048,256,64,256,0.99,2158,824,2.6189320388349513
2048,256,32,512,0.9,4489,3418,1.313341135166764
2048,256,32,512,0.95,4434,2562,1.730679156908665
2048,256,32,512,0.99,4340,1676,2.5894988066825775
2048,256,64,512,0.6,4548,4105,1.107917174177832
2048,256,64,512,0.7,4523,3457,1.3083598495805613
2048,256,64,512,0.8,4522,2920,1.5486301369863014
2048,256,64,512,0.9,4440,2338,1.8990590248075279
2048,256,64,512,0.95,4416,2213,1.9954812471757795
2048,256,64,512,0.99,4326,1450,2.983448275862069
1024,64,32,512,0.95,282,278,1.014388489208633
1024,64,64,512,0.9,282,277,1.0180505415162455
1024,64,64,512,0.95,282,277,1.0180505415162455
1024,64,64,512,0.99,279,82,3.402439024390244
1024,128,32,256,0.99,279,277,1.0072202166064983
1024,128,64,256,0.9,285,278,1.025179856115108
1024,128,64,256,0.95,282,276,1.0217391304347827
1024,128,64,256,0.99,279,276,1.0108695652173914
1024,128,32,512,0.95,567,560,1.0125
1024,128,32,512,0.99,556,388,1.4329896907216495
1024,128,64,512,0.9,568,550,1.0327272727272727
1024,128,64,512,0.95,559,485,1.1525773195876288
1024,128,64,512,0.99,553,375,1.4746666666666666
1024,256,32,256,0.95,566,556,1.0179856115107915
1024,256,32,256,0.99,558,457,1.2210065645514223
1024,256,64,256,0.95,564,504,1.119047619047619
1024,256,64,256,0.99,558,389,1.4344473007712082
1024,256,32,512,0.95,1115,1061,1.0508953817153628
1024,256,32,512,0.99,1102,797,1.3826850690087829
1024,256,64,512,0.95,1100,931,1.1815252416756177
1024,256,64,512,0.99,1087,742,1.4649595687331536

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment