Last active
December 18, 2019 20:30
-
-
Save zmjjmz/21965e3de7c3e9966bf6aaa6732c96c3 to your computer and use it in GitHub Desktop.
Dynamic Slices JAX
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Traceback (most recent call last): | |
File "jax_models.py", line 232, in <module> | |
shuffle=True, | |
File "jax_models.py", line 181, in fit | |
voter_indices, target_indices, ratings, batch_size, batched_dataset_size) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/api.py", line 150, in f_jitted | |
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/core.py", line 592, in call_bind | |
outs = primitive.impl(f, *args, **params) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/interpreters/xla.py", line 400, in _xla_call_impl | |
compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args)) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/linear_util.py", line 209, in memoized_fun | |
ans = call(fun, *args) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/interpreters/xla.py", line 412, in _xla_callable | |
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/linear_util.py", line 153, in call_wrapped | |
ans = self.f(*args, **dict(self.params, **kwargs)) | |
File "jax_models.py", line 140, in train_outer | |
voter_ib = lax.dynamic_slice(voter_indices, (batch_start,), (batch_end,)) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/lax/lax.py", line 690, in dynamic_slice | |
operand_shape=operand.shape) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/core.py", line 153, in bind | |
out_tracer = top_trace.process_primitive(self, tracers, kwargs) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/interpreters/partial_eval.py", line 96, in process_primitive | |
out_aval = primitive.abstract_eval(*avals, **params) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/lax/lax.py", line 1496, in standard_abstract_eval | |
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/lax/lax.py", line 2747, in _dynamic_slice_shape_rule | |
if not onp.all(onp.less_equal(slice_sizes, operand.shape)): | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/core.py", line 340, in __bool__ | |
def __bool__(self): return self.aval._bool(self) | |
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/abstract_arrays.py", line 38, in error | |
raise TypeError(concretization_err_msg(fun)) | |
TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required lev | |
el of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/lib/xla_bridge.py:120: UserWarning: No GPU/TPU found, falling back to CPU. | |
warnings.warn('No GPU/TPU found, falling back to CPU.') | |
Initial loss: 1.2672464847564697 | |
Epoch 0 took 7.99 seconds, loss: 0.99 | |
Epoch 1 took 0.01 seconds, loss: 0.77 | |
Epoch 2 took 0.02 seconds, loss: 0.61 | |
Epoch 3 took 0.01 seconds, loss: 0.49 | |
Epoch 4 took 0.01 seconds, loss: 0.42 | |
Epoch 5 took 0.01 seconds, loss: 0.37 | |
Epoch 6 took 0.01 seconds, loss: 0.34 | |
Epoch 7 took 0.01 seconds, loss: 0.32 | |
Epoch 8 took 0.01 seconds, loss: 0.30 | |
Epoch 9 took 0.01 seconds, loss: 0.29 | |
epoch_inner_toc epoch_toc shuffle_toc | |
count 10.000000 10.000000 10.000000 | |
mean 0.797775 0.810883 0.013099 | |
std 2.521647 2.521210 0.000919 | |
min 0.000338 0.012801 0.011854 | |
25% 0.000347 0.013204 0.012600 | |
50% 0.000359 0.013391 0.012902 | |
75% 0.000380 0.014156 0.013219 | |
max 7.974509 7.986372 0.015189 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from collections import defaultdict | |
from functools import partial | |
from jax import grad, jit, random | |
from jax.ops import index_add | |
from jax.experimental import loops | |
from jax.lax import lax | |
import jax.numpy as np | |
import numpy as onp | |
import pandas as pd | |
def make_index_map(lis): | |
return {x: i for i, x in enumerate(lis)} | |
class JaxSVDModel: | |
def __init__(self, dimensions): | |
self.dimensions = dimensions | |
self.gradient_fun = jit(self.get_gradient_fun()) | |
def get_gradient_fun(self): | |
# voter_mat, target_mat, voter_bias, target_bias | |
return grad(self.loss, argnums=(1, 2, 3, 4)) | |
def forward(self, voter_mat, target_mat, voter_bias, target_bias): | |
predictions = np.sum(voter_mat * target_mat, axis=1) + \ | |
voter_bias + target_bias + self.global_bias | |
return predictions | |
def loss(self, ratings_vec, *forward_args): | |
# l2 loss | |
loss = np.mean(np.sqrt(np.power( | |
ratings_vec - self.forward(*forward_args), | |
2 | |
))) | |
return loss | |
def fit(self, dataset, voter_col='voter_uid', target_col='target_uid', rating_col='interested', | |
epochs=1, batch_size=1, lr=0.01, rand_seed=4, shuffle=False): | |
# we're going to actually implement the fitting stuff here | |
# initialize | |
# move these to an init function? | |
self.canonical_voter_order = list(dataset[voter_col].unique()) | |
self.canonical_target_order = list(dataset[target_col].unique()) | |
self.n_voters = len(self.canonical_voter_order) | |
self.n_targets = len(self.canonical_target_order) | |
self.voter_uid_index_map = make_index_map(self.canonical_voter_order) | |
self.target_uid_index_map = make_index_map(self.canonical_target_order) | |
# we'll just do random uniform for now | |
voter_key, target_key = random.split(random.PRNGKey(rand_seed)) | |
self.voter_embedding_mat = random.uniform( | |
voter_key, shape=(self.n_voters, self.dimensions)) | |
self.target_embedding_mat = random.uniform( | |
target_key, shape=(self.n_targets, self.dimensions)) | |
self.voter_bias_vector = np.zeros(self.n_voters) | |
self.target_bias_vector = np.zeros(self.n_targets) | |
# constant | |
self.global_bias = np.mean(dataset[rating_col]) | |
# prepare the indices | |
voter_indices = [self.voter_uid_index_map[voter] | |
for voter in dataset[voter_col]] | |
target_indices = [self.target_uid_index_map[target] | |
for target in dataset[target_col]] | |
dataset_ind = pd.DataFrame.from_dict({ | |
voter_col: voter_indices, | |
target_col: target_indices, | |
rating_col: dataset[rating_col], }) | |
@jit | |
def train_inner(vem, tem, vbv, tbv, | |
voter_indices, target_indices, ratings): | |
# these could be turned into a jit function | |
# batch_size x dimensions | |
voter_embedding = vem[voter_indices] | |
target_embedding = tem[target_indices] | |
# batch_size x 1 | |
voter_bias = vbv[voter_indices] | |
target_bias = tbv[target_indices] | |
ve_grad, te_grad, vb_grad, tb_grad = self.gradient_fun(ratings, | |
voter_embedding, target_embedding, voter_bias, target_bias) | |
# update | |
new_vem = index_add(vem, voter_indices, -lr * ve_grad) | |
new_tem = index_add(tem, target_indices, -lr * te_grad) | |
new_vbv = index_add(vbv, voter_indices, -lr * vb_grad) | |
new_tbv = index_add(tbv, target_indices, -lr * tb_grad) | |
return new_vem, new_tem, new_vbv, new_tbv | |
@partial(jit,static_argnums=(7,8)) # recompile on batch_size change | |
def train_outer(vem, tem, vbv, tbv, | |
voter_indices, target_indices, ratings, # ea batched_dataset_size x batch_size | |
batch_size, batched_dataset_size): | |
with loops.Scope() as s: | |
s.vem = vem | |
s.tem = tem | |
s.vbv = vbv | |
s.tbv = tbv | |
for i in s.range(0, batched_dataset_size): | |
batch_start = i * batch_size | |
voter_ib = lax.dynamic_slice(voter_indices, (batch_start,), (batch_size,)) | |
target_ib = lax.dynamic_slice(target_indices, (batch_start,), (batch_size,)) | |
ratings_b = lax.dynamic_slice(ratings, (batch_start,), (batch_size,)) | |
s.vem, s.tem, s.vbv, s.tbv = train_inner( | |
s.vem, s.tem, s.vbv, s.tbv, | |
voter_indices=voter_ib, | |
target_indices=target_ib, | |
ratings=ratings_b) | |
return s.vem, s.tem, s.vbv, s.tbv | |
@jit | |
def shuffle(all_inputs, key): # all_inputs: dataset_size x 3 | |
shuffled_input = random.shuffle(key, all_inputs, axis=0) # shuffle along dataset_size axis | |
# now batch | |
# TODO(ZJ): if batch_size doesn't divide len(input_col), this approach will fail | |
# TODO(ZJ): we need to make a padding and use that | |
return ( | |
shuffled_input[:,0], # voter_indices | |
shuffled_input[:,1], # target_indices | |
shuffled_input[:,2], # ratings | |
) | |
@jit | |
def get_loss_full(vem, tem, vbv, tbv, | |
voter_indices, target_indices, ratings): | |
# convenience | |
return self.loss(ratings, vem[voter_indices], | |
tem[target_indices], vbv[voter_indices], | |
tbv[target_indices]) | |
# batch up training | |
_, shuffle_key = random.split(target_key) | |
batched_dataset_size = int(np.ceil(len(dataset) / batch_size)) | |
# it's important that we reuse the same key | |
packed_dataset = dataset_ind[[voter_col, target_col, rating_col]].values | |
voter_indices, target_indices, ratings = shuffle(packed_dataset, shuffle_key) | |
def get_loss_full_ez(): | |
# we couldn't just have get_loss_full do this because it can't jit up the | |
# self.embedding_mat etc | |
return get_loss_full(self.voter_embedding_mat, | |
self.target_embedding_mat, self.voter_bias_vector, | |
self.target_bias_vector, voter_indices, target_indices, ratings) | |
print("Initial loss: {0}".format(get_loss_full_ez())) | |
timing_records = [] | |
for epoch in range(epochs): | |
epoch_tic = time.time() | |
record = defaultdict(float) | |
if shuffle: | |
shuffle_tic = time.time() | |
# get a new shuffle key | |
_, shuffle_key = random.split(shuffle_key) | |
# it's important that we reuse the same key | |
# spaghet | |
voter_indices, target_indices, ratings = shuffle(packed_dataset, shuffle_key) | |
record['shuffle_toc'] = time.time() - shuffle_tic | |
epoch_inner_tic = time.time() | |
self.voter_embedding_mat, self.target_embedding_mat, self.voter_bias_vector, self.target_bias_vector = train_outer( | |
self.voter_embedding_mat, self.target_embedding_mat, self.voter_bias_vector, self.target_bias_vector, | |
voter_indices, target_indices, ratings, batch_size, batched_dataset_size) | |
record['epoch_inner_toc'] = time.time() - epoch_inner_tic | |
record['epoch_toc'] = time.time() - epoch_tic | |
print("Epoch {0} took {1:0.2f} seconds, loss: {2:0.2f}".format( | |
epoch, time.time() - epoch_tic, get_loss_full_ez())) | |
timing_records.append(record) | |
self._is_fitted = True | |
print(pd.DataFrame.from_records(timing_records).describe()) | |
def make_test_data(N=500, seed=7, n_voters=100, n_targets=200): | |
onp.random.seed(seed) | |
voter_ids = [str(x) for x in range(n_voters)] | |
target_ids = [str(x) for x in range(n_targets)] | |
return pd.DataFrame({ | |
"voter_uid": onp.random.choice(voter_ids, N, replace=True), | |
"target_uid": onp.random.choice(target_ids, N, replace=True), | |
"interested": (onp.random.random(N) < 0.2).astype(int), | |
}) | |
if __name__ == "__main__": | |
data = make_test_data(N=1200000, n_voters=96000, n_targets=126000000) | |
ll_model = JaxSVDModel(dimensions=5) | |
ll_model.fit( | |
dataset=data, | |
batch_size=1, | |
epochs=10, | |
shuffle=True, | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@jit | |
def train_inner(vem, tem, vbv, tbv, | |
voter_indices, target_indices, ratings): | |
# these could be turned into a jit function | |
# batch_size x dimensions | |
voter_embedding = vem[voter_indices] | |
target_embedding = tem[target_indices] | |
# batch_size x 1 | |
voter_bias = vbv[voter_indices] | |
target_bias = tbv[target_indices] | |
ve_grad, te_grad, vb_grad, tb_grad = self.gradient_fun(ratings, | |
voter_embedding, target_embedding, voter_bias, target_bias) | |
# update | |
new_vem = index_add(vem, voter_indices, -lr * ve_grad) | |
new_tem = index_add(tem, target_indices, -lr * te_grad) | |
new_vbv = index_add(vbv, voter_indices, -lr * vb_grad) | |
new_tbv = index_add(tbv, target_indices, -lr * tb_grad) | |
return new_vem, new_tem, new_vbv, new_tbv | |
@partial(jit,static_argnums=(7,8)) # recompile on batch_size change | |
def train_outer(vem, tem, vbv, tbv, | |
voter_indices, target_indices, ratings, | |
batch_size, batched_dataset_size): | |
with loops.Scope() as s: | |
s.vem = vem | |
s.tem = tem | |
s.vbv = vbv | |
s.tbv = tbv | |
for i in s.range(0, batched_dataset_size): | |
batch_start = i * batch_size | |
batch_end = (i + 1) * batch_size | |
voter_ib = lax.dynamic_slice(voter_indices, (batch_start,), (batch_end,)) | |
target_ib = lax.dynamic_slice(target_indices, (batch_start,), (batch_end,)) | |
ratings_b = lax.dynamic_slice(ratings, (batch_start,), (batch_end,)) | |
s.vem, s.tem, s.vbv, s.tbv = train_inner( | |
s.vem, s.tem, s.vbv, s.tbv, | |
voter_indices=voter_ib, | |
target_indices=target_ib, | |
ratings=ratings_b) | |
return s.vem, s.tem, s.vbv, s.tbv | |
# batch up training | |
print("Initial loss: {0}".format(self.get_loss_full( | |
dataset[voter_col], dataset[target_col], dataset[rating_col]))) | |
timing_records = [] | |
for epoch in range(epochs): | |
epoch_tic = time.time() | |
record = defaultdict(float) | |
if shuffle: | |
shuffle_tic = time.time() | |
dataset_ind = dataset_ind.sample( | |
n=len(dataset_ind), random_state=rand_seed) | |
record['shuffle_toc'] = time.time() - shuffle_tic | |
values_tic = time.time() | |
# np.stack(list(batch_list(dataset_ind[voter_col].values, size=batch_size))) | |
voter_indices = dataset_ind[voter_col].values | |
# np.stack(list(batch_list(dataset_ind[target_col].values, size=batch_size))) | |
target_indices = dataset_ind[target_col].values | |
# np.stack(list(batch_list(dataset_ind[rating_col].values, size=batch_size))) | |
ratings = dataset_ind[rating_col].values | |
record['values_toc'] = time.time() - values_tic | |
epoch_inner_tic = time.time() | |
self.voter_embedding_mat, self.target_embedding_mat, self.voter_bias_vector, self.target_bias_vector = train_outer( | |
self.voter_embedding_mat, self.target_embedding_mat, self.voter_bias_vector, self.target_bias_vector, | |
# voter_index_chunks, target_index_chunks, rating_chunks) | |
voter_indices, target_indices, ratings, batch_size) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment