Skip to content

Instantly share code, notes, and snippets.

@thowell
Last active January 8, 2024 01:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thowell/cec911dbca3545fe817bee0bbf5208b3 to your computer and use it in GitHub Desktop.
Save thowell/cec911dbca3545fe817bee0bbf5208b3 to your computer and use it in GitHub Desktop.
Contact dynamics model (impact + friction) for a 2D particle on a flat surface
# Copyright 2023 Taylor Howell
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# %matplotlib inline
# System parameters shared by residual() and step().
param = dict(
    timestep=0.1,  # integration step h used in the finite differences
    mass=1.0,  # particle mass
    gravity=jnp.array([0.0, 9.81]),  # gravity acceleration on (x, z); acts along z
    friction_coeff=0.25,  # friction coefficient mu in bp[0] = mu * n
    central_path=0.0,  # complementarity relaxation; step() sets it from a schedule
)
# Root-finding residual for one contact time step.
def residual(var, param):
    """Residual of the discrete contact (impact + friction) step for a 2D particle.

    ``var`` holds 11 entries laid out as
    ``[q0 (2), q1 (2), q2 (2), n (1), bp (2), bd (2)]``:
    previous, current and next configurations (x, z), the normal impulse,
    and a complementary pair of second-order-cone variables.  ``bp[1]`` is
    the friction impulse applied along x, and ``bd[1]`` is tied to the
    tangential (x) velocity.

    The 7 returned entries are, in order:
      * discrete dynamics (2): inertia + gravity - contact impulse,
      * impact complementarity (1): ``q2[1] * n - central_path``,
      * cone complementarity (2): the Jordan product ``bp o bd`` minus
        ``(central_path, 0)``,
      * tangential velocity (1): ``(q2[0] - q1[0]) / timestep - bd[1]``,
      * cone radius (1): ``bp[0] - friction_coeff * n``.

    A solution of the step drives every entry to zero.  A positive
    ``param["central_path"]`` relaxes the complementarity rows.
    """
    h = param["timestep"]
    m = param["mass"]
    mu = param["friction_coeff"]
    cp = param["central_path"]

    q_prev, q_curr, q_next = var[0:2], var[2:4], var[4:6]
    normal = var[6]
    beta_p, beta_d = var[7:9], var[9:11]

    # Finite-difference velocity over the step; only the x-component is used.
    vel = (q_next - q_curr) / h

    # Contact impulse: friction acts along x, normal impulse along z.
    impulse = jnp.array([beta_p[1], normal])
    dynamics = (
        m * (q_next - 2.0 * q_curr + q_prev) / h
        + m * param["gravity"] * h
        - impulse
    )

    gap_comp = q_next[1] * normal - cp
    cone_comp = jnp.array(
        [
            jnp.inner(beta_p, beta_d) - cp,
            beta_p[0] * beta_d[1] + beta_d[0] * beta_p[1],
        ]
    )
    slip = vel[0] - beta_d[1]
    radius = beta_p[0] - mu * normal

    return jnp.hstack([dynamics, gap_comp, cone_comp, slip, radius])
# Contact (impact + friction) dynamics simulation step.
def _step_candidate(q2, n, bp, bd, direction, alpha):
    """Return the damped Newton iterate (q2, n, bp, bd) - alpha * direction, split by variable."""
    return (
        q2 - alpha * direction[0:2],
        n - alpha * direction[2],
        bp - alpha * direction[3:5],
        bd - alpha * direction[5:7],
    )


def _violates_interior(q2, n, bp, bd):
    """True if the iterate leaves the strict interior: n <= 0, gap q2[1] <= 0, or bp/bd outside the open cone |v[1]| < v[0]."""
    return bool(
        n[0] <= 0.0
        or q2[1] <= 0.0
        or jnp.abs(bp[1]) >= bp[0]
        or jnp.abs(bd[1]) >= bd[0]
    )


def step(
    pos,
    vel,
    param,
    verbose=False,
    *,
    central_path_schedule=(0.1, 0.01, 0.001, 0.0001),
    max_newton_iters=10,
    tol=1.0e-4,
):
    """Advance the particle one time step with impact and friction contact.

    Solves ``residual(var, param) = 0`` for the next configuration ``q2``,
    the normal impulse ``n`` and the cone variables ``bp``/``bd``.  The
    solver is an interior-point Newton method that follows a central path:
    for each relaxation value in ``central_path_schedule`` it takes up to
    ``max_newton_iters`` Newton steps.  Each step is first halved until the
    iterate stays strictly interior (n > 0, q2[1] > 0, bp and bd inside the
    cone).  It is then halved until the residual norm decreases.

    Args:
        pos: current configuration q1 = (x, z).
        vel: current velocity.  The previous configuration is rebuilt as
            ``q0 = pos - timestep * vel``.
        param: parameter dict read by ``residual``.  It is NOT modified; the
            central-path value is written into a private copy.
        verbose: if True, print solver failure and convergence messages.
        central_path_schedule: relaxation values, applied in order.
        max_newton_iters: Newton iterations per central-path value.
        tol: residual-norm threshold that ends a central-path stage early.

    Returns:
        Tuple ``(q2, v)``: the next configuration and the velocity
        ``(q2 - pos) / timestep``.
    """
    # Copy so the caller's (often module-global) dict does not keep the last
    # central-path value after the call.
    param = dict(param)
    timestep = param["timestep"]

    q1 = pos
    q0 = pos - timestep * vel

    # Warm start.  The guesses for n, bp and bd are strictly interior.
    # NOTE(review): q2 starts at pos, so the start is interior only if pos[1] > 0.
    q2 = jnp.copy(pos)
    n = jnp.array([0.1])
    bp = jnp.array([1.0, 0.1])
    bd = jnp.array([1.0, 0.1])

    # Differentiate w.r.t. var; columns 4:11 are the unknowns (q2, n, bp, bd).
    jac_fn = jax.jacfwd(residual)

    for cp in central_path_schedule:
        param["central_path"] = cp
        for _ in range(max_newton_iters):
            var = jnp.hstack([q0, q1, q2, n, bp, bd])
            res = residual(var, param)
            res_norm = jnp.linalg.norm(res)
            jac = jac_fn(var, param)[:, 4:11]
            direction = jnp.linalg.solve(jac, res)

            alpha = 1.0
            q2_c, n_c, bp_c, bd_c = _step_candidate(q2, n, bp, bd, direction, alpha)

            # Cone search: halve until the candidate is strictly interior.
            cone_iter = 0
            while _violates_interior(q2_c, n_c, bp_c, bd_c):
                alpha *= 0.5
                q2_c, n_c, bp_c, bd_c = _step_candidate(
                    q2, n, bp, bd, direction, alpha
                )
                cone_iter += 1
                if cone_iter > 10:
                    if verbose:
                        print("cone search failure!")
                    break

            res_c = residual(jnp.hstack([q0, q1, q2_c, n_c, bp_c, bd_c]), param)
            res_c_norm = jnp.linalg.norm(res_c)

            # Backtracking line search: halve until the residual norm decreases.
            step_iter = 0
            while res_c_norm >= res_norm:
                alpha *= 0.5
                q2_c, n_c, bp_c, bd_c = _step_candidate(
                    q2, n, bp, bd, direction, alpha
                )
                res_c = residual(
                    jnp.hstack([q0, q1, q2_c, n_c, bp_c, bd_c]), param
                )
                res_c_norm = jnp.linalg.norm(res_c)
                step_iter += 1
                if step_iter > 10:
                    if verbose:
                        print("line search failure")
                    break

            # The last candidate is accepted even if a search above gave up.
            q2, n, bp, bd = q2_c, n_c, bp_c, bd_c

            if res_c_norm < tol:
                if verbose:
                    print("converged! cp (", cp, ")")
                break

    return q2, (q2 - q1) / timestep
# Simulate a particle starting at (x, z) = (0, 1) with velocity (1, 0), then
# plot x and z against the step index.
num_steps = 10  # number of simulation steps

# initialize trajectories
pos_traj = [jnp.array([0.0, 1.0])]
vel_traj = [jnp.array([1.0, 0.0])]

# step simulation
for t in range(num_steps):
    next_pos, next_vel = step(pos_traj[-1], vel_traj[-1], param, verbose=False)
    pos_traj.append(next_pos)
    vel_traj.append(next_vel)
    print("t: ", t, " pos = ", next_pos)

# stack position trajectories into a (num_steps + 1, 2) array
pos_stack = jnp.vstack(pos_traj)

# plot position trajectories
fig = plt.figure()
plt.plot(pos_stack[:, 0], color="cyan", label="x")
plt.plot(pos_stack[:, 1], color="orange", label="z")
plt.title("Particle 2D")
plt.legend()
plt.xlabel("Time step")
plt.ylabel("Position")
# When run as a plain script the figure is only displayed by an explicit
# show(); in a notebook with an inline backend this call is harmless.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment