Skip to content

Instantly share code, notes, and snippets.

@arunmallya
Created June 29, 2017 01:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save arunmallya/ba077f554deb71bdeeb4eda499e910df to your computer and use it in GitHub Desktop.
Save arunmallya/ba077f554deb71bdeeb4eda499e910df to your computer and use it in GitHub Desktop.
import torch.nn as nn
class myModule(nn.Module):
    """Two-stage network: a dense stage ``X`` followed by a 1-D conv stage ``C``.

    ``X`` is ``Linear -> ReLU -> BatchNorm1d`` and ``C`` is
    ``Conv1d -> ReLU -> BatchNorm1d -> MaxPool1d``.

    The layer sizes are read from names that must already be defined in the
    enclosing module scope when the constructor runs: ``num_input_genes``,
    ``num_tfs``, ``num_conv_out_channels``, ``conv_kernel_size`` and
    ``max_pool_kernel_size``.
    """

    def __init__(self):
        """Build the ``X`` and ``C`` sub-networks."""
        # nn.Module.__init__ must run before any sub-module is assigned as an
        # attribute; otherwise the assignment below raises AttributeError.
        super().__init__()
        # Init stuff here
        self.X = nn.Sequential(
            nn.Linear(num_input_genes, num_tfs),
            nn.ReLU(),
            nn.BatchNorm1d(num_tfs),
        )
        self.C = nn.Sequential(
            nn.Conv1d(num_tfs, num_conv_out_channels, conv_kernel_size),
            nn.ReLU(),
            nn.BatchNorm1d(num_conv_out_channels),
            nn.MaxPool1d(max_pool_kernel_size),
        )

    def forward(self, input, M):
        """Run ``input`` through ``X``, scale it by ``M``, then apply ``C``.

        Args:
            input: Tensor fed to ``X``; its last dimension must match the
                ``in_features`` of the first linear layer.
            M: Tensor multiplied element-wise (with broadcasting) into the
                output of ``X``.

        Returns:
            The output of ``C`` applied to ``M * X(input)``.

        NOTE(review): ``Conv1d`` expects a ``(batch, channels, length)``
        tensor with ``channels == num_tfs``. The original author notes that
        reshaping is required around the multiplication but does not spell it
        out; confirm the intended shapes of ``M`` and the product.
        """
        x_out = self.X(input)
        x_out = M * x_out  # With required reshaping, ...
        x_out = self.C(x_out)
        return x_out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment