Created
November 29, 2021 09:47
-
-
Save PhilipVinc/72fb6d451dc40db4f0cdf15cf585cc3a to your computer and use it in GitHub Desktop.
complex-flax.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Exploration of how flax.linen.Dense handles parameter and output dtypes,
# including real -> complex promotion. Several asserts below describe the
# *desired* behaviour rather than the current one, so they are expected to
# fail on the flax version this was written against; the script stops at the
# first failing assert.
import jax

# Enable double precision at startup, before any arrays are created.
# `jax.config.update` works on both old and new JAX releases, whereas
# `from jax.config import config` is no longer available in recent JAX.
jax.config.update("jax_enable_x64", True)

from jax import random
from jax import numpy as jnp
import flax.linen as nn

key = random.PRNGKey(0)
# The suffix is the precision of each real component: xc32 is complex64
# (two float32 parts), xc64 is complex128 (two float64 parts).
xr32 = random.normal(key, (4,), dtype=jnp.float32)
xr64 = random.normal(key, (4,), dtype=jnp.float64)
xc32 = random.normal(key, (4,), dtype=jnp.complex64)
xc64 = random.normal(key, (4,), dtype=jnp.complex128)

######################################################################
## Inconsistency: the kernel dtype does not match the layer dtype.
# This is easy to miss if you never run with float64 enabled. The layer's
# `dtype` is not passed to the kernel initializer, and JAX initializers
# default to float64 when double precision is enabled. The output is
# therefore float32 while the kernel is float64.
dense32 = nn.Dense(kernel_init=nn.initializers.normal(stddev=0.01), dtype=jnp.float32, features=3)
out, pars = dense32.init_with_output(key, xr32)
print(out.dtype)  # float32
print(pars['params']['kernel'].dtype)  # float64

######################################################################
## Problem 1: type promotion and complex truncation.
dense64 = nn.Dense(kernel_init=nn.initializers.normal(stddev=0.01), dtype=jnp.float64, features=3)
# Expectation: parameters default to float64, and outputs are not cast.
# Instead, the output dtype follows normal type promotion of inputs and
# parameters.

# float64 parameters, float32 input -> float64 output.
out, pars = dense64.init_with_output(key, xr32)
assert pars['params']['kernel'].dtype == jnp.float64
assert out.dtype == jnp.float64

# float64 parameters, float64 input -> float64 output.
out, pars = dense64.init_with_output(key, xr64)
assert pars['params']['kernel'].dtype == jnp.float64
assert out.dtype == jnp.float64  # original author reports this fails -- why?

# float64 parameters, complex64 input -> complex128 output
# (the complex part must not be truncated away).
out, pars = dense64.init_with_output(key, xc32)
assert pars['params']['kernel'].dtype == jnp.float64
assert out.dtype == jnp.complex128

# float64 parameters, complex128 input -> complex128 output.
out, pars = dense64.init_with_output(key, xc64)
assert pars['params']['kernel'].dtype == jnp.float64
assert out.dtype == jnp.complex128

######################################################################
# Problem 2: changing the parameter dtype should be quick and easy.
# Desired API: the `dtype` given to nn.Dense also sets the dtype of the
# parameters it creates.
# Note: `init` returns only the variables; use `init_with_output` to also get
# the output.
dense32 = nn.Dense(dtype=jnp.float32, features=3)
pars = dense32.init(key, xr32)
assert pars['params']['kernel'].dtype == jnp.float32
assert pars['params']['bias'].dtype == jnp.float32

dense64 = nn.Dense(dtype=jnp.float64, features=3)
pars = dense64.init(key, xr32)
assert pars['params']['kernel'].dtype == jnp.float64
assert pars['params']['bias'].dtype == jnp.float64

densec64 = nn.Dense(dtype=jnp.complex64, features=3)
pars = densec64.init(key, xr32)
assert pars['params']['kernel'].dtype == jnp.complex64
assert pars['params']['bias'].dtype == jnp.complex64

densec128 = nn.Dense(dtype=jnp.complex128, features=3)
pars = densec128.init(key, xr32)
assert pars['params']['kernel'].dtype == jnp.complex128
assert pars['params']['bias'].dtype == jnp.complex128

######################################################################
# Question: how do I get a float64 kernel and a float32 bias?
# Answer: common use-cases should come first, and this one is less common.
# It can still be done by pinning the dtype inside the bias initializer.
# Passing float32 explicitly matters: with double precision enabled, the
# initializer's default dtype would be float64. The optional `dtype`
# argument is accepted and ignored, so this works whether or not Dense
# passes a dtype to its initializers:
# dense_mixed = nn.Dense(
#     bias_init=lambda key, shape, dtype=None: nn.initializers.normal(stddev=0.01)(key, shape, jnp.float32),
#     dtype=jnp.float64,
#     features=3,
# )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment