# Source: GitHub Gist vene/dd29002e214302dfeb62c867a883714a
# Created by @vene on October 27, 2020 16:12.
# author: vn
import numpy as np
from scipy.optimize import root_scalar
import torch
import matplotlib.pyplot as plt
def entropy(y, a, b):
    """Fermi-Dirac (binary) entropy of ``y`` rescaled from [a, b] to [0, 1].

    With ``u = (y - a) / (b - a)``, this returns
    ``-u * log(u) - (1 - u) * log(1 - u)`` elementwise.

    Args:
        y: torch tensor of points, expected to lie in ``[a, b]``.
        a: lower end of the interval.
        b: upper end of the interval (must differ from ``a``).

    Returns:
        Tensor shaped like ``y``. The endpoints ``y == a`` and ``y == b``
        give 0 (the limit of ``t * log(t)`` as ``t -> 0``) instead of NaN.
        Points outside ``[a, b]`` give NaN because the log argument is
        negative.
    """
    # move to [0, 1]
    u = (y - a) / (b - a)
    v = 1 - u
    # apply Fermi-Dirac; xlogy(t, t) is defined as 0 at t == 0, whereas the
    # naive t * log(t) evaluates to 0 * -inf = NaN there.
    return -torch.xlogy(u, u) - torch.xlogy(v, v)
def f(y, x, a, b):
    """Objective in ``y``: the linear term ``-y * x`` minus the entropy.

    Args:
        y: torch tensor, the point at which to evaluate the objective.
        x: scalar coefficient of the linear term.
        a: lower end of the interval passed on to ``entropy``.
        b: upper end of the interval passed on to ``entropy``.

    Returns:
        ``-y * x - entropy(y, a, b)``.
    """
    linear_term = -y * x
    entropy_term = entropy(y, a, b)
    return linear_term - entropy_term
def fp(y, x, a, b):
    """Derivative of ``f`` with respect to ``y``, computed by autograd.

    Args:
        y: point at which to differentiate; converted to a double-precision
            tensor that tracks gradients.
        x: scalar coefficient forwarded to ``f``.
        a: lower end of the interval forwarded to ``f``.
        b: upper end of the interval forwarded to ``f``.

    Returns:
        The derivative ``df/dy`` at ``y`` as a Python number.
    """
    point = torch.tensor(y, requires_grad=True, dtype=torch.double)
    # autograd.grad returns a one-element tuple, one gradient per input.
    (slope,) = torch.autograd.grad(f(point, x, a, b), point)
    return slope.item()
def main():
    """Plot the entropy, then compare numerical and closed-form minimizers.

    Figure 1 shows ``entropy`` over ``[a, b]``. Figure 2 shows, for a range
    of ``x`` values, the root of ``fp`` in ``y`` found by bisection next to
    the closed form ``(b - a) * sigmoid((b - a) * x) + a``.
    """
    a, b = 0.5, 3

    # Figure 1: the entropy on a dense grid over [a, b].
    grid = torch.from_numpy(np.linspace(a, b, 2000))
    plt.plot(grid, entropy(grid, a, b))
    plt.xlabel("y")
    plt.ylabel("entropy")
    plt.show()

    xs = np.linspace(a - 2, b + 2)

    # compute numerically: root of df/dy for each x. The bracket is kept
    # strictly inside (a, b) so the log terms in the derivative stay finite.
    bracket = (a + 1e-14, b - 1e-14)
    numerical = [
        root_scalar(fp, args=(x, a, b), bracket=bracket, method="bisect").root
        for x in xs
    ]
    plt.plot(xs, numerical, label="numerical")

    # compute closed form
    w = b - a
    exact = w * torch.sigmoid(w * torch.from_numpy(xs)) + a
    plt.plot(xs, exact, label="exact", ls=":")

    # Without this call the label= arguments above are never displayed.
    plt.legend()
    plt.xlabel("x")
    plt.ylabel("y*")
    plt.show()
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")