@FeepingCreature
Created March 17, 2025 16:38
import torch
import torch.nn.functional as F
import time


def mlp(x, fc1_weight, fc2_weight, proj_weight):
    # SwiGLU-style gated MLP: silu(fc1(x)) * fc2(x), followed by the output projection.
    fc1_out = F.linear(x, fc1_weight)
    fc2_out = F.linear(x, fc2_weight)
    hidden = F.silu(fc1_out) * fc2_out
    return F.linear(hidden, proj_weight)


def single_attention(x, q_weight, k_weight, v_weight, o_weight, n_heads, head_dim):
    bsz, seqlen, _ = x.shape
    # Project to queries, keys, and values.
    q = F.linear(x, q_weight)
    k = F.linear(x, k_weight)
    v = F.linear(x, v_weight)
    # Split the hidden dimension into heads.
    q = q.view(bsz, seqlen, n_heads, head_dim)
    k = k.view(bsz, seqlen, n_heads, head_dim)
    v = v.view(bsz, seqlen, n_heads, head_dim)
    # QK-norm: layer-normalize queries and keys over the head dimension.
    q = F.layer_norm(q, q.shape[-1:])
    k = F.layer_norm(k, k.shape[-1:])
    # Fused attention; permute to (bsz, n_heads, seqlen, head_dim) for SDPA,
    # then merge the heads back into the hidden dimension.
    output = F.scaled_dot_product_attention(
        q.permute(0, 2, 1, 3),
        k.permute(0, 2, 1, 3),
        v.permute(0, 2, 1, 3),
    ).transpose(1, 2).reshape(bsz, seqlen, n_heads * head_dim)
    return F.linear(output, o_weight)


def ditblock(cx, global_cond, modcx_weight,
             attn_q_weight, attn_k_weight, attn_v_weight, attn_o_weight,
             mlp_fc1_weight, mlp_fc2_weight, mlp_proj_weight,
             n_heads, head_dim):
    cxres = cx
    # modCX: project the conditioning into six modulation vectors
    # (shift/scale/gate for attention, shift/scale/gate for the MLP).
    mod_input = F.silu(global_cond)
    mod_output = F.linear(mod_input, modcx_weight)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod_output.chunk(6, dim=1)
    # norm1, modulate
    cx_norm = F.layer_norm(cx, [n_heads * head_dim])
    cx = cx_norm * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
    # attention
    cx = single_attention(cx,
                          attn_q_weight, attn_k_weight, attn_v_weight, attn_o_weight,
                          n_heads, head_dim)
    # norm2, applied to the gated residual of the attention branch
    cx = F.layer_norm(cxres + gate_msa.unsqueeze(1) * cx, [n_heads * head_dim])
    # mlp with modulation
    cx_mod = cx * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
    mlpout = mlp(cx_mod, mlp_fc1_weight, mlp_fc2_weight, mlp_proj_weight)
    cx = gate_mlp.unsqueeze(1) * mlpout
    cx = cxres + cx
    return cx
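

# NOTE: not part of the original gist -- a minimal sanity-check sketch showing how the
# block above is called and that it preserves its input shape. The sizes here are small
# illustrative placeholders, not the benchmark sizes used in run_benchmark() below.
def _ditblock_shape_check(device="cpu", dtype=torch.float32):
    n_heads, head_dim = 2, 8
    dim = n_heads * head_dim

    def w(out_features, in_features):
        # Random weight matrix in the (out_features, in_features) layout F.linear expects.
        return torch.randn(out_features, in_features, dtype=dtype, device=device)

    cx = torch.randn(1, 4, dim, dtype=dtype, device=device)
    global_cond = torch.randn(1, dim, dtype=dtype, device=device)
    out = ditblock(
        cx, global_cond, w(6 * dim, dim),
        w(dim, dim), w(dim, dim), w(dim, dim), w(dim, dim),
        w(4 * dim, dim), w(4 * dim, dim), w(dim, 4 * dim),
        n_heads, head_dim,
    )
    assert out.shape == cx.shape, out.shape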


def run_benchmark():
    # Set device (the benchmark targets CUDA; fp16 kernels may be slow or unsupported on CPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    # Generate random input tensors with the specified sizes.
    cx = torch.randn(2, 4360, 3072, dtype=torch.float16, device=device)
    global_cond = torch.randn(2, 3072, dtype=torch.float16, device=device)
    modcx_weight = torch.randn(18432, 3072, dtype=torch.float16, device=device)
    attn_q_weight = torch.randn(3072, 3072, dtype=torch.float16, device=device)
    attn_k_weight = torch.randn(3072, 3072, dtype=torch.float16, device=device)
    attn_v_weight = torch.randn(3072, 3072, dtype=torch.float16, device=device)
    attn_o_weight = torch.randn(3072, 3072, dtype=torch.float16, device=device)
    mlp_fc1_weight = torch.randn(8192, 3072, dtype=torch.float16, device=device)
    mlp_fc2_weight = torch.randn(8192, 3072, dtype=torch.float16, device=device)
    mlp_proj_weight = torch.randn(3072, 8192, dtype=torch.float16, device=device)
    n_heads = 12
    head_dim = 256

    # Warm-up runs
    print("Performing warm-up run...")
    with torch.no_grad():
        for _ in range(5):
            ditblock(
                cx, global_cond, modcx_weight,
                attn_q_weight, attn_k_weight, attn_v_weight, attn_o_weight,
                mlp_fc1_weight, mlp_fc2_weight, mlp_proj_weight,
                n_heads, head_dim
            )
    if device.type == "cuda":
        torch.cuda.synchronize()

    # Benchmark: run the DiT block num_runs times, feeding each output into the next run.
    print("Running benchmark...")
    num_runs = 40
    total_time = 0.0
    with torch.no_grad():
        # Start timer
        if device.type == "cuda":
            torch.cuda.synchronize()
        start_time = time.time()
        result = cx
        for i in range(num_runs):
            result = ditblock(
                result, global_cond, modcx_weight,
                attn_q_weight, attn_k_weight, attn_v_weight, attn_o_weight,
                mlp_fc1_weight, mlp_fc2_weight, mlp_proj_weight,
                n_heads, head_dim
            )
        # End timer
        if device.type == "cuda":
            torch.cuda.synchronize()
        end_time = time.time()
    total_time = end_time - start_time

    # Report results
    print(f"Total time for {num_runs} runs: {total_time:.4f} seconds")
    print(f"Average time per run: {(total_time / num_runs) * 1000:.4f} ms")


if __name__ == "__main__":
    run_benchmark()