Created
November 23, 2019 21:06
-
-
Save renan-cunha/60caed7b1cbe1026b7072fadddeae99c to your computer and use it in GitHub Desktop.
HER not working with modified version of bitflipping that has max_step = 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
import numpy as np | |
from gym import GoalEnv, spaces | |
class BitFlippingEnv(GoalEnv): | |
""" | |
Simple bit flipping env, useful to test HER. | |
The goal is to flip all the bits to get a vector of ones. | |
In the continuous variant, if the ith action component has a value > 0, | |
then the ith bit will be flipped. | |
:param n_bits: (int) Number of bits to flip | |
:param continuous: (bool) Whether to use the continuous actions version or not, | |
by default, it uses the discrete one | |
:param max_steps: (int) Max number of steps, by default, equal to n_bits | |
:param discrete_obs_space: (bool) Whether to use the discrete observation | |
version or not, by default, it uses the MultiBinary one | |
""" | |
def __init__(self, n_bits=10, continuous=False, max_steps=None, | |
discrete_obs_space=False): | |
super(BitFlippingEnv, self).__init__() | |
# The achieved goal is determined by the current state | |
# here, it is a special where they are equal | |
if discrete_obs_space: | |
# In the discrete case, the agent act on the binary | |
# representation of the observation | |
self.observation_space = spaces.Dict({ | |
'observation': spaces.Discrete(2 ** n_bits - 1), | |
'achieved_goal': spaces.Discrete(2 ** n_bits - 1), | |
'desired_goal': spaces.Discrete(2 ** n_bits - 1) | |
}) | |
else: | |
self.observation_space = spaces.Dict({ | |
'observation': spaces.MultiBinary(n_bits), | |
'achieved_goal': spaces.MultiBinary(n_bits), | |
'desired_goal': spaces.MultiBinary(n_bits) | |
}) | |
self.obs_space = spaces.MultiBinary(n_bits) | |
if continuous: | |
self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) | |
else: | |
self.action_space = spaces.Discrete(n_bits) | |
self.continuous = continuous | |
self.discrete_obs_space = discrete_obs_space | |
self.state = None | |
self.desired_goal = np.ones((n_bits,)) | |
if max_steps is None: | |
max_steps = n_bits | |
self.max_steps = max_steps | |
self.current_step = 0 | |
self.reset() | |
def convert_if_needed(self, state): | |
""" | |
Convert to discrete space if needed. | |
:param state: (np.ndarray) | |
:return: (np.ndarray or int) | |
""" | |
if self.discrete_obs_space: | |
# The internal state is the binary representation of the | |
# observed one | |
return int(sum([state[i] * 2**i for i in range(len(state))])) | |
return state | |
def _get_obs(self): | |
""" | |
Helper to create the observation. | |
:return: (OrderedDict<int or ndarray>) | |
""" | |
return OrderedDict([ | |
('observation', self.convert_if_needed(self.state.copy())), | |
('achieved_goal', self.convert_if_needed(self.state.copy())), | |
('desired_goal', self.convert_if_needed(self.desired_goal.copy())) | |
]) | |
def reset(self): | |
self.current_step = 0 | |
self.state = self.obs_space.sample() | |
return self._get_obs() | |
def step(self, action): | |
if self.continuous: | |
self.state[action > 0] = 1 - self.state[action > 0] | |
else: | |
self.state[action] = 1 - self.state[action] | |
obs = self._get_obs() | |
reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], None) | |
done = reward == 0 | |
self.current_step += 1 | |
# Episode terminate when we reached the goal or the max number of steps | |
info = {'is_success': done} | |
done = done or self.current_step >= self.max_steps | |
return obs, reward, True, info | |
def compute_reward(self, | |
achieved_goal: np.ndarray, | |
desired_goal: np.ndarray, | |
_info) -> float: | |
# Deceptive reward: it is positive only when the goal is achieved | |
print(f"desired_goal {desired_goal}") | |
if not np.all(desired_goal): | |
exit() | |
if self.discrete_obs_space: | |
return 0.0 if achieved_goal == desired_goal else -1.0 | |
return 0.0 if (achieved_goal == desired_goal).all() else -1.0 | |
def render(self, mode='human'): | |
if mode == 'rgb_array': | |
return self.state.copy() | |
print(self.state) | |
def close(self): | |
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from stable_baselines import HER, DQN, SAC, DDPG, TD3 | |
from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper | |
from bit_flipping_env import BitFlippingEnv | |
model_class = DQN # works also with SAC, DDPG and TD3 | |
N_BITS = 16 | |
env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS) | |
# Available strategies (cf paper): future, final, episode, random | |
goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE | |
# Wrap the model | |
model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, | |
verbose=1) | |
# Train the model | |
model.learn(1000) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment