Justin Chu (justinchuby)
🌊 Better ML
import torch
import onnx_ir as ir


class ControlFlowModel(torch.nn.Module):
    def forward(self, x):
        def times_2(x):
            return x * 2

        def neg(x):
            return -x
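The preview cuts off before the branching logic. A minimal sketch of how the two branch functions could be dispatched with torch.cond and exported; the predicate and the export call are assumptions, not part of the gist:

import torch


class ControlFlowModel(torch.nn.Module):
    def forward(self, x):
        def times_2(x):
            return x * 2

        def neg(x):
            return -x

        # torch.cond traces both branches into subgraphs; the predicate is a
        # scalar tensor and both branches must return matching shapes/dtypes.
        return torch.cond(x.sum() > 0, times_2, neg, (x,))


# torch.export captures the control flow; the ONNX exporter lowers it to an If node.
ep = torch.export.export(ControlFlowModel(), (torch.randn(3),))
print(ep)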
justinchuby / export_hf.py
Last active August 28, 2025 01:09
Export HF model to ONNX
"""Export to ONNX.
transformers_version == "4.52.0"
"""
import onnx_diagnostic.tasks.text_generation
import torch
from transformers import AutoConfig, AutoModel
import onnxscript
import onnx_ir as ir
import onnx
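The remainder of the gist isn't shown. A hedged sketch of one way to build the model from its config and run the torch.export-based ONNX exporter; the model id, dummy inputs, and dynamo=True flag are illustrative assumptions:

import torch
from transformers import AutoConfig, AutoModel

model_id = "HuggingFaceTB/SmolLM2-135M"  # hypothetical example checkpoint
config = AutoConfig.from_pretrained(model_id)
model = AutoModel.from_config(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)

# dynamo=True selects the torch.export-based exporter (PyTorch 2.5+).
onnx_program = torch.onnx.export(
    model,
    (),
    kwargs={"input_ids": input_ids, "attention_mask": attention_mask},
    dynamo=True,
)
onnx_program.save("model.onnx")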
def create_model():
    """Create a model that has an unsorted node with a subgraph that uses a value defined later."""
    a = ir.Value(name="a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)))
    b = ir.Value(name="b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((3, 4)))
    b_out = ir.Value(name="b_out", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((3, 4)))
    c = ir.Value(name="c", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((4, 5)))
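A simplified sketch (without the subgraph) of how such values might be wired into nodes and then topologically sorted; the ir.Node/ir.Graph argument names and the Graph.sort() call are assumptions about the onnx_ir API:

import onnx_ir as ir

a = ir.Value(name="a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)))
b = ir.Value(name="b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((3, 4)))
c = ir.Value(name="c", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((4, 5)))

matmul1 = ir.Node("", "MatMul", inputs=(a, b), num_outputs=1)
matmul2 = ir.Node("", "MatMul", inputs=(matmul1.outputs[0], c), num_outputs=1)

graph = ir.Graph(
    inputs=(a, b, c),
    outputs=(matmul2.outputs[0],),
    nodes=(matmul2, matmul1),  # deliberately out of order
    opset_imports={"": 20},
)
graph.sort()  # topological sort restores a valid node order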
justinchuby / pt_export.md
Created August 13, 2025 18:44
PyTorch export non-strict vs strict modes

torch.export non-strict mode uses "make_fx [which] uses __torch_dispatch__ to trace under the hood. this is where it creates the fx nodes. AOTAutograd also calls into make_fx, but before that it also does some things related to functionalization. Since export is now "Training IR", it no longer does functionalization, so we just directly call make_fx."
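For reference, the mode is selected with the strict flag on torch.export.export; the toy module below is illustrative, not from the gist:

import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1


example_args = (torch.randn(2, 3),)

# Non-strict: traced at the ATen level via make_fx / __torch_dispatch__.
ep_nonstrict = torch.export.export(Toy(), example_args, strict=False)

# Strict: TorchDynamo analyzes the Python bytecode while tracing.
ep_strict = torch.export.export(Toy(), example_args, strict=True)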

justinchuby / cast_verification.py
Created July 2, 2025 21:50
Cast verification
"""Verify the cast values"""
import os
import onnx
import onnx_ir as ir
DIR = "onnx/backend/test/data/node/"
def verify_one_case(path: str):
test_name = os.path.basename(path)
input_path = os.path.join(path, "test_data_set_0", "input_0.pb")
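A hedged continuation of the idea: load the recorded input/output protos for a test case and check that a plain numpy cast reproduces the expected output. The helper name and the to_dtype parameter are illustrative, and this only covers dtypes numpy represents natively:

import os

import numpy as np
import onnx
from onnx import numpy_helper


def check_cast_case(path: str, to_dtype: np.dtype) -> bool:
    input_proto = onnx.load_tensor(os.path.join(path, "test_data_set_0", "input_0.pb"))
    output_proto = onnx.load_tensor(os.path.join(path, "test_data_set_0", "output_0.pb"))
    x = numpy_helper.to_array(input_proto)
    expected = numpy_helper.to_array(output_proto)
    actual = x.astype(to_dtype)
    # NaNs compare unequal elementwise, so fall back to allclose with equal_nan.
    return bool(np.array_equal(actual, expected)) or bool(
        np.allclose(actual, expected, equal_nan=True)
    )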
justinchuby / Exported program bundle.txt
Created May 24, 2025 01:58
Exported program bundle
https://github.com/iree-org/iree-turbine/blob/main/iree/turbine/aot/fx_programs.py
See also the ai-edge-torch exporter.
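Those projects bundle multiple exported programs into one artifact; for a single program, torch.export's built-in serialization looks like this (a minimal sketch for comparison, not the bundle format the links describe):

import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1


ep = torch.export.export(Toy(), (torch.randn(2),))
torch.export.save(ep, "toy.pt2")
reloaded = torch.export.load("toy.pt2")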
justinchuby / stable.py
Last active May 23, 2025 04:35
Stable HLO
from ai_edge_torch.odml_torch.export import exported_program_to_mlir
import torch


class PowModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x ** 0.5


model = PowModel()
print(model(torch.tensor(2)))
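The preview stops before the lowering step. A hedged sketch of how the exported program might be lowered to StableHLO MLIR; the get_text() accessor on the lowered object is an assumption about the ai_edge_torch API:

import torch
from ai_edge_torch.odml_torch.export import exported_program_to_mlir


class PowModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x ** 0.5


ep = torch.export.export(PowModel(), (torch.tensor(2.0),))
lowered = exported_program_to_mlir(ep)
print(lowered.get_text())  # assumed accessor for the MLIR text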
justinchuby / export_hf.py
Created April 22, 2025 23:39
Export HF models with torch.onnx
import torch
from onnx_diagnostic import torch_export_patches
from onnxscript.ir.passes.common import clear_metadata_and_docstring
from transformers import AttentionInterface, AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
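The helper body is cut off. A common pattern for deriving position_ids from the attention mask is shown below; this is an assumption, not necessarily what the gist does:

import torch


def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool) -> torch.Tensor:
    # Positions count attended tokens; padded positions are pinned to 1.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    if use_past_kv:
        # With past key/values, only the position of the new token is needed.
        position_ids = position_ids[:, -1:]
    return position_ids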
# Owner(s): ["module: onnx"]
"""Unit LLM tests for the onnx dynamo exporter."""
from __future__ import annotations
from typing import Any
import logging
import transformers
justinchuby / torch_geometric_onnx_comp.py
Last active March 7, 2025 00:13
Code for figuring out where an ONNX model is inaccurate and visualizing it with Model Explorer
import logging
import torch
from torch_geometric.nn import GAT
logger = logging.getLogger(__name__)
logging.getLogger('torch.onnx').setLevel(logging.INFO)
logger.info("Prepare model")
num_features = 23
num_classes = 12
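A hedged sketch of the comparison side of that workflow: export the GAT model and measure the difference between ONNX Runtime and PyTorch outputs. The hyperparameters and graph sizes below are illustrative:

import numpy as np
import onnxruntime as ort
import torch
from torch_geometric.nn import GAT

model = GAT(in_channels=23, hidden_channels=64, num_layers=2, out_channels=12).eval()
x = torch.randn(50, 23)
edge_index = torch.randint(0, 50, (2, 200))

expected = model(x, edge_index).detach().numpy()
torch.onnx.export(model, (x, edge_index), "gat.onnx", dynamo=True)

session = ort.InferenceSession("gat.onnx")
input_names = [i.name for i in session.get_inputs()]
(actual,) = session.run(None, dict(zip(input_names, (x.numpy(), edge_index.numpy()))))
print("max abs diff:", np.abs(actual - expected).max())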