This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import multiprocessing | |
import os | |
from metadrive.utils.waymo_utils.script.convert_waymo_to_metadrive import parse_data | |
from tqdm.auto import tqdm | |
REPO_ROOT = os.path.dirname(os.path.dirname(__file__)) | |
DATA_FOLDER = os.path.join(REPO_ROOT, "trafficgen_v2", "data", "waymo") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import logging | |
import pickle | |
from collections import deque | |
import numpy as np | |
import ray | |
from ray import tune | |
from ray.rllib.agents.es.es import ESTrainer, utils | |
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
#!/bin/bash | |
# Please copy this string to a file name "sbatch_template.sh". | |
# THIS FILE IS GENERATED BY AUTOMATION SCRIPT! PLEASE REFER TO ORIGINAL SCRIPT! | |
# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO PRODUCTION! | |
#SBATCH --partition={{PARTITION_NAME}} | |
#SBATCH --job-name={{JOB_NAME}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
x = "timesteps_total" | |
y = "novelty_reward_mean" | |
ppo_path = "/home/sunhao/ray_results/0131-dece-use_bisector-use_bisector" | |
clip = None | |
dece_path = "/home/sunhao/ray_results/0102-dece" | |
dece_df = parse(dece_path, | |
["seed", "constrain_novelty", "delay_update", "replay_values", "use_diversity_value_network"], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import numpy as np | |
class NumpyEncoder(json.JSONEncoder): | |
""" Special json encoder for numpy types """ | |
def default(self, obj): | |
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, | |
np.int16, np.int32, np.int64, np.uint8, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from toolbox.process_data import parse | |
data3 = [] | |
a2c_data3 = parse("/home/zhpeng/OLD-novel-rl-archive-20200210/0206-a2c-baseline-large-Walker2d-v3-Walker2d-v3-50000000ts.pkl") | |
a2c_data3["label"] = "A2C" | |
a2c_data3.episode_reward_mean *= 5 | |
data3.append(a2c_data3) | |
ppo_data3 = parse("/home/zhpeng/0119-ppo-baseline-walker-Walker2d-v3-50000000ts.pkl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import IPython | |
import tempfile | |
import PIL | |
def animate(img_array): | |
path = tempfile.mkstemp(suffix=".gif")[1] | |
images = [PIL.Image.fromarray(frame) for frame in img_array] | |
images[0].save( | |
path, | |
save_all=True, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import seaborn as sns | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# agents_name = ["ours", "DDPG"] | |
x_name = "uniqueness" | |
y_name = "performance" | |
dpi = 300 | |
figsize = (12, 8) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from gym.envs.box2d.bipedal_walker import ( | |
BipedalWalker, VIEWPORT_H, VIEWPORT_W, SCALE, TERRAIN_HEIGHT, TERRAIN_STEP | |
) | |
from Box2D.b2 import circleShape | |
import cv2 | |
import numpy as np | |
import copy | |
import uuid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import deque | |
class MovingAverage(object): | |
def __init__(self, max_len=10): | |
self.val = deque(maxlen=max_len) | |
self.avg = 0 | |
self.maxlen = max_len | |
def update(self, val): | |
self.val.append(val) | |
self.avg = sum(self.val) / self.maxlen |
NewerOlder