Skip to content

Instantly share code, notes, and snippets.

@ugo-nama-kun
Last active August 2, 2021 10:50
Show Gist options
  • Save ugo-nama-kun/782c92b92939f707d46d76e91dfe040f to your computer and use it in GitHub Desktop.
Save ugo-nama-kun/782c92b92939f707d46d76e91dfe040f to your computer and use it in GitHub Desktop.
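A summary of how to seed the random number generators of several Python ML/RL packages (random, NumPy, TensorFlow, PyTorch, gym, Tianshou, PFRL).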
import random
import functools
import numpy as np
import tensorflow as tf
import torch
import gym
import tianshou as ts
import pfrl
RANDOM_SEED = 0

# --- Library-global RNGs ------------------------------------------------------
# Every package below keeps its own global generator, so each needs its own
# seeding call (applied in this order):
#   python  -> random.seed
#   numpy   -> np.random.seed        (legacy global RandomState)
#   TF      -> tf.random.set_seed    (global-level seed)
#   pytorch -> torch.manual_seed
for _seed_global_rng in (random.seed, np.random.seed, tf.random.set_seed, torch.manual_seed):
    _seed_global_rng(RANDOM_SEED)

# --- gym ----------------------------------------------------------------------
# The environment and its action space are seeded separately: env.seed() does
# not seed the space that action_space.sample() draws from.
# NOTE(review): "Pendulum-v0" and Env.seed() belong to the gym API this script
# was written against; newer gym releases renamed the env and moved seeding to
# reset(seed=...) -- confirm against the installed version.
env = gym.make("Pendulum-v0")
env.seed(RANDOM_SEED)
env.action_space.seed(RANDOM_SEED)

# --- Tianshou (based on Pytorch) ----------------------------------------------
# Three in-process Pendulum instances behind one vectorised interface; a single
# integer seed is handed to the vector env as a whole.
vector_env_t = ts.env.DummyVectorEnv(
    [lambda: gym.make("Pendulum-v0") for _ in range(3)]
)
vector_env_t.seed(RANDOM_SEED)

# --- PFRL (based on Pytorch) --------------------------------------------------
# Alternative to the per-library calls at the top of this section:
#   pfrl.utils.set_random_seed(RANDOM_SEED)
# (TODO confirm exactly which libraries it seeds for the pfrl version in use.)
def make_env(idx, env_id="Pendulum-v0"):
    """Build one sub-environment for the PFRL multiprocess vector env.

    Args:
        idx: Position of this sub-env inside the vector env. It is accepted so
            that a factory can be bound per index with
            ``functools.partial(make_env, idx)``. It is currently unused:
            per-env seeds are applied afterwards via the vector env's
            ``seed(seeds=...)`` call rather than here.
        env_id: gym environment id to construct. Defaults to the Pendulum task
            used throughout this script.

    Returns:
        A newly created gym environment (not seeded by this function).
    """
    return gym.make(env_id)
# PFRL runs every sub-env in its own worker process, so build it only when this
# file is executed as a script. Under the "spawn" start method (the default on
# Windows and on macOS since Python 3.8), each worker re-imports this module.
# Without the guard, that import would try to start new workers from inside a
# worker. make_env stays at module level so the workers can unpickle it.
if __name__ == "__main__":
    vector_env_p = pfrl.envs.MultiprocessVectorEnv(
        [functools.partial(make_env, idx) for idx in range(3)]
    )
    # One explicit seed per sub-env (RANDOM_SEED, RANDOM_SEED + 1, ...) so the
    # three Pendulums start from different, reproducible states.
    vector_env_p.seed(seeds=[RANDOM_SEED + i for i in range(3)])

    # print result:
    print("python: ", random.random())
    print("numpy: ", np.random.rand())
    print("tf: ", tf.random.normal((1,)))
    print("torch: ", torch.rand(1))
    print("gym obs: ", env.reset())
    print("gym action: ", env.action_space.sample())
    print("vector tianshou: ", vector_env_t.reset())
    print("vector pfrl: ", vector_env_p.reset())

    # Release the environments explicitly instead of relying on interpreter
    # shutdown; the pfrl vector env owns live worker processes.
    vector_env_p.close()
    vector_env_t.close()
    env.close()

""" Output of this script with the library versions it was written against
(presumably mid-2021 releases); other versions of gym, TF, PyTorch, tianshou
or pfrl may print different values.
python: 0.8444218515250481
numpy: 0.5488135039273248
tf: tf.Tensor([1.5110626], shape=(1,), dtype=float32)
torch: tensor([0.4963])
gym obs: [-0.94223519 -0.33495202 0.93078187]
gym action: [-1.7825598]
vector tianshou: [[-0.94223519 -0.33495202 0.93078187]
[-0.35283079 0.93568715 0.02900014]
[-0.57182853 0.82037317 0.30004494]]
vector pfrl: [array([-0.94223519, -0.33495202, 0.93078187]), array([-0.35283079, 0.93568715, 0.02900014]), array([-0.57182853, 0.82037317, 0.30004494])]
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment