Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save nilsleh/b6d8aeeb20f3b56d58dda762857e5283 to your computer and use it in GitHub Desktop.
Save nilsleh/b6d8aeeb20f3b56d58dda762857e5283 to your computer and use it in GitHub Desktop.
import numpy as np
import torch
import torch.nn as nn
from functorch import vmap, jacrev, make_functional_with_buffers
# Toy problem sizes used by the demo at the bottom of the script.
batch_size, in_channels, out_channels = 2, 5, 20
# Height == width of the square input feature maps.
feature_shape = 8
# Random input batch, shape (batch_size, in_channels, feature_shape, feature_shape).
feature = torch.rand((batch_size, in_channels, feature_shape, feature_shape))
class ConvBlock(nn.Module):
    """Three stacked 3x3 convolutions (stride 1) mapping ``in_ch`` -> ``out_ch`` channels.

    The first two convolutions keep ``in_ch`` channels and the last one projects
    to ``out_ch``. With PyTorch's default ``padding=0`` each layer shrinks the
    spatial height and width by 2.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Built in the same order as before (two in->in, then in->out) so the
        # parameter order, initialisation RNG draws and `conv.N.*` keys are unchanged.
        layers = [nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1) for _ in range(2)]
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.conv(x)
# Instantiate the demo network, then obtain a stateless callable plus its
# parameter and buffer tuples so they can be passed explicitly to fmodel.
model = ConvBlock(in_ch=in_channels, out_ch=out_channels)
fmodel, params, buffers = make_functional_with_buffers(model)
def map_layer_param_to_flat_param(model):
    """Build a lookup table from flat parameter positions to (tensor, offset) pairs.

    Every scalar in ``model.parameters()`` gets one row
    ``[flat_idx, param_layer_idx, idx_within_layer]``:

    * ``flat_idx`` -- position in the concatenation of all flattened parameters,
    * ``param_layer_idx`` -- index of the owning tensor in ``model.parameters()`` order,
    * ``idx_within_layer`` -- position inside that flattened tensor.

    Args:
        model: an ``nn.Module`` (anything exposing ``parameters()``).

    Returns:
        An int64 tensor of shape ``(total_num_params, 3)``. It has shape
        ``(0, 3)`` when the model has no parameters, so ``[:, k]`` column
        indexing still works downstream.
    """
    blocks = []
    offset = 0
    for layer_idx, p in enumerate(model.parameters()):
        n = p.numel()
        local = torch.arange(n)
        # Build all rows for this tensor at once instead of one Python append per scalar.
        blocks.append(
            torch.stack((local + offset, torch.full_like(local, layer_idx), local), dim=1)
        )
        offset += n
    if not blocks:
        return torch.empty((0, 3), dtype=torch.long)
    return torch.cat(blocks, dim=0)
def define_subnet_and_other_indices(model, subnet_size=10):
    """Randomly partition the model's flat parameter indices into two sets.

    Args:
        model: an ``nn.Module`` (anything exposing ``parameters()``).
        subnet_size: number of distinct flat indices to put in the subnet
            (default 10).

    Returns:
        ``(subnet_indices, deterministic_indices)``. Both are sorted 1-D int64
        tensors, they are disjoint, and together they cover
        ``range(total_num_params)``.

    Raises:
        ValueError: if ``subnet_size`` exceeds the number of parameters.

    Sampling uses ``np.random``, so ``np.random.seed`` controls reproducibility.
    """
    num_params = sum(p.numel() for p in model.parameters())
    if subnet_size > num_params:
        raise ValueError(
            f"subnet_size={subnet_size} exceeds the number of parameters ({num_params})"
        )
    # replace=False is essential: duplicate indices would make split()/combine()
    # reconstruct more values than a parameter holds.
    chosen = np.sort(np.random.choice(num_params, size=subnet_size, replace=False))
    subnet_indices = torch.as_tensor(chosen, dtype=torch.long)
    # Complement via a boolean mask: linear time, and int64 even when empty.
    keep = torch.ones(num_params, dtype=torch.bool)
    keep[subnet_indices] = False
    deterministic_indices = torch.arange(num_params)[keep]
    return subnet_indices, deterministic_indices
def split(params, relevant_indices, other_indices):
    """Scatter each parameter tensor's scalars into "relevant" and "other" pieces.

    ``relevant_indices`` and ``other_indices`` are row tables of the form
    ``[flat_idx, param_layer_idx, idx_within_layer]``. Column 1 selects the
    tensor and column 2 the element of its flattened view.

    Returns:
        ``(relevant_params, other_params, param_shapes)``:

        * ``relevant_params`` / ``other_params`` map a tensor index to a 1-D
          tensor of its selected values, in the same order as the matching
          index rows. Tensors with no selected rows get no key.
        * ``param_shapes`` maps every tensor index to its original shape, so
          the pieces can be reassembled later.
    """
    relevant_params, other_params, param_shapes = {}, {}, {}
    for layer, tensor in enumerate(params):
        flat = tensor.flatten()

        rel_rows = relevant_indices[relevant_indices[:, 1] == layer]
        if rel_rows.nelement():
            relevant_params[layer] = flat[rel_rows[:, 2]]

        oth_rows = other_indices[other_indices[:, 1] == layer]
        if oth_rows.nelement():
            other_params[layer] = flat[oth_rows[:, 2]]

        param_shapes[layer] = tensor.shape
    return relevant_params, other_params, param_shapes
def combine(relevant_params, other_params, relevant_indices, other_indices, param_shapes):
    """Reassemble full-shaped parameter tensors from their split 1-D pieces.

    Inverse of ``split``. For each tensor index in ``param_shapes``, the values
    stored in ``relevant_params`` and/or ``other_params`` are concatenated. They
    are then reordered by their within-tensor position (column 2 of the matching
    index rows) and reshaped to the recorded shape. Weights and biases are
    handled alike.

    The pieces may arrive in any row order. The only requirement is that each
    dict entry's values line up with its index rows, which ``split`` guarantees.

    Returns:
        A tuple of tensors, one per entry of ``param_shapes``, in its iteration order.

    Raises:
        ValueError: if a tensor index has no rows in either index table.
    """
    reconstructed_params = []
    for p_idx, p_shape in param_shapes.items():
        values, positions = [], []
        for index_table, param_dict in (
            (relevant_indices, relevant_params),
            (other_indices, other_params),
        ):
            rows = index_table[index_table[:, 1] == p_idx]
            if rows.nelement() != 0:
                values.append(param_dict[p_idx])
                positions.append(rows[:, 2])
        if not values:
            raise ValueError(f"no parameter values found for parameter index {p_idx}")
        flat = torch.cat(values)
        # Sort only by the within-tensor position (column 2) to restore the
        # original flattened layout of this tensor.
        order = torch.argsort(torch.cat(positions))
        reconstructed_params.append(flat[order].view(p_shape))
    return tuple(reconstructed_params)
def compute_output_stateless_model(relevant_params, other_params, relevant_indices, other_indices, param_shapes, buffers, feature):
    """Run the module-level stateless ``fmodel`` on a single sample.

    The full parameter tuple is rebuilt with ``combine`` from the subnet
    ("relevant") and remaining ("other") pieces. ``feature`` gets a leading
    batch axis of size 1, and the output is reshaped to ``(1, -1, 8)``.
    Taking the subnet pieces as the first argument lets ``jacrev`` differentiate
    with respect to them.
    """
    full_params = combine(relevant_params, other_params, relevant_indices, other_indices, param_shapes)
    single_batch = torch.unsqueeze(feature, 0)
    out = fmodel(full_params, buffers, single_batch)
    # NOTE(review): the trailing 8 is hard-coded and only has to divide the
    # output's element count. For the ConvBlock/demo sizes in this script the
    # output appears to be (1, 20, 2, 2) -- confirm the intended layout.
    return out.view(single_batch.shape[0], -1, 8)
# One [flat_idx, param_layer_idx, idx_within_layer] row per scalar parameter.
param_idx = map_layer_param_to_flat_param(model)
subnet_indices, deterministic_indices = define_subnet_and_other_indices(model)
# Translate flat indices into (tensor, offset) rows for each group.
relevant_indices, other_indices = param_idx[subnet_indices], param_idx[deterministic_indices]
relevant_params, other_params, param_shapes = split(params, relevant_indices, other_indices)

# Jacobian of the model output w.r.t. argument 0 (the subnet parameters).
ft_compute_grad = jacrev(compute_output_stateless_model, argnums=0)
# Vectorise over the leading axis of `feature` only; all other inputs are shared.
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None,) * 6 + (0,))
ft_per_sample_grads = ft_compute_sample_grad(
    relevant_params,
    other_params,
    relevant_indices,
    other_indices,
    param_shapes,
    buffers,
    feature,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment