mikaylagawarecki / gist:2f624656fef264ce2016d1dbe9034e0c
Created October 15, 2024 22:16
Save llama-style model with and without pinned memory in `torch.save`
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import math
from dataclasses import dataclass
from typing import Optional, Tuple
# import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
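The preview stops at the imports. The comparison the description refers to is, roughly, timing `torch.save` on a state dict held in ordinary pageable CPU memory versus one whose tensors have first been copied into pinned (page-locked) memory. A minimal sketch of that idea, assuming a placeholder `model`; the helper names and timing setup below are illustrative and not taken from the gist:
import time
import torch

def pin_state_dict(state_dict):
    # Tensor.pin_memory() returns a copy of the tensor in page-locked CPU memory
    # (requires a CUDA-capable build of PyTorch).
    return {k: v.detach().cpu().pin_memory() for k, v in state_dict.items()}

def time_save(state_dict, path):
    start = time.perf_counter()
    torch.save(state_dict, path)
    return time.perf_counter() - start

# `model` stands in for the llama-style Transformer defined in the full gist.
pageable_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}
pinned_sd = pin_state_dict(pageable_sd)

print("pageable:", time_save(pageable_sd, "ckpt_pageable.pt"))
print("pinned:  ", time_save(pinned_sd, "ckpt_pinned.pt"))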
mikaylagawarecki / gist:60cf8e77f6230fcd45c62a4a7f479370
Created August 12, 2024 22:24
Verify checkpoints are the same
import torch
ref_model_sd = torch.load("ckpt_ref/meta_model_0.pt", weights_only=True)
real_model_sd = torch.load("ckpt_new/meta_model_0.pt", weights_only=True)
for k, v in ref_model_sd.items():
    assert torch.equal(v, real_model_sd[k])
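The loop above only checks keys present in the reference checkpoint, so a key that exists only in the new checkpoint would slip through. A one-line addition (not in the original gist) that also catches that case:
assert ref_model_sd.keys() == real_model_sd.keys()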
mikaylagawarecki / gist:ef0cf6e1c2effd5ffc03eaf4461f9692
Created October 20, 2023 16:59
Breakdown of dispatcher registrations
import torch
from torchgen.model import DispatchKey
num_impl_registrations = 0
registration_dict = {}
failed_keys = set()
# === Collect number of kernel registrations for each dispatch key ===
for key in DispatchKey:
    try:
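        # The preview ends here. A plausible completion (an assumption, not the
        # gist's actual code): ask the dispatcher for the kernels registered to
        # each key via the private helper
        # torch._C._dispatch_get_registrations_for_dispatch_key, and record the
        # keys it does not accept.
        registrations = torch._C._dispatch_get_registrations_for_dispatch_key(key.name)
        registration_dict[key] = registrations
        num_impl_registrations += len(registrations)
    except RuntimeError:
        failed_keys.add(key)

print(f"{num_impl_registrations} impl registrations across {len(registration_dict)} dispatch keys")
print(f"keys with no queryable registrations: {failed_keys}")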