Created November 12, 2024 17:44
You can try playing with those options then to see if the copies are removed.
On Mon, Nov 11, 2024, 5:44 PM Parin Dalal <parindalal@google.com> wrote:
Hi Sharad, we've found that the copies the compiler inserts to re-lay out these very large tensors can have a huge impact on performance for the Mamba kernels. Is there another option?
On Mon, Nov 11, 2024 at 5:30 PM Sharad Vikram <sharadmv@google.com> wrote:
The kernels constrain the layout, and the compiler will add relayouts as needed to match the layouts the kernel wants.
Those two flags let you change the kernel's desired layout. In principle they could be tuned, but in practice we find the default layouts to be the best.
On Mon, Nov 11, 2024, 5:25 PM Parin Dalal <parindalal@google.com> wrote:
Thanks @Sharad Vikram, super helpful.
Do you think the layouts need to be reparameterizable at some point in order to match how the compiler will internally represent these inputs?
Best, Parin
On Mon, Nov 11, 2024 at 1:46 PM Sharad Vikram <sharadmv@google.com> wrote:
Hi Xiongfei,
I'll explain each of the parameters.
block_q: int
block_kv: int
block_kv_compute: int | None = None
You'll definitely need to tune these; they are the fwd pass parameters.
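To make the tuning advice concrete, here is a minimal sketch of sweeping the fwd pass block sizes. The field names come from the parameter list above, but the `BlockSizes` container and the candidate grid are illustrative stand-ins of my own, not the actual kernel API; real tuning would time the kernel under each candidate and keep the fastest.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BlockSizes:
    # Forward-pass block sizes: these always need tuning.
    block_q: int
    block_kv: int
    # Assumption: when unset, the kernel falls back to block_kv.
    block_kv_compute: Optional[int] = None

def fwd_candidates():
    """Yield a small power-of-two grid of fwd-pass configs to benchmark."""
    for block_q in (256, 512, 1024):
        for block_kv in (256, 512, 1024):
            yield BlockSizes(block_q=block_q, block_kv=block_kv)

# In practice you'd benchmark the kernel under each candidate config.
configs = list(fwd_candidates())
```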
block_q_dkv: int | None = None
block_kv_dkv: int | None = None
block_kv_dkv_compute: int | None = None
These are the block sizes for the backward pass, specifically for the dkv matrix. If you are training, you'll need to tune these.
use_fused_bwd_kernel: bool = False
This is a bool that picks the bwd pass kernel implementation. If True, it uses a more HBM-intensive but faster kernel. If you have HBM to spare, set this to True.
If you do not have the HBM to spare and set it to False, you'll also need to tune the following bwd pass parameters:
block_q_dq: int | None = None
block_kv_dq: int | None = None
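Putting the backward-pass advice together, a hedged sketch of choosing between the two configurations. The helper and the 512 placeholders are my own (they would still need tuning); only the field names are taken from the thread.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BwdConfig:
    use_fused_bwd_kernel: bool = False
    # dkv block sizes: tuned in both modes when training.
    block_q_dkv: Optional[int] = None
    block_kv_dkv: Optional[int] = None
    block_kv_dkv_compute: Optional[int] = None
    # dq block sizes: only needed when the fused kernel is off.
    block_q_dq: Optional[int] = None
    block_kv_dq: Optional[int] = None

def make_bwd_config(hbm_to_spare: bool) -> BwdConfig:
    if hbm_to_spare:
        # Fused kernel: faster but more HBM-intensive; dq sizes unused.
        return BwdConfig(use_fused_bwd_kernel=True,
                         block_q_dkv=512, block_kv_dkv=512)
    # Separate dq/dkv kernels: the dq block sizes must be tuned too.
    return BwdConfig(use_fused_bwd_kernel=False,
                     block_q_dkv=512, block_kv_dkv=512,
                     block_q_dq=512, block_kv_dq=512)
```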
Finally, you can probably ignore these parameters:
q_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
k_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
v_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
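For context on what HEAD_DIM_MINOR means, a sketch of such a layout enum. The thread only names the default, so the SEQ_MINOR alternative here is an assumption for illustration:

```python
import enum

class QKVLayout(enum.Enum):
    # head_dim is the minormost (fastest-varying) axis:
    # tensors laid out as [..., seq_len, head_dim].
    HEAD_DIM_MINOR = enum.auto()
    # Assumed alternative: seq_len minormost, i.e. [..., head_dim, seq_len].
    # Changing it changes the layout the kernel requests from the compiler.
    SEQ_MINOR = enum.auto()

# The relayout copies discussed earlier in the thread occur when an input's
# actual layout differs from the layout the kernel requests.
default_layout = QKVLayout.HEAD_DIM_MINOR
```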
Thanks,
Sharad