Gist by @shivance, created December 31, 2022 11:06
PyTorch implementation of UNet
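```python
import torch
import torch.nn as nn


class UNet(nn.Module):
    # __init__ is not shown here; it must define encoder1-4, pool1-4,
    # bottleneck, upconv1-4, decoder1-4 and the final conv layer.

    def forward(self, x):
        # Contracting path: keep each encoder output for its skip connection.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))

        # Expansive path: upsample, concatenate the matching encoder features
        # along the channel dimension (dim=1 for NCHW tensors), then decode.
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        # Output: sigmoid over the final conv gives per-pixel values in (0, 1).
        return torch.sigmoid(self.conv(dec1))
```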
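The gist shows only `forward`, so the class cannot be instantiated from it alone. The sketch below is one possible `__init__` that defines every attribute the forward pass uses. It assumes a conventional UNet layout that the gist does not confirm: a hypothetical `double_conv` helper (two padded 3x3 convolutions, each followed by BatchNorm and ReLU), 2x2 max pooling, stride-2 2x2 transposed convolutions for upsampling, 32 base feature channels, 3 input channels, 1 output channel, and a 1x1 final convolution.

```python
import torch
import torch.nn as nn


def double_conv(in_channels: int, out_channels: int) -> nn.Sequential:
    """Hypothetical helper: two 3x3 conv -> BatchNorm -> ReLU stages.

    padding=1 keeps H and W unchanged, so skip tensors line up for torch.cat.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


class UNet(nn.Module):
    # Assumed defaults (not from the gist): 3 in, 1 out, 32 base features.
    def __init__(self, in_channels: int = 3, out_channels: int = 1, features: int = 32):
        super().__init__()
        f = features
        # Contracting path: channel count doubles at each level.
        self.encoder1 = double_conv(in_channels, f)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = double_conv(f, f * 2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = double_conv(f * 2, f * 4)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = double_conv(f * 4, f * 8)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = double_conv(f * 8, f * 16)
        # Expansive path: each transposed conv doubles H and W and halves the
        # channels; the decoder then receives upsampled + skip channels.
        self.upconv4 = nn.ConvTranspose2d(f * 16, f * 8, kernel_size=2, stride=2)
        self.decoder4 = double_conv(f * 16, f * 8)
        self.upconv3 = nn.ConvTranspose2d(f * 8, f * 4, kernel_size=2, stride=2)
        self.decoder3 = double_conv(f * 8, f * 4)
        self.upconv2 = nn.ConvTranspose2d(f * 4, f * 2, kernel_size=2, stride=2)
        self.decoder2 = double_conv(f * 4, f * 2)
        self.upconv1 = nn.ConvTranspose2d(f * 2, f, kernel_size=2, stride=2)
        self.decoder1 = double_conv(f * 2, f)
        # 1x1 conv maps to the output channels; sigmoid is applied in forward.
        self.conv = nn.Conv2d(f, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))
        dec4 = self.decoder4(torch.cat((self.upconv4(bottleneck), enc4), dim=1))
        dec3 = self.decoder3(torch.cat((self.upconv3(dec4), enc3), dim=1))
        dec2 = self.decoder2(torch.cat((self.upconv2(dec3), enc2), dim=1))
        dec1 = self.decoder1(torch.cat((self.upconv1(dec2), enc1), dim=1))
        return torch.sigmoid(self.conv(dec1))


model = UNet()
model.eval()  # use BatchNorm running statistics for this smoke test
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 1, 64, 64])
```

Because each decoder receives an upsampled tensor concatenated with an encoder tensor of the same width, every `decoderN` in this sketch takes twice the channels it outputs. The `forward` here is written more compactly but performs the same sequence of operations as the gist's method.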
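The skip connections in `forward` depend on two properties of `torch.cat(..., dim=1)`: channel counts add up, and every other dimension must match exactly. The short check below illustrates both, assuming `MaxPool2d(kernel_size=2, stride=2)` for downsampling and `ConvTranspose2d(kernel_size=2, stride=2)` for upsampling (the gist does not show which layers it uses). `fits_unet` is a hypothetical helper, not part of the gist.

```python
import torch
import torch.nn as nn

# Skip connection: concatenating along dim=1 (channels in NCHW) adds channels.
up = torch.randn(2, 64, 32, 32)    # e.g. the output of an upconv layer
skip = torch.randn(2, 64, 32, 32)  # the matching encoder feature map
merged = torch.cat((up, skip), dim=1)
print(merged.shape)  # torch.Size([2, 128, 32, 32])

# Spatial sizes must match for torch.cat. Pooling floors odd sizes, so the
# upsampled tensor comes back one pixel smaller than the encoder output.
pool = nn.MaxPool2d(kernel_size=2, stride=2)
upconv = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
enc = torch.randn(1, 64, 9, 9)
restored = upconv(pool(enc))
print(restored.shape)  # torch.Size([1, 64, 8, 8]), cannot be concatenated with enc


def fits_unet(height: int, width: int, depth: int = 4) -> bool:
    """Hypothetical check: with `depth` 2x poolings, H and W must be
    divisible by 2**depth for every skip connection to line up."""
    factor = 2 ** depth
    return height % factor == 0 and width % factor == 0


print(fits_unet(64, 64), fits_unet(36, 36))  # True False
```

With four pooling stages, as in the gist's `forward`, this means input height and width should be multiples of 16 under these assumptions; otherwise inputs need padding or cropping before the model.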