ohadravid / min_repro.py
Last active January 22, 2025 09:15
A script for comparing the TensorRT compilation of torch's `scaled_dot_product_attention`, `multi_head_attention_forward` and an explicit version
import torch
import torch.nn as nn
import torch.nn.functional as F
import io
import tensorrt as trt
import torch_tensorrt


class AttentionUsingScaledDotProduct(nn.Module):
    """