@john-bradshaw
Created August 12, 2022 17:11
Showing how flax's nn.compact, if used incorrectly, can cause memory leaks.
"""
Simple demonstration of a memory leak.

Run with:
    XLA_PYTHON_CLIENT_PREALLOCATE=false CUDA_VISIBLE_DEVICES=0 python flax_mem_leak.py

Breaks on the last iteration through the loop.

# Tested on versions:
Name: flax
Version: 0.4.1
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author: Flax team
Author-email: flax-dev@google.com
License: UNKNOWN
Location: /home/jbrad/anaconda3/envs/ss_meta_learning/lib/python3.9/site-packages
Requires: numpy, msgpack, optax, matplotlib, jax
Required-by:
Name: jax
Version: 0.3.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/jbrad/anaconda3/envs/ss_meta_learning/lib/python3.9/site-packages
Requires: scipy, typing-extensions, numpy, absl-py, opt-einsum
Required-by: optax, flax, chex
"""
from typing import Sequence

import jax
from jax import numpy as jnp
from flax import linen as nn
from jax import random
from tqdm import tqdm


class Sequential(nn.Module):
    """Chains submodules that were constructed outside this module (see main())."""
    layers: Sequence[nn.Module]

    @nn.compact
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class Sequential2(nn.Module):
    """Stores only layer sizes; the nn.Dense submodules are created inside the compact __call__."""
    layers_sizes: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for i, layer_size in enumerate(self.layers_sizes):
            x = nn.Dense(layer_size)(x)
            if i != len(self.layers_sizes) - 1:
                x = nn.relu(x)
        return x


class NN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(6400)(x)
        x = nn.relu(x)
        x = nn.Dense(6400)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


def define_loss(net):
    def loss(params, x, y):
        preds = net.apply({"params": params}, x)
        assert preds.shape == y.shape  # <- just to check we aren't doing some unintended broadcasting...
        err = preds - y
        return jnp.mean(err ** 2)
    return loss


def main():
    networks = [NN(),
                Sequential2([6400, 6400, 1]),
                Sequential([nn.Dense(6400),
                            nn.relu,
                            nn.Dense(6400),
                            nn.relu,
                            nn.Dense(1)])]

    for feed_forward_net in networks:
        loss = define_loss(feed_forward_net)
        model_all_the_way_to_loss_and_grad = jax.value_and_grad(loss, argnums=0)

        jxkey = random.PRNGKey(42)
        jxkey, subkey = jax.random.split(jxkey)
        x = random.normal(subkey, (5, 728))
        jxkey, subkey = jax.random.split(jxkey)
        y = random.normal(subkey, (5, 1))
        jxkey, subkey = jax.random.split(jxkey)
        params = feed_forward_net.init(subkey, x)['params']

        # Plain SGD on the MSE loss.
        for stp in tqdm(range(1000), desc="running inner loop steps"):
            loss_val, grads = model_all_the_way_to_loss_and_grad(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - 0.01 * g, params, grads)

    print("done!")


if __name__ == '__main__':
    main()
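
For reference (not part of the original gist): below is a minimal, self-contained sketch of the pattern that Sequential2 and NN above use for contrast, where the module stores only plain hyperparameters and the nn.Dense submodules are created inside the @nn.compact __call__. The MLP name and shapes are illustrative; the assumed versions are the flax 0.4.1 / jax 0.3.13 listed in the docstring.

from typing import Sequence

from flax import linen as nn
from jax import random


class MLP(nn.Module):
    """Illustrative module: stores layer sizes only, not pre-built submodules."""
    layer_sizes: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for i, size in enumerate(self.layer_sizes):
            x = nn.Dense(size)(x)  # submodule created inside the compact call
            if i != len(self.layer_sizes) - 1:
                x = nn.relu(x)
        return x


key = random.PRNGKey(0)
x = random.normal(key, (5, 728))
model = MLP([6400, 6400, 1])
params = model.init(key, x)["params"]            # parameters materialised at init time
print(model.apply({"params": params}, x).shape)  # -> (5, 1)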