@leVirve
Last active May 15, 2019 18:46
The real CoordConv in PyTorch. It can auto-infer the x-y dimensions in tensors. Use it without pain. 💜
import torch
import torch.nn as nn


class AddCoords(nn.Module):
    """Concatenates normalized coordinate channels (and optionally a radial channel) to the input."""

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape (batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        # Integer pixel coordinates along each spatial axis.
        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

        # Normalize to [0, 1], then rescale to [-1, 1].
        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        # Broadcast to the batch: shape (batch, 1, x_dim, y_dim).
        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        # type_as keeps the coordinate channels on the input's device and dtype.
        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            # Radial channel: distance from (0.5, 0.5) in the [-1, 1] coordinate frame.
            # type_as is needed here too, otherwise the cat fails on GPU inputs.
            rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
            ret = torch.cat([ret, rr.type_as(input_tensor)], dim=1)

        return ret


class CoordConv(nn.Module):
    """A Conv2d that also sees the coordinate channels appended by AddCoords."""

    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        # Two extra channels for x and y, plus one more for r when requested.
        in_size = in_channels + 2
        if with_r:
            in_size += 1
        self.conv = nn.Conv2d(in_size, out_channels, **kwargs)

    def forward(self, x):
        ret = self.addcoords(x)
        ret = self.conv(ret)
        return ret
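
A quick usage sketch: CoordConv accepts the same keyword arguments as nn.Conv2d, so it drops in directly. The channel counts and spatial size below are arbitrary, for illustration only.

x = torch.randn(4, 8, 64, 64)  # (batch, channel, x_dim, y_dim)
conv = CoordConv(8, 16, kernel_size=3, padding=1)
out = conv(x)
print(out.shape)  # torch.Size([4, 16, 64, 64])

# with_r appends a third coordinate channel; the conv's input size is adjusted internally.
conv_r = CoordConv(8, 16, with_r=True, kernel_size=3, padding=1)
print(conv_r(x).shape)  # torch.Size([4, 16, 64, 64])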
@mkocabas

Hi @leVirve, thanks for the implementation! Can I update my repo according to yours? Or, if you want, you can create a pull request to my repo.

@leVirve

leVirve commented Jul 14, 2018

@mkocabas Sure! Note, though, that this is an alternative implementation.
I'm looking into the authors' version and will make a pull request for your project. 😄

@casssoft

casssoft commented May 15, 2019

@leVirve This is great!

FYI, there was one problem I ran into when running multiple experiments with multiple GPUs: I kept getting out-of-memory errors on the xx_channel.type_as(input_tensor) line, even with small batch sizes. It looks like I hit the same issue as pytorch/pytorch#3477.

The solution that worked for me was wrapping AddCoords.forward with with torch.cuda.device_of(input_tensor): (sketched below), though maybe that with should wrap the entire network's training code. Anyhow, hope this helps anyone else who runs into this issue.
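
For reference, a minimal sketch of that workaround (the subclass name AddCoordsOnInputDevice is made up here for illustration; torch.cuda.device_of is PyTorch's context manager that switches the current CUDA device to the given tensor's device, and is a no-op for CPU tensors):

import torch

class AddCoordsOnInputDevice(AddCoords):
    """Hypothetical wrapper: run coordinate construction with the input's GPU as current device."""

    def forward(self, input_tensor):
        # Makes input_tensor's device the current CUDA device for this block,
        # so tensors allocated inside land on the same GPU as the input.
        with torch.cuda.device_of(input_tensor):
            return super().forward(input_tensor)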
