Skip to content

Instantly share code, notes, and snippets.

@J3698
Created March 24, 2021 00:15
Show Gist options
Save J3698/b0852b231d3a448e257546080ad239cd to your computer and use it in GitHub Desktop.
def main():
    """Sanity-check ``adain`` on random tensors and print one boolean per check.

    Builds random integer-valued ``source`` and ``target`` tensors, stylizes
    ``source`` with ``adain(source, target)``, and checks three properties,
    each computed per (sample, channel) over the flattened spatial dims:

    1. the stylized output's variances equal the target's variances;
    2. the stylized output's means equal the target's means;
    3. the stylized output is only a per-channel rescale/shift of ``source``,
       i.e. both instance-normalise to the same tensor.

    Each check prints ``True`` when it passes.
    """
    batch, channels, height, width = 8, 3, 4, 4
    target = torch.randint(-20, 20, (batch, channels, height, width)).float()
    source = torch.randint(-20, 20, (batch, channels, height, width)).float()
    stylized_source = adain(source, target)

    # Flatten spatial dims so the statistics below are per (sample, channel).
    # reshape (not view) tolerates a non-contiguous result from adain.
    target_flat = target.reshape(batch, channels, -1)
    stylized_flat = stylized_source.reshape(batch, channels, -1)

    # Check the variances are the same. The variances here are O(100), so a
    # fixed 1e-6 absolute bound is below float32 resolution; use a relative
    # tolerance instead.
    target_variances = target_flat.var(-1)
    stylized_variances = stylized_flat.var(-1)
    print(torch.allclose(stylized_variances, target_variances,
                         rtol=1e-4, atol=1e-4))

    # Check the means are the same.
    target_means = target_flat.mean(-1)
    stylized_means = stylized_flat.mean(-1)
    print(torch.allclose(stylized_means, target_means, rtol=1e-4, atol=1e-4))

    # Check the values are only rescaled / shifted per channel: instance
    # normalisation removes any per-channel affine change, so both tensors
    # should normalise to the same result.
    stylized_4d = stylized_source.reshape(batch, channels, height, width)
    print(torch.allclose(F.instance_norm(source), F.instance_norm(stylized_4d),
                         rtol=1e-4, atol=1e-4))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment