@Dapid
Created June 27, 2019 14:44
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import time
import json

import numpy

import jax.numpy as np
from jax import random, grad, jit
from jax import lax, vmap


def _compute_single_loss(h, J, sigma, N, lambda_h, lambda_j):
    # L2 regularisation on the fields h and the couplings J.
    loss = lambda_h * np.sum(h * h) + lambda_j * np.sum(J * J)
    indexes = np.arange(0, N)
    for r in range(N):
        # Local energies of the q_max states at position r, given the rest
        # of the sequence sigma.
        this = h[r, :] + np.sum(J[:, r, sigma, :], axis=(0, 1))
        # Local partition function at position r; its reciprocal is added
        # to the loss.
        denominator = np.sum(np.exp(this))
        loss += np.sum(1 / denominator)
    return loss
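
# Expected shapes (inferred from how the model is built below, not stated in
# the original gist): h is (N, q_max), J is (N, N, q_max, q_max), and sigma is
# one integer-encoded sequence of length N with states in [0, q_max).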


class ExponentialModel:
    def __init__(self, N, lambda_h=0.001, lambda_j=0.01, q_max=21):
        # Random initialisation of the fields h and the couplings J.
        key = random.PRNGKey(42)
        key, subkey_h = random.split(key)
        key, subkey_j = random.split(key)

        self.h = random.normal(subkey_h, shape=(N, q_max))
        self.J = random.normal(subkey_j, shape=(N, N, q_max, q_max))

        self.lambda_h = lambda_h
        self.lambda_j = lambda_j
        self.N = N
        self.q_max = q_max
        self.eps = 1e-7

        # Close over the parameters so that vmap and jit only see sigma.
        def single_loss_(sigma):
            return _compute_single_loss(self.h, self.J, sigma, self.N,
                                        self.lambda_h, self.lambda_j)

        self.batched_loss = jit(vmap(single_loss_))

        # Gradient of the loss with respect to h and J (argnums 0 and 1).
        self._single_grad = grad(_compute_single_loss, argnums=(0, 1))

        def single_grad_(sigma):
            return self._single_grad(self.h, self.J, sigma, self.N,
                                     self.lambda_h, self.lambda_j)

        def batched_grad_(sigma):
            # Per-sample gradients, summed over the batch.
            x, y = vmap(single_grad_)(sigma)
            return np.sum(x, axis=0), np.sum(y, axis=0)

        self._batched_grad = vmap(jit(single_grad_))

        # self.single_loss_fast = jit(self.single_loss)
        # self.single_grad_fast = jit(self.single_grad)
        self.batched_grad_fast2 = jit(self._batched_grad)
        self.batched_grad_fast = jit(batched_grad_)

    def single_loss(self, sigma):
        return _compute_single_loss(self.h, self.J, sigma, self.N,
                                    self.lambda_h, self.lambda_j)

    # def single_grad(self, sigma):
    #     return self._single_grad(self.h, self.J, sigma, self.N, self.lambda_h, self.lambda_j)

    # def batched_grad(self, sigma):
    #     return self._batched_grad(self.h, self.J, sigma, self.N, self.lambda_h, self.lambda_j)
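
# Example usage (a sketch, not part of the original gist): `sequences` would be
# a batch of integer-encoded sequences of shape (batch_size, N) with states in
# [0, q_max), e.g. built with numpy.random.randint as in the benchmark below.
#
#     model = ExponentialModel(N=50)
#     losses = model.batched_loss(sequences)       # one loss per sequence
#     dh, dJ = model.batched_grad_fast(sequences)  # gradients summed over the batch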


if __name__ == '__main__':
    t0 = time.time()
    plm = ExponentialModel(50)
    print('Instantiation:', time.time() - t0)

    all_times = dict()
    for batch_size in range(6, 20, 2):
        print('--->', batch_size)
        times_run = []
        # Random integer-encoded sequences: batch_size sequences of length 50,
        # with states in [0, 21).
        data = numpy.random.randint(0, 21, size=(batch_size, 50))
        for _ in range(30):
            t0 = time.time()
            grads = plm.batched_grad_fast2(data)
            # Wait for the device computation to finish, so the timing measures
            # the work itself rather than JAX's asynchronous dispatch.
            for g in grads:
                g.block_until_ready()
            dt = time.time() - t0
            print('  Backward:', dt)
            times_run.append(dt)
        all_times[batch_size] = times_run

    with open('times.json', 'w') as f:
        json.dump(all_times, f)
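
A minimal sketch (not part of the original gist) of how the resulting times.json could be summarised afterwards. It only assumes the file written above, with batch sizes as keys and lists of per-run timings as values; note that json.dump stores the integer keys as strings.

    import json

    import numpy

    with open('times.json') as f:
        all_times = json.load(f)

    for batch_size, runs in sorted(all_times.items(), key=lambda kv: int(kv[0])):
        runs = numpy.asarray(runs)
        # The first run of each batch size includes jit compilation, so it is
        # reported separately from the rest.
        print(batch_size, 'first:', runs[0], 'median of the rest:', numpy.median(runs[1:]))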