# GitHub Gist rrampage/c2fe7a585c6639163eeaca749df4cac7: "Arena stuff"
# (last active February 28, 2024 22:08)
# %%
import os

# Allow more than one OpenMP runtime to be loaded in this process. This must be
# set before torch / numpy are imported.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import sys
import re
import time
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Iterable, Optional, Union, Dict, List, Tuple

import torch as t
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

Arr = np.ndarray

# Global switch read by the forward wrappers below; when False, no Recipe
# (computational-graph information) is recorded for new tensors.
grad_tracking_enabled = True

# Make sure exercises are in the path
chapter = r"chapter0_fundamentals"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part4_backprop"
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

import part4_backprop.tests as tests
from part4_backprop.utils import visualize, get_mnist
from plotly_utils import line
# %%
def log_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    '''Backwards function for f(x) = log(x)

    grad_out: Gradient of some loss wrt out
    out: the output of np.log(x).
    x: the input of np.log.

    Return: gradient of the given loss wrt x
    '''
    # d/dx log(x) = 1/x, chained elementwise with the upstream gradient.
    return grad_out / x
# %%
def unbroadcast(broadcasted: Arr, original: Arr) -> Arr:
    '''
    Sum 'broadcasted' until it has the shape of 'original'.

    broadcasted: An array that was formerly of the same shape of 'original' and was
        expanded by broadcasting rules.
    original: The array whose shape the result should have.

    Return: an array with the same shape as 'original'.
    '''
    # Step 1: sum and remove prepended dims, so both arrays have same number of dims
    n_dims_to_sum = broadcasted.ndim - original.ndim
    broadcasted = broadcasted.sum(axis=tuple(range(n_dims_to_sum)))

    # Step 2: sum over dims which were originally 1 (but don't remove them).
    # A size-1 dim can be broadcast to any size, including 0, so compare with
    # `!= 1` rather than `> 1`.
    dims_to_sum = tuple(
        i
        for i, (o, b) in enumerate(zip(original.shape, broadcasted.shape))
        if o == 1 and b != 1
    )
    return broadcasted.sum(axis=dims_to_sum, keepdims=True)
# %%
def multiply_back0(grad_out: Arr, out: Arr, x: Arr, y: Union[Arr, float]) -> Arr:
    '''Backwards function for x * y wrt argument 0 aka x.

    grad_out: gradient of some loss wrt out (= x * y, after broadcasting).
    out: the output of x * y (unused).
    x, y: the inputs of the multiplication; y may be a Python scalar.

    Return: gradient of the loss wrt x, with x's shape.
    '''
    if not isinstance(y, Arr):
        y = np.array(y)
    # d(x*y)/dx = y. The product may have been broadcast, so sum the gradient
    # back down to x's shape.
    return unbroadcast(y * grad_out, x)


def multiply_back1(grad_out: Arr, out: Arr, x: Union[Arr, float], y: Arr) -> Arr:
    '''Backwards function for x * y wrt argument 1 aka y.

    grad_out: gradient of some loss wrt out (= x * y, after broadcasting).
    out: the output of x * y (unused).
    x, y: the inputs of the multiplication; x may be a Python scalar.

    Return: gradient of the loss wrt y, with y's shape.
    '''
    if not isinstance(x, Arr):
        x = np.array(x)
    # d(x*y)/dy = x, summed back down to y's shape.
    return unbroadcast(x * grad_out, y)


tests.test_multiply_back(multiply_back0, multiply_back1)
tests.test_multiply_back_float(multiply_back0, multiply_back1)
# %%
def forward_and_back(a: Arr, b: Arr, c: Arr) -> Tuple[Arr, Arr, Arr]:
    '''
    Computational graph:

        g = log(f)
        f = d * e
        d = a * b
        e = log(c)

    Calculates the output of the computational graph above (g), then backpropagates
    the gradients and returns dg/da, dg/db, and dg/dc.
    '''
    # Forward pass
    d = a * b
    e = np.log(c)
    f = d * e
    g = np.log(f)

    # Backward pass, seeded with dg/dg = 1 and walking the graph from g back to the inputs.
    final_grad_out = np.ones_like(g)
    dg_df = log_back(final_grad_out, g, f)
    dg_dd = multiply_back0(dg_df, f, d, e)
    dg_de = multiply_back1(dg_df, f, d, e)
    dg_dc = log_back(dg_de, e, c)
    dg_da = multiply_back0(dg_dd, d, a, b)
    dg_db = multiply_back1(dg_dd, d, a, b)

    return (dg_da, dg_db, dg_dc)
# %%
@dataclass
class Recipe:
    '''Extra information necessary to run backpropagation. You don't need to modify this.'''

    # The 'inner' NumPy function that does the actual forward computation.
    # Note, we call it 'inner' to distinguish it from the wrapper we'll create for it later on.
    func: Callable

    # The input arguments passed to func.
    # For instance, if func was np.sum then args would be a length-1 tuple containing the tensor to be summed.
    args: tuple

    # Keyword arguments passed to func.
    # For instance, if func was np.sum then kwargs might contain 'dim' and 'keepdims'.
    kwargs: Dict[str, Any]

    # Map from positional argument index to the Tensor at that position, in order to be able
    # to pass gradients back along the computational graph.
    parents: Dict[int, "Tensor"]
# %%
class BackwardFuncLookup:
    '''Registry mapping (forward function, argument position) to its backward function.'''

    def __init__(self) -> None:
        self.back_funcs: Dict[Tuple[Callable, int], Callable] = {}

    def add_back_func(self, forward_fn: Callable, arg_position: int, back_fn: Callable) -> None:
        '''Register `back_fn` as the gradient function of `forward_fn` wrt argument `arg_position`.'''
        self.back_funcs[(forward_fn, arg_position)] = back_fn

    def get_back_func(self, forward_fn: Callable, arg_position: int) -> Callable:
        '''Return the backward function registered for (`forward_fn`, `arg_position`).

        Raises KeyError (with a descriptive message) if nothing was registered.
        '''
        try:
            return self.back_funcs[(forward_fn, arg_position)]
        except KeyError:
            name = getattr(forward_fn, "__name__", repr(forward_fn))
            raise KeyError(
                f"No backward function registered for {name} argument {arg_position}"
            ) from None
BACK_FUNCS = BackwardFuncLookup()

# (forward function, argument position, backward function) triples known so far.
_initial_back_funcs = (
    (np.log, 0, log_back),
    (np.multiply, 0, multiply_back0),
    (np.multiply, 1, multiply_back1),
)
for _fwd, _pos, _back in _initial_back_funcs:
    BACK_FUNCS.add_back_func(_fwd, _pos, _back)
# Sanity-check that every registration round-trips through the lookup.
for _fwd, _pos, _back in _initial_back_funcs:
    assert BACK_FUNCS.get_back_func(_fwd, _pos) == _back
del _initial_back_funcs, _fwd, _pos, _back

print("Tests passed - BackwardFuncLookup class is working as expected!")
# %%
Arr = np.ndarray


class Tensor:
    '''
    A drop-in replacement for torch.Tensor supporting a subset of features.
    '''

    # The underlying array. Can be shared between multiple Tensors.
    array: Arr
    # If True, calling functions or methods on this tensor will track relevant data for backprop.
    requires_grad: bool
    # Backpropagation will accumulate gradients into this field.
    grad: Optional["Tensor"]
    # Extra information necessary to run backpropagation. If not None, this tensor's array
    # was created via recipe.func(*recipe.args, **recipe.kwargs).
    recipe: Optional["Recipe"]

    def __init__(self, array: Union[Arr, list], requires_grad=False):
        self.array = array if isinstance(array, Arr) else np.array(array)
        # Store float64 data (NumPy's default float) as float32.
        if self.array.dtype == np.float64:
            self.array = self.array.astype(np.float32)
        self.requires_grad = requires_grad
        self.grad = None
        self.recipe = None

    # Operator overloads delegate to the module-level functions
    # (negative, add, subtract, ...) defined elsewhere in this file.
    def __neg__(self) -> "Tensor":
        return negative(self)

    def __add__(self, other) -> "Tensor":
        return add(self, other)

    def __radd__(self, other) -> "Tensor":
        return add(other, self)

    def __sub__(self, other) -> "Tensor":
        return subtract(self, other)

    def __rsub__(self, other):
        return subtract(other, self)

    def __mul__(self, other) -> "Tensor":
        return multiply(self, other)

    def __rmul__(self, other) -> "Tensor":
        return multiply(other, self)

    def __truediv__(self, other) -> "Tensor":
        return true_divide(self, other)

    def __rtruediv__(self, other) -> "Tensor":
        return true_divide(other, self)

    def __matmul__(self, other) -> "Tensor":
        return matmul(self, other)

    def __rmatmul__(self, other) -> "Tensor":
        return matmul(other, self)

    def __eq__(self, other) -> "Tensor":
        # Elementwise comparison (returns a Tensor, not a bool), hence the id-based __hash__ below.
        return eq(self, other)

    def __repr__(self) -> str:
        return f"Tensor({repr(self.array)}, requires_grad={self.requires_grad})"

    def __len__(self) -> int:
        if self.array.ndim == 0:
            raise TypeError("len() of a 0-d tensor")
        return self.array.shape[0]

    def __hash__(self) -> int:
        # Identity-based hashing so Tensors can be dict keys in backprop.
        return id(self)

    def __getitem__(self, index) -> "Tensor":
        return getitem(self, index)

    def add_(self, other: "Tensor", alpha: float = 1.0) -> "Tensor":
        add_(self, other, alpha=alpha)
        return self

    @property
    def T(self) -> "Tensor":
        return permute(self, axes=(-1, -2))

    def item(self):
        return self.array.item()

    def sum(self, dim=None, keepdim=False):
        return sum(self, dim=dim, keepdim=keepdim)

    def log(self):
        return log(self)

    def exp(self):
        return exp(self)

    def reshape(self, new_shape):
        return reshape(self, new_shape)

    def expand(self, new_shape):
        return expand(self, new_shape)

    def permute(self, dims):
        return permute(self, dims)

    def maximum(self, other):
        return maximum(self, other)

    def relu(self):
        return relu(self)

    def argmax(self, dim=None, keepdim=False):
        return argmax(self, dim=dim, keepdim=keepdim)

    def uniform_(self, low: float, high: float) -> "Tensor":
        self.array[:] = np.random.uniform(low, high, self.array.shape)
        return self

    def backward(self, end_grad: Union[Arr, "Tensor", None] = None) -> None:
        '''Backpropagate from this tensor; see `backprop`.'''
        if isinstance(end_grad, Arr):
            end_grad = Tensor(end_grad)
        return backprop(self, end_grad)

    def size(self, dim: Optional[int] = None):
        if dim is None:
            return self.shape
        return self.shape[dim]

    @property
    def shape(self):
        return self.array.shape

    @property
    def ndim(self):
        return self.array.ndim

    @property
    def is_leaf(self):
        '''Same as torch.Tensor.is_leaf: False only for a tensor that requires grad and was
        produced by a recorded operation with at least one parent Tensor.'''
        if self.requires_grad and self.recipe and self.recipe.parents:
            return False
        return True

    def __bool__(self):
        if self.array.size != 1:
            raise RuntimeError("bool value of Tensor with more than one value is ambiguous")
        return bool(self.item())
def empty(*shape: int) -> Tensor:
    '''Like torch.empty (uninitialised values; stored as float32 by Tensor).'''
    return Tensor(np.empty(shape))


def zeros(*shape: int) -> Tensor:
    '''Like torch.zeros (stored as float32 by Tensor).'''
    return Tensor(np.zeros(shape))


def arange(start: int, end: int, step=1) -> Tensor:
    '''Like torch.arange(start, end).'''
    return Tensor(np.arange(start, end, step=step))


def tensor(array: Arr, requires_grad=False) -> Tensor:
    '''Like torch.tensor.'''
    return Tensor(array, requires_grad=requires_grad)
# %%
def log_forward(x: Tensor) -> Tensor:
    '''Performs np.log on a Tensor object.

    A Recipe is recorded only if grad tracking is globally enabled and `x` either
    requires grad or was itself produced by a recorded operation.
    '''
    requires_grad = grad_tracking_enabled and (x.requires_grad or x.recipe is not None)
    out = Tensor(array=np.log(x.array), requires_grad=requires_grad)
    if requires_grad:
        out.recipe = Recipe(func=np.log, args=(x.array,), kwargs={}, parents={0: x})
    return out


log = log_forward
tests.test_log(Tensor, log_forward)
tests.test_log_no_grad(Tensor, log_forward)

a = Tensor([1], requires_grad=True)
grad_tracking_enabled = False
try:
    b = log_forward(a)
finally:
    # Always re-enable tracking, even if the call above raises.
    grad_tracking_enabled = True
assert not b.requires_grad, "should not require grad if grad tracking globally disabled"
assert b.recipe is None, "should not create recipe if grad tracking globally disabled"
# %%
def multiply_forward(a: Union[Tensor, int, float], b: Union[Tensor, int, float]) -> Tensor:
    '''Performs np.multiply on a Tensor object.

    At least one of `a`, `b` must be a Tensor; the other may be a Python scalar.
    A Recipe is recorded only if grad tracking is globally enabled and some Tensor
    argument requires grad or was produced by a recorded operation.
    '''
    assert isinstance(a, Tensor) or isinstance(b, Tensor), "at least one argument must be a Tensor"

    # Unwrap Tensor arguments; remember which positions were Tensors (they are the parents).
    parents: Dict[int, Tensor] = {}
    if isinstance(a, Tensor):
        parents[0] = a
    if isinstance(b, Tensor):
        parents[1] = b
    a_arg = a.array if isinstance(a, Tensor) else a
    b_arg = b.array if isinstance(b, Tensor) else b

    requires_grad = grad_tracking_enabled and any(
        p.requires_grad or p.recipe is not None for p in parents.values()
    )
    out = Tensor(array=a_arg * b_arg, requires_grad=requires_grad)
    if requires_grad:
        out.recipe = Recipe(func=np.multiply, args=(a_arg, b_arg), kwargs={}, parents=parents)
    return out


multiply = multiply_forward
tests.test_multiply(Tensor, multiply_forward)
tests.test_multiply_no_grad(Tensor, multiply_forward)
tests.test_multiply_float(Tensor, multiply_forward)

a = Tensor([2], requires_grad=True)
b = Tensor([3], requires_grad=True)
grad_tracking_enabled = False
try:
    b = multiply_forward(a, b)
finally:
    # Always re-enable tracking, even if the call above raises.
    grad_tracking_enabled = True
assert not b.requires_grad, "should not require grad if grad tracking globally disabled"
assert b.recipe is None, "should not create recipe if grad tracking globally disabled"
# %%
def wrap_forward_fn(numpy_func: Callable, is_differentiable=True) -> Callable:
    '''
    numpy_func: Callable
        takes any number of positional arguments, some of which may be NumPy arrays, and
        any number of keyword arguments which we aren't allowing to be NumPy arrays at
        present. It returns a single NumPy array.

    is_differentiable:
        if True, numpy_func is differentiable with respect to some input argument, so we
        may need to track information in a Recipe. If False, we definitely don't need to
        track information.

    Return: Callable
        It has the same signature as numpy_func, except wherever there was a NumPy array,
        this has a Tensor instead.
    '''

    def tensor_func(*args: Any, **kwargs: Any) -> Tensor:
        requires_grad = (
            is_differentiable
            and grad_tracking_enabled
            and any(
                isinstance(arg, Tensor) and (arg.requires_grad or arg.recipe is not None)
                for arg in args
            )
        )
        # Replace Tensor arguments by their arrays for the NumPy call; stored as a tuple
        # because Recipe.args is declared as one.
        arg_arrays = tuple(arg.array if isinstance(arg, Tensor) else arg for arg in args)
        out = Tensor(array=numpy_func(*arg_arrays, **kwargs), requires_grad=requires_grad)
        if requires_grad:
            parents = {idx: arg for idx, arg in enumerate(args) if isinstance(arg, Tensor)}
            out.recipe = Recipe(func=numpy_func, args=arg_arrays, kwargs=kwargs, parents=parents)
        return out

    return tensor_func


def _sum(x: Arr, dim=None, keepdim=False) -> Arr:
    # need to be careful with sum, because kwargs have different names in torch and numpy
    return np.sum(x, axis=dim, keepdims=keepdim)


log = wrap_forward_fn(np.log)
multiply = wrap_forward_fn(np.multiply)
eq = wrap_forward_fn(np.equal, is_differentiable=False)
# NOTE: deliberately shadows the builtin `sum` at module level; Tensor.sum relies on this name.
sum = wrap_forward_fn(_sum)

tests.test_log(Tensor, log)
tests.test_log_no_grad(Tensor, log)
tests.test_multiply(Tensor, multiply)
tests.test_multiply_no_grad(Tensor, multiply)
tests.test_multiply_float(Tensor, multiply)
# %%
class Node:
    '''Minimal graph node (holding only its children) used to exercise `topological_sort`.'''

    def __init__(self, *children):
        self.children = list(children)


def get_children(node: Node) -> List[Node]:
    return node.children


def topological_sort(node: Node, get_children: Callable) -> List[Node]:
    '''
    Return a list of node's descendants in reverse topological order from future to past
    (i.e. `node` should be last).

    Should raise an error if the graph with `node` as root is not in fact acyclic.

    Implemented as an iterative post-order DFS (same output order as the recursive
    version), so deep graphs do not hit Python's recursion limit.
    '''
    result: List[Node] = []
    finished = set()  # nodes already appended to `result`
    on_path = {node}  # nodes on the current DFS path; meeting one again means a cycle
    # Each stack entry holds a node and the (partially consumed) iterator over its children.
    stack = [(node, iter(get_children(node)))]

    while stack:
        current, children = stack[-1]
        for child in children:
            if child in finished:
                continue
            if child in on_path:
                raise ValueError("Cycle!!")
            on_path.add(child)
            stack.append((child, iter(get_children(child))))
            break
        else:
            # All children handled: emit `current` after its descendants.
            stack.pop()
            on_path.discard(current)
            finished.add(current)
            result.append(current)

    return result
# %%
def sorted_computational_graph(tensor: Tensor) -> List[Tensor]:
    '''
    For a given tensor, return a list of Tensors that make up the nodes of the given
    Tensor's computational graph, in reverse topological order (i.e. `tensor` should be first).
    '''

    def get_parents(node: Tensor) -> List[Tensor]:
        # Only tensors produced by a tracked operation have parents to traverse.
        if not node.requires_grad or node.recipe is None:
            return []
        return list(node.recipe.parents.values())

    return topological_sort(tensor, get_parents)[::-1]


a = Tensor([1], requires_grad=True)
b = Tensor([2], requires_grad=True)
c = Tensor([3], requires_grad=True)
d = a * b
e = c.log()
f = d * e
g = f.log()
name_lookup = {a: "a", b: "b", c: "c", d: "d", e: "e", f: "f", g: "g"}

print([name_lookup[node] for node in sorted_computational_graph(g)])
# %%
def backprop(end_node: Tensor, end_grad: Optional[Tensor] = None) -> None:
    '''Accumulates gradients in the grad field of each leaf node.

    tensor.backward() is equivalent to backprop(tensor).

    end_node:
        The rightmost node in the computation graph.
        If it contains more than one element, end_grad must be provided.
    end_grad:
        A tensor of the same shape as end_node.
        Set to 1 if not specified and end_node has only one element.
    '''
    end_grad_arr = np.ones_like(end_node.array) if end_grad is None else end_grad.array

    # Gradient of end_node wrt each node, summed over every path reaching that node.
    grads: Dict[Tensor, Arr] = {end_node: end_grad_arr}

    # Reverse topological order guarantees every child has contributed to grads[node]
    # before node itself is visited.
    for node in sorted_computational_graph(end_node):
        outgrad = grads.pop(node)

        if node.is_leaf and node.requires_grad:
            if node.grad is None:
                # Copy so later in-place accumulation can't mutate an array that is
                # shared with a caller (e.g. end_grad) or with another node.
                node.grad = Tensor(np.array(outgrad))
            else:
                node.grad.array += outgrad

        if node.recipe is None or not node.recipe.parents:
            continue

        for argnum, parent in node.recipe.parents.items():
            back_fn = BACK_FUNCS.get_back_func(node.recipe.func, argnum)
            in_grad = back_fn(outgrad, node.array, *node.recipe.args, **node.recipe.kwargs)
            # Out-of-place add: `in_grad` may alias another node's gradient array.
            grads[parent] = grads[parent] + in_grad if parent in grads else in_grad
# %%
def negative_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    '''Backward function for f(x) = -x elementwise.'''
    # d(-x)/dx = -1, so the upstream gradient is simply negated.
    return -grad_out


negative = wrap_forward_fn(np.negative)
BACK_FUNCS.add_back_func(np.negative, 0, negative_back)
# %%
MAIN = __name__ == "__main__"
# %%
import os
import sys
import json
import functools
from pathlib import Path
from dataclasses import dataclass
from typing import Union, Optional, Tuple, List, Dict

import numpy as np
import einops
import torch as t
import torch.nn as nn
import torch.nn.functional as F
# NOTE: from here on `Tensor` refers to torch.Tensor, not the hand-written class above.
from torch import Tensor
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from jaxtyping import Float, Int
from tqdm.notebook import tqdm
from PIL import Image

# Make sure exercises are in the path
chapter = r"chapter0_fundamentals"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part2_cnns"
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

from plotly_utils import imshow, line, bar
import part2_cnns.tests as tests
from part2_cnns.utils import print_param_count

# Prefer a CUDA GPU, then Apple-silicon MPS, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")
# %%
class ReLU(nn.Module):
    '''Elementwise rectified linear unit: max(x, 0).'''

    def forward(self, x: t.Tensor) -> t.Tensor:
        return t.maximum(x, t.zeros_like(x))
# %%
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias=True):
        '''
        A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        '''
        super().__init__()
        # Weights and bias are initialised uniformly in [-1/sqrt(in_features), 1/sqrt(in_features)].
        bound = in_features ** -0.5
        self.weight = nn.Parameter(t.zeros([out_features, in_features]).uniform_(-bound, bound))
        self.bias = nn.Parameter(t.zeros([out_features]).uniform_(-bound, bound)) if bias else None
        self.is_bias = bias

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (*, in_features)
        Return: shape (*, out_features)
        '''
        out = x @ self.weight.T
        return out if self.bias is None else out + self.bias

    def extra_repr(self) -> str:
        return f'Weight: {self.weight.shape=}'
# %%
class Flatten(nn.Module):
    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: t.Tensor) -> t.Tensor:
        '''
        Flatten out dimensions from start_dim to end_dim, inclusive of both.

        Negative dims count from the end. Raises ValueError if start_dim comes after end_dim.
        '''
        shape = input.shape
        ndim = len(shape)
        start = self.start_dim if self.start_dim >= 0 else ndim + self.start_dim
        end = self.end_dim if self.end_dim >= 0 else ndim + self.end_dim
        if start > end:
            raise ValueError(f"start_dim ({self.start_dim}) cannot come after end_dim ({self.end_dim})")

        flat_size = 1
        for size in shape[start:end + 1]:
            flat_size *= size

        new_shape = [*shape[:start], flat_size, *shape[end + 1:]]
        return t.reshape(input, new_shape)

    def extra_repr(self) -> str:
        return f'Flatten: {self.start_dim=} {self.end_dim=}'
# %%
class SimpleMLP(nn.Module):
    '''Flatten -> Linear(784, 100) -> ReLU -> Linear(100, 10); 784 = 28 * 28 MNIST pixels.'''

    def __init__(self):
        super().__init__()
        self.flatten = Flatten(1, -1)
        self.layer0 = Linear(784, 100, True)
        self.relu = ReLU()
        self.layer1 = Linear(100, 10, True)

    def forward(self, x: t.Tensor) -> t.Tensor:
        return self.layer1(self.relu(self.layer0(self.flatten(x))))
# %%
# PIL image -> float tensor in [0, 1], then standardise with MNIST's mean / std.
MNIST_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


def get_mnist(subset: int = 1):
    '''Returns MNIST training data, sampled by the frequency given in `subset`.

    subset: keep every `subset`-th example (1 keeps everything).
    Return: (train dataset, test dataset)
    '''
    mnist_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=MNIST_TRANSFORM)
    mnist_testset = datasets.MNIST(root="./data", train=False, download=True, transform=MNIST_TRANSFORM)

    if subset > 1:
        mnist_trainset = Subset(mnist_trainset, indices=range(0, len(mnist_trainset), subset))
        mnist_testset = Subset(mnist_testset, indices=range(0, len(mnist_testset), subset))

    return mnist_trainset, mnist_testset


mnist_trainset, mnist_testset = get_mnist()
mnist_trainloader = DataLoader(mnist_trainset, batch_size=64, shuffle=True)
mnist_testloader = DataLoader(mnist_testset, batch_size=64, shuffle=False)
# %%
@dataclass
class SimpleMLPTrainingArgs():
    '''
    Defining this class implicitly creates an __init__ method, which sets arguments as
    given below, e.g. self.batch_size = 64. Any of these arguments can also be overridden
    when you create an instance, e.g. args = SimpleMLPTrainingArgs(batch_size=128).
    '''
    batch_size: int = 64
    epochs: int = 3
    learning_rate: float = 1e-3
    # Keep every `subset`-th MNIST example (passed to get_mnist).
    subset: int = 10
def validate(model: SimpleMLP, loader: DataLoader) -> float:
    '''Return the fraction of samples in `loader` that `model` classifies correctly.

    Runs without building an autograd graph. Raises ValueError if `loader` yields no samples.
    '''
    n_correct = 0
    n_seen = 0
    with t.inference_mode():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            predictions = model(imgs).argmax(dim=-1)
            n_correct += (predictions == labels).sum().item()
            n_seen += imgs.shape[0]
    if n_seen == 0:
        raise ValueError("validate() received a loader with no samples")
    return n_correct / n_seen
def train(args: SimpleMLPTrainingArgs):
    '''
    Trains the model, using training parameters from the `args` object.

    Plots the per-batch training loss and the validation accuracy measured before
    training and after each epoch.
    '''
    model = SimpleMLP().to(device)

    mnist_trainset, mnist_testset = get_mnist(subset=args.subset)
    mnist_trainloader = DataLoader(mnist_trainset, batch_size=args.batch_size, shuffle=True)
    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    mnist_testloader = DataLoader(mnist_testset, batch_size=args.batch_size, shuffle=False)

    optimizer = t.optim.Adam(model.parameters(), lr=args.learning_rate)
    loss_list = []
    validation_accuracy = [validate(model, mnist_testloader)]

    for _ in tqdm(range(args.epochs)):
        for imgs, labels in mnist_trainloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
        validation_accuracy.append(validate(model, mnist_testloader))

    line(
        loss_list,
        yaxis_range=[0, max(loss_list) + 0.1],
        labels={"x": "Num batches seen", "y": "Cross entropy loss"},
        title="SimpleMLP training on MNIST",
    )
    line(
        validation_accuracy,
        yaxis_range=[0, max(validation_accuracy) + 0.1],
        labels={"x": "Num epochs", "y": "Validation accuracy"},
        title="SimpleMLP validation accuracy on MNIST",
    )


args = SimpleMLPTrainingArgs()
# %%
class Conv2d(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0
    ):
        '''
        Same as torch.nn.Conv2d with bias=False.

        Name your weight field `self.weight` for compatibility with the PyTorch version.

        Raises ValueError if stride is not positive.
        '''
        super().__init__()
        if stride <= 0:
            raise ValueError(f"stride must be a positive integer, got {stride}")
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Glorot-style uniform init: bound = sqrt(6 / (fan_in + fan_out + 1)).
        fan_in = in_channels * kernel_size * kernel_size
        fan_out = out_channels * kernel_size * kernel_size
        bound = (6 / (fan_in + fan_out + 1)) ** 0.5
        self.weight = nn.Parameter(
            t.zeros([out_channels, in_channels, kernel_size, kernel_size]).uniform_(-bound, bound)
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Apply the functional conv2d, which you can import.'''
        return nn.functional.conv2d(input=x, weight=self.weight, stride=self.stride, padding=self.padding, bias=None)

    def extra_repr(self) -> str:
        keys = ["in_channels", "out_channels", "kernel_size", "stride", "padding"]
        return ", ".join(f"{key}={getattr(self, key)}" for key in keys)


m = Conv2d(in_channels=24, out_channels=12, kernel_size=3, stride=2, padding=1)
print(f"Manually verify that this is an informative repr: {m}")
# %%
class MaxPool2d(nn.Module):
    '''Module wrapper around the functional max_pool2d.

    stride=None means "same as kernel_size". Note that padding defaults to 1 here
    (torch.nn.MaxPool2d defaults to 0).
    '''

    def __init__(self, kernel_size: int, stride: Optional[int] = None, padding: int = 1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Call the functional version of max_pool2d.'''
        return nn.functional.max_pool2d(x, self.kernel_size, stride=self.stride, padding=self.padding)

    def extra_repr(self) -> str:
        '''Add additional information to the string representation of this class.'''
        return ", ".join([f"{key}={getattr(self, key)}" for key in ["kernel_size", "stride", "padding"]])


m = MaxPool2d(kernel_size=3, stride=2, padding=1)
print(f"Manually verify that this is an informative repr: {m}")
# %%
class Sequential(nn.Module):
    '''Apply a fixed list of modules one after another (like torch.nn.Sequential).

    Sub-modules are registered under the keys "0", "1", ... so that parameter names
    match torch.nn.Sequential.
    '''
    _modules: Dict[str, nn.Module]

    def __init__(self, *modules: nn.Module):
        super().__init__()
        for index, mod in enumerate(modules):
            self._modules[str(index)] = mod

    def _normalize_index(self, index: int) -> int:
        '''Map a possibly-negative index into [0, len); raise IndexError if out of range.'''
        n = len(self._modules)
        if not -n <= index < n:
            raise IndexError(f"index {index} is out of range for Sequential of length {n}")
        return index % n

    def __getitem__(self, index: int) -> nn.Module:
        return self._modules[str(self._normalize_index(index))]

    def __setitem__(self, index: int, module: nn.Module) -> None:
        self._modules[str(self._normalize_index(index))] = module

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Chain each module together, with the output from one feeding into the next one.'''
        for mod in self._modules.values():
            x = mod(x)
        return x
# %%
class BatchNorm2d(nn.Module):
    # The type hints below aren't functional, they're just for documentation
    # (written as strings so they are not evaluated at class creation).
    running_mean: "Float[Tensor, 'num_features']"
    running_var: "Float[Tensor, 'num_features']"
    num_batches_tracked: "Int[Tensor, '']"  # This is how we denote a scalar tensor

    def __init__(self, num_features: int, eps=1e-05, momentum=0.1):
        '''
        Like nn.BatchNorm2d with track_running_stats=True and affine=True.

        Name the learnable affine parameters `weight` and `bias` in that order.
        '''
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        self.register_buffer("running_mean", t.zeros(num_features))
        self.register_buffer("running_var", t.ones(num_features))
        self.register_buffer("num_batches_tracked", t.tensor(0))

        self.weight = nn.Parameter(t.ones(num_features))
        self.bias = nn.Parameter(t.zeros(num_features))

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        Normalize each channel.

        In training mode the batch statistics are used (variance via
        `torch.var(x, unbiased=False)`) and the running statistics are updated;
        in eval mode the running statistics are used.

        x: shape (batch, channels, height, width)
        Return: shape (batch, channels, height, width)
        '''
        if self.training:
            dims = (0, 2, 3)
            if x.shape[0] * x.shape[2] * x.shape[3] <= 1:
                raise ValueError("Expected more than 1 value per channel when training")
            mean = x.mean(dim=dims)
            var = t.var(x, unbiased=False, dim=dims)
            # running <- (1 - momentum) * running + momentum * batch_stat.
            # Done without autograd so the buffers don't keep this batch's graph alive;
            # like PyTorch, the running variance stores the *unbiased* estimate.
            with t.no_grad():
                unbiased_var = t.var(x, unbiased=True, dim=dims)
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbiased_var
                self.num_batches_tracked += 1
        else:
            mean = self.running_mean
            var = self.running_var

        # Reshape per-channel vectors to (1, C, 1, 1) so they broadcast over x.
        x_hat = (x - mean.view(1, -1, 1, 1)) / (var.view(1, -1, 1, 1) + self.eps) ** 0.5
        return x_hat * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)

    def extra_repr(self) -> str:
        '''Add additional information to the string representation of this class.'''
        return ", ".join([f"{key}={getattr(self, key)}" for key in ["num_features", "eps", "momentum"]])
# %%
class AveragePool(nn.Module):
    '''Global average pooling over the two spatial dimensions.'''

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, channels, height, width)
        Return: shape (batch, channels)
        '''
        return t.mean(x, dim=(2, 3))
# %%
class ResidualBlock(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, first_stride=1):
        '''
        A single residual block with optional downsampling.

        For compatibility with the pretrained model, declare the left side branch first using a `Sequential`.

        If first_stride is > 1 (or the channel count changes), the optional (conv + bn)
        should be present on the right branch. Declare it second using another `Sequential`.
        '''
        super().__init__()
        self.first_stride = first_stride
        self.in_feats = in_feats
        self.out_feats = out_feats

        # conv -> bn -> relu -> conv -> bn; only the first conv uses `first_stride`.
        self.left = Sequential(
            Conv2d(in_channels=in_feats, out_channels=out_feats, stride=first_stride, kernel_size=3, padding=1),
            BatchNorm2d(out_feats),
            ReLU(),
            Conv2d(in_channels=out_feats, out_channels=out_feats, stride=1, kernel_size=3, padding=1),
            BatchNorm2d(out_feats),
        )

        # The identity can only be added to the left branch when the spatial size and the
        # channel count are unchanged; otherwise project with a 1x1 conv + bn.
        if first_stride > 1 or in_feats != out_feats:
            self.right = Sequential(
                Conv2d(in_channels=in_feats, out_channels=out_feats, stride=first_stride, kernel_size=1, padding=0),
                BatchNorm2d(out_feats),
            )
        else:
            self.right = nn.Identity()

        self.relu = ReLU()

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        Compute the forward pass.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / stride, width / stride)

        If no downsampling block is present, the addition should just add the left branch's output to the input.
        '''
        return self.relu(self.left(x) + self.right(x))
# %%
class BlockGroup(nn.Module):
    def __init__(self, n_blocks: int, in_feats: int, out_feats: int, first_stride=1):
        '''An n_blocks-long sequence of ResidualBlock where only the first block uses the provided stride.

        Raises ValueError if n_blocks < 1.
        '''
        super().__init__()
        if n_blocks < 1:
            raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")
        self.n_blocks = n_blocks
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.first_stride = first_stride

        blocks = [ResidualBlock(in_feats=in_feats, out_feats=out_feats, first_stride=first_stride)]
        blocks += [
            ResidualBlock(in_feats=out_feats, out_feats=out_feats, first_stride=1)
            for _ in range(n_blocks - 1)
        ]
        # NOTE: named `weight` (not e.g. `blocks`) in the original; renaming would change state_dict keys.
        self.weight = Sequential(*blocks)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        Compute the forward pass.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        '''
        return self.weight(x)
# %%
class ResNet34(nn.Module):
    def __init__(
        self,
        n_blocks_per_group=(3, 4, 6, 3),
        out_features_per_group=(64, 128, 256, 512),
        first_strides_per_group=(1, 2, 2, 2),
        n_classes=1000,
    ):
        '''ResNet-34 laid out so its state_dict lines up (by position) with torchvision's resnet34.

        head: conv7x7/2 -> bn -> relu -> maxpool3x3/2
        body: one BlockGroup per entry of the *_per_group arguments
        tail: global average pool -> flatten -> linear(last group width, n_classes)
        '''
        super().__init__()
        # Copy into fresh lists so callers' sequences are never shared or mutated.
        self.n_blocks_per_group = list(n_blocks_per_group)
        self.out_features_per_group = list(out_features_per_group)
        self.first_strides_per_group = list(first_strides_per_group)
        self.n_classes = n_classes

        self.head = Sequential(
            Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            BatchNorm2d(num_features=64),
            ReLU(),
            # This file's MaxPool2d defaults to padding=1.
            MaxPool2d(kernel_size=3, stride=2),
        )

        # Each group's input width is the previous group's output width (64 for the first).
        in_features_per_group = [64] + self.out_features_per_group[:-1]
        self.block_grp_seq = Sequential(*(
            BlockGroup(*params)
            for params in zip(
                self.n_blocks_per_group,
                in_features_per_group,
                self.out_features_per_group,
                self.first_strides_per_group,
            )
        ))

        self.tail = Sequential(
            AveragePool(),
            Flatten(1, -1),
            Linear(self.out_features_per_group[-1], n_classes),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, channels, height, width)
        Return: shape (batch, n_classes)
        '''
        x = self.head(x)
        x = self.block_grp_seq(x)
        return self.tail(x)


my_resnet = ResNet34()
# %%
from pprint import pprint as pp


def copy_weights(my_resnet: ResNet34, pretrained_resnet: models.resnet.ResNet) -> ResNet34:
    '''Copy over the weights of `pretrained_resnet` to your resnet.

    Parameters and buffers are paired by position in the two state_dicts, so the
    models must declare their layers in the same order.

    Raises AssertionError if the state dicts differ in length, and ValueError if a
    positional pair has mismatched shapes.
    '''
    # Get the state dictionaries for each model, check they have the same number of parameters & buffers
    mydict = my_resnet.state_dict()
    pretraineddict = pretrained_resnet.state_dict()
    assert len(mydict) == len(pretraineddict), "Mismatching state dictionaries."

    # Map the names of your parameters / buffers to their values in the pretrained model
    state_dict_to_load = {}
    for (mykey, myvalue), (pretrainedkey, pretrainedvalue) in zip(mydict.items(), pretraineddict.items()):
        if myvalue.shape != pretrainedvalue.shape:
            raise ValueError(
                f"Shape mismatch pairing {mykey} {tuple(myvalue.shape)} "
                f"with {pretrainedkey} {tuple(pretrainedvalue.shape)}"
            )
        state_dict_to_load[mykey] = pretrainedvalue

    # Load in this dictionary to your model
    my_resnet.load_state_dict(state_dict_to_load)
    return my_resnet


pretrained_resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
my_resnet = copy_weights(my_resnet, pretrained_resnet)
# %%
IMAGE_FOLDER = section_dir / "resnet_inputs"

# NOTE(review): the original explicit IMAGE_FILENAMES list did not survive; this picks up
# every image file in the folder instead (sorted for a deterministic order).
IMAGE_FILENAMES = sorted(
    path.name for path in IMAGE_FOLDER.iterdir()
    if path.suffix.lower() in {".jpg", ".jpeg", ".png"}
)

# Force 3-channel RGB so e.g. RGBA/greyscale files match the 3-channel normalisation below.
images = [Image.open(IMAGE_FOLDER / filename).convert("RGB") for filename in IMAGE_FILENAMES]

# Standard ImageNet input size and per-channel statistics.
IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

prepared_images = t.stack([IMAGENET_TRANSFORM(img) for img in images], dim=0)

assert prepared_images.shape == (len(images), 3, IMAGE_SIZE, IMAGE_SIZE)
# %%
def predict(model, images: t.Tensor) -> t.Tensor:
    '''
    Returns the predicted class for each image (as a 1D array of ints).

    The model is run in eval mode (so BatchNorm uses its running statistics and does not
    update them) and without gradient tracking; its previous train/eval mode is restored.
    '''
    was_training = model.training
    model.eval()
    try:
        with t.inference_mode():
            return t.argmax(model(images), dim=-1)
    finally:
        model.train(was_training)
with open(section_dir / "imagenet_labels.json", encoding="utf-8") as f:
    imagenet_labels = list(json.load(f).values())

# Check your predictions match those of the pretrained model
my_predictions = predict(my_resnet, prepared_images)
pretrained_predictions = predict(pretrained_resnet, prepared_images)
assert t.equal(my_predictions, pretrained_predictions)
print("All predictions match!")

# Print out your predictions, next to the corresponding images
for img, label in zip(images, my_predictions.tolist()):
    print(f"Class {label}: {imagenet_labels[label]}")
# %%
# End of gist.