Last active
January 8, 2024 01:10
-
-
Save thowell/cec911dbca3545fe817bee0bbf5208b3 to your computer and use it in GitHub Desktop.
Contact dynamics model (impact + friction) for a 2D particle on a flat surface
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2023 Taylor Howell | |
import jax | |
import jax.numpy as jnp | |
import matplotlib.pyplot as plt | |
# %matplotlib inline | |
# system parameters | |
param = { | |
"timestep": 0.1, | |
"mass": 1.0, | |
"gravity": jnp.array([0.0, 9.81]), | |
"friction_coeff": 0.25, | |
"central_path": 0.0, | |
} | |
# fixed point for contact model | |
def residual(var, param): | |
# configurations | |
q0 = var[0:2] | |
q1 = var[2:4] | |
q2 = var[4:6] | |
# velocity | |
v1 = (q2 - q1) / param["timestep"] | |
# normal impulse | |
n = var[6] | |
# friction | |
bp = var[7:9] | |
bd = var[9:11] | |
f = bp[1] | |
# dynamics | |
d = ( | |
param["mass"] * (q2 - 2.0 * q1 + q0) / param["timestep"] | |
+ param["mass"] * param["gravity"] * param["timestep"] | |
- jnp.array([f, n]) | |
) | |
# impact complementarity | |
cn = q2[1] * n - param["central_path"] | |
# friction complementarity | |
cf = jnp.array( | |
[jnp.inner(bp, bd) - param["central_path"], bp[0] * bd[1] + bd[0] * bp[1]] | |
) | |
# tangential velocity | |
tv = v1[0] - bd[1] | |
# impact impulse | |
ii = bp[0] - param["friction_coeff"] * n | |
return jnp.hstack([d, cn, cf, tv, ii]) | |
# contact (impact + friction) dynamics simulation step | |
def step(pos, vel, param, verbose=False): | |
# current configuration | |
q1 = pos | |
# previous configuration | |
q0 = pos - param["timestep"] * vel | |
# warm start next configuration | |
q2 = jnp.copy(pos) | |
# warm start contact impulse | |
n = jnp.array([0.1]) | |
# warm start friction impulse | |
bp = jnp.array([1.0, 0.1]) | |
bd = jnp.array([1.0, 0.1]) | |
# loop over central-path parameters | |
for cp in [0.1, 0.01, 0.001, 0.0001]: | |
param["central_path"] = cp | |
# loop over Newton steps | |
for i in range(10): | |
# assemble | |
var = jnp.hstack([q0, q1, q2, n, bp, bd]) | |
# residual | |
res = residual(var, param) | |
# residual norm | |
res_norm = jnp.linalg.norm(res) | |
# jacobian of residual | |
jac = jax.jacfwd(residual)(var, param)[:, 4:11] | |
# search direction | |
dir = jnp.linalg.solve(jac, res) | |
# initial step size | |
step = 1.0 | |
# candidate next configuration | |
q2_cand = q2 - step * dir[0:2] | |
# candidate contact impulse | |
n_cand = n - step * dir[2] | |
# candidate friction | |
bp_cand = bp - step * dir[3:5] | |
bd_cand = bd - step * dir[5:7] | |
# cone search | |
cone_iter = 0 | |
while ( | |
n_cand <= 0.0 | |
or q2_cand[1] <= 0.0 | |
or jnp.abs(bp_cand[1]) >= bp_cand[0] | |
or jnp.abs(bd_cand[1]) >= bd_cand[0] | |
): | |
step *= 0.5 | |
q2_cand = q2 - step * dir[0:2] | |
n_cand = n - step * dir[2] | |
bp_cand = bp - step * dir[3:5] | |
bd_cand = bd - step * dir[5:7] | |
cone_iter += 1 | |
if cone_iter > 10: | |
if verbose: | |
print("cone search failure!") | |
break | |
# residual candidate | |
res_cand = residual( | |
jnp.hstack([q0, q1, q2_cand, n_cand, bp_cand, bd_cand]), param | |
) | |
# residual candidate norm | |
res_cand_norm = jnp.linalg.norm(res_cand) | |
# decrease step size until cost reduce | |
step_iter = 0 | |
while res_cand_norm >= res_norm: | |
step *= 0.5 | |
# candidate next configuration | |
q2_cand = q2 - step * dir[0:2] | |
# candidate contact impulse | |
n_cand = n - step * dir[2] | |
# candidate friction impulse | |
bp_cand = bp - step * dir[3:5] | |
bd_cand = bd - step * dir[5:7] | |
# residual candidate | |
res_cand = residual( | |
jnp.hstack([q0, q1, q2_cand, n_cand, bp_cand, bd_cand]), param | |
) | |
# residual candidate norm | |
res_cand_norm = jnp.linalg.norm(res_cand) | |
# increment | |
step_iter += 1 | |
if step_iter > 10: | |
if verbose: | |
print("line search failure") | |
break | |
# update | |
q2 = q2_cand | |
n = n_cand | |
bp = bp_cand | |
bd = bd_cand | |
# convergence | |
if res_cand_norm < 1.0e-4: | |
if verbose: | |
print("converged! cp (", cp, ")") | |
break | |
return q2, (q2 - q1) / param["timestep"] | |
# Simulate the particle starting at (x, z) = (0, 1) with velocity (1, 0),
# then plot both coordinates against the step index.
num_steps = 10  # simulation horizon (number of calls to `step`)

# initialize trajectories
pos_traj = [jnp.array([0.0, 1.0])]
vel_traj = [jnp.array([1.0, 0.0])]

# step simulation
for t in range(num_steps):
    next_pos, next_vel = step(pos_traj[-1], vel_traj[-1], param, verbose=False)
    pos_traj.append(next_pos)
    vel_traj.append(next_vel)
    print("t: ", t, " pos = ", next_pos)

# stack position trajectories: row t is (x, z) after t steps
pos_stack = jnp.vstack(pos_traj)

# plot position trajectories
fig = plt.figure()
plt.plot(pos_stack[:, 0], color="cyan", label="x")
plt.plot(pos_stack[:, 1], color="orange", label="z")
plt.title("Particle 2D")
plt.legend()
plt.xlabel("Time step")
plt.ylabel("Position")
# Required to display the figure when run as a plain script; harmless in a
# notebook with an inline backend.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment