Skip to content

Instantly share code, notes, and snippets.

@christopherhesse
Last active June 10, 2019 22:15
Show Gist options
  • Save christopherhesse/3fa507c7b1d50dceede20b60653d307f to your computer and use it in GitHub Desktop.
import jax
from jax import lax
import jax.numpy as jp
import numpy as np
# Observation shape (height, width, channels) shared by the benchmark below.
obs_shape = (64, 64, 3)
def maxpool(x, kersize, stride, padding):
    """Max-pool over the trailing (H, W, C) axes of an NHWC-like array.

    Any leading axes (e.g. batch) are left untouched: the pooling window
    and stride are 1 everywhere except the height/width dimensions.
    """
    lead = (1,) * (x.ndim - 3)
    window = lead + (kersize, kersize, 1)
    steps = lead + (stride, stride, 1)
    # reduce_window with max and -inf identity implements spatial max-pooling.
    return lax.reduce_window(x, -jp.inf, lax.max, window, steps, padding)
class Model:
    """Base class giving modules recursive parameter initialization.

    ``initialize`` walks the instance's attributes and assembles a nested
    dict of parameters: child ``Model`` attributes contribute their own
    ``initialize()`` result, and dicts / sequences whose elements are all
    ``Model``s are initialized element-wise. Mixed containers (some elements
    Model, some not) are rejected; plain non-Model attributes are skipped.
    """

    def initialize(self):
        params = {}
        for name, attr in self.__dict__.items():
            if isinstance(attr, Model):
                params[name] = attr.initialize()
            elif isinstance(attr, dict) and any(
                    isinstance(elem, Model) for elem in attr.values()):
                assert all(isinstance(elem, Model) for elem in attr.values()), \
                    f"I don't know how to initialize {type(self)}.{name}, "\
                    f"a dict with types {list(map(type, attr.values()))}."
                params[name] = {key: child.initialize()
                                for key, child in attr.items()}
            elif isinstance(attr, (list, tuple)) and any(
                    isinstance(elem, Model) for elem in attr):
                assert all(isinstance(elem, Model) for elem in attr), \
                    f"I don't know how to initialize {type(self)}.{name}, "\
                    f"a sequence with types {list(map(type, attr))}."
                params[name] = [child.initialize() for child in attr]
        return params
class Conv2d(Model):
    """2-D convolution layer operating on NHWC inputs with HWIO weights."""

    def __init__(self, inchan, outchan, kersize, padding, stride=1):
        self.inchan = inchan
        self.outchan = outchan
        self.kersize = kersize
        self.stride = stride
        self.padding = padding

    def initialize(self, scale=1.0):
        """Sample Gaussian HWIO weights, rescaled so each output filter has
        L2 norm `scale`; biases start at zero."""
        shape = (self.kersize, self.kersize, self.inchan, self.outchan)
        W = np.random.randn(*shape).astype('f')
        # Per-output-channel norm over the H, W, and input-channel axes.
        norms = jp.sqrt(jp.square(W).sum(axis=(0, 1, 2), keepdims=True))
        W *= scale / norms
        b = jp.zeros(self.outchan)
        return dict(W=W, b=b)

    def __call__(self, params, x):
        """Convolve x (NHWC) with params['W'] and add the broadcast bias."""
        conv = lax.conv_general_dilated(
            x, params['W'],
            window_strides=(self.stride, self.stride),
            padding=self.padding,
            dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        )
        return conv + params['b']
class Block(Model):
    """Residual block: two 3x3 same-padded convolutions plus a skip connection.

    Note there is no nonlinearity between the convs in this benchmark snippet.
    """

    def __init__(self, inchan):
        self.inchan = inchan
        self.conv0 = Conv2d(inchan=inchan, outchan=inchan,
                            kersize=3, padding='SAME')
        self.conv1 = Conv2d(inchan=inchan, outchan=inchan,
                            kersize=3, padding='SAME')

    def __call__(self, params, x):
        residual = self.conv0(params['conv0'], x)
        residual = self.conv1(params['conv1'], residual)
        return residual + x

    def initialize(self):
        params = super().initialize()
        # Shrink the second conv's weights so the block starts close to identity.
        params['conv1']['W'] *= 0.1
        return params
class Stack(Model):
    """Downsampling stage: channel-changing conv, stride-2 max-pool, then
    a sequence of residual blocks at the new resolution."""

    def __init__(self, inchan, nblock, outchan):
        self.first_conv = Conv2d(inchan=inchan, kersize=3,
                                 outchan=outchan, padding='SAME')
        self.blocks = [Block(outchan) for _ in range(nblock)]

    def __call__(self, params, x):
        x = self.first_conv(params['first_conv'], x)
        # SAME-padded stride-2 pooling halves the spatial resolution (ceil).
        x = maxpool(x, kersize=3, stride=2, padding='SAME')
        for block, block_params in zip(self.blocks, params['blocks']):
            x = block(block_params, x)
        return x
class CNN(Model):
    """Stacked residual conv network over NHWC observations.

    NOTE(review): `outsize` is accepted but never used here, and the tracked
    (h, w) output resolution is discarded — presumably a final dense layer
    was trimmed from this benchmark snippet; confirm against the full model.
    """

    def __init__(self, inshape, nblock=2, stack_channels=(16, 32, 32),
                 outsize=256, scale_ob=1.0):
        inheight, inwidth, inchan = inshape
        self.scale_ob = scale_ob
        self.stacks = []
        h, w = inheight, inwidth
        for outchan in stack_channels:
            self.stacks.append(Stack(inchan, nblock, outchan))
            inchan = outchan
            # Each stack halves the resolution (SAME padding rounds up).
            h, w = (h + 1) // 2, (w + 1) // 2

    def __call__(self, params, x):
        x = x / self.scale_ob
        # Collapse any leading batch dims into a single batch axis for the convs;
        # the result keeps the flattened batch axis.
        batch_shape = x.shape[:-3]
        x = x.reshape((int(np.prod(batch_shape)), *x.shape[-3:]))
        for stack, stack_params in zip(self.stacks, params['stacks']):
            x = stack(stack_params, x)
        return x
def main():
    """Benchmark: JIT-compile and run a gradient of the CNN's summed output
    across a range of batch sizes, printing each batch size as it goes."""
    model_size = 8
    batch_sizes = [2 ** factor for factor in range(10)] + [600, 700, 800]
    for batch_size in batch_sizes:
        print("batch_size", batch_size)
        fe = CNN(inshape=obs_shape, scale_ob=255.0, outsize=256,
                 stack_channels=(16 * model_size,
                                 32 * model_size,
                                 32 * model_size))
        params = fe.initialize()

        def loss(params, ob):
            return -jp.sum(fe(params, ob))

        training_gradient_fun = jax.jit(jax.grad(loss))
        # One random observation tiled across the batch dimension.
        ob = np.random.randint(0, 255, size=obs_shape, dtype=np.uint8)
        inputs = np.tile(ob, (batch_size, 1, 1, 1))
        training_gradient_fun(params, inputs)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment