kshitij12345 / moe_exec_trc.py
Created October 7, 2025 13:02
MoE TP Trace
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(l_hidden_states_, l_self_modules_shared_experts_modules_gate_proj_parameters_weight_, l_self_modules_shared_experts_modules_up_proj_parameters_weight_, l_self_modules_shared_experts_modules_down_proj_parameters_weight_, l_self_modules_gate_parameters_weight_, l_self_modules_routed_experts_modules_gate_proj_parameters_weight_, l_self_modules_routed_experts_modules_up_proj_parameters_weight_, l_self_modules_routed_experts_modules_down_proj_parameters_weight_):
    # l_hidden_states_: "cuda:1 bf16[1, 2048, 256]"
    # l_self_modules_shared_experts_modules_gate_proj_parameters_weight_: "DTensor cuda:1 bf16[512, 256] mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Shard(dim=0),)"
    # l_self_modules_shared_experts_modules_up_proj_parameters_weight_: "DTensor cuda:1 bf16[512, 256] mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Shard(dim=0),)"
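
For context, a minimal sketch (not part of the gist; assumes two ranks launched via torchrun and a recent PyTorch with the public torch.distributed.tensor API) of how a weight ends up annotated as "DTensor ... placements=(Shard(dim=0),)" in the trace above:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumes launch with: torchrun --nproc-per-node=2 this_script.py
dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (2,))  # DeviceMesh((2,), device: 'cuda')

weight = torch.randn(512, 256, dtype=torch.bfloat16)
# Shard dim 0 across the two ranks: each rank holds a bf16[256, 256]
# local shard, matching the Shard(dim=0) placement seen in the trace.
dtensor_weight = distribute_tensor(weight, mesh, placements=[Shard(0)])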
kshitij12345 / flip.py
Created May 5, 2019 14:57
numpy flip
import numpy as np

# To handle negative axis.
def normalize_axis(axis, ndim):
    if axis < -ndim or ndim <= axis:
        raise ValueError(
            "Invalid value for axis, out of range for array of given dim.")
    if axis < 0:
        return axis + ndim
    return axis
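
A hedged usage sketch (my_flip is a hypothetical name, not from the gist) showing how normalize_axis can back a hand-rolled flip via a negative-stride slice, checked against np.flip:

def my_flip(arr, axis):
    # Reverse `arr` along `axis` with a [::-1] slice on that dimension.
    axis = normalize_axis(axis, arr.ndim)
    index = [slice(None)] * arr.ndim
    index[axis] = slice(None, None, -1)
    return arr[tuple(index)]

a = np.arange(6).reshape(2, 3)
assert np.array_equal(my_flip(a, -1), np.flip(a, axis=-1))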