# author: vlad niculae <vlad@vene.ro>
# license: mit
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from entmax import sparsemax, entmax15
from entmax.losses import sparsemax_loss, entmax15_loss
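
# note: depends on the `entmax` package (https://github.com/deep-spin/entmax),
# installable with e.g. `pip install entmax`.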


# unify API for output layers: softmax, sparsemax, entmax.
def _extend_2d(z):
    # turn a scalar score z into the two-class score vector [0, z], so the
    # multi-class sparsemax/entmax mappings and losses apply to binary labels.
    z = z.unsqueeze(dim=-1)
    z = torch.cat((torch.zeros_like(z), z), dim=-1)
    return z
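
# note: for two classes, sparsemax([0, z]) is the "hard sigmoid"
# clip((1 + z) / 2, 0, 1), so predicted probabilities can hit exactly 0 or 1;
# entmax15 (alpha = 1.5) sits between the sigmoid and this, and is also sparse.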


class SoftmaxOut(object):
    def __init__(self):
        self.loss_obj = torch.nn.BCEWithLogitsLoss()

    def loss(self, z, y_true):
        return self.loss_obj(z, y_true)

    def yhat(self, z):
        return torch.sigmoid(z)


class SparsemaxOut(object):
    def loss(self, z, y_true):
        return sparsemax_loss(_extend_2d(z), y_true.long()).mean()

    def yhat(self, z):
        return sparsemax(_extend_2d(z))[..., 1]


class Entmax15Out(object):
    def loss(self, z, y_true):
        return entmax15_loss(_extend_2d(z), y_true.long()).mean()

    def yhat(self, z):
        return entmax15(_extend_2d(z))[..., 1]
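
# all three output layers share the same interface:
#     out.loss(z, y_true)  -> scalar training loss (z: raw scores, y_true in {0, 1})
#     out.yhat(z)          -> probability assigned to the positive class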


# evaluate a function over a mesh for making contour plots
class MeshEval(object):
    def __init__(self, xlim, ylim, n_points):
        self.n_points = n_points
        x_min, x_max = xlim
        y_min, y_max = ylim
        grid_x = np.linspace(x_min, x_max, n_points)
        grid_y = np.linspace(y_min, y_max, n_points)
        mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
        grid_pts = np.column_stack([mesh_x.ravel(), mesh_y.ravel()])
        self.mesh_x = mesh_x
        self.mesh_y = mesh_y
        self.grid_pts = torch.from_numpy(grid_pts).float()

    def __call__(self, net):
        z = (net(self.grid_pts)
             .reshape(self.n_points, self.n_points)
             .detach())
        return z
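
# calling a MeshEval instance evaluates the network at every grid point and
# returns an (n_points, n_points) tensor ready for plt.contourf.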


OUT_LAYERS = {
    'softmax': SoftmaxOut(),
    'sparsemax': SparsemaxOut(),
    'entmax15': Entmax15Out(),
}


def train(net, X, y, out, n_epochs, callback):
    optim = torch.optim.SGD(params=net.parameters(), lr=0.05,
                            momentum=.9, nesterov=True)

    for it in range(n_epochs):
        optim.zero_grad()
        z = net(X).squeeze()
        loss = out.loss(z, y)
        callback(it, z.detach(), loss.item(), net)
        loss.backward()
        optim.step()
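
# train() calls callback(it, detached scores, loss value, net) before each
# parameter update; main() below uses this hook to log metrics and to save
# one plot per epoch.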


def main(out_name):
    hid = 100
    n_epochs = 300

    torch.manual_seed(42)

    # xor data: the label is positive iff the two coordinates have the same sign
    def sample_X_y(batch_size):
        X = 2 * torch.rand(size=(batch_size, 2)) - 1
        y = (torch.prod(X, dim=1) > 0).float()
        return X, y

    X_train, y_train = sample_X_y(batch_size=300)

    net = torch.nn.Sequential(
        torch.nn.Linear(2, hid),
        torch.nn.ReLU(),
        torch.nn.Linear(hid, 1))

    its = []
    loss_vals = []
    acc_vals = []

    mesh_eval = MeshEval(xlim=(-1.5, 1.5),
                         ylim=(-1.5, 1.5),
                         n_points=50)

    out = OUT_LAYERS[out_name]

    def callback(it, y_pred, loss, net):
        accuracy = torch.mean(((y_pred > 0) == (y_train > 0)).double())
        its.append(it)
        loss_vals.append(loss)
        acc_vals.append(accuracy)

        print("Iter {:3d} Loss {:.3f} Acc {:.3f}".format(
            it,
            loss,
            accuracy
        ))

        mesh_z = mesh_eval(net)
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 4),
                                            tight_layout=True)

        # plot decision function (z)
        max_val = np.abs(mesh_z).max()
        # note: DivergingNorm was renamed TwoSlopeNorm in matplotlib >= 3.2
        divnorm = colors.DivergingNorm(vmin=-max_val, vcenter=0, vmax=max_val)
        contour = ax1.contourf(mesh_eval.mesh_x,
                               mesh_eval.mesh_y,
                               mesh_z,
                               levels=np.linspace(-max_val, max_val, 30),
                               norm=divnorm,
                               cmap=plt.cm.PuOr)
        ax1.axvline(0, ls=":", color="k")
        ax1.axhline(0, ls=":", color="k")
        ax1.set_title('z')
        plt.colorbar(contour, ax=ax1)

        # plot positive class probability sigma(z)
        ax2.set_title('$\\sigma(z)$')
        divnorm = colors.DivergingNorm(vmin=0, vcenter=0.5, vmax=1)
        contour = ax2.contourf(mesh_eval.mesh_x,
                               mesh_eval.mesh_y,
                               out.yhat(mesh_z),
                               levels=np.linspace(0, 1, 30),
                               norm=divnorm,
                               cmap=plt.cm.PuOr)
        ax2.axvline(0, ls=":", color="k")
        ax2.axhline(0, ls=":", color="k")
        plt.colorbar(contour, ax=ax2)

        # plot training loss value
        ax3.plot(its, loss_vals)
        ax3.set_xlim(-1, n_epochs + 1)
        ax3.set_ylim(0, loss_vals[0])
        ax3.set_xlabel("iteration")
        ax3.set_title("loss")

        plt.suptitle(f"{out_name} Iter {it:03d} Accuracy {accuracy * 100:0.02f}")
        plt.savefig(f"{out_name}_{it:03d}.png")
        plt.close(fig)

    train(net, X_train, y_train, out, n_epochs, callback)


if __name__ == '__main__':
    main('softmax')
    main('sparsemax')
    main('entmax15')


@vene (Author) commented Nov 1, 2020:

[Attached images: softmax_out_opt, sparsemax_out_opt, entmax15_out_opt]
