@Dapid
Last active June 20, 2019 16:33

import time
import numpy
import jax.numpy as np
from jax import random, grad, jit
from jax import vmap
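

# Per-configuration loss: an L2 penalty on the fields h and couplings J, plus,
# for each position r, the inverse of a normalisation term summed over the
# q_max possible states, given the integer configuration sigma (length N).
# The Python loop over r is unrolled when the function is traced by jit/grad.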
def _compute_single_loss(h, J, sigma, N, lambda_h, lambda_j):
    loss = lambda_h * np.sum(h * h) + lambda_j * np.sum(J * J)
    indexes = np.arange(0, N)
    for r in range(N):
        this = h[r, :] + np.sum(J[indexes, r, sigma[indexes], :], axis=(0, 1))
        denominator = np.sum(np.exp(this))
        loss += np.sum(1 / denominator)
    return loss
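

# Holds the model parameters (fields h of shape (N, q_max) and couplings J of
# shape (N, N, q_max, q_max)) and builds jitted, vmapped loss and gradient
# functions that operate on batches of configurations of shape (batch, N).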
class ExponentialModel:
    def __init__(self, N, lambda_h=0.001, lambda_j=0.01, q_max=21):
        key = random.PRNGKey(42)
        key, subkey_h = random.split(key)
        key, subkey_j = random.split(key)

        self.h = random.normal(subkey_h, shape=(N, q_max))
        self.J = random.normal(subkey_j, shape=(N, N, q_max, q_max))

        self.lambda_h = lambda_h
        self.lambda_j = lambda_j
        self.N = N
        self.q_max = q_max
        self.eps = 1e-7
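
        # The closures below capture self.h and self.J, so vmap only maps
        # over the configuration argument sigma.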
        def single_loss_(sigma):
            return _compute_single_loss(self.h, self.J, sigma, self.N, self.lambda_h, self.lambda_j)

        self.batched_loss = jit(vmap(single_loss_))
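
        # Gradient of the loss with respect to h and J (argnums 0 and 1);
        # batched_grad_ then sums the per-configuration gradients over the
        # batch. Note: holomorphic=True is kept from the original; recent JAX
        # releases may reject it for real-valued inputs, in which case it can
        # simply be dropped.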
        self._single_grad = grad(_compute_single_loss, argnums=(0, 1), holomorphic=True)

        def single_grad_(sigma):
            return self._single_grad(self.h, self.J, sigma, self.N, self.lambda_h, self.lambda_j)

        def batched_grad_(sigma):
            x, y = vmap(single_grad_)(sigma)
            return np.sum(x, axis=0), np.sum(y, axis=0)

        self._batched_grad = vmap(jit(single_grad_))
        self.batched_grad_fast2 = jit(self._batched_grad)
        self.batched_grad_fast = jit(batched_grad_)
    def single_loss(self, sigma):
        return _compute_single_loss(self.h, self.J, sigma, self.N, self.lambda_h, self.lambda_j)
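

# Benchmark: time the jitted batched gradient for increasing batch sizes.
# Each new input shape triggers a fresh jit compilation, so the first of the
# three repetitions per batch size includes compilation time. The timings do
# not call block_until_ready(), so on asynchronous backends they may
# underestimate the actual execution time.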
if __name__ == '__main__':
    t0 = time.time()
    plm = ExponentialModel(50)
    print('Instantiation:', time.time() - t0)

    for batch_size in range(6, 20, 2):
        print('--->', batch_size)
        data = numpy.random.randint(0, 21, size=(batch_size, 50))

        for _ in range(3):
            t0 = time.time()
            grads = plm.batched_grad_fast(data)
            print(' Backward:', time.time() - t0)
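
        # Extra (not in the original gist): a minimal sketch that also times
        # the jitted, vmapped loss on the same batch, assuming the integer
        # configurations are valid state indices (0 <= sigma < q_max).
        t0 = time.time()
        losses = plm.batched_loss(data)
        print(' Loss:', time.time() - t0)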