Skip to content

Instantly share code, notes, and snippets.

albanD /
Created July 7, 2023 22:50
Tracking time and stack traces of when Tensors are created, used and die
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
import time
import warnings
import weakref
import traceback
albanD /
Created January 24, 2023 19:34
Make PyTorch custom Function unpack input and output using pytree.
import torch
from torch.autograd import Function
import torch.utils._pytree as pytree
# Basically wraps inputs and outputs (unpacking them with pytree) before passing them to the real function that the user defined.
def pytreeify(cls):
assert issubclass(cls, Function)
orig_fw = cls.forward
orig_bw = cls.backward
# Implements Alban's idea of making available the forward traceback
# corresponding to the execution of the current backward node as a global.
# Updated version of an earlier gist (the link was lost in this copy)
# to add inter-op tracking.
import torch
from torch import autograd
from torch.utils._python_dispatch import TorchDispatchMode
# Module-level global exposing metadata for the autograd node currently being
# executed; starts as None. NOTE(review): presumably assigned/cleared by the
# TorchDispatchMode / autograd hooks defined later in this gist (not visible
# here) — confirm before relying on it.
current_metadata = None
albanD /
Last active August 8, 2023 07:49
PyTorch optimizer as hook
import torch
from torch import nn
from torch.optim.sgd import sgd
import gc
import objgraph
import weakref
def all():
# Only a subset of the args you could have
def set_sgd_hook(mod, p, lr, weight_decay, momentum):
albanD /
Last active May 18, 2020 19:21
Python function common dtype

Ops to test on python side

If nothing is specified, all argument combinations should be considered.


  • copy_ no_sparse && no_quantize && self!=source && not_copy_transpose
  • gather
  • gather(out=)
  • scatter_(Tensor)
  • scatter(Tensor)
  • scatter_(value)
from patch_convolution import *
import torch
import torch.nn as nn
import time
# ---------------
# Parameters
# ---------------
# Number of profile iterations to run
itt = 30  # NOTE(review): the profiling loop that reads this count is not shown in this excerpt — confirm usage
import torch
from torch import nn
from torch.nn import functional as F
class EasyDataParallel(nn.Module):
def __init__(self, gpus):
# Handle cpu / 1 gpu case better
assert isinstance(gpus, list)
albanD /
Last active October 16, 2019 19:27
Autodiff linear debugging

Debugging code

std::cout << "Forwarding into jit module" << std::endl;
std::cout << "Forward code:" << std::endl;
std::cout << *grad.f.get() << std::endl;
std::cout << "Backward code:" << std::endl;
std::cout << *grad.df.get() << std::endl;
std::cout << "End print !" << std::endl;
albanD /
Created September 25, 2019 21:03
Compute full Hessian of a network
import torch
from torch import nn
from torchviz import make_dot
from torch.autograd.gradcheck import gradcheck
# Tiny bias-free MLP: 2 -> 2 -> 2 -> 1, with a sigmoid after each of the two
# hidden linear layers and a plain linear output. Ten weights in total, which
# keeps the full Hessian small.
my_mod = nn.Sequential(
    nn.Linear(2, 2, bias=False),
    nn.Sigmoid(),
    nn.Linear(2, 2, bias=False),
    nn.Sigmoid(),
    nn.Linear(2, 1, bias=False),
)
# Materialize the three weight matrices (in layer order) as a reusable list.
params = [*my_mod.parameters()]
local threads = require "threads"
n_task = 3
local pools = {}
for task=1,n_task do
pools[task] = threads.Threads(5,
-- Needed only for serialized elements