CharlesJQuarra / grad_checkpoint_linear.py
Created September 21, 2018 16:59
Attempt at a linear unit that supports splitting the parameter space into a grid of gradient-checkpoint nodes. The issue right now is that when there is more than one segment, `backward()` only updates the gradient for the last parameter.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

def get_segments(total, max_length):
    if total > max_length:
        # Ceil division, so the last segment never exceeds max_length
        # (plain floor division would return e.g. [4, 6] for total=10, max_length=4).
        segments = -(-total // max_length)
    else:
        segments = 1
    return (segments - 1) * [max_length] + [total - (segments - 1) * max_length]
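
A minimal sketch of how get_segments could drive the checkpointed unit described above. The class name SegmentedLinear, the max_segment argument, and the column-block layout are illustrative assumptions, not the gist's actual code: the output features are split into blocks, and each block runs under torch.utils.checkpoint.checkpoint so its activations are recomputed during backward().

class SegmentedLinear(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, in_features, out_features, max_segment=64):
        super().__init__()
        # One nn.Linear per segment of the output dimension,
        # e.g. out_features=10, max_segment=4 -> block widths [4, 4, 2]
        self.blocks = nn.ModuleList(
            nn.Linear(in_features, n)
            for n in get_segments(out_features, max_segment)
        )

    def forward(self, x):
        # use_reentrant=False (recent PyTorch) propagates gradients through
        # checkpointed segments more robustly than the reentrant variant
        outs = [checkpoint.checkpoint(blk, x, use_reentrant=False)
                for blk in self.blocks]
        return torch.cat(outs, dim=-1)
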
CharlesJQuarra / fat_cat.py
Created August 16, 2018 15:56
broadcastable version of `torch.cat`
import torch

"""
behavior:
fat_cat([torch.randn(1,7,20), torch.randn(5,1,13)], dim=-1).size() == torch.Size([5, 7, 33])
"""

def axis_repeat(t, dim, times):
    if t.size()[dim] != 1:
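        # Gist truncated at this point; what follows is a hedged completion
        # sketch consistent with the documented behavior above, not the
        # original code.
        return t  # only size-1 axes are broadcast; others pass through
    repeats = [1] * t.dim()
    repeats[dim] = times
    return t.repeat(*repeats)  # tile the size-1 axis `times` times

def fat_cat(tensors, dim=-1):
    # Assumed driver: expand every non-concatenation axis of each tensor to
    # the common broadcast size, then concatenate along `dim`, matching the
    # behavior note: (1,7,20) and (5,1,13) -> (5,7,33) for dim=-1.
    ndim = tensors[0].dim()
    dim = dim % ndim
    sizes = [max(t.size()[d] for t in tensors) for d in range(ndim)]
    expanded = []
    for t in tensors:
        for d in range(ndim):
            if d != dim:
                t = axis_repeat(t, d, sizes[d])
        expanded.append(t)
    return torch.cat(expanded, dim=dim)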