
@hackyon
Created February 15, 2024 22:06
[SDPA for BERT] Training benchmark for PR #28802
=== on existing PyTorch
num_training_steps, batch_size, seq_len, is cuda, Time per batch (eager - s), Time per batch (sdpa - s), Speedup (%), Eager peak mem (MB), sdpa peak mem (MB), Mem saving (%)
1000,4,256,True,0.023,0.019,18.233,939.213,764.834,22.800
1000,4,512,True,0.023,0.021,12.111,1970.447,1225.602,60.774
1000,8,256,True,0.024,0.019,25.488,1594.295,1225.065,30.140
1000,8,512,True,0.035,0.029,20.619,3629.401,2134.262,70.054
1000,16,256,True,0.030,0.026,14.265,2874.426,2134.262,34.680
1000,16,512,True,0.065,0.053,22.641,6964.659,3961.013,75.830
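Both percentage columns appear to use the convention (eager / sdpa - 1) × 100. In other words, the gain is expressed relative to the SDPA value, not the eager baseline. For example, 939.213 / 764.834 - 1 ≈ 22.800%, which matches the first row. A reduction measured against the eager baseline would be about 18.6% for the same row. The sketch below shows both formulas. The helper names are hypothetical, and the convention is inferred from the numbers rather than taken from the benchmark script.

```python
def relative_gain_pct(eager: float, sdpa: float) -> float:
    """Percent by which the eager value exceeds the SDPA value.

    Inferred convention for the Speedup / Mem saving columns:
    (eager / sdpa - 1) * 100, i.e. relative to the SDPA value.
    """
    return (eager / sdpa - 1.0) * 100.0


def reduction_vs_eager_pct(eager: float, sdpa: float) -> float:
    """Alternative convention (not used in the tables): saving relative to eager."""
    return (1.0 - sdpa / eager) * 100.0


# First row of the table above (batch_size=4, seq_len=256), peak memory in MB.
eager_mem_mb, sdpa_mem_mb = 939.213, 764.834
print(f"{relative_gain_pct(eager_mem_mb, sdpa_mem_mb):.3f}")       # 22.800
print(f"{reduction_vs_eager_pct(eager_mem_mb, sdpa_mem_mb):.1f}")  # 18.6

# The same formula on the *rounded* per-batch times of that row.
print(f"{relative_gain_pct(0.023, 0.019):.1f}")  # 21.1 (table reports 18.233)
```

The formula applied to the 3-decimal times does not reproduce the Speedup column exactly. For example, 0.023 / 0.019 gives about 21%, but the table reports 18.233%. The column was presumably computed from the unrounded timings. Rows with larger times agree closely: for batch_size=16, seq_len=512, 0.065 / 0.053 - 1 ≈ 22.64%, and the table reports 22.641%.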
=== on PyTorch 2.2.0 (SDPA should support FlashAttention-2, a.k.a. FA2)
num_training_steps, batch_size, seq_len, is cuda, Time per batch (eager - s), Time per batch (sdpa - s), Speedup (%), Eager peak mem (MB), sdpa peak mem (MB), Mem saving (%)
1000,4,256,True,0.023,0.020,17.960,939.213,782.791,19.983
1000,4,512,True,0.023,0.019,16.743,1970.447,1263.469,55.955
1000,8,256,True,0.024,0.019,28.112,1594.295,1263.993,26.132
1000,8,512,True,0.035,0.026,35.147,3629.401,2209.760,64.244
1000,16,256,True,0.030,0.025,18.073,2874.426,2209.760,30.079
1000,16,512,True,0.065,0.047,37.607,6964.659,4112.008,69.374
=== on PyTorch 2.2.0, without .contiguous()
num_training_steps, batch_size, seq_len, is cuda, Time per batch (eager - s), Time per batch (sdpa - s), Speedup (%), Eager peak mem (MB), sdpa peak mem (MB), Mem saving (%)
1000,4,256,True,0.023,0.017,35.472,939.213,764.834,22.800
1000,4,512,True,0.023,0.018,23.687,1970.447,1227.162,60.569
1000,8,256,True,0.023,0.018,23.491,1594.295,1226.114,30.028
1000,8,512,True,0.035,0.025,43.058,3629.401,2134.262,70.054
1000,16,256,True,0.030,0.024,25.583,2874.426,2134.262,34.680
1000,16,512,True,0.064,0.044,46.223,6964.659,3961.013,75.830
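The three tables can be compared programmatically. The header row has a space after each comma but the data rows do not, so Python's `csv` module needs `skipinitialspace=True` to get clean column names. The sketch below parses the batch_size=16, seq_len=512 row of the last two tables (copied verbatim) and compares them. The helper `parse_table` and the variable names are hypothetical and are not part of the benchmark.

```python
import csv
import io

HEADER = ("num_training_steps, batch_size, seq_len, is cuda, Time per batch (eager - s), "
          "Time per batch (sdpa - s), Speedup (%), Eager peak mem (MB), "
          "sdpa peak mem (MB), Mem saving (%)\n")

# Rows copied from the "PyTorch 2.2.0" and "PyTorch 2.2.0, no contiguous" tables.
PT22 = HEADER + "1000,16,512,True,0.065,0.047,37.607,6964.659,4112.008,69.374\n"
PT22_NO_CONTIGUOUS = HEADER + "1000,16,512,True,0.064,0.044,46.223,6964.659,3961.013,75.830\n"


def parse_table(text: str) -> dict:
    """Map (batch_size, seq_len) -> {column name: float} for one CSV table."""
    # skipinitialspace drops the blank after each comma, including in the header.
    reader = csv.DictReader(io.StringIO(text), skipinitialspace=True)
    rows = {}
    for row in reader:
        key = (int(row["batch_size"]), int(row["seq_len"]))
        # "is cuda" holds "True", so it is skipped; every other column is numeric.
        rows[key] = {k: float(v) for k, v in row.items() if k != "is cuda"}
    return rows


with_contig = parse_table(PT22)[(16, 512)]
no_contig = parse_table(PT22_NO_CONTIGUOUS)[(16, 512)]

mem_delta_mb = with_contig["sdpa peak mem (MB)"] - no_contig["sdpa peak mem (MB)"]
speedup_delta = no_contig["Speedup (%)"] - with_contig["Speedup (%)"]
print(f"{mem_delta_mb:.3f}")   # 150.995
print(f"{speedup_delta:.3f}")  # 8.616
```

For this configuration, the "no contiguous" run uses about 151 MB less SDPA peak memory. Its sdpa peak memory (3961.013 MB) is the same as in the "existing PyTorch" table.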