[RFC] Proposal for complex-valued optimization in Optax

Motivation

Complex-valued neural networks are widely used in various fields of science, such as complex-valued wave functions in quantum physics and molecular chemistry, complex-valued Fourier coefficients in signal processing, and manifold learning on tori, which carry a complex structure. A recent survey is J. Bassey et al., arXiv:2101.12249.

One of the differences between complex- and real-valued neural networks lies in optimization: there is not yet a widely accepted way to generalize most optimizers to the complex domain, and most machine learning (ML) frameworks do not implement them properly.

In the JAX community, Optax is the only actively maintained package dedicated to optimization, used alongside neural network construction packages such as Flax and Haiku. In particular, the Flax developers have proposed replacing flax.optim with Optax (see FLIP 1009). Optax is also used in various downstream projects such as NetKet, JAX MD, RLax, and Brax. We would like to implement complex-valued optimization in Optax and benefit the whole community.

Status in other ML frameworks

  • PyTorch: There has recently been a long discussion about complex-valued optimizers in pytorch#59998, and they are now being implemented. Given the popularity of PyTorch, we may follow some of their design choices.
  • TensorFlow: The maintainers declined to implement them in the main repo (see the comment in tensorflow#38541), but it is possible to contribute them to TF Addons.
  • Julia language: Many kinds of complex-valued optimizers have been implemented in Optim.jl and GalacticOptim.jl. However, the common neural network optimizers implemented in Flux and other packages are unaware of complex numbers, and there is no open issue about this yet.

Norm of complex variables

Many optimizers make use of the norm of variables (parameters and gradients). For example, in Adam we accumulate the first moment of the gradients m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t in the numerator, and the second moment of the norm of the gradients v_t = \beta_2 v_{t-1} + (1-\beta_2) |g_t|^2 in the denominator. Most ML frameworks, including Optax, incorrectly compute the squared norm as g**2, which holds only in the real domain. As a result, complex parameters will not be optimized correctly, as shown in pytorch#59998.
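
To illustrate the problem (a minimal sketch, not part of the original proposal): squaring a complex gradient does not yield its squared norm, so Adam's denominator becomes a complex number and the update direction is spuriously rotated.

import jax.numpy as jnp

g = jnp.array(3 + 4j)
print(g ** 2)               # (-7+24j): complex, not a norm
print((g.conj() * g).real)  # 25.0: the correct squared norm |g|**2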

To generalize those optimizers to the complex domain, there are two natural choices: the complex norm and the split real norm. We may decide to implement either or both of them.

Complex norm

We define the complex norm via norm(g: complex) = (g.conj() * g).real, i.e. the squared modulus |g|^2. This is the standard norm on the complex plane, and it reduces to g**2 in the real case, when the imaginary part of g is zero.

To implement this norm, we take Optax's Adam optimizer as an example (the update_fn of scale_by_adam):

def update_fn(updates, state, params=None):
  del params
  mu = _update_moment(updates, state.mu, b1, 1)
  nu = _update_moment(updates, state.nu, b2, 2)
  count_inc = numerics.safe_int32_increment(state.count)
  mu_hat = utils.cast_tree(_bias_correction(mu, b1, count_inc), mu_dtype)
  nu_hat = _bias_correction(nu, b2, count_inc)
  updates = jax.tree_multimap(
      lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
  return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

The only change needed is to accumulate the complex norm in nu, so we replace the _update_moment call for nu with the following function:

def _update_norm_moment(updates, moments, decay, order):
  """Compute the exponential moving average of the `order-th` moment of the norm."""

  def orderth_norm(g):
    if jnp.isrealobj(g):
      return g ** order
    else:
      return (g.conj() * g).real ** (order / 2)

  return jax.tree_multimap(
      lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments)

This function is semantically different from _update_moment, which accumulates the variable itself rather than its norm.

This change is non-breaking, in the sense that it does not affect users who only perform real-valued optimization.
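
As a quick sanity check (an illustrative sketch, not part of the proposed patch; orderth_norm is restated standalone here, with order passed explicitly):

import jax.numpy as jnp

def orderth_norm(g, order):
  # Mirrors the inner function above.
  if jnp.isrealobj(g):
    return g ** order
  return (g.conj() * g).real ** (order / 2)

print(orderth_norm(jnp.array(2.0), 2))     # 4.0: identical to g**2 in the real case
print(orderth_norm(jnp.array(3 + 4j), 2))  # 25.0: |g|**2, real and non-negative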

Split real norm

Another choice of norm is, quoting the PyTorch developer's comment, "Optimizers on complex tensors should behave the same way as if they were running on two real tensors". For example, in Adam we separately normalize g.real and g.imag for each complex gradient g, as if they were two real parameters. This principle is consistent with the behavior of vjp chosen by JAX.

To implement this following the composable design of Optax, we can write an optimizer wrapper that splits the complex parameters into pairs of real parameters before passing them to the wrapped optimizer's update, and merges the pairs of real updates back into complex updates afterward:

def split_real_and_imaginary(inner):
  def init_fn(params):
    params = jax.tree_map(_complex_to_real_pair, params)
    inner_state = inner.init(params)
    return SplitRealAndImaginaryState(inner_state)

  def update_fn(updates, state, params=None):
    inner_state = state.inner_state
    updates = jax.tree_map(_complex_to_real_pair, updates)
    params = jax.tree_map(_complex_to_real_pair, params)
    updates, inner_state = inner.update(updates, inner_state, params)
    updates = jax.tree_map(_real_pair_to_complex, updates, is_leaf=_is_real_pair)
    return updates, SplitRealAndImaginaryState(inner_state)

  return base.GradientTransformation(init_fn, update_fn)
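
The helper functions and the state class are not spelled out in this proposal; a minimal sketch of what they could look like is below (the _RealPair container and its field names are assumptions for illustration, not a final API):

from typing import NamedTuple

import jax.numpy as jnp
from optax._src import base

class _RealPair(NamedTuple):
  """A complex leaf split into two real leaves (the container name is an assumption)."""
  real: jnp.ndarray
  imaginary: jnp.ndarray

class SplitRealAndImaginaryState(NamedTuple):
  """Maintains the state of the wrapped transformation."""
  inner_state: base.OptState

def _complex_to_real_pair(x):
  # Split a complex leaf into a pair of real leaves; real leaves pass through.
  return _RealPair(x.real, x.imag) if jnp.iscomplexobj(x) else x

def _real_pair_to_complex(x):
  # Merge a pair of real leaves back into a complex leaf.
  return x.real + 1j * x.imaginary if isinstance(x, _RealPair) else x

def _is_real_pair(x):
  return isinstance(x, _RealPair)

Because a NamedTuple is itself a pytree, the wrapped optimizer transparently sees each _RealPair as two real leaves, which is exactly the intended behavior; the is_leaf=_is_real_pair argument is then needed to treat each pair as a single unit when merging the updates back.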

The usage is, for example, optimizer = optax.split_real_and_imaginary(optax.adam(learning_rate)).

There is no change to the existing API.

Implementing both of the norms

It is also possible to implement both the complex norm and the split real norm as proposed above, and let users choose between them according to their needs. A user who needs the complex norm may ignore split_real_and_imaginary and directly use optimizer = optax.adam(learning_rate). A user who needs the split real norm may instead use optimizer = optax.split_real_and_imaginary(optax.adam(learning_rate)), so that adam only processes real gradients and the above change to adam is irrelevant.

Other related issues

Conjugate of complex gradients

JAX takes the convention that the output of jax.grad needs to be conjugated before being added to the parameters in gradient descent. This originates from a choice of convention for the Wirtinger derivatives, and differs from the convention used by PyTorch, TensorFlow, and Flux (see jax#4891 and pytorch#41857). Here is an example demonstrating the difference:

import torch
from torch.autograd import grad
x = torch.tensor(1j, requires_grad=True)
print(grad(abs(x), x)) # 1j

from jax import numpy as jnp
from jax import grad
x = jnp.array(1j)
print(grad(abs)(x)) # -1j
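
For context (a brief explanation added here in terms of Wirtinger calculus): for f(z) = |z| = \sqrt{z \bar{z}}, the two Wirtinger derivatives are \partial f / \partial z = \bar{z} / (2|z|) and \partial f / \partial \bar{z} = z / (2|z|). PyTorch returns 2 \partial f / \partial \bar{z} = z / |z|, which equals 1j at z = 1j and is the direction of steepest ascent, while jax.grad returns the conjugate 2 \partial f / \partial z = \bar{z} / |z| = -1j, which is why it must be conjugated before being used in gradient descent.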

In Optax, we should establish a guideline for how to handle this conjugation during optimization. One choice is to conjugate explicitly before calling optimizer.update:

grads = jax.grad(compute_loss)(params)
grads = jax.tree_map(lambda x: x.conj(), grads)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

Some users have already implemented this in their code, such as the wrapped vjp in NetKet.

Another choice is to perform the conjugation inside optax.apply_updates. Specifically, we apply conj to each update u:

def apply_updates(params, updates):
  return jax.tree_multimap(
      lambda p, u: jnp.asarray(p + u.conj()).astype(jnp.asarray(p).dtype),
      params, updates)

This change does not affect users who only perform real-valued optimization, and there is no performance regression, as jax.jit eliminates the dispatch overhead of conj. Users who have already implemented the conjugation themselves would need to modify their code accordingly. This change saves users one line of code, but may break some semantics of gradients in JAX and Optax. For now, we do not intend to implement it until further consensus is reached.

Gradient clipping

The gradient clipping transformations also depend on the choice of norm for complex numbers. If we decide to implement the complex norm (or both norms), we need to implement it in the gradient clipping transformations accordingly, as in PR #161.
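
For instance, a complex-aware global norm could look like the following sketch (an illustration of the idea only, not necessarily the exact change in that PR):

import jax
import jax.numpy as jnp

def global_norm(updates):
  # Sum |x|**2 over all leaves; for real leaves, (x.conj() * x).real equals x**2.
  return jnp.sqrt(sum(
      jnp.sum((x.conj() * x).real) for x in jax.tree_util.tree_leaves(updates)))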

If we decide on the split real norm, the usage of gradient clipping becomes

optimizer = optax.split_real_and_imaginary(
    optax.chain(
        optax.clip_by_global_norm(max_norm),
        optax.adam(learning_rate)))

so that clip_by_global_norm only processes real gradients and no change is needed.

@sracaniere

Thanks a lot for this nice proposal!

I'd personally vote for the split real norm because, as you pointed out, it is consistent with the behaviour of jax.vjp. Another reason for me to prefer that approach is that users wanting to use complex numbers in neural networks can currently do so by creating their own Haiku modules that split the real and imaginary parts of all inputs. For example, a user could write:

from typing import Optional

import haiku as hk
import jax.numpy as jnp
from jax import lax


class ComplexLinear(hk.Module):

  def __init__(
      self,
      output_size: int,
      with_bias: bool = True,
      w_init: Optional[hk.initializers.Initializer] = None,
      b_init: Optional[hk.initializers.Initializer] = None,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)
    self.lin_r = hk.Linear(output_size, with_bias=with_bias, w_init=w_init, b_init=b_init)
    self.lin_j = hk.Linear(output_size, with_bias=with_bias, w_init=w_init, b_init=b_init)

  def __call__(
      self,
      inputs: jnp.ndarray,
      *,
      precision: Optional[lax.Precision] = None,
  ) -> jnp.ndarray:
    inputs_s = jnp.stack([inputs.real, inputs.imag])
    out_r = self.lin_r(inputs_s, precision=precision)
    out_j = self.lin_j(inputs_s, precision=precision)
    return out_r[0] - out_j[1] + 1j * (out_r[1] + out_j[0])

The split real norm approach would allow them to use normal hk.Linear instead, but have their network and optimisation behave the same way as with the above workaround.
