Dian Wu wdphy16

  • CQSL, EPFL
  • Lausanne, Switzerland
[RFC] Proposal for complex-valued optimization in Optax

Motivation

Complex-valued neural networks are widely used across scientific fields. Examples include complex-valued wave functions in quantum physics and molecular chemistry, complex-valued Fourier coefficients in signal processing, and manifold learning on tori that carry a complex structure. A recent survey is J. Bassey et al., arXiv:2101.12249.

One of the differences between complex- and real-valued neural networks lies in optimization. There is not yet a widely accepted way to generalize most optimizers to the complex domain, and most machine learning (ML) frameworks do not yet implement such generalizations properly.

In the JAX community, Optax is the only actively maintained package dedicated to optimization, and it is used in conjunction with packages dedicated to neural network construction, such as Flax and Haiku. In particular, Flax developers have