@hackyon
Created February 15, 2024 22:06
[SDPA for BERT] Inference benchmark for PR #28802
=== on existing PyTorch
num_batches, batch_size, seq_len, is cuda, is half, use mask, Per token latency eager (ms), Per token latency SDPA (ms), Speedup (%), Mem eager (MB), Mem SDPA (MB), Mem saved (%)
50,1,128,True,True,True,5.691,5.421,4.984,282.661,282.924,-0.093
50,1,256,True,True,True,5.688,5.523,2.983,298.686,298.948,-0.088
50,2,128,True,True,True,6.161,5.545,11.110,314.523,314.785,-0.083
50,2,256,True,True,True,6.132,5.558,10.323,347.546,347.758,-0.061
50,4,128,True,True,True,6.154,5.532,11.242,378.895,379.158,-0.069
50,4,256,True,True,True,6.326,5.987,5.669,443.209,444.382,-0.264
=== on PyTorch 2.2.0 (should support FlashAttention-2)
num_batches, batch_size, seq_len, is cuda, is half, use mask, Per token latency eager (ms), Per token latency SDPA (ms), Speedup (%), Mem eager (MB), Mem SDPA (MB), Mem saved (%)
50,1,128,True,True,True,5.777,5.315,8.692,282.661,282.924,-0.093
50,1,256,True,True,True,5.739,5.295,8.370,298.686,298.948,-0.088
50,2,128,True,True,True,6.142,5.359,14.622,314.523,314.785,-0.083
50,2,256,True,True,True,6.160,5.308,16.046,347.546,347.758,-0.061
50,4,128,True,True,True,6.157,5.346,15.169,378.895,379.158,-0.069
50,4,256,True,True,True,6.371,5.766,10.504,443.209,444.382,-0.264
=== on PyTorch 2.2.0, no contiguous
num_batches, batch_size, seq_len, is cuda, is half, use mask, Per token latency eager (ms), Per token latency SDPA (ms), Speedup (%), Mem eager (MB), Mem SDPA (MB), Mem saved (%)
50,1,128,True,True,True,5.736,4.987,15.022,282.661,282.924,-0.093
50,1,256,True,True,True,5.689,4.945,15.055,298.686,298.948,-0.088
50,2,128,True,True,True,6.154,4.982,23.521,314.523,314.785,-0.083
50,2,256,True,True,True,6.201,4.949,25.303,347.546,347.033,0.148
50,4,128,True,True,True,6.049,4.987,21.305,378.895,379.301,-0.107
50,4,256,True,True,True,6.285,5.364,17.166,443.209,444.382,-0.264
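
The two percentage columns are not defined in the output above. The reported values appear consistent with `(eager / SDPA - 1) * 100` for both latency and memory, so a negative "Mem saved (%)" means the SDPA run used slightly more memory than eager. The following is a minimal sketch under that assumption; the benchmark script itself is not shown here, and the helper name `pct_gain` is hypothetical.

```python
def pct_gain(eager: float, sdpa: float) -> float:
    """Relative gain of SDPA over eager, in percent.

    Assumed formula: (eager / sdpa - 1) * 100. It is inferred from the
    reported numbers, not taken from the benchmark script.
    """
    return (eager / sdpa - 1.0) * 100.0


# First row of the "existing PyTorch" table (batch_size=1, seq_len=128).
speedup = pct_gain(5.691, 5.421)        # per-token latency, ms
mem_saved = pct_gain(282.661, 282.924)  # peak memory, MB

print(f"speedup={speedup:.3f}%  mem_saved={mem_saved:.3f}%")
```

Recomputing from the rounded table values can differ from the reported percentages in the last digits, since the script presumably worked from unrounded measurements.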