This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# inspired by the post: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ | |
# tl;dr | |
# If you are using numpy random generator with `torch.utils.data.Dataset`, | |
# you might get identical results either across different workers or epochs | |
# disclaimer: this might not be the best choice since setting worker to be persistent requires additional RAM. | |
# Welcome for any idea | |
# Here's a simple fix with torch>=1.7.0 | |
# See the original example here: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/#a-minimal-example |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{"0": "tench", "1": "goldfish", "2": "great white shark", "3": "tiger shark", "4": "hammerhead", "5": "electric ray", "6": "stingray,", "7": "cock,", "8": "hen,", "9": "ostrich", "10": "brambling", "11": "goldfinch", "12": "house finch", "13": "junco", "14": "indigo bunting", "15": "robin", "16": "bulbul,", "17": "jay,", "18": "magpie,", "19": "chickadee,", "20": "water ouzel", "21": "kite,", "22": "bald eagle", "23": "vulture,", "24": "great grey owl", "25": "European fire salamander", "26": "common newt", "27": "eft,", "28": "spotted salamander", "29": "axolotl", "30": "bullfrog", "31": "tree frog", "32": "tailed frog", "33": "loggerhead", "34": "leatherback turtle", "35": "mud turtle,", "36": "terrapin,", "37": "box turtle", "38": "banded gecko,", "39": "common iguana", "40": "American chameleon", "41": "whiptail", "42": "agama,", "43": "frilled lizard", "44": "alligator lizard,", "45": "Gila monster", "46": "green lizard", "47": "African chameleon", "48": "Komodo dragon", "49": "African crocodile", "50": |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import imageio | |
import numpy as np | |
from utils import * | |
mode = 'sgd' # sgd, fisher, or dig_fisher | |
X_train, X_test, t_train, t_test = get_data() | |
W = get_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import imageio | |
import numpy as np | |
import seaborn | |
import matplotlib.pyplot as plt | |
import matplotlib | |
torch.manual_seed(1) | |
# data generation: y = ax + b |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# using pytorch==0.4.0 | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.parameter import Parameter | |
from torch.nn.modules.rnn import RNNCellBase | |
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import gym | |
import numpy as np | |
from itertools import count | |
from collections import namedtuple, deque | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# code locate in baselines/gail | |
def sample(algo, load_model_path, policy_fn): | |
assert algo in ['trpo', 'ppo', 'acktr', 'ddpg', 'a2c'] | |
if algo in ['trpo', 'ppo']: | |
with tf.Session() as sess: | |
# manually build graph | |
policy = policy_fn() | |
# load model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# for model in `trpo_mpi`, `ppo` | |
class CnnPolicy(): | |
def __init__(): | |
# build graph | |
_ = conv2d() | |
_ = conv2d() | |
def step(): | |
sess.run(act, feed_dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# modified from https://gist.github.com/sorenbouma/6502fbf55ecdf988aa247ef7f60a9546 | |
import gym | |
import numpy as np | |
import matplotlib.pyplot as plt | |
env = gym.make('CartPole-v0') | |
env.render(close=True) | |
#vector of means(mu) and standard dev(sigma) for each paramater | |
mu=np.random.uniform(size=env.observation_space.shape) | |
sigma=np.random.uniform(low=0.001,size=env.observation_space.shape) |