Created
February 26, 2020 03:59
-
-
Save alexmlamb/f5c4241040fad812f39f1381910f4ca9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard library.
import random

# Numerical / deep-learning stack.
import numpy as np
import torch
from torch.autograd import Variable, grad

# Plotting. The Agg backend is selected before pyplot is imported so that
# figures are rendered straight to image files.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import rcParams

# Every figure produced by this script is 10 x 10 inches.
rcParams['figure.figsize'] = 10, 10
def sample_moons(n_samples=100, noise=0):
    """Sample a two-class "two moons"-style dataset.

    Class 0 points lie on the upper half of the unit circle centred at the
    origin, ``(cos t, sin t)``.  Class 1 points lie on the curve
    ``(1 - cos t, 0.5 - sin t)``.  In both cases ``t`` is drawn uniformly
    from ``[0, pi)``.  Gaussian noise with standard deviation ``noise`` is
    then added to every coordinate.

    Args:
        n_samples: Requested number of points.  ``n_samples // 2`` points are
            drawn per class, so an odd value yields ``n_samples - 1`` points.
        noise: Standard deviation of the additive Gaussian jitter.

    Returns:
        A tuple ``(X, y)`` where ``X`` has shape ``(N, 2)`` and ``y`` has
        shape ``(N, 1)`` with ``0.`` for the first half and ``1.`` for the
        second half.  Both tensors are placed on the CUDA device when one is
        available and on the CPU otherwise.

    Side effects:
        Prints the shapes of ``X`` and of the (still flat) label vector.
    """
    n_per_class = n_samples // 2
    # Fall back to the CPU instead of crashing on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Angles are drawn in this order (outer, then inner) to keep the torch
    # random stream identical to the original implementation.
    outer_angle = torch.rand(n_per_class, 1) * np.pi
    inner_angle = torch.rand(n_per_class, 1) * np.pi

    x_coord = torch.cat((torch.cos(outer_angle), 1 - torch.cos(inner_angle)), 0)
    y_coord = torch.cat((torch.sin(outer_angle),
                         1 - torch.sin(inner_angle) - .5), 0)
    X = torch.cat((x_coord, y_coord), 1)
    y = torch.cat((torch.zeros(n_per_class), torch.ones(n_per_class)), 0)

    # The noise tensor is drawn even when noise == 0 so that the random
    # stream consumed by this function does not depend on the noise level.
    X += torch.randn(X.size()) * noise

    print(X.shape, y.shape)
    return X.detach().to(device), y.view(-1, 1).detach().to(device)
def sample_linear(n_samples=100, noise=0):
    """Sample two horizontally elongated, linearly separable point clouds.

    Each class is drawn uniformly from a box that is 55 wide along the first
    coordinate (``[-27.5, 27.5)``) and 5.1 tall along the second.  The class 0
    box is shifted down by 3 and the class 1 box up by 3, so with
    ``noise == 0`` the classes are separated by the band ``|x2| < 0.45``.
    Gaussian noise with standard deviation ``noise`` is added afterwards.

    Args:
        n_samples: Requested number of points; ``n_samples // 2`` are drawn
            per class.
        noise: Standard deviation of the additive Gaussian jitter.

    Returns:
        A tuple ``(X, y)`` where ``X`` has shape ``(N, 2)`` and ``y`` has
        shape ``(N,)`` (flat, unlike ``sample_moons`` which returns
        ``(N, 1)``), with ``0.`` for the first half and ``1.`` for the second.
        Both tensors are placed on the CUDA device when one is available and
        on the CPU otherwise.

    Side effects:
        Prints the shapes of ``X`` and ``y``.
    """
    n_per_class = n_samples // 2
    # Fall back to the CPU instead of crashing on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Uniform samples in [-0.5, 0.5)^2, class 0 first to keep the torch
    # random stream identical to the original implementation.
    first = torch.rand(n_per_class, 2) - 0.5
    second = torch.rand(n_per_class, 2) - 0.5

    # Stretch both clouds: wide along axis 0, thin along axis 1.
    for cloud in (first, second):
        cloud[:, 0] *= 55.0
        cloud[:, 1] *= 5.1

    # Push the classes apart along axis 1.
    first[:, 1] -= 3.0
    second[:, 1] += 3.0

    X = torch.cat((first, second), 0)
    y = torch.cat((torch.zeros(n_per_class), torch.ones(n_per_class)), 0)
    X += torch.randn(X.size()) * noise

    print(X.shape, y.shape)
    return X.to(device), y.to(device)
def sample_twospirals(n_samples=100, noise=0):
    """Sample two interleaved spiral arms.

    The angle/radius ``t`` of each point is ``sqrt(u) * 780 degrees`` (in
    radians) with ``u`` uniform in ``[0, 1)``.  The first arm is
    ``(-t cos t, t sin t)`` plus a uniform offset in ``[0, noise)`` on each
    coordinate; the second arm is the point reflection of the first through
    the origin.

    Args:
        n_samples: Requested number of points; ``n_samples // 2`` are drawn
            per arm.
        noise: Scale of the uniform ``[0, 1)`` offset added to the first arm
            (and, negated, to the second arm).

    Returns:
        A tuple ``(X, Y)`` of float32 tensors with shapes ``(N, 2)`` and
        ``(N,)``; labels are ``0.`` for the first arm and ``1.`` for the
        second.  Both tensors are placed on the CUDA device when one is
        available and on the CPU otherwise.

    Side effects:
        Re-seeds NumPy's *global* random generator with 42 on every call, so
        the returned data is always the same for a given ``n_samples`` and
        ``noise`` and any later ``np.random`` draws become deterministic too.
        Prints the shapes of ``X`` and ``Y``.
    """
    # Kept as a global re-seed on purpose: later code in the process may rely
    # on NumPy's generator being in this known state.
    np.random.seed(42)
    n_points = n_samples // 2
    # Fall back to the CPU instead of crashing on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 780 degrees expressed in radians; sqrt spreads points evenly along the arm.
    theta = np.sqrt(np.random.rand(n_points, 1)) * 780 * (2 * np.pi) / 360
    arm_x = -np.cos(theta) * theta + np.random.rand(n_points, 1) * noise
    arm_y = np.sin(theta) * theta + np.random.rand(n_points, 1) * noise

    first_arm = np.hstack((arm_x, arm_y))
    X = np.vstack((first_arm, -first_arm))
    Y = np.hstack((np.zeros(n_points), np.ones(n_points)))

    print(X.shape, Y.shape)
    return (torch.from_numpy(X.astype('float32')).to(device),
            torch.from_numpy(Y.astype('float32')).to(device))
def plot_moons(net, x_labeled, y_labeled, x_unlabeled, meshres):
    """Plot the labeled points over the classifier output and save the figure.

    The classifier is evaluated on a regular grid covering the fixed window
    ``[-5, 5) x [-5, 5)`` and drawn as a translucent heat map underneath a
    scatter plot of the labeled points (blue for label 0, red otherwise).
    The figure is written to ``mymoon.png`` and the current figure is then
    cleared.

    Args:
        net: Callable mapping a float tensor of shape ``(M, 2)`` to one
            output value per row.
        x_labeled: Tensor of labeled inputs; columns 0 and 1 are plotted.
        y_labeled: Tensor of labels, one per row of ``x_labeled``.
        x_unlabeled: Unused.  Kept for backward compatibility; the plot window
            is fixed rather than derived from this data.
        meshres: Grid spacing used when evaluating ``net``.
    """
    # Evaluate on the GPU when there is one (as the original code always did),
    # otherwise on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Pull the data to host memory once instead of element by element.
    labels = y_labeled.detach().cpu().numpy().ravel()
    points = x_labeled.detach().cpu().numpy()
    colors = ['blue' if label == 0 else 'red' for label in labels]
    plt.scatter(points[:, 0], points[:, 1], c=colors, s=100)

    # Fixed plotting window.
    x_min, x_max = -5.0, 5.0
    y_min, y_max = -5.0, 5.0

    # The second axis must use the y bounds (the original reused the x bounds,
    # which only worked because both ranges happen to be equal).
    xx, yy = np.meshgrid(np.arange(x_min, x_max, meshres),
                         np.arange(y_min, y_max, meshres))
    grid = torch.Tensor(np.c_[xx.ravel(), yy.ravel()]).to(device)
    with torch.no_grad():
        Z = net(grid).cpu().numpy()
    Z = Z.reshape(xx.shape)

    plt.imshow(Z, extent=(x_min, x_max, y_min, y_max), cmap='jet',
               interpolation='bilinear', origin='lower', alpha=0.5)
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.savefig('mymoon.png')
    plt.clf()
def mixup(x, y, a=1, *, extrapolate=False):
    """Mix every example with a randomly chosen partner from the same batch.

    A single mixing coefficient is drawn per call and shared by the whole
    batch; partners are chosen with a random permutation of the batch.

    Args:
        x: Input tensor whose first dimension is the batch dimension.
        y: Target tensor with the same batch dimension as ``x``.
        a: Concentration of the symmetric Beta distribution the mixing
            coefficient is drawn from.
        extrapolate: Selects the alternative mixing rule.  When ``False``
            (the default, and the only behaviour of the original code)
            ``lam ~ Beta(a, a)`` and both inputs and targets are mixed as
            ``lam * v + (1 - lam) * v[perm]``.  When ``True``
            ``u ~ Beta(a + 0.01, a + 0.01)``, inputs are moved a fraction
            ``u`` of the way towards their partner, and targets are weighted
            by the inverse summed L1 distance of the mixed batch to each
            endpoint; the resulting targets are detached from the graph.
            NOTE(review): ``u`` lies in ``[0, 1]``, so despite the name this
            branch still interpolates.

    Returns:
        A tuple ``(x_mix, y_mix)`` with the same shapes as ``x`` and ``y``.
    """
    # Drawn on the CPU and then moved, so the permutation follows torch's CPU
    # generator and works for inputs on any device.
    perm = torch.randperm(x.size(0)).to(x.device)

    if extrapolate:
        u = np.random.beta(a + 0.01, a + 0.01)
        direction = x[perm] - x
        x_mix = x + u * direction
        # Batch-wide L1 distances from the mixed points to both endpoints.
        dist_self = torch.abs(x_mix - x).sum()
        dist_partner = torch.abs(x_mix - x[perm]).sum()
        # Small constant guards against division by zero.
        inv_self = 1 / (dist_self + 0.0001)
        inv_partner = 1 / (dist_partner + 0.0001)
        mix_ratio = inv_self / (inv_self + inv_partner)
        y_mix = (mix_ratio * y + (1 - mix_ratio) * y[perm]).detach()
    else:
        lam = np.random.beta(a, a)
        x_mix = lam * x + (1 - lam) * x[perm]
        y_mix = lam * y + (1 - lam) * y[perm]
    return x_mix, y_mix
def train_net(x_labeled,
              y_labeled,
              x_unlabeled,
              n_hiddens=512,
              n_iterations=3200,
              lamba=1,
              mix="h"):
    """Train a two-stage binary classifier with optional mixup.

    The model is ``net2(net1(x))``: ``net1`` maps the input to a 2-D hidden
    representation and ``net2`` maps that representation to a probability
    through a final sigmoid.  Both are LeakyReLU multilayer perceptrons
    trained jointly, full batch, with Adam (``lr=1e-4``,
    ``weight_decay=1e-4``) on a binary cross-entropy loss.

    Args:
        x_labeled: Training inputs, shape ``(N, D)``.  Must live on the same
            device the networks are placed on (CUDA when available, else CPU).
        y_labeled: Training targets in ``[0, 1]`` with the same shape as the
            model output, i.e. ``(N, 1)``.
        x_unlabeled: Unused; kept for backward compatibility.
        n_hiddens: Width of every hidden layer.
        n_iterations: Number of full-batch optimisation steps.
        lamba: Unused; kept for backward compatibility.
        mix: Mixing strategy.  ``"none"`` trains on the raw batch, ``"x"``
            mixes pairs of inputs, ``"h"`` (default, the behaviour of the
            original code) mixes pairs of ``net1`` outputs.  For ``"x"`` and
            ``"h"`` the loss is the same convex-style combination of the BCE
            against each partner's label.

    Returns:
        The trained ``(net1, net2)`` pair.

    Raises:
        Exception: If ``mix`` is not one of the supported options.

    Side effects:
        Prints the iteration number and loss every 100 iterations.
    """
    if mix not in ('none', 'x', 'h'):
        raise Exception('mixing option not found')

    # Use the GPU when there is one (as the original code always did),
    # otherwise train on the CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # net1 is built before net2 so that weight initialisation consumes the
    # torch random stream in the same order as the original implementation.
    net1 = _make_mlp([x_labeled.size(1), n_hiddens, n_hiddens, n_hiddens,
                      n_hiddens, 2])
    net2 = _make_mlp([2, n_hiddens, n_hiddens, n_hiddens, 1],
                     final_activation=torch.nn.Sigmoid())
    net1.to(device)
    net2.to(device)

    opt = torch.optim.Adam(list(net1.parameters()) + list(net2.parameters()),
                           lr=0.0001, weight_decay=1e-4)
    bce = torch.nn.BCELoss()

    # The inputs are constants of the optimisation; no gradient is needed
    # with respect to them.
    x_labeled = x_labeled.detach()

    for iteration in range(n_iterations):
        opt.zero_grad()

        if mix == 'none':
            error_labeled = bce(net2(net1(x_labeled)), y_labeled)
        else:
            # NOTE(review): a Gaussian coefficient can fall outside [0, 1],
            # which extrapolates and gives one loss term a negative weight.
            # Kept as in the original; confirm this is intended.
            lamb = np.random.normal(0.5, 0.2)
            perm = torch.randperm(x_labeled.shape[0])
            if mix == 'x':
                mixed_input = x_labeled * lamb + x_labeled[perm] * (1 - lamb)
                prediction = net2(net1(mixed_input))
            else:
                h = net1(x_labeled)
                prediction = net2(h * lamb + h[perm] * (1 - lamb))
            # One forward pass serves both loss terms.
            loss_self = bce(prediction, y_labeled)
            loss_partner = bce(prediction, y_labeled[perm])
            error_labeled = loss_self * lamb + loss_partner * (1 - lamb)

        error_labeled.backward()
        opt.step()

        if iteration % 100 == 0:
            print(iteration, error_labeled)

    return net1, net2


def _make_mlp(layer_sizes, final_activation=None):
    """Build Linear layers of the given sizes with LeakyReLU between them."""
    layers = []
    last = len(layer_sizes) - 2
    for index, (n_in, n_out) in enumerate(zip(layer_sizes[:-1],
                                              layer_sizes[1:])):
        layers.append(torch.nn.Linear(n_in, n_out))
        if index != last:
            layers.append(torch.nn.LeakyReLU())
    if final_activation is not None:
        layers.append(final_activation)
    return torch.nn.Sequential(*layers)
def plot_h(net1, net2, x_labeled, bound_stat, meshres):
    """Plot the hidden representation together with ``net2``'s output.

    The labeled points are mapped through ``net1`` and scattered in hidden
    space; the first half of the batch is drawn in blue and the second half
    in red.  300 inputs drawn uniformly from ``[-9, 9)^2`` are mapped through
    ``net1`` as well and drawn as faint black dots.  ``net2`` is then
    evaluated on a square grid in hidden space and shown as a translucent
    heat map.  The figure is written to ``hmoon.png`` and the current figure
    is then cleared.

    Args:
        net1: Module mapping 2-D inputs to a hidden representation; the first
            two hidden coordinates are plotted.
        net2: Module mapping 2-D hidden points to one output value per row.
        x_labeled: Labeled inputs.  The colouring assumes the first half of
            the rows belongs to one class and the second half to the other.
        bound_stat: Margin added around the hidden-space extent of the
            labeled points when building the grid.
        meshres: Grid spacing used when evaluating ``net2``.
    """
    # Evaluate on the GPU when there is one (as the original code always did),
    # otherwise on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        h = net1(x_labeled.detach())
        # Random probes showing where arbitrary inputs land in hidden space.
        h_random = net1(18.0 * torch.rand((300, 2)).to(device) - 9.0)
    h = h.cpu().numpy()
    h_random = h_random.cpu().numpy()

    half = h.shape[0] // 2
    plt.scatter(h[:half, 0], h[:half, 1], color='blue')
    plt.scatter(h[half:, 0], h[half:, 1], color='red')
    plt.scatter(h_random[:, 0], h_random[:, 1], color='black', alpha=0.08,
                linewidth=0.5)

    # Grid bounds cover BOTH plotted hidden coordinates of the labeled points.
    # (The original mixed column 0 with row 0, i.e. with a single sample.)
    p_min = float(h[:, :2].min()) - bound_stat
    p_max = float(h[:, :2].max()) + bound_stat
    axis = np.arange(p_min, p_max, meshres)
    xx, yy = np.meshgrid(axis, axis)

    grid = torch.Tensor(np.c_[xx.ravel(), yy.ravel()]).to(device)
    with torch.no_grad():
        Z = net2(grid).cpu().numpy()
    Z = Z.reshape(xx.shape)

    plt.imshow(Z, extent=(xx.min(), xx.max(), yy.min(), yy.max()), cmap='jet',
               interpolation='bilinear', origin='lower', alpha=0.5)
    plt.savefig('hmoon.png')
    plt.clf()
if __name__ == "__main__":
    # Seed torch so data sampling and weight initialisation are repeatable.
    torch.manual_seed(1000)
    lamba = 1

    # 128 labeled two-moons points with light noise, plus a two-spirals set
    # that is passed along as the "unlabeled" data.
    x_labeled, y_labeled = sample_moons(128, 0.01)
    x_unlabeled, y_unlabeled = sample_twospirals(1000, noise=0.1)

    net1, net2 = train_net(x_labeled, y_labeled, x_unlabeled, lamba=lamba)

    def classifier(inp):
        """Full model: hidden representation followed by the output head."""
        return net2(net1(inp))

    # Decision surface in input space, then in hidden space.
    plot_moons(classifier, x_labeled, y_labeled, x_unlabeled, meshres=0.1)
    plot_h(net1, net2, x_labeled, bound_stat=0.2, meshres=0.2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment