Last active
December 21, 2020 21:51
-
-
Save InnovArul/bd8b7ce5ac9e615ec251470293563eb2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch, torch.nn as nn, torch.nn.functional as F | |
def perform_non_parallelconv(input, convs):
    """Apply conv i to channel-group i of the input and concatenate the results.

    Args:
        input: tensor of shape (B, len(convs), C, H, W); the slice
            input[:, i] (shape (B, C, H, W)) is fed to convs[i].
        convs: sequence of nn.Conv2d modules, one per input group.

    Returns:
        The per-conv outputs concatenated along the channel axis (dim=1).
    """
    # enumerate pairs each conv with its slice index -- clearer than
    # indexing both sequences via range(len(convs)).
    outs = [conv(input[:, i]) for i, conv in enumerate(convs)]
    return torch.cat(outs, dim=1)
def create_parallel_conv(convs):
    """Fuse several same-shaped Conv2d modules into a single grouped Conv2d.

    Args:
        convs: sequence of nn.Conv2d modules sharing in_channels and
            kernel_size (assumed; not validated here).

    Returns:
        Tuple ``(parallel_conv, weights, biases)``: a grouped nn.Conv2d
        equivalent to running each conv on its own input slice, plus the
        stacked weight and bias tensors it was loaded with.
    """
    # Stack every conv's parameters along the output-channel axis.
    weights = torch.cat([conv.weight for conv in convs], dim=0)
    biases = torch.cat([conv.bias for conv in convs], dim=0)

    # A grouped convolution routes input slice i exclusively to filter
    # bank i, reproducing the independent per-conv behavior in one call.
    n_groups = len(convs)
    parallel_conv = nn.Conv2d(
        convs[0].in_channels * n_groups,
        weights.shape[0],
        convs[0].kernel_size,
        groups=n_groups,
    )
    parallel_conv.weight = nn.Parameter(weights)
    parallel_conv.bias = nn.Parameter(biases)
    return parallel_conv, weights, biases
if __name__ == "__main__":
    # Select the GPU when one exists; the original hard-coded .cuda()
    # and crashed outright on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Shape (batch, num_convs, channels, H, W): one 3-channel slice per conv.
    input = torch.randn(1, 4, 3, 120, 120, device=device)
    convs = nn.ModuleList([nn.Conv2d(3, 50, 3) for _ in range(4)]).to(device)

    non_parallel_out = perform_non_parallelconv(input, convs)
    parallel_conv, weights, biases = create_parallel_conv(convs)

    # Parallel path via the grouped Conv2d module: fold the per-conv
    # slices into a single channel axis of size 4*3.
    parallel_out = parallel_conv(input.view(1, -1, 120, 120))
    print(torch.allclose(non_parallel_out, parallel_out))

    # Same comparison through the functional API.
    parallel_out = F.conv2d(input.view(1, -1, 120, 120), weights, biases, groups=4)
    print(torch.allclose(non_parallel_out, parallel_out))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment