Showing how flax's nn.compact, if used incorrectly, can cause memory leaks.
""" | |
Simple demonstration of memory leak. | |
XLA_PYTHON_CLIENT_PREALLOCATE=false CUDA_VISIBLE_DEVICES=0 python flax_mem_leak.py | |
Breaks on last iter through loop. | |
# Tested on version: | |
Name: flax | |
Version: 0.4.1 | |
Summary: Flax: A neural network library for JAX designed for flexibility | |
Home-page: https://github.com/google/flax | |
Author: Flax team | |
Author-email: flax-dev@google.com | |
License: UNKNOWN | |
Location: /home/jbrad/anaconda3/envs/ss_meta_learning/lib/python3.9/site-packages | |
Requires: numpy, msgpack, optax, matplotlib, jax | |
Required-by: | |
Name: jax | |
Version: 0.3.13 | |
Summary: Differentiate, compile, and transform Numpy code. | |
Home-page: https://github.com/google/jax | |
Author: JAX team | |
Author-email: jax-dev@google.com | |
License: Apache-2.0 | |
Location: /home/jbrad/anaconda3/envs/ss_meta_learning/lib/python3.9/site-packages | |
Requires: scipy, typing-extensions, numpy, absl-py, opt-einsum | |
Required-by: optax, flax, chex | |
""" | |
from typing import Sequence

import jax
from jax import numpy as jnp
from jax import random
from flax import linen as nn
from tqdm import tqdm
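

# Optional helper (an addition for illustration, not part of the original repro): a rough
# way to watch device-memory growth while the script runs. It assumes the default backend
# exposes live_buffers() (true on the jax/jaxlib versions listed above); watching
# `nvidia-smi` works just as well.
def live_buffer_count():
    """Number of live device buffers on the default backend (debugging sketch only)."""
    return len(jax.lib.xla_bridge.get_backend().live_buffers())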
class Sequential(nn.Module):
    """Chains pre-constructed modules/callables handed in as a dataclass attribute."""
    layers: Sequence[nn.Module]

    @nn.compact
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
class Sequential2(nn.Module):
    """Builds its Dense layers inside nn.compact from a list of layer sizes."""
    layers_sizes: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for i, size in enumerate(self.layers_sizes):
            x = nn.Dense(size)(x)
            if i != len(self.layers_sizes) - 1:
                x = nn.relu(x)
        return x
class NN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(6400)(x)
        x = nn.relu(x)
        x = nn.Dense(6400)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x
def define_loss(net):
    def loss(params, x, y):
        preds = net.apply({"params": params}, x)
        assert preds.shape == y.shape  # <- just to check we aren't doing some unintended broadcasting...
        err = preds - y
        return jnp.mean(err ** 2)
    return loss
def main():
    networks = [
        NN(),
        Sequential2([6400, 6400, 1]),
        Sequential(
            [nn.Dense(6400),
             nn.relu,
             nn.Dense(6400),
             nn.relu,
             nn.Dense(1)]),
    ]

    for feed_forward_net in networks:
        loss = define_loss(feed_forward_net)
        model_all_the_way_to_loss_and_grad = jax.value_and_grad(loss, argnums=0)

        jxkey = random.PRNGKey(42)
        jxkey, subkey = jax.random.split(jxkey)
        x = random.normal(subkey, (5, 728))
        jxkey, subkey = jax.random.split(jxkey)
        y = random.normal(subkey, (5, 1))

        jxkey, subkey = jax.random.split(jxkey)
        params = feed_forward_net.init(subkey, x)['params']

        for stp in tqdm(range(1000), desc="running inner loop steps"):
            loss_val, grads = model_all_the_way_to_loss_and_grad(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - 0.01 * g, params, grads)

    print("done!")


if __name__ == '__main__':
    main()
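
# Example usage (as in the docstring above):
#   XLA_PYTHON_CLIENT_PREALLOCATE=false CUDA_VISIBLE_DEVICES=0 python flax_mem_leak.py
# To make the growth visible, one option is to print live_buffer_count() (the sketch added
# after the imports) every few hundred inner-loop steps, or simply watch `nvidia-smi`
# while the loop over networks runs.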