Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save schmrlng/fe64d89b637c79342e18bfd5f61cb34a to your computer and use it in GitHub Desktop.
Save schmrlng/fe64d89b637c79342e18bfd5f61cb34a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "embedded-assembly",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt; plt.rcParams.update({'font.size': 20})\n",
"from ipywidgets import interact\n",
"\n",
"# `NamedTuple`s are used (more accurately, abused) in this notebook to minimize dependencies;\n",
"# better choices would be `flax.struct.dataclass` or `equinox.Module`.\n",
"from typing import NamedTuple"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def interval_interval_separation_distance(interval_0, interval_1):\n",
" return jnp.maximum(\n",
" interval_0[0] - interval_1[1],\n",
" interval_1[0] - interval_0[1],\n",
" )\n",
"\n",
"\n",
"def rotate_points(points, angle):\n",
" c, s = jnp.cos(angle), jnp.sin(angle)\n",
" return points @ jnp.array([[c, s], [-s, c]])\n",
"\n",
"\n",
"class Rectangle(NamedTuple):\n",
" center: jnp.array\n",
" orientation: jnp.array\n",
" half_dimensions: jnp.array # (half_width, half_height)\n",
"\n",
" @property\n",
" def corners(self):\n",
" half_width, half_height = self.half_dimensions\n",
" untransformed_corners = jnp.array([\n",
" [half_width, half_height],\n",
" [-half_width, half_height],\n",
" [-half_width, -half_height],\n",
" [half_width, -half_height],\n",
" ])\n",
" return self.center + rotate_points(untransformed_corners, self.orientation)\n",
"\n",
" def distance_to_point(self, point):\n",
" point_in_body_frame = rotate_points(point - self.center, -self.orientation)\n",
" return jnp.linalg.norm(point_in_body_frame -\n",
" jnp.clip(point_in_body_frame, -self.half_dimensions, self.half_dimensions))\n",
"\n",
"\n",
"@jax.jit\n",
"def rectangle_rectangle_separation_distance(rectangle_0: Rectangle, rectangle_1: Rectangle):\n",
"\n",
" def _separation_distance(rectangle_0, rectangle_1):\n",
" points_1 = rotate_points(rectangle_1.corners - rectangle_0.center, -rectangle_0.orientation)\n",
" return jnp.max(\n",
" jax.vmap(interval_interval_separation_distance, 1)(\n",
" jnp.array([-rectangle_0.half_dimensions, rectangle_0.half_dimensions]),\n",
" jnp.array([jnp.min(points_1, 0), jnp.max(points_1, 0)]),\n",
" ))\n",
"\n",
" return jnp.maximum(\n",
" _separation_distance(rectangle_0, rectangle_1),\n",
" _separation_distance(rectangle_1, rectangle_0),\n",
" )\n",
"\n",
"\n",
"@jax.jit\n",
"def rectangle_rectangle_signed_distance(rectangle_0: Rectangle, rectangle_1: Rectangle):\n",
" separation_distance = rectangle_rectangle_separation_distance(rectangle_0, rectangle_1)\n",
" return jnp.where(\n",
" separation_distance < 0, separation_distance,\n",
" jnp.minimum(jnp.min(jax.vmap(rectangle_0.distance_to_point)(rectangle_1.corners)),\n",
" jnp.min(jax.vmap(rectangle_1.distance_to_point)(rectangle_0.corners))))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@interact(x=(-10, 10), y=(-10, 10), q=(-np.pi, np.pi))\n",
"def plot(x, y, q):\n",
" r0 = Rectangle(np.array([1., 2.]), 0., np.array([1., 2.]))\n",
" r1 = Rectangle(np.array([x, y]), q, np.array([1., 2.]))\n",
"\n",
" plt.figure(figsize=(10, 8))\n",
" plt.title(rectangle_rectangle_signed_distance(r0, r1))\n",
" plt.scatter(*r0.corners.T)\n",
" plt.scatter(*r1.corners.T)\n",
" plt.axis(\"equal\")\n",
" plt.xlim(-10, 10)\n",
" plt.ylim(-10, 10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment