# Gist by @InnovArul (last active December 21, 2020): fuse several independent
# nn.Conv2d modules into one grouped nn.Conv2d and verify the outputs match.
import torch
import torch.nn as nn
import torch.nn.functional as F

def perform_non_parallelconv(input, convs):
    """Apply each conv to its own slice of the input, then concatenate along channels."""
    outs = []
    for i in range(len(convs)):
        # input is (N, num_convs, C, H, W); slice i feeds conv i
        o = convs[i](input[:, i])
        outs.append(o)
    outs = torch.cat(outs, dim=1)
    return outs
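
# Equivalent variant (my addition, not part of the original gist): the same
# loop written with zip over torch.unbind, pairing each conv directly with
# its input slice.
def perform_non_parallelconv_zip(input, convs):
    return torch.cat(
        [conv(x) for x, conv in zip(input.unbind(dim=1), convs)], dim=1
    )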

def create_parallel_conv(convs):
    """Fuse identically shaped Conv2d modules into a single grouped Conv2d."""
    # collect weights and biases from the individual convs
    weights = []
    biases = []
    for i in range(len(convs)):
        weights.append(convs[i].weight)
        biases.append(convs[i].bias)
    weights = torch.cat(weights, dim=0)  # (num_convs * out_channels, in_channels, kH, kW)
    biases = torch.cat(biases, dim=0)    # (num_convs * out_channels,)

    # one grouped conv: each group sees only its own conv's input channels;
    # assumes every conv shares the same in_channels and kernel_size
    parallel_conv = nn.Conv2d(
        convs[0].in_channels * len(convs),
        weights.shape[0],
        convs[0].kernel_size,
        groups=len(convs),
    )
    parallel_conv.weight = nn.Parameter(weights)
    parallel_conv.bias = nn.Parameter(biases)
    return parallel_conv, weights, biases
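
# Sanity-check sketch (my addition, not part of the original gist): a grouped
# conv orders its output channels group by group, so conv i's output should
# equal the slice [i*out_channels : (i+1)*out_channels] of the fused output.
# Assumes every conv has the same out_channels.
def check_group_layout(input, convs, parallel_conv):
    oc = convs[0].out_channels
    n, num_convs, c, h, w = input.shape
    fused = parallel_conv(input.view(n, num_convs * c, h, w))
    for i, conv in enumerate(convs):
        assert torch.allclose(conv(input[:, i]), fused[:, i * oc:(i + 1) * oc])
    return True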

if __name__ == "__main__":
    # batch of 1 with 4 separate 3-channel 120x120 inputs, one per conv
    input = torch.randn(1, 4, 3, 120, 120).cuda()
    convs = nn.ModuleList([
        nn.Conv2d(3, 50, 3),
        nn.Conv2d(3, 50, 3),
        nn.Conv2d(3, 50, 3),
        nn.Conv2d(3, 50, 3),
    ]).cuda()

    non_parallel_out = perform_non_parallelconv(input, convs)
    parallel_conv, weights, biases = create_parallel_conv(convs)

    # parallel conv with the fused Conv2d module: flatten the per-conv inputs
    # into a single channel axis of size 4 * 3 = 12
    parallel_out = parallel_conv(input.view(1, -1, 120, 120))
    print(torch.allclose(non_parallel_out, parallel_out))

    # same check with the functional F.conv2d and the concatenated parameters
    parallel_out = F.conv2d(input.view(1, -1, 120, 120), weights, biases, groups=4)
    print(torch.allclose(non_parallel_out, parallel_out))
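
    # Rough timing sketch (my addition, not part of the original gist): compare
    # the Python loop over convs against the single fused grouped conv.
    # torch.cuda.synchronize() is needed because CUDA kernels launch
    # asynchronously.
    import time
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        perform_non_parallelconv(input, convs)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    for _ in range(100):
        parallel_conv(input.view(1, -1, 120, 120))
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    print(f"looped: {t1 - t0:.4f}s  fused: {t2 - t1:.4f}s")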