This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BatchedGP(ExactGP): | |
"""Class for creating batched Gaussian Process Regression models. Ideal candidate if | |
using GPU-based acceleration such as CUDA for training. | |
Parameters: | |
train_x (torch.tensor): The training features used for Gaussian Process | |
Regression. These features will take shape (B * YD, N, XD), where: | |
(i) B is the batch dimension - minibatch size | |
(ii) N is the number of data points per GPR - the neighbors considered | |
(iii) XD is the dimension of the features (d_state + d_action) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CompositeBatchedGP(ExactGP): | |
"""Class for creating batched Gaussian Process Regression models. Ideal candidate if | |
using GPU-based acceleration such as CUDA for training. | |
This kernel produces a composite kernel that multiplies actions times states, | |
i.e. we have a different kernel for both the actions and states. In turn, | |
the composite kernel is then multiplied by a Scale kernel. | |
Parameters: | |
train_x (torch.tensor): The training features used for Gaussian Process |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_gp_batched_scalar(Zs, Ys, use_cuda=False, epochs=10, | |
lr=0.1, thr=0, use_ard=False, composite_kernel=False, | |
ds=None, global_hyperparams=False, | |
model_hyperparams=None): | |
"""Computes a Gaussian Process object using GPyTorch. Each outcome is | |
modeled as a single scalar outcome. | |
Parameters: | |
Zs (np.array): Array of inputs of expanded shape (B, N, XD), where B is | |
the size of the minibatch, N is the number of data points in each |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Wrapper placed around Gym Environments enabling easier multi-agent | |
reinforcement learning. Compatible with single-agent RL environments as well.""" | |
import tensorflow as tf | |
import numpy as np | |
class ObservationWrapper: | |
""" Class for stacking and processing frame observations. | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
import gym_multi_car_racing | |
env = gym.make("MultiCarRacing-v0", num_agents=2, direction='CCW', | |
use_random_direction=True, backwards_flag=True, | |
h_ratio=0.25, use_ego_color=False) | |
obs = env.reset() | |
done = False | |
total_reward = 0 | |
while not done: | |
# The actions have to be of the format (num_agents,3) | |
# The action format for each car is as in the CarRacing env. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Class object and functions for creating, training, and evaluating PPO agents | |
using the TensorFlow Agents API. """ | |
# Native Python imports | |
import os | |
import argparse | |
from datetime import datetime | |
import pickle | |
# TensorFlow and tf-agents |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Tester script for GPyTorch using analytic sine functions.""" | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import time | |
from sklearn.metrics import mean_squared_error as mse | |
def main(): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
p_values = [1, 1.5, 2] | |
labels = ["Lasso (L=1)", "ElasticNet (L=1.5)", "Ridge (L=2)"] | |
xx, yy = np.meshgrid(np.linspace(-3, 3, num=101), np.linspace(-3, 3, num=101)) | |
fig, axes = plt.subplots(ncols=3, figsize=(28, 7)) | |
for p, ax, l in zip(p_values, axes.flat, labels): | |
if p == 0: | |
zz = (xx != 0).astype(int) + (yy != 0).astype(int) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Functions and script to generate a reward surface with state and | |
action inputs. Configured for the Pendulum-v0 Gym environment.""" | |
import gym | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# Parameters | |
SAMPLES = 50 | |
SEED = 42 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Module for implementing an Audi Autonomous Driving Dataset (A2D2) DataLoader | |
designed for performing data fusion between 2D RGB images and 3D Lidar point clouds. | |
This is the FULL class.""" | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader | |
import numpy as np | |
import os | |
import matplotlib.pyplot as plt | |
import cv2 as cv |
OlderNewer