Skip to content

Instantly share code, notes, and snippets.

@phuocphn
Created June 2, 2020 07:33
Show Gist options
  • Save phuocphn/dbfea39071a523502be0dbf120a4590a to your computer and use it in GitHub Desktop.
import math
import numpy as np
import torch.nn.functional as F
import torch
import torch.nn as nn
# Reproducibility: make cuDNN deterministic and fix both RNG seeds so the
# random input below (and therefore every histogram) is identical across runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
torch.manual_seed(0)

# Random input: standard-normal samples scaled by 0.3 * 0.3.
x = torch.randn(1, 3, 32, 32) * 0.3 * 0.3

# Three successive "scale, then subtract a signed step" stages.
#   v_k     = 0.5 * 2 * x_k / step_k
#   x_{k+1} = x_k - step_k * sign(v_k)
# so that step_k * sign(v_k) == x_k - x_{k+1}, and consequently
#   x1 - x3 == 0.425 * sign(v1) + 0.335 * sign(v2).
# NOTE(review): the step sizes 0.425 / 0.335 / 0.1225 and the `0.5 * 2`
# factor are kept exactly as written; their origin is not shown here.
x1 = x
v1 = 0.5 * 2 * x1 / 0.425
x2 = x1 - 0.425 * v1.sign()  # 0.425 * v1.sign() = x1 - x2
v2 = 0.5 * 2 * x2 / 0.335
x3 = x2 - 0.335 * v2.sign()  # 0.335 * v2.sign() = x2 - x3
v3 = 0.5 * 2 * x3 / 0.1225


def show_histograms():
    """Plot overlaid 100-bin histograms of x1..x3 and v1..v3 and show them.

    Opens a single Matplotlib figure (number 0) titled 'Output histogram'
    using the TkAgg backend and blocks in ``plt.show()`` until the window
    is closed.
    """
    # Select the backend BEFORE importing pyplot, as Matplotlib recommends.
    import matplotlib
    matplotlib.use('TkAgg')
    import matplotlib.pyplot as plt

    plt.close('all')
    fig = plt.figure(0)
    # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4 and
    # removed in 3.6; the window title lives on the figure manager.
    fig.canvas.manager.set_window_title('Output histogram')

    # (tensor, alpha, label) in the original drawing order; x1 is drawn last
    # and most opaque so it sits on top of the other histograms.
    series = [
        (v1, 0.5, "v1"),
        (x2, 0.5, "x2"),
        (v2, 0.5, "v2"),
        (x3, 0.5, "x3"),
        (v3, 0.3, "v3"),
        (x1, 0.9, "x1"),
    ]
    for tensor, alpha, label in series:
        # detach() replaces the legacy `.data` accessor; reshape(-1) flattens
        # regardless of memory layout.
        values = tensor.detach().reshape(-1).numpy()
        plt.hist(values, bins=100, alpha=alpha, label=label)

    plt.legend(prop={'size': 10})
    plt.show()


# Only open the (blocking) plot window when executed as a script.
if __name__ == "__main__":
    show_histograms()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment