Dynamic Slices JAX
Traceback (most recent call last):
File "jax_models.py", line 232, in <module>
shuffle=True,
File "jax_models.py", line 181, in fit
voter_indices, target_indices, ratings, batch_size, batched_dataset_size)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/api.py", line 150, in f_jitted
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/core.py", line 592, in call_bind
outs = primitive.impl(f, *args, **params)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/interpreters/xla.py", line 400, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/linear_util.py", line 209, in memoized_fun
ans = call(fun, *args)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/interpreters/xla.py", line 412, in _xla_callable
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/linear_util.py", line 153, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "jax_models.py", line 140, in train_outer
voter_ib = lax.dynamic_slice(voter_indices, (batch_start,), (batch_end,))
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/lax/lax.py", line 690, in dynamic_slice
operand_shape=operand.shape)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/core.py", line 153, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/interpreters/partial_eval.py", line 96, in process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/lax/lax.py", line 1496, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/lax/lax.py", line 2747, in _dynamic_slice_shape_rule
if not onp.all(onp.less_equal(slice_sizes, operand.shape)):
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/core.py", line 340, in __bool__
def __bool__(self): return self.aval._bool(self)
File "/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/abstract_arrays.py", line 38, in error
raise TypeError(concretization_err_msg(fun))
TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
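The TypeError above comes from the dynamic_slice call in the earlier version of train_outer (reproduced at the bottom of this gist): lax.dynamic_slice needs concrete slice sizes, but batch_end = (i + 1) * batch_size depends on the traced loop index i, so under jit it is an abstract value (and it is an end index rather than a length in any case). The working version below passes the static batch_size instead, marked static via static_argnums, so only the start index is dynamic. A minimal sketch of the difference for a 1-D operand:

# broken: the slice size is a traced value (and an end index, not a length)
voter_ib = lax.dynamic_slice(voter_indices, (batch_start,), (batch_end,))
# fixed: the start index may be traced, but the slice size is a concrete Python int
voter_ib = lax.dynamic_slice(voter_indices, (batch_start,), (batch_size,))

With that change the run completes: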
/home/u1/zach/proj/dataplayground3/lib/python3.5/site-packages/jax/lib/xla_bridge.py:120: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Initial loss: 1.2672464847564697
Epoch 0 took 7.99 seconds, loss: 0.99
Epoch 1 took 0.01 seconds, loss: 0.77
Epoch 2 took 0.02 seconds, loss: 0.61
Epoch 3 took 0.01 seconds, loss: 0.49
Epoch 4 took 0.01 seconds, loss: 0.42
Epoch 5 took 0.01 seconds, loss: 0.37
Epoch 6 took 0.01 seconds, loss: 0.34
Epoch 7 took 0.01 seconds, loss: 0.32
Epoch 8 took 0.01 seconds, loss: 0.30
Epoch 9 took 0.01 seconds, loss: 0.29
epoch_inner_toc epoch_toc shuffle_toc
count 10.000000 10.000000 10.000000
mean 0.797775 0.810883 0.013099
std 2.521647 2.521210 0.000919
min 0.000338 0.012801 0.011854
25% 0.000347 0.013204 0.012600
50% 0.000359 0.013391 0.012902
75% 0.000380 0.014156 0.013219
max 7.974509 7.986372 0.015189
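Nearly all of the 7.97 s first epoch is XLA compiling train_outer on its first call; the remaining epochs reuse the compiled function and take roughly 10 ms, which is why the mean and std in the table are dominated by epoch 0. If you want the timing table to exclude compilation, one option (a sketch, not part of the gist code) is to make a single warm-up call inside fit before the timed loop and block on its outputs, since jitted calls return asynchronously:

# hypothetical warm-up before the epoch loop: forces compilation and waits for the results
warm = train_outer(self.voter_embedding_mat, self.target_embedding_mat,
                   self.voter_bias_vector, self.target_bias_vector,
                   voter_indices, target_indices, ratings,
                   batch_size, batched_dataset_size)
for arr in warm:
    arr.block_until_ready()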
import time
from collections import defaultdict
from functools import partial
from jax import grad, jit, random
from jax.ops import index_add
from jax.experimental import loops
from jax.lax import lax
import jax.numpy as np
import numpy as onp
import pandas as pd
def make_index_map(lis):
return {x: i for i, x in enumerate(lis)}
class JaxSVDModel:
def __init__(self, dimensions):
self.dimensions = dimensions
self.gradient_fun = jit(self.get_gradient_fun())
def get_gradient_fun(self):
# voter_mat, target_mat, voter_bias, target_bias
return grad(self.loss, argnums=(1, 2, 3, 4))
def forward(self, voter_mat, target_mat, voter_bias, target_bias):
predictions = np.sum(voter_mat * target_mat, axis=1) + \
voter_bias + target_bias + self.global_bias
return predictions
def loss(self, ratings_vec, *forward_args):
# l2 loss (note: sqrt and power are elementwise here, so this is effectively a mean absolute error)
loss = np.mean(np.sqrt(np.power(
ratings_vec - self.forward(*forward_args),
2
)))
return loss
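# aside: since sqrt and power are applied elementwise, the loss above reduces to
# np.mean(np.abs(ratings_vec - predictions)). if a true squared-error (l2) loss was
# intended, a minimal variant (not what the code above computes) would drop the sqrt:
#
#     def mse_loss(self, ratings_vec, *forward_args):
#         diff = ratings_vec - self.forward(*forward_args)
#         return np.mean(np.power(diff, 2))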
def fit(self, dataset, voter_col='voter_uid', target_col='target_uid', rating_col='interested',
epochs=1, batch_size=1, lr=0.01, rand_seed=4, shuffle=False):
# we're going to actually implement the fitting stuff here
# initialize
# move these to an init function?
self.canonical_voter_order = list(dataset[voter_col].unique())
self.canonical_target_order = list(dataset[target_col].unique())
self.n_voters = len(self.canonical_voter_order)
self.n_targets = len(self.canonical_target_order)
self.voter_uid_index_map = make_index_map(self.canonical_voter_order)
self.target_uid_index_map = make_index_map(self.canonical_target_order)
# we'll just do random uniform for now
voter_key, target_key = random.split(random.PRNGKey(rand_seed))
self.voter_embedding_mat = random.uniform(
voter_key, shape=(self.n_voters, self.dimensions))
self.target_embedding_mat = random.uniform(
target_key, shape=(self.n_targets, self.dimensions))
self.voter_bias_vector = np.zeros(self.n_voters)
self.target_bias_vector = np.zeros(self.n_targets)
# constant
self.global_bias = np.mean(dataset[rating_col])
# prepare the indices
voter_indices = [self.voter_uid_index_map[voter]
for voter in dataset[voter_col]]
target_indices = [self.target_uid_index_map[target]
for target in dataset[target_col]]
dataset_ind = pd.DataFrame.from_dict({
voter_col: voter_indices,
target_col: target_indices,
rating_col: dataset[rating_col], })
@jit
def train_inner(vem, tem, vbv, tbv,
voter_indices, target_indices, ratings):
# these could be turned into a jit function
# batch_size x dimensions
voter_embedding = vem[voter_indices]
target_embedding = tem[target_indices]
# batch_size x 1
voter_bias = vbv[voter_indices]
target_bias = tbv[target_indices]
ve_grad, te_grad, vb_grad, tb_grad = self.gradient_fun(ratings,
voter_embedding, target_embedding, voter_bias, target_bias)
# update
new_vem = index_add(vem, voter_indices, -lr * ve_grad)
new_tem = index_add(tem, target_indices, -lr * te_grad)
new_vbv = index_add(vbv, voter_indices, -lr * vb_grad)
new_tbv = index_add(tbv, target_indices, -lr * tb_grad)
return new_vem, new_tem, new_vbv, new_tbv
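# note on the index_add updates above: index_add is a functional scatter-add, so new_vem
# is vem with -lr * ve_grad added at the rows listed in voter_indices and every other row
# left untouched; repeated indices within a batch accumulate. a tiny sketch:
#
#     x = np.zeros((4, 2))
#     y = index_add(x, np.array([1, 1]), np.ones((2, 2)))  # row 1 of y is [2., 2.]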
@partial(jit,static_argnums=(7,8)) # recompile on batch_size change
def train_outer(vem, tem, vbv, tbv,
voter_indices, target_indices, ratings, # each of length dataset_size; batched via dynamic_slice below
batch_size, batched_dataset_size):
with loops.Scope() as s:
s.vem = vem
s.tem = tem
s.vbv = vbv
s.tbv = tbv
for i in s.range(0, batched_dataset_size):
batch_start = i * batch_size
voter_ib = lax.dynamic_slice(voter_indices, (batch_start,), (batch_size,))
target_ib = lax.dynamic_slice(target_indices, (batch_start,), (batch_size,))
ratings_b = lax.dynamic_slice(ratings, (batch_start,), (batch_size,))
s.vem, s.tem, s.vbv, s.tbv = train_inner(
s.vem, s.tem, s.vbv, s.tbv,
voter_indices=voter_ib,
target_indices=target_ib,
ratings=ratings_b)
return s.vem, s.tem, s.vbv, s.tbv
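# the s.range loop above is traced by jax.experimental.loops into XLA control flow, so
# every value produced in the loop body needs a shape that is known at trace time; that
# is why batch_size and batched_dataset_size are static_argnums, and why dynamic_slice
# gets the static batch_size as its slice size while only batch_start varies with i.
# roughly the same loop without the experimental sugar (a sketch, not used here):
#
#     def body(i, carry):
#         vem_, tem_, vbv_, tbv_ = carry
#         start = i * batch_size
#         vb = lax.dynamic_slice(voter_indices, (start,), (batch_size,))
#         tb = lax.dynamic_slice(target_indices, (start,), (batch_size,))
#         rb = lax.dynamic_slice(ratings, (start,), (batch_size,))
#         return train_inner(vem_, tem_, vbv_, tbv_, vb, tb, rb)
#
#     vem, tem, vbv, tbv = lax.fori_loop(0, batched_dataset_size, body, (vem, tem, vbv, tbv))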
@jit
def shuffle(all_inputs, key): # all_inputs: dataset_size x 3
shuffled_input = random.shuffle(key, all_inputs, axis=0) # shuffle along dataset_size axis
# now batch
# TODO(ZJ): if batch_size doesn't evenly divide len(all_inputs), this approach won't handle
# the trailing partial batch correctly; we need to pad the dataset up to a multiple of
# batch_size and use that (see the padding sketch after this function)
return (
shuffled_input[:,0], # voter_indices
shuffled_input[:,1], # target_indices
shuffled_input[:,2], # ratings
)
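# one way to handle the padding TODO in shuffle above (a sketch, not implemented here):
# pad the packed dataset up to a multiple of batch_size before shuffling, e.g. by
# repeating the first rows, and compute batched_dataset_size from the padded length so
# every batch slice stays in bounds:
#
#     pad = (-len(packed_dataset)) % batch_size
#     packed_dataset = onp.concatenate([packed_dataset, packed_dataset[:pad]], axis=0)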
@jit
def get_loss_full(vem, tem, vbv, tbv,
voter_indices, target_indices, ratings):
# convenience
return self.loss(ratings, vem[voter_indices],
tem[target_indices], vbv[voter_indices],
tbv[target_indices])
# batch up training
_, shuffle_key = random.split(target_key)
batched_dataset_size = int(np.ceil(len(dataset) / batch_size))
# do an initial shuffle so the first epoch doesn't start from the raw dataset order
packed_dataset = dataset_ind[[voter_col, target_col, rating_col]].values
voter_indices, target_indices, ratings = shuffle(packed_dataset, shuffle_key)
def get_loss_full_ez():
# we couldn't just have get_loss_full do this because it can't jit up the
# self.embedding_mat etc
return get_loss_full(self.voter_embedding_mat,
self.target_embedding_mat, self.voter_bias_vector,
self.target_bias_vector, voter_indices, target_indices, ratings)
print("Initial loss: {0}".format(get_loss_full_ez()))
timing_records = []
for epoch in range(epochs):
epoch_tic = time.time()
record = defaultdict(float)
if shuffle:
shuffle_tic = time.time()
# get a new shuffle key for this epoch by splitting the previous shuffle_key
_, shuffle_key = random.split(shuffle_key)
# (the key is threaded through epochs rather than re-created from the original seed)
# spaghet
voter_indices, target_indices, ratings = shuffle(packed_dataset, shuffle_key)
record['shuffle_toc'] = time.time() - shuffle_tic
epoch_inner_tic = time.time()
self.voter_embedding_mat, self.target_embedding_mat, self.voter_bias_vector, self.target_bias_vector = train_outer(
self.voter_embedding_mat, self.target_embedding_mat, self.voter_bias_vector, self.target_bias_vector,
voter_indices, target_indices, ratings, batch_size, batched_dataset_size)
record['epoch_inner_toc'] = time.time() - epoch_inner_tic
record['epoch_toc'] = time.time() - epoch_tic
print("Epoch {0} took {1:0.2f} seconds, loss: {2:0.2f}".format(
epoch, time.time() - epoch_tic, get_loss_full_ez()))
timing_records.append(record)
self._is_fitted = True
print(pd.DataFrame.from_records(timing_records).describe())
def make_test_data(N=500, seed=7, n_voters=100, n_targets=200):
onp.random.seed(seed)
voter_ids = [str(x) for x in range(n_voters)]
target_ids = [str(x) for x in range(n_targets)]
return pd.DataFrame({
"voter_uid": onp.random.choice(voter_ids, N, replace=True),
"target_uid": onp.random.choice(target_ids, N, replace=True),
"interested": (onp.random.random(N) < 0.2).astype(int),
})
if __name__ == "__main__":
data = make_test_data(N=1200000, n_voters=96000, n_targets=126000000)
ll_model = JaxSVDModel(dimensions=5)
ll_model.fit(
dataset=data,
batch_size=1,
epochs=10,
shuffle=True,
)
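# earlier variant of train_inner / train_outer, kept for reference: this is the version
# the traceback at the top of the gist points to (note the (batch_end,) slice sizes in
# train_outer below).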
@jit
def train_inner(vem, tem, vbv, tbv,
voter_indices, target_indices, ratings):
# these could be turned into a jit function
# batch_size x dimensions
voter_embedding = vem[voter_indices]
target_embedding = tem[target_indices]
# batch_size x 1
voter_bias = vbv[voter_indices]
target_bias = tbv[target_indices]
ve_grad, te_grad, vb_grad, tb_grad = self.gradient_fun(ratings,
voter_embedding, target_embedding, voter_bias, target_bias)
# update
new_vem = index_add(vem, voter_indices, -lr * ve_grad)
new_tem = index_add(tem, target_indices, -lr * te_grad)
new_vbv = index_add(vbv, voter_indices, -lr * vb_grad)
new_tbv = index_add(tbv, target_indices, -lr * tb_grad)
return new_vem, new_tem, new_vbv, new_tbv
@partial(jit,static_argnums=(7,8)) # recompile on batch_size change
def train_outer(vem, tem, vbv, tbv,
voter_indices, target_indices, ratings,
batch_size, batched_dataset_size):
with loops.Scope() as s:
s.vem = vem
s.tem = tem
s.vbv = vbv
s.tbv = tbv
for i in s.range(0, batched_dataset_size):
batch_start = i * batch_size
batch_end = (i + 1) * batch_size
voter_ib = lax.dynamic_slice(voter_indices, (batch_start,), (batch_end,))
target_ib = lax.dynamic_slice(target_indices, (batch_start,), (batch_end,))
ratings_b = lax.dynamic_slice(ratings, (batch_start,), (batch_end,))
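# ^ these dynamic_slice calls are what raise the TypeError in the traceback: slice_sizes
# must be concrete, but batch_end depends on the traced loop index i (and is an end index
# rather than a length); the fixed version above passes (batch_size,) instead.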
s.vem, s.tem, s.vbv, s.tbv = train_inner(
s.vem, s.tem, s.vbv, s.tbv,
voter_indices=voter_ib,
target_indices=target_ib,
ratings=ratings_b)
return s.vem, s.tem, s.vbv, s.tbv
# batch up training
print("Initial loss: {0}".format(self.get_loss_full(
dataset[voter_col], dataset[target_col], dataset[rating_col])))
timing_records = []
for epoch in range(epochs):
epoch_tic = time.time()
record = defaultdict(float)
if shuffle:
shuffle_tic = time.time()
dataset_ind = dataset_ind.sample(
n=len(dataset_ind), random_state=rand_seed)
record['shuffle_toc'] = time.time() - shuffle_tic
values_tic = time.time()
# np.stack(list(batch_list(dataset_ind[voter_col].values, size=batch_size)))
voter_indices = dataset_ind[voter_col].values
# np.stack(list(batch_list(dataset_ind[target_col].values, size=batch_size)))
target_indices = dataset_ind[target_col].values
# np.stack(list(batch_list(dataset_ind[rating_col].values, size=batch_size)))
ratings = dataset_ind[rating_col].values
record['values_toc'] = time.time() - values_tic
epoch_inner_tic = time.time()
self.voter_embedding_mat, self.target_embedding_mat, self.voter_bias_vector, self.target_bias_vector = train_outer(
self.voter_embedding_mat, self.target_embedding_mat, self.voter_bias_vector, self.target_bias_vector,
# voter_index_chunks, target_index_chunks, rating_chunks)
voter_indices, target_indices, ratings, batch_size)