
@fgolemo
Created February 19, 2021 05:48
autobot_toy.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "autobot_toy.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/fgolemo/b762ddc59c83ca19cd15f3767e2c3780/autobot_toy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EB9MwlqV98w-"
},
"source": [
"# AutoBot Toy Dataset Modelling"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"id": "6lMLqmTN9i8G"
},
"source": [
"## Generate Toy Dataset\n",
"In this section, we generate a tiny toy dataset that showcases AutoBot's ability to model multimodal trajectories. Each trajectory follows a simple kinematic bicycle model: the vehicle first drives straight ahead and then turns with a constant steering angle at a constant speed. This dataset not only demonstrates the multimodal trajectories that AutoBot generates but also shows the importance of the entropy regularization term."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 596
},
"id": "cBIkHktQ9i8K",
"outputId": "13cebbfc-0bd7-434f-e929-22b7948eec08"
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"\n",
"l = 2  # wheelbase of the bicycle (distance between the rear and front axle)\n",
"dt = 0.5\n",
"start_pos = np.array([-0.0, -10.0])\n",
"data = []\n",
"\n",
"# we'll generate a total of 6 trajectories\n",
"\n",
"# 2 trajectories turn right (phi < 0), 2 go straight, 2 turn left (phi > 0)\n",
"phis = [-0.2, 0.0, 0.2]\n",
"\n",
"# the trajectories will go left/straight/right at one of 2 possible speeds\n",
"speeds = [1.5, 3.0]\n",
"\n",
"# all 3 x 2 = 6 (phi, speed) combinations\n",
"configs = np.array(np.meshgrid(phis, speeds)).T.reshape(-1, 2)\n",
"\n",
"fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))\n",
"row = 0\n",
"for i, config in enumerate(configs):\n",
"    wheel_pos = []\n",
"    speed = 3.0\n",
"    heading = np.pi / 2\n",
"    phi = 0.0\n",
"    positions = {\n",
"        \"rear\": np.array([0.0, 0.0]) + start_pos,\n",
"        \"front\": np.array([l * np.cos(heading), l * np.sin(heading)]) + start_pos\n",
"    }\n",
"\n",
"    # we are generating trajectories of total length 18,\n",
"    # 6 of which will be used as input trajectory and\n",
"    # the remaining 12 as output trajectory\n",
"    for t in range(18):\n",
"\n",
"        # steps t = 0..6 drive straight up at a fixed speed of 3.0 (the 6 input steps plus\n",
"        # the first output step); from t = 7 on, use this trajectory's steering angle and speed\n",
"        if t > 6:\n",
"            phi, speed = config\n",
"\n",
"        # kinematic bicycle update: move the rear axle along the current heading,\n",
"        # then rotate the heading by the yaw rate omega = speed * tan(phi) / l\n",
"        x_v = speed*np.cos(heading)\n",
"        y_v = speed*np.sin(heading)\n",
"        omega = speed*np.tan(phi)/l\n",
"        heading += omega * dt\n",
"        positions[\"rear\"] += np.array([x_v * dt, y_v * dt])\n",
"        positions[\"front\"] = positions[\"rear\"] + np.array([l * np.cos(heading), l * np.sin(heading)])\n",
"        wheel_pos.append([positions[\"rear\"][0], positions[\"rear\"][1], positions[\"front\"][0], positions[\"front\"][1]])\n",
"    data.append(np.array(wheel_pos))\n",
"\n",
"    # plotting the data\n",
"    col = i % 3\n",
"    if i > 0 and i % 3 == 0:\n",
"        row += 1\n",
"    ax[row, col].scatter(np.array(wheel_pos)[:6, 0], np.array(wheel_pos)[:6, 1], color='#94D0FF', label='past', s=40)\n",
"    ax[row, col].scatter(np.array(wheel_pos)[6:, 0], np.array(wheel_pos)[6:, 1], color='#FF6AD5', label='future', s=40)\n",
"    ax[row, col].axis(xmin=-15, xmax=15, ymin=-15, ymax=20)\n",
"\n",
"ax[1, 2].legend()\n",
"plt.show()\n",
"# keep only the rear-axle (x, y) positions: data has shape (6, 18, 2)\n",
"data = np.array(data)[:, :, :2]"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x720 with 6 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
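{
"cell_type": "markdown",
"metadata": {},
"source": [
"The position update in the loop above is a kinematic bicycle model. As a minimal standalone sketch of one step (the helper name `bicycle_step` is hypothetical, not part of the original code), the rear axle first moves along the current heading and the heading then rotates by the yaw rate `omega = speed * tan(phi) / l`. With `phi = 0` the heading stays constant; a positive `phi` increases the heading, which turns an upward-facing vehicle to the left."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"\n",
"def bicycle_step(rear, heading, speed, phi, l=2.0, dt=0.5):\n",
"    # Hypothetical helper mirroring the update in the data-generation loop.\n",
"    # Move the rear axle along the current (not yet updated) heading.\n",
"    rear = rear + speed * dt * np.array([np.cos(heading), np.sin(heading)])\n",
"    # Rotate the heading by the yaw rate omega = speed * tan(phi) / l.\n",
"    heading = heading + speed * np.tan(phi) / l * dt\n",
"    return rear, heading\n",
"\n",
"\n",
"# Drive 4 steps straight up, then 4 steps with a positive (left) steering angle.\n",
"rear, heading = np.array([0.0, -10.0]), np.pi / 2\n",
"for t in range(8):\n",
"    phi = 0.2 if t >= 4 else 0.0\n",
"    rear, heading = bicycle_step(rear, heading, speed=3.0, phi=phi)\n",
"print(rear, np.degrees(heading))"
],
"execution_count": null,
"outputs": []
},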
{
"cell_type": "markdown",
"metadata": {
"id": "9JaFXTes9i8M"
},
"source": [
"## Creating a PyTorch dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "P0eRpxAs9i8M"
},
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class NPYFakeDataset(Dataset):\n",
"    def __init__(self):\n",
"        self.ego_dataset = data  # the (6, 18, 2) array built above\n",
"\n",
"    def get_input_output_seqs(self, ego_data):\n",
"        # 6 input timesteps (cyan in the plot above),\n",
"        # which are identical across all 6 examples.\n",
"        ego_in = ego_data[:6]\n",
"\n",
"        # 12 output timesteps to be predicted by the model\n",
"        # (pink in the plot above)\n",
"        ego_out = ego_data[6:]\n",
"\n",
"        return ego_in, ego_out\n",
"\n",
"    def __getitem__(self, idx: int):\n",
"        ego_data = self.ego_dataset[idx]\n",
"        in_ego, out_ego = self.get_input_output_seqs(ego_data)\n",
"        return in_ego, out_ego\n",
"\n",
"    def __len__(self):\n",
"        return len(self.ego_dataset)\n"
],
"execution_count": null,
"outputs": []
},
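{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of what the `DataLoader` in the training loop receives, the hypothetical `SplitDataset` below reimplements the same 6/12 split on a random stand-in array (it takes the array as an argument instead of reading the global `data`). PyTorch's default collate function stacks the per-example NumPy arrays into batched tensors of shape `(B, 6, 2)` and `(B, 12, 2)`; they come out as float64, which is why the training loop calls `.float()`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"\n",
"class SplitDataset(Dataset):\n",
"    # Hypothetical stand-in for NPYFakeDataset that receives its trajectories explicitly.\n",
"    def __init__(self, trajectories, n_in=6):\n",
"        self.trajectories = trajectories  # shape (N, 18, 2)\n",
"        self.n_in = n_in\n",
"\n",
"    def __getitem__(self, idx):\n",
"        traj = self.trajectories[idx]\n",
"        return traj[:self.n_in], traj[self.n_in:]  # (6, 2) input, (12, 2) output\n",
"\n",
"    def __len__(self):\n",
"        return len(self.trajectories)\n",
"\n",
"\n",
"fake = np.random.randn(6, 18, 2)  # random stand-in for the 6 toy trajectories\n",
"ego_in, ego_out = next(iter(DataLoader(SplitDataset(fake), batch_size=6)))\n",
"print(ego_in.shape, ego_out.shape, ego_in.dtype)"
],
"execution_count": null,
"outputs": []
},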
{
"cell_type": "markdown",
"metadata": {
"id": "7sg9ohVJ9i8N"
},
"source": [
"## Model Code"
]
},
{
"cell_type": "code",
"metadata": {
"id": "9On9HNKh9i8N"
},
"source": [
"import math\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def init(module, weight_init, bias_init, gain=1):\n",
"    weight_init(module.weight.data, gain=gain)\n",
"    bias_init(module.bias.data)\n",
"    return module\n",
"\n",
"\n",
"class PositionalEncoding(nn.Module):\n",
"    '''\n",
"    Sine/cosine positional encoding (standard procedure for transformer sequential inputs)\n",
"    '''\n",
"\n",
"    def __init__(self, d_model, dropout=0.1, max_len=20):\n",
"        super(PositionalEncoding, self).__init__()\n",
"        self.dropout = nn.Dropout(p=dropout)\n",
"        pe = torch.zeros(max_len, d_model)\n",
"        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
"        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
"        pe[:, 0::2] = torch.sin(position * div_term)\n",
"        pe[:, 1::2] = torch.cos(position * div_term)\n",
"        pe = pe.unsqueeze(0).transpose(0, 1)\n",
"        self.register_buffer('pe', pe)\n",
"\n",
"    def forward(self, x):\n",
"        '''\n",
"        :param x: must be (T, B, H)\n",
"        :return: x plus the encoding of its first T positions, with dropout applied, shape (T, B, H)\n",
"        '''\n",
"        x = x + self.pe[:x.size(0), :]\n",
"        return self.dropout(x)\n",
"\n",
"\n",
"class AutoBot(nn.Module):\n",
"    '''\n",
"    Single-agent (ego-only) AutoBot model: a transformer encoder over the past positions and an\n",
"    autoregressive transformer decoder that rolls out one future trajectory per mode, plus a\n",
"    predicted probability for each mode.\n",
"    '''\n",
"    def __init__(self, hidden_size=64, num_modes=3):\n",
"        super(AutoBot, self).__init__()\n",
"\n",
"        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))\n",
"\n",
"        self.hidden_size = hidden_size\n",
"        self.num_modes = num_modes\n",
"        self.num_heads = 8\n",
"\n",
"        self.output_model = OutputModelBVG(hidden_size=hidden_size)\n",
"\n",
"        tx_encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=self.num_heads)\n",
"        self.tx_encoder = nn.TransformerEncoder(tx_encoder_layer, num_layers=2)\n",
"\n",
"        self.emb_pos = init_(nn.Linear(2, hidden_size))\n",
"\n",
"        tx_decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=self.num_heads)\n",
"        self.tx_decoder = nn.TransformerDecoder(tx_decoder_layer, num_layers=2)\n",
"\n",
"        self.pos_encoder = PositionalEncoding(hidden_size, dropout=0.0)\n",
"\n",
"        self.emb_intention = nn.Sequential(\n",
"            init_(nn.Linear(num_modes, hidden_size))\n",
"        )\n",
"        self.emb_posint = nn.Sequential(\n",
"            init_(nn.Linear(2*hidden_size, hidden_size)), nn.ReLU(),\n",
"            init_(nn.Linear(hidden_size, hidden_size))\n",
"        )\n",
"\n",
"        self.mode_parameters = nn.Parameter(torch.Tensor(1, num_modes, hidden_size))\n",
"        nn.init.xavier_uniform_(self.mode_parameters)\n",
"        self.prob_decoder = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=8)\n",
"        self.prob_predictor = init_(nn.Linear(hidden_size, 1))\n",
"\n",
"        self.train()\n",
"\n",
"    def generate_decoder_mask(self, seq_len, device):\n",
"        ''' For masking out the subsequent info. '''\n",
"        subsequent_mask = (torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1)).bool()\n",
"        return subsequent_mask\n",
"\n",
"    def forward(self, ego_input_positions, ego_output_positions):\n",
"        B = ego_input_positions.size(0)\n",
"        horizon = ego_output_positions.size(1)\n",
"\n",
"        # Encode all observations\n",
"        encoded_obs = self.emb_pos(ego_input_positions).transpose(0, 1)\n",
"\n",
"        # Add positional encoding\n",
"        encoded_obs = self.pos_encoder(encoded_obs)\n",
"\n",
"        # Transformer encoder over the input sequence\n",
"        in_memory = self.tx_encoder(encoded_obs)\n",
"        # Mode probabilities p(z|x): learned per-mode queries attend to the encoded input\n",
"        mode_probs = self.prob_decoder(self.mode_parameters.repeat(B, 1, 1).transpose(0, 1), in_memory).transpose(0,1)\n",
"        mode_probs = F.softmax(self.prob_predictor(mode_probs).squeeze(-1), dim=1)\n",
"\n",
"        # One-hot intention per mode, embedded and tiled so that index b * num_modes + k is (example b, mode k)\n",
"        intentions = torch.eye(self.num_modes).to(device=ego_input_positions.device).unsqueeze(0).repeat(B, 1, 1)\n",
"        enc_intentions = self.emb_intention(intentions).view(B*self.num_modes, self.hidden_size).unsqueeze(0)\n",
"        in_memory = in_memory.unsqueeze(2).repeat(1, 1, self.num_modes, 1).view(-1, B * self.num_modes, self.hidden_size)\n",
"\n",
"        # Every mode starts decoding from the last observed position\n",
"        pred_obs = [ego_input_positions[:, -1].unsqueeze(1).repeat(1, self.num_modes, 1).view(B * self.num_modes, -1)]\n",
"        dec_start_emb = self.emb_pos(torch.stack(pred_obs, dim=0))\n",
"        dec_input_emb = dec_start_emb\n",
"        for ts in range(horizon):  # autoregressive rollout; out_seq grows to length ts + 1\n",
"            T = len(dec_input_emb)\n",
"            curr_intentions = enc_intentions.repeat(T, 1, 1)\n",
"            out_emb = torch.cat((curr_intentions, dec_input_emb), dim=-1)\n",
"            out_emb = self.emb_posint(out_emb)\n",
"\n",
"            out_emb = self.pos_encoder(out_emb)\n",
"            time_masks = self.generate_decoder_mask(seq_len=T, device=ego_input_positions.device)\n",
"            out_seq = self.tx_decoder(out_emb, in_memory, tgt_mask=time_masks)\n",
"            dec_input_emb = torch.cat((dec_start_emb, out_seq), dim=0)\n",
"\n",
"        # out_seq: (horizon, B * num_modes, H) -> out_dists: (num_modes, horizon, B, 5)\n",
"        out_dists = self.output_model(out_seq).view(horizon, B, self.num_modes, -1).permute(2, 0, 1, 3)\n",
"        return out_dists, mode_probs\n",
"\n",
"\n",
"class OutputModelBVG(nn.Module):\n",
"    def __init__(self, hidden_size=64):\n",
"        super(OutputModelBVG, self).__init__()\n",
"        self.hidden_size = hidden_size\n",
"        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))\n",
"        self.observation_model = nn.Sequential(\n",
"            init_(nn.Linear(hidden_size, hidden_size)), nn.ReLU(),\n",
"            init_(nn.Linear(hidden_size, hidden_size)), nn.ReLU(),\n",
"            init_(nn.Linear(hidden_size, 5))\n",
"        )\n",
"        self.min_stdev = 0.1\n",
"\n",
"    def forward(self, agent_latent_state):\n",
"        '''\n",
"        :param agent_latent_state: decoder output for the ego-agent, shape (T, B, H).\n",
"        :return: bivariate Gaussian parameters (mu_x, mu_y, sigma_x, sigma_y, rho), shape (T, B, 5).\n",
"        '''\n",
"        pred_obs = self.observation_model(agent_latent_state)\n",
"        x_mean = pred_obs[:, :, 0]\n",
"        y_mean = pred_obs[:, :, 1]\n",
"        x_sigma = F.softplus(pred_obs[:, :, 2]) + self.min_stdev\n",
"        y_sigma = F.softplus(pred_obs[:, :, 3]) + self.min_stdev\n",
"        rho = torch.tanh(pred_obs[:, :, 4]) * 0.9  # for stability\n",
"        return torch.stack([x_mean, y_mean, x_sigma, y_sigma, rho], dim=2)\n",
"\n"
],
"execution_count": null,
"outputs": []
},
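{
"cell_type": "markdown",
"metadata": {},
"source": [
"`generate_decoder_mask` builds the causal mask passed as `tgt_mask` to `nn.TransformerDecoder`. For a boolean mask, PyTorch treats `True` as a position that may *not* be attended to, so the strictly upper-triangular pattern lets decoder step `t` attend only to steps `0..t`. The standalone sketch below rebuilds the same mask for a length-4 sequence (`causal_mask` is our own name for it, not part of the original code)."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"\n",
"def causal_mask(seq_len, device='cpu'):\n",
"    # Same construction as AutoBot.generate_decoder_mask: ones strictly above\n",
"    # the diagonal become True, i.e. future steps are masked out.\n",
"    return torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()\n",
"\n",
"\n",
"mask = causal_mask(4)\n",
"# Row t marks the keys that query t may NOT attend to:\n",
"# row 0 masks keys 1..3, while the last row masks nothing.\n",
"print(mask.int())"
],
"execution_count": null,
"outputs": []
},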
{
"cell_type": "markdown",
"metadata": {
"id": "hvQG71Au9i8Q"
},
"source": [
"## Utility Functions\n",
"We define some utility functions for plotting the 95% confidence ellipse of each predicted output distribution (mean and covariance at each timestep) and for calculating the multimodal loss.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Iwz5gzTE9i8S"
},
"source": [
"import numpy as np\n",
"import torch\n",
"from scipy import special\n",
"import torch.distributions as D\n",
"from torch.distributions import MultivariateNormal\n",
"from matplotlib.patches import Ellipse\n",
"\n",
"\n",
"def _plot_gaussian(dist, ax, color, zorder=0):\n",
"    \"\"\"Plots the 95% confidence ellipse of a bivariate Gaussian given as (mu_x, mu_y, sigma_x, sigma_y, rho).\"\"\"\n",
"    cov_val = dist[4] * dist[2] * dist[3]\n",
"    mean = [dist[0], dist[1]]\n",
"    covariance = np.array([[dist[2] ** 2, cov_val], [cov_val, dist[3] ** 2]])\n",
"\n",
"    if covariance.ndim == 1:\n",
"        covariance = np.diag(covariance)\n",
"\n",
"    radius = np.sqrt(5.991)  # 95% quantile of the chi-square distribution with 2 degrees of freedom\n",
"    eigvals, eigvecs = np.linalg.eigh(covariance)  # the covariance matrix is symmetric\n",
"    axis = np.sqrt(eigvals) * radius\n",
"    # rotation of the first principal axis, i.e. the eigenvector in column 0\n",
"    angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))\n",
"\n",
"    e = Ellipse(mean, 2 * axis[0], 2 * axis[1], angle=angle, fill=False, color=color, linewidth=1, zorder=zorder, alpha=1.0)\n",
"    ax.add_artist(e)\n",
"    e.set_clip_box(ax.bbox)\n",
"    return ax\n",
"\n",
"\n",
"def get_BVG_distributions(pred):\n",
"    '''\n",
"    Transform a prediction tensor of shape (B, T, 5) into a batch of torch MultivariateNormal (bivariate Gaussian) distributions.\n",
"    '''\n",
"    B = pred.size(0)\n",
"    T = pred.size(1)\n",
"    mu_x = pred[:, :, 0].unsqueeze(2)\n",
"    mu_y = pred[:, :, 1].unsqueeze(2)\n",
"    sigma_x = pred[:, :, 2]\n",
"    sigma_y = pred[:, :, 3]\n",
"    rho = pred[:, :, 4]\n",
"\n",
"    cov = torch.zeros((B, T, 2, 2)).to(pred.device)\n",
"    cov[:, :, 0, 0] = sigma_x ** 2\n",
"    cov[:, :, 1, 1] = sigma_y ** 2\n",
"    cov_val = rho * sigma_x * sigma_y\n",
"    cov[:, :, 0, 1] = cov_val\n",
"    cov[:, :, 1, 0] = cov_val\n",
"\n",
"    biv_gauss_dist = MultivariateNormal(loc=torch.cat((mu_x, mu_y), dim=-1), covariance_matrix=cov)\n",
"    return biv_gauss_dist\n",
"\n",
"\n",
"def nll_pytorch_dist(pred, data):\n",
"    '''\n",
"    Args:\n",
"        pred: [B, T, 5]\n",
"        data: [B, T, 2]\n",
"    This function computes the negative log-likelihood of the data given the predicted distributions.\n",
"    Returns the nll vector for all elements in the batch.\n",
"    '''\n",
"    biv_gauss_dist = get_BVG_distributions(pred)\n",
"    loss = -biv_gauss_dist.log_prob(data).sum(1)  # sum over all timesteps\n",
"    return loss  # [B]\n",
"\n",
"\n",
"def nll_loss_multimodes(pred, data, modes_pred, entropy_weight=1.0, val_nll=False, kl_weight=1.0):\n",
"    \"\"\"NLL loss multimodes for training. MFP Loss function\n",
"    Args:\n",
"        pred: [K, T, B, 5]\n",
"        data: [B, T, 2]\n",
"        modes_pred: [B, K], prior prob over modes\n",
"    \"\"\"\n",
"    K = len(pred)\n",
"    T, B, dim = pred[0].shape\n",
"\n",
"    # Here, we compute the log-likelihood of the data given the predicted distributions, p(y|z,x).\n",
"    # This part is used in combination with the predicted prior distribution p(z|x) to compute the posterior p(z|y,x).\n",
"    log_lik = np.zeros((B, K))\n",
"    with torch.no_grad():\n",
"        for kk in range(K):\n",
"            nll = nll_pytorch_dist(pred[kk].transpose(0, 1), data)\n",
"            log_lik[:, kk] = -nll.cpu().numpy()\n",
"\n",
"    # The following is an application of Bayes' rule (in log space).\n",
"    priors = modes_pred.detach().cpu().numpy()\n",
"    log_post_unnorm = log_lik + np.log(priors)\n",
"    log_post = log_post_unnorm - special.logsumexp(log_post_unnorm, axis=1).reshape((B, 1))\n",
"    post_prob = np.exp(log_post)\n",
"    post_prob = torch.tensor(post_prob).float().to(data.device)\n",
"\n",
"    # Using the computed posterior as fixed weights, we compute the posterior-weighted data negative log-likelihood.\n",
"    loss = 0.0\n",
"    for kk in range(K):\n",
"        nll_k = nll_pytorch_dist(pred[kk].transpose(0, 1), data) * post_prob[:, kk]\n",
"        loss += nll_k.sum() / float(B)\n",
"\n",
"    # Compute the KL divergence KL(p(z|x,y) || p(z|x)) between the posterior and the predicted prior.\n",
"    kl_loss = torch.nn.KLDivLoss(reduction='batchmean')\n",
"    loss += kl_weight*kl_loss(torch.log(modes_pred), post_prob)\n",
"\n",
"    # The entropy regularization term: for each example, take the mode whose output distributions\n",
"    # have the largest entropy summed over time, then average over the batch.\n",
"    if not val_nll:\n",
"        entropy_vals = []\n",
"        for kk in range(K):\n",
"            entropy_vals.append(get_BVG_distributions(pred[kk]).entropy())\n",
"        entropy_loss = torch.mean(torch.stack(entropy_vals).permute(2, 0, 1).sum(2).max(1)[0])\n",
"        loss += entropy_weight*entropy_loss\n",
"\n",
"    return loss\n"
],
"execution_count": null,
"outputs": []
},
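{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two details of the utilities above can be checked on made-up numbers. First, the posterior in `nll_loss_multimodes` is Bayes' rule in log space, normalized with `logsumexp`: with a uniform prior over two modes and log-likelihoods of -1 and -3, the first mode receives 1 / (1 + e^-2) ≈ 0.88 of the posterior mass. Second, the ellipse radius `sqrt(5.991)` in `_plot_gaussian` follows from the squared Mahalanobis distance of a 2-D Gaussian being chi-square distributed with 2 degrees of freedom, whose 95% quantile is -2 ln(0.05) ≈ 5.991. The values in the sketch below are illustrative only."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from scipy import special\n",
"\n",
"# Illustrative values only: one example (B = 1), two modes (K = 2).\n",
"log_lik = np.array([[-1.0, -3.0]])  # log p(y | z=k, x)\n",
"priors = np.array([[0.5, 0.5]])  # p(z=k | x)\n",
"\n",
"# Bayes' rule in log space, normalized over modes with logsumexp.\n",
"log_post_unnorm = log_lik + np.log(priors)\n",
"log_post = log_post_unnorm - special.logsumexp(log_post_unnorm, axis=1, keepdims=True)\n",
"post_prob = np.exp(log_post)\n",
"print(post_prob)\n",
"\n",
"# 95% quantile of a chi-square distribution with 2 degrees of freedom:\n",
"# its CDF is 1 - exp(-x / 2), so the quantile is -2 * ln(0.05).\n",
"chi2_95 = -2 * np.log(0.05)\n",
"print(chi2_95)"
],
"execution_count": null,
"outputs": []
},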
{
"cell_type": "markdown",
"metadata": {
"id": "aXJMdKqU9i8U"
},
"source": [
"## Training Loop\n",
"The training loop takes about 10 minutes on a single-GPU machine. Each of the 3000 iterations is one pass over the 6 toy trajectories, which form a single batch."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6mxWEJKd9i8V",
"outputId": "700362c2-e517-4d47-e169-6c7700cf8aed"
},
"source": [
"import torch\n",
"from torch import optim\n",
"import torch.distributions as D\n",
"\n",
"\n",
"num_modes = 10\n",
"hidden_size = 64\n",
"learning_rate = 0.000075\n",
"entropy_weight = 10.0  # turn this up/down to see the effect on the variance.\n",
"seed = 0\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA\n",
"\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda\")\n",
"    torch.cuda.manual_seed(seed)\n",
"else:\n",
"    device = torch.device(\"cpu\")\n",
"\n",
"# Initialize model\n",
"autobot_model = AutoBot(hidden_size=hidden_size, num_modes=num_modes).to(device)\n",
"optimiser = optim.Adam(autobot_model.parameters(), lr=learning_rate, eps=1e-4)\n",
"\n",
"# Initialize dataloader\n",
"train_dataset = NPYFakeDataset()\n",
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=6, shuffle=True, num_workers=3, drop_last=True, pin_memory=True)\n",
"\n",
"total_steps = 0\n",
"losses = []\n",
"for train_iter in range(0, 3000):\n",
"    for i, batch in enumerate(train_loader):\n",
"        ego_in, ego_out = batch\n",
"        ego_in = ego_in.float().to(device)\n",
"        ego_out = ego_out.float().to(device)\n",
"\n",
"        # Encode the observations and decode one trajectory per mode\n",
"        pred_obs, modes_pred = autobot_model(ego_in, ego_out)\n",
"\n",
"        # Compute the loss.\n",
"        loss = nll_loss_multimodes(pred_obs, ego_out[:, :, :2], modes_pred, entropy_weight=entropy_weight)\n",
"\n",
"        # A proxy for the spread (entropy) of the output distributions: mean norm of (sigma_x, sigma_y).\n",
"        sigmas = pred_obs[:, :, :, 2:4]\n",
"        sigma_magnitude = torch.mean(torch.norm(sigmas, dim=-1))\n",
"\n",
"        optimiser.zero_grad()\n",
"        loss.backward()\n",
"        torch.nn.utils.clip_grad_norm_(autobot_model.parameters(), 0.5)\n",
"        optimiser.step()\n",
"\n",
"        # Record the total loss (NLL + KL + entropy terms)\n",
"        losses.append(loss.item())\n",
"\n",
"    if train_iter % 50 == 0:\n",
"        print(train_iter, \"Obs_Loss\", losses[-1], \"Prior Entropy\", torch.mean(D.Categorical(modes_pred).entropy()).item(), \"Sigma Magnitude\", sigma_magnitude.item())\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"0 Obs_Loss 11723.958984375 Prior Entropy 1.857177734375 Sigma Magnitude 1.3646622896194458\n",
"50 Obs_Loss 306.9662170410156 Prior Entropy 2.06771183013916 Sigma Magnitude 0.9929002523422241\n",
"100 Obs_Loss 234.70278930664062 Prior Entropy 2.0464229583740234 Sigma Magnitude 0.8464294075965881\n",
"150 Obs_Loss 219.2841339111328 Prior Entropy 1.9615720510482788 Sigma Magnitude 0.7770584225654602\n",
"200 Obs_Loss 198.8031463623047 Prior Entropy 2.008603811264038 Sigma Magnitude 0.673245906829834\n",
"250 Obs_Loss 138.7711639404297 Prior Entropy 2.072270393371582 Sigma Magnitude 0.6537865400314331\n",
"300 Obs_Loss 123.04424285888672 Prior Entropy 2.0243639945983887 Sigma Magnitude 0.6342011094093323\n",
"350 Obs_Loss 96.79385375976562 Prior Entropy 2.082216501235962 Sigma Magnitude 0.6264477968215942\n",
"400 Obs_Loss 81.11720275878906 Prior Entropy 2.120720863342285 Sigma Magnitude 0.619354784488678\n",
"450 Obs_Loss 67.96978759765625 Prior Entropy 2.129866123199463 Sigma Magnitude 0.5915012359619141\n",
"500 Obs_Loss 98.8485336303711 Prior Entropy 2.142521858215332 Sigma Magnitude 0.5233737826347351\n",
"550 Obs_Loss 142.97808837890625 Prior Entropy 2.0970046520233154 Sigma Magnitude 0.6683558225631714\n",
"600 Obs_Loss 68.43778228759766 Prior Entropy 2.091160297393799 Sigma Magnitude 0.6223364472389221\n",
"650 Obs_Loss 29.45511245727539 Prior Entropy 2.0658717155456543 Sigma Magnitude 0.5888274908065796\n",
"700 Obs_Loss 17.957138061523438 Prior Entropy 2.093405246734619 Sigma Magnitude 0.6362965106964111\n",
"750 Obs_Loss 31.798233032226562 Prior Entropy 2.0856008529663086 Sigma Magnitude 0.5461905002593994\n",
"800 Obs_Loss 10.051948547363281 Prior Entropy 2.090461254119873 Sigma Magnitude 0.6038259267807007\n",
"850 Obs_Loss 59.378623962402344 Prior Entropy 2.0865261554718018 Sigma Magnitude 0.5151104927062988\n",
"900 Obs_Loss -6.811798095703125 Prior Entropy 2.1136908531188965 Sigma Magnitude 0.5225884914398193\n",
"950 Obs_Loss -10.759872436523438 Prior Entropy 2.118497371673584 Sigma Magnitude 0.6135889291763306\n",
"1000 Obs_Loss -33.444671630859375 Prior Entropy 2.0734384059906006 Sigma Magnitude 0.5419284105300903\n",
"1050 Obs_Loss -35.84764862060547 Prior Entropy 2.1095805168151855 Sigma Magnitude 0.5613827109336853\n",
"1100 Obs_Loss -29.234100341796875 Prior Entropy 2.1014890670776367 Sigma Magnitude 0.5027151703834534\n",
"1150 Obs_Loss -34.968345642089844 Prior Entropy 2.1073012351989746 Sigma Magnitude 0.5491200685501099\n",
"1200 Obs_Loss -39.077980041503906 Prior Entropy 2.100574016571045 Sigma Magnitude 0.4987635016441345\n",
"1250 Obs_Loss -48.33007049560547 Prior Entropy 2.1208224296569824 Sigma Magnitude 0.510762631893158\n",
"1300 Obs_Loss -58.208953857421875 Prior Entropy 2.0919651985168457 Sigma Magnitude 0.4930139482021332\n",
"1350 Obs_Loss -87.03924560546875 Prior Entropy 2.118295669555664 Sigma Magnitude 0.4436866343021393\n",
"1400 Obs_Loss -47.048927307128906 Prior Entropy 2.097531795501709 Sigma Magnitude 0.44906753301620483\n",
"1450 Obs_Loss -72.15503692626953 Prior Entropy 2.092167854309082 Sigma Magnitude 0.4120866358280182\n",
"1500 Obs_Loss -80.47696685791016 Prior Entropy 2.103466749191284 Sigma Magnitude 0.4134567379951477\n",
"1550 Obs_Loss -94.25555419921875 Prior Entropy 2.080009698867798 Sigma Magnitude 0.36127781867980957\n",
"1600 Obs_Loss -88.59275817871094 Prior Entropy 2.0875422954559326 Sigma Magnitude 0.3591216206550598\n",
"1650 Obs_Loss -124.26649475097656 Prior Entropy 2.093277931213379 Sigma Magnitude 0.34293627738952637\n",
"1700 Obs_Loss -108.07246398925781 Prior Entropy 2.0913078784942627 Sigma Magnitude 0.29710695147514343\n",
"1750 Obs_Loss -133.356689453125 Prior Entropy 2.0764667987823486 Sigma Magnitude 0.2864120900630951\n",
"1800 Obs_Loss -150.91000366210938 Prior Entropy 2.1066884994506836 Sigma Magnitude 0.27043548226356506\n",
"1850 Obs_Loss -144.8616943359375 Prior Entropy 2.0808260440826416 Sigma Magnitude 0.2537614107131958\n",
"1900 Obs_Loss -142.93882751464844 Prior Entropy 2.0847625732421875 Sigma Magnitude 0.2622417211532593\n",
"1950 Obs_Loss -174.2861328125 Prior Entropy 2.078876495361328 Sigma Magnitude 0.23730503022670746\n",
"2000 Obs_Loss -148.6193389892578 Prior Entropy 2.062580108642578 Sigma Magnitude 0.24768507480621338\n",
"2050 Obs_Loss -131.22848510742188 Prior Entropy 2.0892975330352783 Sigma Magnitude 0.21530592441558838\n",
"2100 Obs_Loss -183.94528198242188 Prior Entropy 2.0834622383117676 Sigma Magnitude 0.2223784178495407\n",
"2150 Obs_Loss -139.89260864257812 Prior Entropy 2.076235055923462 Sigma Magnitude 0.19737689197063446\n",
"2200 Obs_Loss -177.97433471679688 Prior Entropy 2.0840203762054443 Sigma Magnitude 0.20732863247394562\n",
"2250 Obs_Loss -194.98463439941406 Prior Entropy 2.1062493324279785 Sigma Magnitude 0.20598143339157104\n",
"2300 Obs_Loss -177.98443603515625 Prior Entropy 2.083662748336792 Sigma Magnitude 0.19716553390026093\n",
"2350 Obs_Loss -183.91876220703125 Prior Entropy 2.076279640197754 Sigma Magnitude 0.19121479988098145\n",
"2400 Obs_Loss -201.8154296875 Prior Entropy 2.092423677444458 Sigma Magnitude 0.19742748141288757\n",
"2450 Obs_Loss -189.528076171875 Prior Entropy 2.092670440673828 Sigma Magnitude 0.1925940066576004\n",
"2500 Obs_Loss -194.66998291015625 Prior Entropy 2.0779054164886475 Sigma Magnitude 0.17988410592079163\n",
"2550 Obs_Loss -200.75863647460938 Prior Entropy 2.0983500480651855 Sigma Magnitude 0.18578983843326569\n",
"2600 Obs_Loss -199.38670349121094 Prior Entropy 2.093242645263672 Sigma Magnitude 0.18134918808937073\n",
"2650 Obs_Loss -192.5203857421875 Prior Entropy 2.089531421661377 Sigma Magnitude 0.18082661926746368\n",
"2700 Obs_Loss -187.53192138671875 Prior Entropy 2.0818114280700684 Sigma Magnitude 0.17122453451156616\n",
"2750 Obs_Loss -207.5955352783203 Prior Entropy 2.062140941619873 Sigma Magnitude 0.17080239951610565\n",
"2800 Obs_Loss -189.11380004882812 Prior Entropy 2.0947937965393066 Sigma Magnitude 0.17124949395656586\n",
"2850 Obs_Loss -198.44921875 Prior Entropy 2.0777957439422607 Sigma Magnitude 0.16536590456962585\n",
"2900 Obs_Loss -217.40306091308594 Prior Entropy 2.0797410011291504 Sigma Magnitude 0.16628232598304749\n",
"2950 Obs_Loss -228.67001342773438 Prior Entropy 2.0792946815490723 Sigma Magnitude 0.1648883819580078\n"
],
"name": "stdout"
}
]
},
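{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why does `entropy_weight` control the spread of the predictions? A bivariate Gaussian has the closed-form entropy H = 1 + ln(2π) + ln(σ_x σ_y sqrt(1 - ρ²)), so penalizing entropy pushes σ_x and σ_y down towards the `min_stdev` floor of 0.1 set in `OutputModelBVG`; halving both standard deviations lowers H by 2 ln 2 ≈ 1.39 nats. The sketch below (the helper `bvg_entropy` is ours, not part of the original code) checks the formula against `torch.distributions.MultivariateNormal.entropy()` for a few made-up parameter values."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"import torch\n",
"from torch.distributions import MultivariateNormal\n",
"\n",
"\n",
"def bvg_entropy(sigma_x, sigma_y, rho):\n",
"    # Closed-form entropy (in nats) of a bivariate Gaussian.\n",
"    return 1.0 + math.log(2 * math.pi) + math.log(sigma_x * sigma_y * math.sqrt(1 - rho ** 2))\n",
"\n",
"\n",
"for sx, sy, rho in [(1.0, 1.0, 0.0), (0.5, 0.5, 0.0), (0.1, 0.1, 0.5)]:\n",
"    cov = torch.tensor([[sx ** 2, rho * sx * sy], [rho * sx * sy, sy ** 2]])\n",
"    torch_h = MultivariateNormal(torch.zeros(2), covariance_matrix=cov).entropy().item()\n",
"    print(sx, sy, rho, round(bvg_entropy(sx, sy, rho), 4), round(torch_h, 4))"
],
"execution_count": null,
"outputs": []
},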
{
"cell_type": "markdown",
"metadata": {
"id": "Tw0HaOtk9i8W"
},
"source": [
"## Testing Mode Learning on Toy Dataset\n",
"\n",
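"This constitutes the results shown in Figure 5 (bottom row) of the paper. To get the results corresponding to the middle row, reduce `entropy_weight` in the training block above and repeat training. Each of the 10 panels shows one mode for the first example in the batch: the predicted means (black dots) and their 95% confidence ellipses (purple).\n"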
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 596
},
"id": "9o6cRo6y9i8W",
"outputId": "021a3675-9863-4103-b5a8-e41e0534b706"
},
"source": [
"autobot_model.eval()\n",
"with torch.no_grad():\n",
"    for i, batch in enumerate(train_loader):\n",
"        ego_in, ego_out = batch\n",
"        ego_in = ego_in.float().to(device)\n",
"        ego_out = ego_out.float().to(device)\n",
"\n",
"        pred_obs, mode_preds = autobot_model(ego_in, ego_out)\n",
"        # First example of the batch: means (num_modes, 12, 2), parameters (num_modes, 12, 5)\n",
"        pred_positions = pred_obs[:, :, 0, :2].squeeze().cpu().numpy()\n",
"        mode_probs_np = mode_preds[0].squeeze().cpu().numpy()\n",
"        pred_distributions = pred_obs[:, :, 0].squeeze().cpu().numpy()\n",
"\n",
"        top_6_modes = mode_probs_np.argsort()[-6:][::-1]\n",
"        fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(15, 10))\n",
"        row = 0\n",
"        for k_idx in range(num_modes):\n",
"            col = k_idx % 5\n",
"            if k_idx > 0 and k_idx % 5 == 0:\n",
"                row += 1\n",
"            k = k_idx\n",
"            ax[row, col].scatter(pred_positions[k, :, 0], pred_positions[k, :, 1], s=10, color='k')\n",
"\n",
"            for t in range(12):\n",
"                ax[row, col] = _plot_gaussian(pred_distributions[k, t], ax[row, col], color='#966BFF')\n",
"            ax[row, col].axis(xmin=-15, xmax=15, ymin=-15, ymax=20)\n",
"\n",
"        plt.show()\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x720 with 10 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
}
]
}