Skip to content

Instantly share code, notes, and snippets.

@PhilipVinc
Created November 29, 2021 09:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PhilipVinc/72fb6d451dc40db4f0cdf15cf585cc3a to your computer and use it in GitHub Desktop.
complex-flax.py
# Enable double precision. This must happen before any jax arrays are
# created, otherwise x64 mode may not take effect for them.
# NOTE: `from jax.config import config` is deprecated and removed in
# modern jax releases; the supported spelling is `jax.config.update`.
import jax

jax.config.update("jax_enable_x64", True)

from jax import random
from jax import numpy as jnp
import flax.linen as nn

# Deterministic PRNG key and one length-4 sample vector per dtype of
# interest (real and complex, single and double precision).
key = random.PRNGKey(0)
xr32 = random.normal(key, (4,), dtype=jnp.float32)
xr64 = random.normal(key, (4,), dtype=jnp.float64)
xc32 = random.normal(key, (4,), dtype=jnp.complex64)
xc64 = random.normal(key, (4,), dtype=jnp.complex128)
######################################################################
## Inconsistency: Probably because you don't run with float64 enabled
# The kernel initializer here is invoked without an explicit dtype, and
# jax initializers fall back to float64 once double precision is on, so
# the stored parameter dtype disagrees with the module's dtype=float32.
dense32 = nn.Dense(
    features=3,
    dtype=jnp.float32,
    kernel_init=nn.initializers.normal(stddev=0.01),
)
out, pars = dense32.init_with_output(key, xr32)
print(out.dtype)  # float32
print(pars['params']['kernel'].dtype)  # float64
######################################################################
## Problem 1: type promotion and complex truncation
# Two issues interact here. The module below presumably stores its
# parameters as float64; the expectation is that outputs should follow
# ordinary type promotion of (params dtype, input dtype) rather than
# being cast.
dense64 = nn.Dense(
    features=3,
    dtype=jnp.float64,
    kernel_init=nn.initializers.normal(stddev=0.01),
)

# float64 params + float32 input -> promoted output should be float64.
out, pars = dense64.init_with_output(key, xr32)
kernel = pars['params']['kernel']
assert kernel.dtype == jnp.float64
assert out.dtype == jnp.float64

# float64 params + float64 input -> output should trivially be float64.
out, pars = dense64.init_with_output(key, xr64)
kernel = pars['params']['kernel']
assert kernel.dtype == jnp.float64
assert out.dtype == jnp.float64  # this fails. Why?

# float64 params + complex64 input -> promotion gives complex128.
out, pars = dense64.init_with_output(key, xc32)
kernel = pars['params']['kernel']
assert kernel.dtype == jnp.float64
assert out.dtype == jnp.complex128

# float64 params + complex128 input -> complex128 as well.
out, pars = dense64.init_with_output(key, xc64)
kernel = pars['params']['kernel']
assert kernel.dtype == jnp.float64
assert out.dtype == jnp.complex128
######################################################################
# Problem 2: I would like to change the dtype of the parameters quickly and easily
# I would like an API more like..
# NOTE(fix): the original sketch had copy-paste bugs — every case after
# the first called `dense32.init` instead of the module it had just
# constructed, and the complex cases unpacked `out, pars` from `.init()`,
# which returns only the variables dict (use `init_with_output` for an
# output). Each case below initializes its own module.
dense32 = nn.Dense(dtype=jnp.float32, features=3)
pars = dense32.init(key, xr32)
assert pars['params']['kernel'].dtype == jnp.float32
assert pars['params']['bias'].dtype == jnp.float32
dense64 = nn.Dense(dtype=jnp.float64, features=3)
pars = dense64.init(key, xr32)
assert pars['params']['kernel'].dtype == jnp.float64
assert pars['params']['bias'].dtype == jnp.float64
densec64 = nn.Dense(dtype=jnp.complex64, features=3)
pars = densec64.init(key, xr32)
assert pars['params']['kernel'].dtype == jnp.complex64
assert pars['params']['bias'].dtype == jnp.complex64
densec128 = nn.Dense(dtype=jnp.complex128, features=3)
pars = densec128.init(key, xr32)
assert pars['params']['kernel'].dtype == jnp.complex128
assert pars['params']['bias'].dtype == jnp.complex128
######################################################################
# Question: I want to have a float64 kernel and a float32 bias.
# Answer: We should prioritize common use-cases, and this use-case is
# less common. Regardless, it can still be implemented with a custom
# bias_init that deliberately ignores the requested dtype `d` and
# returns float32, e.g.:
# dense = nn.Dense(
#     bias_init=lambda k, s, d: nn.initializers.normal(stddev=0.01)(k, s, jnp.float32),
#     dtype=jnp.float64, features=3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment