Last active
June 7, 2023 19:40
-
-
Save csuter/2a86ac22495bc7f0c3bce4bbfa140c57 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def intersect(l1, l2):
    """Return the intersection point of two lines in the plane.

    Each line is given by three coefficients ``(a, b, c)`` describing
    ``a * x + b * y + c = 0``. The point is found by solving the 2x2 system
    ``[[a1, b1], [a2, b2]] @ p = [-c1, -c2]``.

    Args:
      l1: Coefficients ``(a, b, c)`` of the first line.
      l2: Coefficients ``(a, b, c)`` of the second line.

    Returns:
      The intersection point ``p`` as a length-2 array.

    Note:
      For parallel or coincident lines the system is singular.
      ``jnp.linalg.solve`` does not raise in that case; the result then
      contains non-finite values, so callers must tolerate that.
    """
    mat = jnp.stack([l1[:2], l2[:2]], axis=0)
    vec = -jnp.stack([l1[2], l2[2]], axis=0)
    return jnp.linalg.solve(mat, vec)
def intersection_depth(x, hd, seg):
    """Distance from ``x`` along heading ``hd`` to the segment ``seg``.

    Solves ``x + t * d == seg[0] + u * (seg[1] - seg[0])`` in closed form,
    where ``d = (cos hd, sin hd)`` is the unit ray direction.

    The ray hits the segment when ``0 <= u <= 1`` (the hit lies between the
    endpoints) and ``t > 0`` (it is in front of ``x``). Because ``d`` has
    unit length, ``t`` is then the Euclidean distance to the hit point.

    Args:
      x: Ray origin, a 2-vector.
      hd: Ray heading in radians.
      seg: Segment whose two endpoints are ``seg[0]`` and ``seg[1]``
        (2-vectors each).

    Returns:
      Scalar distance to the hit point, or ``jnp.inf`` if the ray misses.
      Rays parallel to the segment (including collinear ones) and
      zero-length segments are treated as misses.
    """
    def _cross(u, v):
        # z-component of the 2D cross product.
        return u[0] * v[1] - u[1] * v[0]

    dx = jnp.stack([jnp.cos(hd), jnp.sin(hd)])
    ds = seg[1] - seg[0]
    w = seg[0] - x
    denom = _cross(dx, ds)
    parallel = denom == 0.
    # Use a harmless denominator for parallel rays. This keeps inf/nan out
    # of both the forward pass and the masked-out branch of jnp.where;
    # jnp.where alone does not stop NaNs from leaking into gradients.
    safe_denom = jnp.where(parallel, 1., denom)
    t = _cross(w, ds) / safe_denom   # distance along the ray
    u = _cross(w, dx) / safe_denom   # fraction of the way along the segment
    in_bounds = ~parallel & (u >= 0.) & (u <= 1.) & (t > 0.)
    return jnp.where(in_bounds, t, jnp.inf)
def cast(x, hd, segs, fov, num_a, zmax):
    """Cast a fan of ``num_a`` rays from ``x`` and report per-ray depths.

    Ray headings are spaced evenly over ``[hd - fov/2, hd + fov/2]``, with
    both endpoints included. Each ray's depth is the smallest
    ``intersection_depth`` over every segment in ``segs``, capped at
    ``zmax``.

    Returns:
      A tuple ``(angles, depths)``: the absolute heading of each ray and its
      capped nearest-hit depth.
    """
    half_fov = fov / 2
    angles = hd + jnp.linspace(-half_fov, half_fov, num_a)

    # Inner vmap sweeps the ray headings; outer vmap sweeps the segments,
    # so the raw result is indexed [segment, ray].
    over_rays = jax.vmap(intersection_depth, in_axes=[None, 0, None])
    over_segs_and_rays = jax.vmap(over_rays, in_axes=[None, None, 0])

    nearest = jnp.min(over_segs_and_rays(x, angles, segs), axis=0)
    capped = jnp.where(nearest < zmax, nearest, zmax)
    return angles, capped
# Batched caster: maps `cast` over a leading axis of poses (xs) and headings
# (hds), sharing the segments and the sensor parameters across the batch.
# NOTE(review): assumes data_json, xs, hds and np are defined earlier
# (not visible here).
vcast = jax.vmap(cast, in_axes=(0, 0, None, None, None, None))

# Pack every wall and clutter segment into a single array indexed by segment.
# NOTE(review): each segment is stacked as [d['x'], d['y']], and
# `intersection_depth` treats seg[0] / seg[1] as the two endpoints. That is
# only right if d['x'] and d['y'] are the two endpoints. If they instead hold
# the x- and y-coordinates of both endpoints, stack along axis=-1.
# Confirm against the data format.
all_segs = data_json['env']['segs'] + data_json['env']['clutter']
segs = jnp.array(
    [jnp.stack([jnp.array(d['x']), jnp.array(d['y'])], axis=0)
     for d in all_segs])

# Full-circle sweep of 361 rays with no range cap (zmax = inf).
# NOTE(review): with fov == 2*pi the first and last headings (hd - pi and
# hd + pi) point the same way, so those two rays duplicate each other.
angles, depths = vcast(xs, hds, segs, 2 * np.pi, 361, jnp.inf)
# Additive Gaussian noise (std 0.1, fixed PRNG seed) on the depth readings;
# rays that hit nothing remain at inf.
depths += .1 * jax.random.normal(key=jax.random.PRNGKey(0), shape=depths.shape)
angles.shape, depths.shape
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment