Skip to content

Instantly share code, notes, and snippets.

@jlebensold
Last active May 7, 2019 13:49
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jlebensold/f7d5c889ae4d94f7630a96f7effc7e8e to your computer and use it in GitHub Desktop.
Save jlebensold/f7d5c889ae4d94f7630a96f7effc7e8e to your computer and use it in GitHub Desktop.
Baselines for MorphNet paper
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
class TanhNet(nn.Module):
    """Two-layer fully connected baseline with tanh activations.

    Architecture: ``in_features -> h_units -> out_features``. tanh is applied
    after both layers, so every output value lies in [-1, 1].

    Args:
        in_features: number of input values per sample.
        h_units: width of the hidden layer.
        out_features: width of the output layer (defaults to 10, the value
            previously hard-coded).
    """

    def __init__(self, in_features, h_units, out_features=10):
        super(TanhNet, self).__init__()
        self.fc1 = nn.Linear(in_features, h_units)
        self.fc2 = nn.Linear(h_units, out_features)
        self.in_features = in_features
        self.h_units = h_units

    def forward(self, input: torch.Tensor):
        # Reshape into rows of ``in_features`` values; the leading dim is inferred.
        flattened = input.view(-1, self.in_features)
        # torch.tanh replaces the deprecated F.tanh (identical values, no warning).
        hidden = torch.tanh(self.fc1(flattened))
        return torch.tanh(self.fc2(hidden))

    def name(self, dset_name: str):
        """Return the identifier ``tanh_<in_features>x<h_units>_<dset_name>``."""
        return "tanh_{}x{}_{}".format(
            self.in_features, self.h_units, dset_name)

    def store(self, dset_name: str, directory: Path):
        """Pickle the whole module with ``torch.save`` to ``directory/<name>.tch``."""
        name = self.name(dset_name)
        fname = "{}.tch".format(name)
        torch.save(self, str(directory / fname))
class ReLUNet(nn.Module):
    """Two-layer fully connected baseline with ReLU activations.

    Architecture: ``in_features -> h_units -> out_features``. ReLU is applied
    after both layers, so every output value is non-negative.

    Args:
        in_features: number of input values per sample.
        h_units: width of the hidden layer.
        out_features: width of the output layer (defaults to 10, the value
            previously hard-coded; matches ``MaxoutNet``'s signature).
    """

    def __init__(self, in_features, h_units, out_features=10):
        super(ReLUNet, self).__init__()
        self.fc1 = nn.Linear(in_features, h_units)
        self.fc2 = nn.Linear(h_units, out_features)
        self.in_features = in_features
        self.h_units = h_units

    def forward(self, input: torch.Tensor):
        # Reshape into rows of ``in_features`` values; the leading dim is inferred.
        flattened = input.view(-1, self.in_features)
        hidden = F.relu(self.fc1(flattened))
        # NOTE(review): ReLU on the output layer clamps negative scores to 0,
        # which is unusual for classifier logits; kept as-is for baseline parity.
        return F.relu(self.fc2(hidden))

    def name(self, dset_name: str):
        """Return the identifier ``relu_<in_features>x<h_units>_<dset_name>``."""
        return "relu_{}x{}_{}".format(
            self.in_features, self.h_units, dset_name)

    def store(self, dset_name: str, directory: Path):
        """Pickle the whole module with ``torch.save`` to ``directory/<name>.tch``."""
        name = self.name(dset_name)
        fname = "{}.tch".format(name)
        torch.save(self, str(directory / fname))
class MaxoutNet(nn.Module):
    """Two-layer maxout network (each layer pools over 2 linear pieces).

    Architecture: ``in_features -> h_units -> out_features``, built from two
    ``Maxout`` layers with ``pool_size=2``.
    """

    def __init__(self, in_features, h_units, out_features=10):
        super().__init__()
        self.fc1 = Maxout(in_features, h_units, 2)
        self.fc2 = Maxout(h_units, out_features, 2)
        self.in_features = in_features
        self.h_units = h_units

    def forward(self, input: torch.Tensor):
        # Rows of ``in_features`` values go through both maxout layers in turn.
        hidden = self.fc1(input.view(-1, self.in_features))
        return self.fc2(hidden)

    def name(self, dset_name: str):
        """Identifier of the form ``maxout_<in_features>x<h_units>_<dset_name>``."""
        template = "maxout_{}x{}_{}"
        return template.format(self.in_features, self.h_units, dset_name)

    def store(self, dset_name: str, directory: Path):
        """Serialize this module with ``torch.save`` as ``directory/<name>.tch``."""
        target = directory / "{}.tch".format(self.name(dset_name))
        torch.save(self, str(target))
# from https://github.com/pytorch/pytorch/issues/805
class Maxout(nn.Module):
    """Maxout layer: a linear map to ``d_out * pool_size`` features, max-pooled in groups.

    The last dimension of the input must be ``d_in``. In the output it is
    replaced by ``d_out``, each value being the maximum over ``pool_size``
    linear pieces. Leading dimensions are preserved.
    """

    def __init__(self, d_in, d_out, pool_size):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.pool_size = pool_size
        self.lin = nn.Linear(d_in, d_out * pool_size)

    def forward(self, inputs):
        projected = self.lin(inputs)
        # Split the last dim into (d_out, pool_size) and keep each group's max.
        grouped = projected.view(*inputs.shape[:-1], self.d_out, self.pool_size)
        values, _ = grouped.max(grouped.dim() - 1)
        return values
import numpy as np
from pathlib import Path
import torch
from torch import nn
# Default device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
def dilate(x, s):
    """Return the ``(values, indices)`` result of taking the max of ``x + s`` over dim 1."""
    shifted = x + s
    return shifted.max(dim=1)
def erode(x, s):
    """Return the ``(values, indices)`` result of taking the min of ``x - s`` over dim 1."""
    shifted = x - s
    return shifted.min(dim=1)
class DilateErode(nn.Module):
    """Morphological layer computing learned dilations and erosions of its input.

    The input is reshaped to ``in_features`` values per sample, ``x``.
    - Dilation unit ``j`` outputs ``max_i(x_i + dilations[i, j])``.
    - Erosion unit ``k`` outputs ``min_i(x_i - erosions[i, k])``.

    A family with at least one unit gets one extra bias column appended:
    ``dilation_bias`` after the dilations, and ``-erosion_bias`` after the
    erosions.

    Output shape is ``(batch, D + E)``:
    - ``D = number_of_dilations + 1`` if there are dilations, else 0.
    - ``E = number_of_erosions + 1`` if there are erosions, else 0.
    """

    def __init__(self, in_features: int, number_of_dilations: int, number_of_erosions: int):
        super().__init__()
        self.in_features = in_features
        self.number_of_dilations = number_of_dilations
        self.number_of_erosions = number_of_erosions
        if self.number_of_dilations > 0:
            self.dilations = nn.Parameter(torch.randn(in_features, number_of_dilations))
        else:
            # Placeholder so the attribute always exists; forward() never reads it.
            self.dilations = torch.Tensor()
        if self.number_of_erosions > 0:
            self.erosions = nn.Parameter(torch.randn(in_features, number_of_erosions))
        else:
            self.erosions = torch.Tensor()
        self.dilation_bias = nn.Parameter(torch.zeros(1))
        self.erosion_bias = nn.Parameter(torch.zeros(1))

    def forward(self, input: torch.Tensor):
        batch_size = input.shape[0]
        # (batch, in_features, 1) broadcasts against the (in_features, n_units)
        # weights, giving (batch, in_features, n_units).
        flattened = input.view(batch_size, self.in_features, 1)
        columns = []
        if self.number_of_dilations > 0:
            # Each dilation is a max of a sum of all the input features.
            dilated = torch.max(flattened + self.dilations, dim=1)[0]
            columns.append(dilated)
            # Dilation bias column. The paper treats it as a tensor, but because you
            # take a max it's actually just a constant.
            columns.append(self.dilation_bias.expand(batch_size, 1))
        if self.number_of_erosions > 0:
            # Each erosion is a min of a difference of all the input features.
            eroded = torch.min(flattened - self.erosions, dim=1)[0]
            columns.append(eroded)
            # Erosion bias column.
            columns.append((-self.erosion_bias).expand(batch_size, 1))
        if not columns:
            # No units at all: an empty (batch, 0) tensor on the input's own device
            # and dtype, rather than a placeholder tied to a module-level device.
            return input.new_empty(batch_size, 0)
        # Layout: [dilations..., dilation_bias, erosions..., -erosion_bias]
        return torch.cat(columns, dim=1)
class DenMoNet(nn.Module):
    """The dilation-erosion network.

    A ``DilateErode`` layer followed by a linear layer that maps its outputs to
    ``output_space_dim`` values. The most recent ``DilateErode`` output is kept
    on ``self.temp`` (detached) for inspection.
    """

    def __init__(self, input_space_dim: int, number_dilations: int, number_erosions: int, output_space_dim: int):
        super().__init__()
        self.de_layer = DilateErode(input_space_dim, number_dilations, number_erosions)
        lc_size = self._combination_size(number_dilations, number_erosions)
        self.linear_combination_layer = nn.Linear(lc_size, output_space_dim)

    @staticmethod
    def _combination_size(number_dilations: int, number_erosions: int) -> int:
        """Return the width of the morphological layer's output, as a Python int.

        Each family with at least one unit contributes its units plus one bias
        node. The ``> 0`` tests mirror the ones ``DilateErode`` uses to decide
        whether a family exists, so the sizes always agree.
        """
        size = 0
        if number_dilations > 0:
            size += number_dilations + 1
        if number_erosions > 0:
            size += number_erosions + 1
        return size

    def name(self, dset_name: str):
        """Return the identifier ``denmo_<dilations>x<erosions>_<dset_name>``."""
        return "denmo_{}x{}_{}".format(self.de_layer.number_of_dilations,
                                       self.de_layer.number_of_erosions, dset_name)

    def forward(self, input: torch.Tensor):
        temp = self.de_layer(input)
        # Keep the last morphological activations for inspection, detached so the
        # module does not hold a reference to the autograd graph between steps.
        self.temp = temp.detach()
        classification = self.linear_combination_layer(temp)
        return classification

    def store(self, dset_name: str, directory: Path):
        """Pickle the whole module with ``torch.save`` to ``directory/<name>.tch``."""
        name = self.name(dset_name)
        fname = "{}.tch".format(name)
        torch.save(self, str(directory / fname))
@jlebensold
Copy link
Author

screen shot 2018-11-30 at 1 47 07 pm

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment