-
-
Save christopherhesse/3fa507c7b1d50dceede20b60653d307f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
from jax import lax | |
import jax.numpy as jp | |
import numpy as np | |
# Default observation shape (H, W, C) — 64x64 with 3 channels; presumably
# RGB frames given scale_ob=255.0 below (TODO confirm).
obs_shape = (64, 64, 3)
def maxpool(x, kersize, stride, padding):
    """Max-pool over the spatial (H, W) axes of an NHWC-like array.

    Any leading axes beyond the last three are treated as batch dimensions
    (window size and stride of 1 there); the trailing channel axis is left
    untouched. ``padding`` is passed straight through to
    ``lax.reduce_window`` (e.g. 'SAME' or 'VALID').
    """
    nbatch = x.ndim - 3
    window = (1,) * nbatch + (kersize, kersize, 1)
    steps = (1,) * nbatch + (stride, stride, 1)
    # -inf is the identity element for max, so padded cells never win.
    return lax.reduce_window(x, -jp.inf, lax.max, window, steps, padding)
class Model:
    """Base class for modules whose parameters live in nested dicts."""

    def initialize(self):
        """Build a nested parameter dict mirroring this module's attributes.

        Attributes that are Models, dicts of Models, or sequences of Models
        are initialized recursively; every other attribute is skipped.
        Sequences are returned as lists regardless of their original type.
        Raises AssertionError for containers that mix Models with non-Models.
        """
        params = {}
        for name, attr in self.__dict__.items():
            if isinstance(attr, Model):
                params[name] = attr.initialize()
                continue
            if isinstance(attr, dict):
                members = list(attr.values())
                if any(isinstance(m, Model) for m in members):
                    assert all(isinstance(m, Model) for m in members), \
                        f"I don't know how to initialize {type(self)}.{name}, "\
                        f"a dict with types {list(map(type, attr.values()))}."
                    params[name] = {key: sub.initialize()
                                    for key, sub in attr.items()}
                continue
            if isinstance(attr, (list, tuple)):
                if any(isinstance(m, Model) for m in attr):
                    assert all(isinstance(m, Model) for m in attr), \
                        f"I don't know how to initialize {type(self)}.{name}, "\
                        f"a sequence with types {list(map(type, attr))}."
                    params[name] = [sub.initialize() for sub in attr]
        return params
class Conv2d(Model):
    """2-D convolution layer operating on NHWC inputs with HWIO weights."""

    def __init__(self, inchan, outchan, kersize, padding, stride=1):
        self.inchan = inchan
        self.outchan = outchan
        self.kersize = kersize
        self.stride = stride
        self.padding = padding

    def initialize(self, scale=1.0):
        """Return fresh parameters: dict with weight 'W' and bias 'b'.

        W is Gaussian-initialized then rescaled so each output channel's
        filter has L2 norm ``scale``; b starts at zero.
        """
        shape = (self.kersize, self.kersize, self.inchan, self.outchan)
        W = np.random.randn(*shape).astype('f')
        # Per-output-channel norm over the kernel and input-channel axes.
        norm = jp.sqrt(jp.square(W).sum(axis=(0, 1, 2), keepdims=True))
        W *= scale / norm
        return dict(W=W, b=jp.zeros(self.outchan))

    def __call__(self, params, x):
        """Convolve NHWC input x with params['W'] and add params['b']."""
        out = lax.conv_general_dilated(
            x, params['W'],
            window_strides=(self.stride, self.stride),
            padding=self.padding,
            dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        )
        return out + params['b']
class Block(Model):
    """Residual block: two same-padding 3x3 convs with a skip connection."""

    def __init__(self, inchan):
        self.inchan = inchan
        self.conv0 = Conv2d(inchan=inchan, outchan=inchan,
                            kersize=3, padding='SAME')
        self.conv1 = Conv2d(inchan=inchan, outchan=inchan,
                            kersize=3, padding='SAME')

    def __call__(self, params, x):
        """Apply conv0 then conv1 and add the block's input back."""
        residual = x
        x = self.conv0(params['conv0'], x)
        x = self.conv1(params['conv1'], x)
        return x + residual

    def initialize(self):
        """Initialize both convs, then damp conv1's weights by a factor of 10
        so the block starts out close to an identity map."""
        params = super().initialize()
        params['conv1']['W'] *= 0.1
        return params
class Stack(Model):
    """One downsampling stage: conv, 2x max-pool, then residual blocks."""

    def __init__(self, inchan, nblock, outchan):
        self.first_conv = Conv2d(inchan=inchan, kersize=3,
                                 outchan=outchan, padding='SAME')
        self.blocks = [Block(outchan) for _ in range(nblock)]

    def __call__(self, params, x):
        """Run x through the conv, halve H and W, and apply each block."""
        x = self.first_conv(params['first_conv'], x)
        x = maxpool(x, kersize=3, stride=2, padding='SAME')
        for block, block_params in zip(self.blocks, params['blocks']):
            x = block(block_params, x)
        return x
class CNN(Model):
    """Convolutional trunk: a sequence of downsampling Stacks over NHWC input.

    NOTE(review): ``outsize`` is accepted but never used here (no dense head
    is built), and the tracked output height/width are computed but never
    read — confirm whether a final projection was intended.
    """

    def __init__(self, inshape, nblock=2, stack_channels=(16, 32, 32),
                 outsize=256, scale_ob=1.0):
        height, width, chan = inshape
        self.scale_ob = scale_ob
        self.stacks = []
        for nch in stack_channels:
            self.stacks.append(Stack(chan, nblock, nch))
            chan = nch
            # Each stack's 'SAME' pooling halves spatial size, rounding up.
            width = (width + 1) // 2
            height = (height + 1) // 2

    def __call__(self, params, x):
        """Scale x by 1/scale_ob and run it through every stack.

        Leading batch dimensions (everything before the trailing H, W, C)
        are flattened into one axis and NOT restored before returning.
        """
        x = x / self.scale_ob
        batch_shape = x.shape[:-3]
        flat = x.reshape((int(np.prod(batch_shape)), *x.shape[-3:]))
        for stack, stack_params in zip(self.stacks, params['stacks']):
            flat = stack(stack_params, flat)
        return flat
def main():
    """Benchmark one jitted gradient step of the CNN across batch sizes.

    For each batch size a fresh model and parameters are built, then a
    single gradient of a dummy scalar loss (negative sum of features) is
    computed on tiled uint8 observations.
    """
    model_size = 8
    batch_sizes = [2 ** factor for factor in range(10)] + [600, 700, 800]
    for batch_size in batch_sizes:
        print("batch_size", batch_size)
        fe = CNN(inshape=obs_shape, scale_ob=255.0, outsize=256,
                 stack_channels=(16 * model_size, 32 * model_size,
                                 32 * model_size))
        params = fe.initialize()

        def loss(params, ob):
            return -jp.sum(fe(params, ob))

        grad_fn = jax.jit(jax.grad(loss))
        # One random frame tiled across the batch; values drawn from [0, 255).
        frame = np.random.randint(0, 255, size=obs_shape, dtype=np.uint8)
        inputs = np.tile(frame, (batch_size, 1, 1, 1))
        grad_fn(params, inputs)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment