Skip to content

Instantly share code, notes, and snippets.

View EelcoHoogendoorn's full-sized avatar

Eelco Hoogendoorn EelcoHoogendoorn

  • 3dhubs
  • Amsterdam
View GitHub Profile
"""Minimal example of a DLR (diagonal linear recurrent) layer in JAX.
https://arxiv.org/pdf/2212.00768.pdf
"""
from typing import Any, Callable, Sequence, Tuple
from flax import linen
import jax
import jax.numpy as jnp
"""Minimal example of a DLR (diagonal linear recurrent) layer in JAX.
https://arxiv.org/pdf/2212.00768.pdf
"""
from typing import Any, Callable, Sequence, Tuple
from flax import linen
import jax
import jax.numpy as jnp
def spin_transform_deform(mesh, rho):
# some boilerplate to convert pycomplex mesh datastructures to GA-sparse matrix operators
I20 = mesh.topology.incidence[2, 0] # [F, 3] face-vertex incidence
I21 = mesh.topology.incidence[2, 1] # [F, 3] face-edge incidence
I10 = mesh.topology.incidence[1, 0] # [E, 2] edge-vertex incidence
O10 = mesh.topology._orientation[0] # [E, 2] edge-vertex relative orientations
O21 = mesh.topology._orientation[1] # [F, 3] face-edge relative orientations
T10 = as_ga_sparse(I10, as_scalar(O10)) # edge-vertex oriented boundary operator
A10 = as_ga_sparse(I10, as_scalar(np.ones_like(I10) / 2)) # averages vertices over edges