Skip to content

Instantly share code, notes, and snippets.

@maxrohleder
Last active January 11, 2023 13:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxrohleder/82507a4eecbe4192871117cdb2897181 to your computer and use it in GitHub Desktop.
Save maxrohleder/82507a4eecbe4192871117cdb2897181 to your computer and use it in GitHub Desktop.
Convert a TensorFlow (TF1) layer to an equivalent PyTorch module
# >>> TF1 implementation (extracted from its enclosing class; the leftover `self` parameter is unused)
import tensorflow as tf
def upconvcat(self, x1, x2, n_filter, name):
    """Upsample ``x1`` by 2, apply a 3x3 conv, and concatenate ``x2`` (TF1 graph API).

    ``self`` is not used in the body; it remains from the class this
    function was extracted from.

    Args:
        x1: Tensor to upsample; channels-last (NHWC) layout is assumed, since
            the concatenation is on the last axis.
        x2: Tensor appended after the convolved ``x1`` along the channel axis.
        n_filter: Number of output channels of the 3x3 convolution.
        name: Suffix used to name the conv layer and the concat op.

    Returns:
        ``concat([conv(upsample(x1)), x2], axis=-1)``.
    """
    upsampled = tf.keras.layers.UpSampling2D((2, 2))(x1)
    # padding='same' keeps the upsampled spatial size; presumably x2 has
    # twice x1's height/width so the concat lines up — TODO confirm at callers.
    conv_name = "upsample_{0}".format(name)
    convolved = tf.layers.conv2d(
        upsampled,
        filters=n_filter,
        kernel_size=(3, 3),
        padding='same',
        name=conv_name,
    )
    concat_name = "concat_{0}".format(name)
    return tf.concat([convolved, x2], axis=-1, name=concat_name)  # NHWC format
# >>> pytorch implementation
import torch
class UpConvCat(torch.nn.Module):
    """PyTorch port of ``upconvcat``: 2x upsample, 3x3 conv, then concat a skip tensor.

    Tensors are NCHW. The output has ``out_channels + x2.shape[1]`` channels.

    Args:
        in_channels: Number of channels of ``x1`` (the tensor that gets upsampled).
        out_channels: Number of channels produced by the 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # torch.nn.Upsample defaults to nearest-neighbour, like Keras UpSampling2D.
        self.up = torch.nn.Upsample(scale_factor=2)
        # kernel_size=3 with padding=1 preserves H and W, i.e. TF's padding='same'.
        # No activation follows, matching tf.layers.conv2d's default activation=None.
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        """Return ``cat([conv(up(x1)), x2], dim=1)``.

        ``x2`` must share ``x1``'s batch size and have twice its height and
        width, otherwise ``torch.cat`` raises.
        """
        # Apply the convolution after upsampling, as the TF version does.
        x1 = self.conv(self.up(x1))
        return torch.cat([x1, x2], dim=1)  # NCHW format
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment