Last active
February 20, 2023 11:59
-
-
Save Kei-jan/2cb4200ca0a5d136ed72c03a929e353c to your computer and use it in GitHub Desktop.
AdaIN plugin for any tensor size and origin code backup
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
AdaIN
url: https://gist.github.com/Kei-jan/2cb4200ca0a5d136ed72c03a929e353c
usage:
    normalized_feat = adaptive_instance_normalization(content_feat, style_feat)
    content_feat: Tensor[B,C,*]
    style_feat: Tensor[B,C,*]
    normalized_feat: Tensor[B,C,*] (same shape as content_feat)
copied from:
    https://github.com/naoto0804/pytorch-AdaIN
change:
    extended to support different tensor shapes.
'''
# import torch | |
def adaptive_instance_normalization(content_feat, style_feat, eps=1e-5):
    """Impose per-channel style statistics onto content features (AdaIN).

    Statistics are reduced over every dimension after [B, C], so any
    tensor rank is accepted. Tensor shapes are NOT asserted here; the
    caller must ensure content_feat and style_feat share the same [B, C].

    Args:
        content_feat: Tensor[B, C, *] whose spatial structure is kept.
        style_feat: Tensor[B, C, *] whose per-channel mean/std are imposed.
        eps: small constant added to the variance before the square root
            so a constant content channel no longer divides by zero
            (matches the eps used by ``calc_mean_std`` below).

    Returns:
        Tensor with the same shape as ``content_feat``.

    NOTE(review): for rank-2 inputs (no dims after [B, C]) there is
    nothing to reduce over, so the std is ill-defined with the unbiased
    estimator — prefer inputs of rank >= 3. TODO confirm intended
    semantics for the 2-D case.
    """
    size = content_feat.size()

    def _channel_stats(feat):
        # Reduce over all dims after [B, C]. Unbiased variance matches
        # the default of Tensor.std() used by the original code.
        dims = list(range(2, feat.dim()))
        mean = feat.mean(dim=dims)
        std = (feat.var(dim=dims) + eps).sqrt()
        # Pad trailing singleton dims so the stats can be expanded to
        # the content shape.
        while mean.dim() < len(size):
            mean = mean.unsqueeze(-1)
            std = std.unsqueeze(-1)
        return mean, std

    content_mean, content_std = _channel_stats(content_feat)
    style_mean, style_std = _channel_stats(style_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
# origin code | |
def calc_mean_std(feat, eps=1e-5):
    """Return per-sample, per-channel mean and std of a 4-D NCHW tensor.

    ``eps`` is added to the variance before the square root so the
    returned std is never exactly zero (avoids divide-by-zero in the
    AdaIN normalization downstream). Both outputs have shape
    (N, C, 1, 1) so they expand cleanly over the feature map.
    """
    assert feat.dim() == 4  # original code only supports NCHW input
    batch, channels = feat.shape[0], feat.shape[1]
    flattened = feat.view(batch, channels, -1)
    mean = flattened.mean(dim=2).view(batch, channels, 1, 1)
    variance = flattened.var(dim=2) + eps
    std = variance.sqrt().view(batch, channels, 1, 1)
    return mean, std
# origin code | |
def adaptive_instance_normalization_origin(content_feat, style_feat):
    """Original 4-D AdaIN: whiten content statistics, re-color with style's.

    Both tensors must agree on their leading [B, C] dims; shapes beyond
    that are constrained to 4-D by ``calc_mean_std``.
    """
    assert content_feat.size()[:2] == style_feat.size()[:2]
    target_shape = content_feat.size()
    s_mean, s_std = calc_mean_std(style_feat)
    c_mean, c_std = calc_mean_std(content_feat)
    # Remove the content's own per-channel statistics...
    whitened = (content_feat - c_mean.expand(target_shape)) / c_std.expand(target_shape)
    # ...then impose the style's statistics.
    return whitened * s_std.expand(target_shape) + s_mean.expand(target_shape)
if __name__ == '__main__':
    # BUG FIX: the module-level `import torch` is commented out, so the
    # demo raised NameError when run as a script. Import it here so the
    # guard is self-contained.
    import torch

    y_content = torch.rand((1, 3, 128, 128))  # content: Tensor[B, C, H, W]
    # print(y_content.shape) # torch.Size([1, 3, 128, 128])
    y_style = torch.rand((1, 3,))  # style: Tensor[B, C], no spatial dims
    # print(y_style.shape) # torch.Size([1, 3])
    y_mix = adaptive_instance_normalization(y_content, y_style)
    # print(y_mix.shape) # torch.Size([1, 3, 128, 128])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment