Skip to content

Instantly share code, notes, and snippets.

@Chandler
Created April 10, 2020 02:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Chandler/80ef67b164cd1a7bb50446adc24f927a to your computer and use it in GitHub Desktop.
Save Chandler/80ef67b164cd1a7bb50446adc24f927a to your computer and use it in GitHub Desktop.
Normals w/ jax
# 3D Surface normals with JAX
import jax.numpy as np  # the old `jax.np` alias was removed from JAX
import matplotlib.pyplot as plt
from jax import vmap, jacfwd
# Simple Torus surface parameterization f(u,v) -> (x,y,z)
# The multivariable analytic function to be differentiated
def _f(uv):
u,v = uv
x = (8 + 2 * np.cos(v)) * np.cos(u)
y = (8 + 2 * np.cos(v)) * np.sin(u)
z = 2 * np.sin(v)
return [x,y,z]
# List of surface points to compute normals for: every (u, v) pair on a
# 100 x 100 grid over [0, 2*pi] x [0, 2*pi].
angles = np.linspace(0, 2 * np.pi, 100)
# indexing="ij" keeps u as the slow axis and v as the fast axis, i.e. the
# same row order as a nested `for u: for v:` loop.
u_grid, v_grid = np.meshgrid(angles, angles, indexing="ij")
coordinates = np.stack([u_grid.ravel(), v_grid.ravel()], axis=1)  # (N x 2)
#================== JAX ==========================
# Vectorized version of _f: maps an (N x 2) array of (u, v) rows to a list
# [x, y, z] of three length-N arrays.
f = vmap(_f)
# Vectorized differential of f: jacfwd(_f) returns the Jacobian of _f at a
# single (u, v) point (forward-mode autodiff), essentially the multi-variable
# version of a derivative; vmap maps it over all N rows.
df = vmap(jacfwd(_f))
#=================================================
# 3d points on the surface (N x 3)
points = np.array(f(coordinates)).swapaxes(0, 1)
# The Jacobian is the (3 x 2) matrix of all partial derivatives of f
# [dxdu, dxdv]
# [dydu, dydv] aka [dfdu, dfdv]
# [dzdu, dzdv]
jacobian_matrices = np.array(df(coordinates))  # Stacked shape is (3 x N x 2)
# Move the (u, v) axis to the front -> (2 x N x 3), then unpack it:
# dfdu: list of 3d vectors tangent to the surface in u direction (N x 3)
# dfdv: list of 3d vectors tangent to the surface in v direction (N x 3)
dfdu, dfdv = jacobian_matrices.swapaxes(0, 2)
# Row-wise cross product of the two tangents: one normal per point (N x 3)
normals = np.cross(dfdu, dfdv)
# A list of 3d unit vectors normal to the surface (N x 3).
# The norm must be taken per row (axis=1); without an axis it is a single
# norm over the whole (N x 3) array and the results are not unit length.
unit_normals = normals / np.linalg.norm(normals, axis=1, keepdims=True)
# See for yourself, create a quiver plot of the vectors
x, y, z = points.swapaxes(0, 1)  # Vector positions
u, v, w = unit_normals.swapaxes(0, 1)  # Vector directions
# Figure.gca() no longer accepts keyword arguments (deprecated in
# Matplotlib 3.4, removed in 3.6), so create the 3D axes with add_subplot.
ax = plt.figure().add_subplot(projection='3d')
ax.quiver(x, y, z, u, v, w)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment