
@vanbasten23
Created November 12, 2024 17:44
You can try playing with those options to see if the copies are removed.
On Mon, Nov 11, 2024, 5:44 PM Parin Dalal <parindalal@google.com> wrote:
Hi Sharad, we've found that the copies the compiler inserts to re-layout these very large tensors can have a huge impact on performance in the Mamba kernels. Is there another option?
On Mon, Nov 11, 2024 at 5:30 PM Sharad Vikram <sharadmv@google.com> wrote:
The kernels constrain the layout, and the compiler will add relayouts as needed to match the layouts the kernel wants.
Those two flags let you change the kernel's desired layout. In principle they could be tuned, but in practice we find the default layouts to be the best.
On Mon, Nov 11, 2024, 5:25 PM Parin Dalal <parindalal@google.com> wrote:
Thanks @Sharad Vikram, super helpful.
Do you think the layouts need to be reparameterizable at some point in order to match how the compiler will internally represent these inputs?
Best, Parin
On Mon, Nov 11, 2024 at 1:46 PM Sharad Vikram <sharadmv@google.com> wrote:
Hi Xiongfei,
I'll explain each of the parameters.
block_q: int
block_kv: int
block_kv_compute: int | None = None
These are the fwd pass block sizes; you'll definitely need to tune these.
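Tuning these usually means sweeping a few candidate block sizes and timing the kernel. A minimal sketch of such a sweep, where `benchmark_fn` is a hypothetical stand-in for timing the real kernel with a given (block_q, block_kv) pair and the candidate values are illustrative, not recommendations:

```python
import itertools

def tune_fwd_blocks(benchmark_fn,
                    block_q_candidates=(128, 256, 512),
                    block_kv_candidates=(128, 256, 512)):
    """Grid-search the fwd pass block sizes; returns the fastest pair."""
    best, best_time = None, float("inf")
    for bq, bkv in itertools.product(block_q_candidates, block_kv_candidates):
        t = benchmark_fn(block_q=bq, block_kv=bkv)  # measured runtime (s)
        if t < best_time:
            best_time, best = t, (bq, bkv)
    return best, best_time

# Toy cost model standing in for a real measurement on hardware:
best, t = tune_fwd_blocks(lambda block_q, block_kv:
                          abs(block_q - 256) + abs(block_kv - 512))
```

In practice `benchmark_fn` would compile and time the attention kernel itself (with warm-up iterations, since the first call includes compilation).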
block_q_dkv: int | None = None
block_kv_dkv: int | None = None
block_kv_dkv_compute: int | None = None
These are block sizes for the backward pass, specifically for the dkv matrix. If you are training, you'll need to tune these.
use_fused_bwd_kernel: bool = False
This is a bool that picks a bwd pass kernel implementation. If True, it uses a more HBM-intensive but faster kernel. If you have HBM to spare, set this to True.
If you do not have the HBM to spare and set it to False, you'll also need to tune the following bwd pass parameters:
block_q_dq: int | None = None
block_kv_dq: int | None = None
Finally, you can probably ignore these parameters:
q_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
k_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
v_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
Thanks,
Sharad
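Pulling the walkthrough above together, the parameters form one block-size configuration. A self-contained sketch of that configuration as a plain dataclass, mirroring the fields described in the email (this is a standalone stand-in for illustration, not the actual JAX class, and the `SEQ_MINOR` alternative layout name is an assumption):

```python
from __future__ import annotations
import enum
from dataclasses import dataclass

class QKVLayout(enum.Enum):
    HEAD_DIM_MINOR = 0  # the default for q, k, and v
    SEQ_MINOR = 1       # hypothetical alternative layout

@dataclass
class BlockSizes:
    # Fwd pass block sizes -- always tune these.
    block_q: int
    block_kv: int
    block_kv_compute: int | None = None
    # Bwd pass (dkv) block sizes -- tune these when training.
    block_q_dkv: int | None = None
    block_kv_dkv: int | None = None
    block_kv_dkv_compute: int | None = None
    # True picks the faster but more HBM-intensive bwd kernel.
    use_fused_bwd_kernel: bool = False
    # Non-fused bwd pass (dq) block sizes -- only needed when
    # use_fused_bwd_kernel is False.
    block_q_dq: int | None = None
    block_kv_dq: int | None = None
    # Layouts can usually be left at their defaults.
    q_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
    k_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
    v_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR

# Example: inference-only config where only the fwd pass matters.
bs = BlockSizes(block_q=512, block_kv=1024)
```

Grouping the fields this way makes the tuning story explicit: the fwd fields always matter, the dkv fields matter for training, and the dq fields matter only when the fused bwd kernel is disabled.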