@CharlesJQuarra
Created September 21, 2018 16:59
Attempt at a linear unit that supports splitting the parameter space into a grid of gradient-checkpoint nodes. The issue right now is that when there is more than one segment, `backward()` only populates the gradient of the last weight parameter.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


def get_segments(total, max_length):
    # Split `total` into segment lengths of at most `max_length`.
    # Ceiling division, so the trailing segment never exceeds `max_length`.
    if total > max_length:
        segments = -(-total // max_length)
    else:
        segments = 1
    return (segments - 1) * [max_length] + [total - (segments - 1) * max_length]
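
# Illustration (not part of the original gist) of how get_segments splits a dimension:
#   get_segments(6, 2) -> [2, 2, 2]
#   get_segments(7, 3) -> [3, 3, 1]   (with the ceiling division, no segment exceeds max_length)
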
class GradCheckpoint_Linear(nn.Module):

    def __init__(self, in_features, out_features, cpc_specs={}):
        super(GradCheckpoint_Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        if 'in_max_segment' in cpc_specs:
            in_max_segment = cpc_specs['in_max_segment']
        else:
            in_max_segment = in_features
        self.in_segment_lengths = get_segments(in_features, in_max_segment)
        if 'out_max_segment' in cpc_specs:
            out_max_segment = cpc_specs['out_max_segment']
        else:
            out_max_segment = out_features
        if 'initializer' in cpc_specs:
            self.initializer = cpc_specs['initializer']
        else:
            def get_init(w, h):
                return torch.randn(w, h)
            self.initializer = get_init
        self.out_segment_lengths = get_segments(out_features, out_max_segment)
        print("in_segment_lengths: {0}".format(self.in_segment_lengths))
        print("out_segment_lengths: {0}".format(self.out_segment_lengths))
        weight_parameters_ = []
        bias_parameters_ = []
        # Maps (in_segment, out_segment) grid coordinates to an index into the flat
        # weight ParameterList; -1 marks an unassigned slot.
        self.array_to_weight_param = -torch.ones(len(self.in_segment_lengths),
                                                 len(self.out_segment_lengths),
                                                 dtype=torch.int32)
        for in_idx, in_s_length in enumerate(self.in_segment_lengths):
            for out_idx, out_s_length in enumerate(self.out_segment_lengths):
                param = nn.Parameter(self.initializer(out_s_length, in_s_length))
                self.array_to_weight_param[in_idx, out_idx] = len(weight_parameters_)
                weight_parameters_.append(param)
        self.weight_parameters = nn.ParameterList(weight_parameters_)
        for out_s_length in self.out_segment_lengths:
            bias_parameters_.append(nn.Parameter(self.initializer(1, out_s_length).view(out_s_length)))
        self.bias_parameters = nn.ParameterList(bias_parameters_)
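
    # Illustration (not part of the original gist): with in_features=6, out_features=2,
    # in_max_segment=2 and out_max_segment=1 (the `u_split` unit below), the segment
    # lengths are [2, 2, 2] and [1, 1], so the weight matrix is stored as a 3x2 grid of
    # 1x2 blocks: six weight parameters plus two bias parameters.
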
    def reset_parameters(self):
        pass

    def forward(self, inp):
        unit_outs = []
        for out_idx, out_s_length in enumerate(self.out_segment_lengths):
            bias_param = self.bias_parameters[out_idx]
            in_offset = 0
            weight_outs = []
            for in_idx, in_s_length in enumerate(self.in_segment_lengths):
                weight_param = self.weight_parameters[self.array_to_weight_param[in_idx, out_idx]]

                def fwd_unit_segment(inp_):
                    return torch.mv(weight_param, inp_)

                weight_out = checkpoint.checkpoint(fwd_unit_segment, inp[in_offset:in_offset + in_s_length])
                in_offset += in_s_length
                weight_outs.append(weight_out)
            unit_outs.append(bias_param + sum(weight_outs))  # (*1 we squeeze back the 1 after matmul)
        result = torch.cat(unit_outs)
        return result


def init_(w, h):
    return torch.ones(w, h)

u_original = GradCheckpoint_Linear(6, 2, cpc_specs={'initializer': init_})
u_split = GradCheckpoint_Linear(6, 2, cpc_specs={'in_max_segment': 2, 'out_max_segment': 1, 'initializer': init_})
inp_0_orig = torch.ones(6, requires_grad=True)
inp_0_split = torch.ones(6, requires_grad=True)
u_original(inp_0_orig).sum().backward()
u_original.weight_parameters[0].grad #<--- looks good
u_split(inp_0_split).sum().backward()
u_split.weight_parameters[0].grad  # <--- grad is None
u_split.weight_parameters[5].grad  # <--- the only weight parameter whose grad gets populated is the last one?
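
A likely cause, read from the code rather than stated anywhere in the gist: `fwd_unit_segment` is redefined inside the inner loop and captures the loop variable `weight_param` by closure. `torch.utils.checkpoint.checkpoint` re-runs the function during `backward()` to rebuild the subgraph, and by that time the closure sees whatever `weight_param` last held, i.e. the weight of the final segment, so only that parameter ever receives a gradient (behaviour of the reentrant checkpoint implementation PyTorch shipped around 2018; recent releases additionally warn unless an explicit `use_reentrant` flag is passed). Below is a minimal, self-contained sketch of the pitfall and of the usual workaround of passing the weight to `checkpoint` as an explicit argument; the names `w1`, `w2`, `x`, `seg` and `seg2` are illustrative and not part of the gist.

import torch
import torch.utils.checkpoint as checkpoint

w1 = torch.ones(2, 3, requires_grad=True)
w2 = torch.ones(2, 3, requires_grad=True)
x = torch.ones(6, requires_grad=True)

# Closure version, mirroring the gist: during backward() each checkpointed segment is
# recomputed with whatever `w` holds at that point, i.e. the last weight.
outs = []
for i, w in enumerate([w1, w2]):
    def seg(inp_):
        return torch.mv(w, inp_)
    outs.append(checkpoint.checkpoint(seg, x[3 * i:3 * (i + 1)]))
torch.cat(outs).sum().backward()
print(w1.grad, w2.grad)  # w1.grad is None; w2.grad absorbs both segments (the reported symptom)

# Explicit-argument version: the weight is an input to checkpoint, so each segment's
# recomputation uses (and differentiates) the right block.
w1.grad = w2.grad = None

def seg2(weight_, inp_):
    return torch.mv(weight_, inp_)

outs = []
for i, w in enumerate([w1, w2]):
    outs.append(checkpoint.checkpoint(seg2, w, x[3 * i:3 * (i + 1)]))
torch.cat(outs).sum().backward()
print(w1.grad, w2.grad)  # both gradients are populated

Applied to `GradCheckpoint_Linear.forward`, the same idea would mean giving `fwd_unit_segment` a weight argument and calling `checkpoint.checkpoint(fwd_unit_segment, weight_param, inp[in_offset:in_offset + in_s_length])` in the inner loop; binding the weight as a default argument of `fwd_unit_segment` would also sidestep the late binding, but passing the weight as a checkpoint input is the more common pattern.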