Trying to draw Mickey using binary cross entropy, comparing a plain BCE-with-logits loss against a focal-loss variant.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random
def bce_loss_with_logits(x, y):
    """Numerically stable BCE with logits: relu(x) - x*z + log(1 + exp(-|x|))."""
    z = y.float()
    x = x.squeeze()
    losses = F.relu(x) - x * z + torch.log(1 + torch.exp(-torch.abs(x)))
    return losses.mean()
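# Quick sanity check (a minimal sketch, not part of the original gist): the hand-rolled
# formula above should match PyTorch's built-in F.binary_cross_entropy_with_logits.
_x = torch.randn(8, 1)
_y = torch.randint(0, 2, (8,))
assert torch.allclose(bce_loss_with_logits(_x, _y),
                      F.binary_cross_entropy_with_logits(_x.squeeze(), _y.float()))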
def bce_focal_loss(x, y):
    alpha = 0.25
    gamma = 2
    z = y.float()
    x = x.squeeze()
    p = x.sigmoid()
    pt = torch.where(z > 0, p, 1 - p)   # probability assigned to the true class
    weights = (1 - pt).pow(gamma)       # focal modulation: down-weight easy examples
    # weights = torch.where(z > 0, alpha * weights, (1 - alpha) * weights)  # standard alpha-balancing
    losses = F.relu(x) - x * z + torch.log(1 + torch.exp(-torch.abs(x)))
    # losses = F.binary_cross_entropy_with_logits(x, z, reduction='none')
    loss = losses * weights
    return loss.mean()
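# Illustrative check (not in the original gist): the focal weight (1 - pt)**gamma shrinks
# the contribution of confidently-correct samples. With gamma = 2, a sample predicted at
# pt = 0.9 is weighted 0.01, while one at pt = 0.5 is weighted 0.25.
_pt = torch.tensor([0.9, 0.5])
print('focal weights for pt=0.9, 0.5:', (1 - _pt).pow(2).tolist())  # ~[0.01, 0.25]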
def softmax_focal_loss(x, y):
    gamma = 2
    r = torch.arange(x.size(0))
    pt = F.softmax(x, dim=1)[r, y]       # probability of the true class
    weights = (1 - pt).pow(gamma)        # should normalize?
    ce = -F.log_softmax(x, dim=1)[r, y]
    loss = weights * ce
    return loss.mean()
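# Sanity check (a rough sketch, not in the original gist): since the focal weight
# (1 - pt)**gamma lies in [0, 1], the focal loss should never exceed plain cross-entropy
# on the same logits and targets.
_logits = torch.randn(16, 3)
_targets = torch.randint(0, 3, (16,))
assert softmax_focal_loss(_logits, _targets) <= F.cross_entropy(_logits, _targets)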
def make_pic_distribution():
    import cv2
    img = cv2.imread('mickey.jpg', cv2.IMREAD_GRAYSCALE)
    x1, x2 = np.where(img > -1)               # (row, col) coordinates of every pixel
    x1 = img.shape[0] - x1                     # flip rows so the plot is not upside down
    x = np.concatenate([x2[:, None], x1[:, None]], axis=1)
    y = (img > 3).reshape(-1)                  # binary label per pixel (gray-level threshold)
    return x, y
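# Optional peek at the data (assumes 'mickey.jpg' sits next to the script): the two
# classes are typically far from balanced, which is what motivates the focal-loss
# comparison below.
# _px, _py = make_pic_distribution()
# print('num points:', len(_px), 'positive fraction:', _py.mean())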
def make_net(cin=2, hidden=64, cout=1):
    return nn.Sequential(nn.BatchNorm1d(cin),
                         nn.Linear(cin, hidden), nn.ELU(),
                         nn.Linear(hidden, hidden), nn.ELU(),
                         nn.Linear(hidden, cout))
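# Note: make_net is an unused plain-MLP alternative; the experiment below uses the
# ResNet defined next. A drop-in usage would be, e.g.:
# net1 = make_net(cin=2, hidden=64, cout=1)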
class ResNet(nn.Module):
    def __init__(self, cin=2, hidden=64, num_layers=5):
        super(ResNet, self).__init__()
        self.prepare = nn.Sequential(nn.BatchNorm1d(cin),
                                     nn.Linear(cin, hidden),
                                     nn.ReLU())
        self.residuals = nn.ModuleList()
        for _ in range(num_layers):
            self.residuals.append(nn.Sequential(nn.Linear(hidden, hidden), nn.BatchNorm1d(hidden), nn.ReLU()))
        self.out = nn.Linear(hidden, 1)

    def forward(self, x):
        x = self.prepare(x)
        for res in self.residuals:
            x = x + res(x)
        return self.out(x)
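# Shape sanity check (a minimal sketch, not in the original gist): a batch of (x, y)
# coordinates goes in, one logit per row comes out.
_net = ResNet(num_layers=2)
_out = _net(torch.randn(4, 2))
assert _out.shape == (4, 1)
del _net, _out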
#N = 50000
#Ntr = N*50/100
#x = torch.randn(N, 4)
# xtr = x[:Ntr]
# ytr = xtr.norm(dim=1)
# ytr = ytr < (ytr.mean()-2*ytr.std())
# xval = x[Ntr:]
# yval = xval.norm(dim=1)
# yval = yval < (yval.mean()-2*yval.std())
x, y = make_pic_distribution()
N = len(x)
Ntr = N * 100 // 100  # integer split index: here the whole set is used for training
x = torch.from_numpy(x).float()
# x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-7)
y = torch.from_numpy(y)  # bool labels, usable directly as a mask later on
idx = list(range(N))
random.shuffle(idx)
x = x[idx]
y = y[idx]
xtr = x[:Ntr]
ytr = y[:Ntr]
xval = x[Ntr:]
yval = y[Ntr:]
ytrnp = ytr.numpy().astype(np.int32)
cuda = torch.cuda.is_available()  # fall back to CPU when no GPU is present
hidden = 256  # note: the ResNets below still use their default hidden=64
net1 = ResNet(num_layers=10)
net2 = ResNet(num_layers=10)
if cuda:
    xtr = xtr.cuda()
    ytr = ytr.cuda()
    net1.cuda()
    net2.cuda()
p = np.array([0.8, 0.2], dtype=np.float32)
p = p[ytrnp]
p = p/p.sum()
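# Sampling probabilities per training point: class-0 points get relative weight 0.8 and
# class-1 points 0.2 (then normalized to sum to 1), so np.random.choice below draws a
# deliberately skewed mini-batch rather than a uniform one. For example:
# labels [0, 1, 1, 0] -> [0.8, 0.2, 0.2, 0.8] / 2.0 = [0.4, 0.1, 0.1, 0.4]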
batchsize = 1024*5
net1, net2 = net1.train(), net2.train()
opt1 = optim.Adam(net1.parameters(), lr=0.1, betas=(0.9, 0.99), weight_decay=1e-5)
opt2 = optim.Adam(net2.parameters(), lr=0.1, betas=(0.9, 0.99), weight_decay=1e-5)
for i in range(1000):
    idx = np.random.choice(np.arange(0, len(ytr)), size=batchsize, p=p)
    bx = xtr[idx]
    by = ytr[idx]

    opt1.zero_grad()
    out = net1(bx)
    loss1 = bce_loss_with_logits(out, by)
    loss1.backward()
    opt1.step()

    opt2.zero_grad()
    out = net2(bx)
    loss2 = bce_focal_loss(out, by)
    loss2.backward()
    opt2.step()

    if i % 100 == 0:
        print('loss1: ', loss1.item(), ' loss2: ', loss2.item())
net1.cpu()
net2.cpu()
net1.eval()
net2.eval()
# Ntr covers every sample, so evaluation here runs on the (shuffled) training set
xval = xtr.cpu()
yval = ytr.cpu()
y_hat = (net1(xval) >= 0).squeeze()
Error = (y_hat != yval).float().mean()
y_hat2 = (net2(xval) >= 0).squeeze()
Error2 = (y_hat2 != yval).float().mean()
print('BCE Error: ', Error.item(), ' FocalBCE Error2: ', Error2.item())
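# With imbalanced labels, raw error can look good even if one class is mostly missed,
# so a per-class check may be more informative (an optional sketch, not in the original gist):
recall1 = (y_hat & yval).float().sum() / yval.float().sum()
recall2 = (y_hat2 & yval).float().sum() / yval.float().sum()
print('positive-class recall, BCE: %.3f focal: %.3f' % (recall1.item(), recall2.item()))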
plt.subplot(311)
plt.title('ground truth')
plt.scatter(xval[yval, 0], xval[yval, 1], marker='.', color='b', s=1)
plt.scatter(xval[~yval, 0], xval[~yval, 1], marker='.', color='r', s=1)
plt.subplot(312)
plt.title('BCE prediction')
plt.scatter(xval[y_hat, 0], xval[y_hat, 1], marker='.', color='b', s=1)
plt.scatter(xval[~y_hat, 0], xval[~y_hat, 1], marker='.', color='r', s=1)
plt.subplot(313)
plt.title('focal BCE prediction')
plt.scatter(xval[y_hat2, 0], xval[y_hat2, 1], marker='.', color='b', s=1)
plt.scatter(xval[~y_hat2, 0], xval[~y_hat2, 1], marker='.', color='r', s=1)
# xerr = xval[y_hat != yval]
# yerr = yval[y_hat != yval]
#
# plt.subplot(312)
# plt.scatter(xerr[yerr, 0], xerr[yerr, 1], marker='.', color='b', s=1)
# plt.scatter(xerr[~yerr, 0], xerr[~yerr, 1], marker='.', color='r', s=1)
#
# xerr = xval[y_hat2 != yval]
# yerr = yval[y_hat2 != yval]
#
# plt.subplot(313)
# plt.scatter(xerr[yerr, 0], xerr[yerr, 1], marker='.', color='b', s=1)
# plt.scatter(xerr[~yerr, 0], xerr[~yerr, 1], marker='.', color='r', s=1)
plt.show()