# sac_her_1_file
import time
import warnings
from os import listdir, makedirs
from typing import Any, Dict, Optional
import gym
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from stable_baselines3.common.callbacks import BaseCallback
from tensorflow.python.summary.summary_iterator import summary_iterator
from tianshou.data import (Batch, to_numpy)
from tqdm import tqdm
class LogStepsCallback(BaseCallback):
def __init__(self, log_dir, verbose=0):
self.log_dir = log_dir
super(LogStepsCallback, self).__init__(verbose)
def _on_training_start(self) -> None:
self.results = pd.DataFrame(columns=['Reward', 'Done'])
print("Τraining starts!")
def _on_step(self) -> bool:
if 'reward' in self.locals:
keys = ['reward', 'done']
else:
keys = ['rewards', 'dones']
self.results.loc[len(self.results)] = [self.locals[keys[0]][0], self.locals[keys[1]][0]]
return True
def _on_training_end(self) -> None:
self.results.to_csv(self.log_dir + 'training_data.csv', index=False)
print("Τraining ends!")
class TqdmCallback(BaseCallback):
def __init__(self):
super().__init__()
self.progress_bar = None
def _on_training_start(self):
self.progress_bar = tqdm(total=self.locals['total_timesteps'])
def _on_step(self):
self.progress_bar.update(1)
return True
def _on_training_end(self):
self.progress_bar.close()
self.progress_bar = None
def save_dict_to_file(dict_to_save, path, txt_name='hyperparameter_dict'):
    # avoid shadowing the built-in dict and close the file automatically
    with open(path + '/' + txt_name + '.txt', 'w') as f:
        f.write(str(dict_to_save))
def calc_episode_rewards(training_data):
# Calculate the rewards for each training episode
episode_rewards = []
temp_reward_sum = 0
for step in range(training_data.shape[0]):
reward, done = training_data.iloc[step, :]
temp_reward_sum += reward
if done:
episode_rewards.append(temp_reward_sum)
temp_reward_sum = 0
result = pd.DataFrame(columns=['Reward'])
result['Reward'] = episode_rewards
return result
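# Illustrative sketch (not part of the original pipeline): how calc_episode_rewards
# groups per-step rewards into per-episode sums. The toy DataFrame below is an
# assumption purely for demonstration.
def _example_calc_episode_rewards():
    toy = pd.DataFrame({'Reward': [1.0, 0.5, 2.0, 1.0], 'Done': [False, True, False, True]})
    # Two episodes are detected: 1.0 + 0.5 = 1.5 and 2.0 + 1.0 = 3.0
    print(calc_episode_rewards(toy))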
def learning_curve(episode_rewards, log_dir, window=10):
# Calculate rolling window metrics
rolling_average = episode_rewards.rolling(window=window, min_periods=window).mean().dropna()
rolling_max = episode_rewards.rolling(window=window, min_periods=window).max().dropna()
rolling_min = episode_rewards.rolling(window=window, min_periods=window).min().dropna()
# Change column name
rolling_average.columns = ['Average Reward']
rolling_max.columns = ['Max Reward']
rolling_min.columns = ['Min Reward']
rolling_data = pd.concat([rolling_average, rolling_max, rolling_min], axis=1)
# Plot
sns.set()
plt.figure(0)
ax = sns.lineplot(data=rolling_data)
ax.fill_between(rolling_average.index, rolling_min.iloc[:, 0], rolling_max.iloc[:, 0], alpha=0.2)
ax.set_title('Learning Curve')
ax.set_ylabel('Reward')
ax.set_xlabel('Updates')
# Save figure
plt.savefig(log_dir + 'learning_curve' + str(window) + '.png')
def learning_curve_baselines(log_dir, window=10):
# Read data
training_data = pd.read_csv(log_dir + 'training_data.csv', index_col=None)
# Calculate episode rewards
episode_rewards = calc_episode_rewards(training_data)
learning_curve(episode_rewards=episode_rewards, log_dir=log_dir, window=window)
def learning_curve_tianshou(log_dir, window=10):
# Find event file
files = listdir(log_dir)
for f in files:
if 'events' in f:
event_file = f
break
# Read episode rewards
episode_rewards_list = []
episode_rewards = pd.DataFrame(columns=['Reward'])
try:
for e in summary_iterator(log_dir + event_file):
if len(e.summary.value) > 0:
if e.summary.value[0].tag == 'train/reward':
episode_rewards_list.append(e.summary.value[0].simple_value)
except Exception:
pass  # event files can end with a truncated record; ignore it and keep what was read
episode_rewards['Reward'] = episode_rewards_list
# Learning curve
learning_curve(episode_rewards, log_dir, window=window)
def learning_curve_tianshou_multiple_runs(log_dirs, window=10):
episode_rewards_list = []
episode_rewards = pd.DataFrame(columns=['Reward'])
for log_dir in log_dirs:
# Find event file
files = listdir(log_dir)
for f in files:
if 'events' in f:
event_file = f
break
# Read episode rewards
try:
for e in summary_iterator(log_dir + event_file):
if len(e.summary.value) > 0:
if e.summary.value[0].tag == 'train/reward':
episode_rewards_list.append(e.summary.value[0].simple_value)
except Exception:
pass  # event files can end with a truncated record; ignore it and keep what was read
episode_rewards['Reward'] = episode_rewards_list
# Learning curve
learning_curve(episode_rewards, log_dir, window=window)
# Stand-alone variant of tianshou's Collector.collect() that also records rendered frames
# to a video file; it is attached to a Collector instance at the end of this file and
# called with `self` passed explicitly.
def collect_and_record(self, video_dir, n_step: Optional[int] = None, n_episode: Optional[int] = None,
random: bool = False, render: Optional[float] = None, no_grad: bool = True,
) -> Dict[str, Any]:
"""Collect a specified number of steps or episodes and record a video.
To ensure unbiased sampling result with n_episode option, this function will
first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
episodes, they will be collected evenly from each env.
:param int n_step: how many steps you want to collect.
:param int n_episode: how many episodes you want to collect.
:param bool random: whether to use random policy for collecting data. Default
to False.
:param float render: the sleep time between rendering consecutive frames.
Default to None (no rendering).
:param bool no_grad: whether to retain gradient in policy.forward(). Default to
True (no gradient retaining).
.. note::
One and only one collection number specification is permitted, either
``n_step`` or ``n_episode``.
:return: A dict including the following keys
* ``n/ep`` collected number of episodes.
* ``n/st`` collected number of steps.
* ``rews`` array of episode reward over collected episodes.
* ``lens`` array of episode length over collected episodes.
* ``idxs`` array of episode start index in buffer over collected episodes.
"""
assert not self.env.is_async, "Please use AsyncCollector if using async venv."
if n_step is not None:
assert n_episode is None, (
f"Only one of n_step or n_episode is allowed in Collector."
f"collect, got n_step={n_step}, n_episode={n_episode}."
)
assert n_step > 0
if n_step % self.env_num != 0:
warnings.warn(
f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
"which may cause extra transitions collected into the buffer."
)
ready_env_ids = np.arange(self.env_num)
elif n_episode is not None:
assert n_episode > 0
ready_env_ids = np.arange(min(self.env_num, n_episode))
self.data = self.data[:min(self.env_num, n_episode)]
else:
raise TypeError(
"Please specify at least one (either n_step or n_episode) "
"in AsyncCollector.collect()."
)
start_time = time.time()
step_count = 0
episode_count = 0
episode_rews = []
episode_lens = []
episode_start_indices = []
img_array_list = []
while True:
assert len(self.data) == len(ready_env_ids)
# restore the state: if the last state is None, it won't store
last_state = self.data.policy.pop("hidden_state", None)
# get the next action
if random:
self.data.update(
act=[self._action_space[i].sample() for i in ready_env_ids]
)
else:
if no_grad:
with torch.no_grad(): # faster than retain_grad version
# self.data.obs will be used by agent to get result
result = self.policy(self.data, last_state)
else:
result = self.policy(self.data, last_state)
# update state / act / policy into self.data
policy = result.get("policy", Batch())
assert isinstance(policy, Batch)
state = result.get("state", None)
if state is not None:
policy.hidden_state = state # save state into buffer
act = to_numpy(result.act)
if self.exploration_noise:
act = self.policy.exploration_noise(act, self.data)
self.data.update(policy=policy, act=act)
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
result = self.env.step(action_remap, ready_env_ids) # type: ignore
obs_next, rew, done, info = result
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
if self.preprocess_fn:
self.data.update(
self.preprocess_fn(
obs_next=self.data.obs_next,
rew=self.data.rew,
done=self.data.done,
info=self.data.info,
policy=self.data.policy,
env_id=ready_env_ids,
)
)
if render:
img_array = self.env.render(mode='rgb_array')
img_array = np.array(img_array)[0, :, :, :]
img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
img_array_list.append(img_array)
if render > 0 and not np.isclose(render, 0):
time.sleep(render)
# add data into the buffer
ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
self.data, buffer_ids=ready_env_ids
)
# collect statistics
step_count += len(ready_env_ids)
if np.any(done):
env_ind_local = np.where(done)[0]
env_ind_global = ready_env_ids[env_ind_local]
episode_count += len(env_ind_local)
episode_lens.append(ep_len[env_ind_local])
episode_rews.append(ep_rew[env_ind_local])
episode_start_indices.append(ep_idx[env_ind_local])
# now we copy obs_next to obs, but since there might be
# finished episodes, we have to reset finished envs first.
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_reset = self.preprocess_fn(
obs=obs_reset, env_id=env_ind_global
).get("obs", obs_reset)
self.data.obs_next[env_ind_local] = obs_reset
for i in env_ind_local:
self._reset_state(i)
# remove surplus env id from ready_env_ids
# to avoid bias in selecting environments
if n_episode:
surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
if surplus_env_num > 0:
mask = np.ones_like(ready_env_ids, dtype=bool)
mask[env_ind_local[:surplus_env_num]] = False
ready_env_ids = ready_env_ids[mask]
self.data = self.data[mask]
self.data.obs = self.data.obs_next
if (n_step and step_count >= n_step) or \
(n_episode and episode_count >= n_episode):
break
# generate statistics
self.collect_step += step_count
self.collect_episode += episode_count
self.collect_time += max(time.time() - start_time, 1e-9)
if n_episode:
self.data = Batch(
obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
)
self.reset_env()
if episode_count > 0:
rews, lens, idxs = list(
map(
np.concatenate,
[episode_rews, episode_lens, episode_start_indices]
)
)
else:
rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
# Save video (frames are (height, width, channels); VideoWriter expects (width, height))
height, width = img_array_list[0].shape[0], img_array_list[0].shape[1]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
makedirs(video_dir, exist_ok=True)
video = cv2.VideoWriter(video_dir + 'video.mp4', fourcc, 60, (width, height))
for img in img_array_list:
video.write(img)
video.release()
save_dict_to_file({'reward': rews[0], 'length': lens[0]}, video_dir, txt_name='episode_stats')
return {
"n/ep": episode_count,
"n/st": step_count,
"rews": rews,
"lens": lens,
"idxs": idxs,
}
class Wrapper(gym.Wrapper):
"""Env wrapper for reward scale, action repeat and removing done penalty"""
def __init__(self, env, action_repeat=3, reward_scale=5, rm_done=True):
super().__init__(env)
self.action_repeat = action_repeat
self.reward_scale = reward_scale
self.rm_done = rm_done
def step(self, action):
r = 0.0
for _ in range(self.action_repeat):
obs, reward, done, info = self.env.step(action)
# remove done reward penalty
if not done or not self.rm_done:
r = r + reward
if done:
break
# scale reward
return obs, self.reward_scale * r, done, info
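# Minimal usage sketch of the Wrapper above, assuming a standard single-env gym
# environment with the old 4-tuple step API; the env id 'Pendulum-v1' is only an
# illustrative assumption, not part of the original script.
def _example_wrapper_usage():
    env = Wrapper(gym.make('Pendulum-v1'), action_repeat=3, reward_scale=5, rm_done=True)
    env.reset()
    # One call to step() applies the same action up to 3 times and returns the scaled
    # sum of the inner-step rewards (the reward of a terminal inner step is dropped
    # when rm_done=True).
    obs, reward, done, info = env.step(env.action_space.sample())
    print(reward, done)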
import time
import warnings
from typing import Any, Callable, Dict, Optional
import gym.spaces as space
import numpy as np
import torch
from tianshou.data import Batch, Collector, ReplayBuffer, to_numpy
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
class HERCollector(Collector):
"""Hindsight Experience Replay Collector.
The collector will construct hindsight trajectory from achieved goals
after one trajectory is fully collected.
HER Collector provides two methods for relabel: `online` and `offline`.
For details, please refer to https://arxiv.org/abs/1707.01495
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param env: a ``gym.Env`` environment or an instance of the
:class:`~tianshou.env.BaseVectorEnv` class.
:param dict_observation_space: a ``gym.spaces.Dict`` instance, which is
used to get the goal and achieved goal from the flattened observation.
:param function reward_fn: a function called to calculate reward.
Often defined as `env.compute_reward()`.
:param str strategy: can be `online` or `offline`. The `offline` strategy adds
relabeled data directly back to the buffer, while the `online` strategy stores
the future achieved goals in `batch.info.achieved_goal`,
which can be used in `process_fn` to relabel data during the training process.
:param int replay_k: proportion of data to be relabeled.
For example, if `replay_k` is set to 4, then the collector will
generate 4 new trajectories with relabeled data.
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
If set to None, it will not store the data. Default to None.
:param function preprocess_fn: a function called before the data has been added to
the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
:param bool exploration_noise: determine whether the action needs to be modified
with corresponding policy's exploration noise. If so, "policy.
exploration_noise(act, batch)" will be called automatically to add the
exploration noise into action. Default to False.
.. note::
1. Following the results reported in the paper, only the `future` replay
strategy is implemented in this collector.
2. Make sure your environment's `info` contains an `achieved_goal` entry
before using the `online` replay strategy; it is used as a Batch placeholder.
3. Observation normalization in the environment is not recommended,
since it biases the relabeling.
4. The success rate is also included in the returned statistics to monitor
training progress.
"""
def __init__(
self,
policy: BasePolicy,
env: BaseVectorEnv,
dict_observation_space: space.Dict,
reward_fn: Callable[[np.ndarray, np.ndarray, Optional[dict]], np.ndarray],
replay_k: int = 4,
strategy: str = 'offline',
buffer: Optional[ReplayBuffer] = None,
preprocess_fn: Optional[Callable[..., Batch]] = None,
exploration_noise: bool = False,
) -> None:
# HER need dict observation space
self.dict_observation_space = dict_observation_space
self.reward_fn = reward_fn
assert replay_k > 0, f'Replay k = {replay_k}, it must be a positive integer'
self.replay_k = replay_k
assert strategy == 'offline' or strategy == 'online', \
f'Unsupported {strategy} replay strategy'
self.strategy = strategy
# Record the index ranges of the goal, achieved goal, and observation in obs;
# this saves roughly 80% of the time needed to extract goals compared to
# using OpenAI Gym's unflatten() function.
current_idx = 0
self.obs_index_range = {}
for (key, s) in dict_observation_space.spaces.items():
self.obs_index_range[key] = np.arange(
current_idx, current_idx + s.shape[0]
)
current_idx += s.shape[0]
# assert type in base class
self.data: Batch
self.buffer: ReplayBuffer
super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
def collect(
self,
n_step: Optional[int] = None,
n_episode: Optional[int] = None,
random: bool = False,
render: Optional[float] = None,
no_grad: bool = True,
) -> Dict[str, Any]:
if n_step is not None:
assert n_episode is None, (
f"Only one of n_step or n_episode is allowed in Collector."
f"collect, got n_step={n_step}, n_episode={n_episode}."
)
assert n_step > 0
if n_step % self.env_num != 0:
warnings.warn(
f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
"which may cause extra transitions collected into the buffer."
)
ready_env_ids = np.arange(self.env_num)
elif n_episode is not None:
assert n_episode > 0
ready_env_ids = np.arange(min(self.env_num, n_episode))
self.data = self.data[:min(self.env_num, n_episode)]
else:
raise TypeError(
"Please specify at least one (either n_step or n_episode) "
"in AsyncCollector.collect()."
)
start_time = time.time()
step_count = 0
episode_count = 0
episode_rews = []
episode_success = []
episode_lens = []
episode_start_indices = []
while True:
assert len(self.data) == len(ready_env_ids)
# restore the state: if the last state is None, it won't store
last_state = self.data.policy.pop("hidden_state", None)
# get the next action
if random:
self.data.update(
act=[self._action_space[i].sample() for i in ready_env_ids]
)
else:
if no_grad:
with torch.no_grad(): # faster than retain_grad version
# self.data.obs will be used by agent to get result
result = self.policy(self.data, last_state)
else:
result = self.policy(self.data, last_state)
# update state / act / policy into self.data
policy = result.get("policy", Batch())
assert isinstance(policy, Batch)
state = result.get("state", None)
if state is not None:
policy.hidden_state = state # save state into buffer
act = to_numpy(result.act)
if self.exploration_noise:
act = self.policy.exploration_noise(act, self.data)
self.data.update(policy=policy, act=act)
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
result = self.env.step(action_remap, ready_env_ids) # type: ignore
obs_next, rew, done, info = result
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
if self.preprocess_fn:
self.data.update(
self.preprocess_fn(
obs_next=self.data.obs_next,
rew=self.data.rew,
done=self.data.done,
info=self.data.info,
policy=self.data.policy,
env_id=ready_env_ids,
)
)
if render:
self.env.render(mode='rgb_array')
if render > 0 and not np.isclose(render, 0):
time.sleep(render)
# add data into the buffer
ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
self.data, buffer_ids=ready_env_ids
)
# collect statistics
step_count += len(ready_env_ids)
if np.any(done):
env_ind_local = np.where(done)[0]
env_ind_global = ready_env_ids[env_ind_local]
episode_count += len(env_ind_local)
episode_lens.append(ep_len[env_ind_local])
episode_rews.append(ep_rew[env_ind_local])
episode_success.append(self.data[env_ind_local].info.is_success)
episode_start_indices.append(ep_idx[env_ind_local])
# now we copy obs_next to obs, but since there might be
# finished episodes, we have to reset finished envs first.
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_reset = self.preprocess_fn(
obs=obs_reset, env_id=env_ind_global
).get("obs", obs_reset)
self.data.obs_next[env_ind_local] = obs_reset
for i in env_ind_local:
self._reset_state(i)
# remove surplus env id from ready_env_ids
# to avoid bias in selecting environments
if n_episode:
surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
if surplus_env_num > 0:
mask = np.ones_like(ready_env_ids, dtype=bool)
mask[env_ind_local[:surplus_env_num]] = False
ready_env_ids = ready_env_ids[mask]
self.data = self.data[mask]
# use HER to create more trajectory
for env_id in env_ind_global: # enumerate env
# get recently collected data from buffer
env_buffer = self.buffer.buffers[env_id]
env_buffer_len = env_buffer.last_index[0] + 1
traj_len = ep_len[env_id]
obs_index_range = np.arange(
env_buffer_len - traj_len, env_buffer_len
) % len(env_buffer)
original_trajectory = env_buffer[obs_index_range]
if self.strategy == 'offline':
new_trajectory_len = (
np.random.random(size=self.replay_k) * traj_len
).astype(int) + 1
# relabel data and add back
for length in new_trajectory_len:
trajectory = Batch(original_trajectory[:length], copy=True)
new_goal = trajectory.obs_next[
length - 1, self.obs_index_range['achieved_goal']]
new_goals = np.repeat([new_goal], length, axis=0)
trajectory.obs[:, self.
obs_index_range['desired_goal']] = new_goals
trajectory.obs_next[:, self.obs_index_range['desired_goal']
] = new_goals
trajectory.rew = self.reward_fn(
trajectory.obs_next[:, self.
obs_index_range['achieved_goal']],
new_goals, None
)
trajectory.done[-1] = True
for i in range(length):
env_buffer.add(trajectory[i])
elif self.strategy == 'online':
# record the achieved goals of future steps
# to reduce relabeling time during training
ag = original_trajectory.obs_next[:, self.obs_index_range[
'achieved_goal']]
for i, idx in enumerate(obs_index_range):
env_buffer.info.achieved_goal[idx] = ag[i:]
self.data.obs = self.data.obs_next
if (n_step and step_count >= n_step) or \
(n_episode and episode_count >= n_episode):
break
# generate statistics
self.collect_step += step_count
self.collect_episode += episode_count
self.collect_time += max(time.time() - start_time, 1e-9)
if n_episode:
self.data = Batch(
obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
)
self.reset_env()
if episode_count > 0:
rews, success, lens, idxs = list(
map(
np.concatenate, [
episode_rews, episode_success, episode_lens,
episode_start_indices
]
)
)
rew_mean, rew_std = rews.mean(), rews.std()
len_mean, len_std = lens.mean(), lens.std()
else:
rews, success, lens, idxs = np.array([]), np.array(
[]
), np.array([], int), np.array([], int)
rew_mean = rew_std = len_mean = len_std = 0
return {
"n/ep": episode_count,
"n/st": step_count,
"rews": rews,
"success": success,
"lens": lens,
"idxs": idxs,
"rew": rew_mean,
"len": len_mean,
"rew_std": rew_std,
"len_std": len_std,
}
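# Illustrative sketch of the index bookkeeping HERCollector relies on: with a
# flattened Dict observation, every key occupies a contiguous slice, so precomputed
# index ranges let the collector pull out 'achieved_goal' and 'desired_goal' without
# gym's unflatten(). The shapes below are assumptions chosen to mimic the Fetch tasks.
def _example_obs_index_ranges():
    dict_space = space.Dict({
        'achieved_goal': space.Box(-1.0, 1.0, shape=(3,)),
        'desired_goal': space.Box(-1.0, 1.0, shape=(3,)),
        'observation': space.Box(-1.0, 1.0, shape=(10,)),
    })
    index_range, current_idx = {}, 0
    for key, s in dict_space.spaces.items():
        index_range[key] = np.arange(current_idx, current_idx + s.shape[0])
        current_idx += s.shape[0]
    # with these shapes: achieved_goal -> [0, 1, 2], desired_goal -> [3, 4, 5],
    # observation -> [6, ..., 15]
    print(index_range)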
from typing import Any, Callable, Optional, Tuple, Union
import gym.spaces as space
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.exploration import BaseNoise
from tianshou.policy import BasePolicy, SACPolicy
class SACHERPolicy(SACPolicy):
"""Implementation of Hindsight Experience Replay Based on SAC. arXiv:1707.01495.
The key difference is that we redesigned the process_fn to get relabel return,
if the replay strategy is `offline`, then it will behave the same as `SACPolicy`.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
critic network.
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param float tau: param for soft update of the target network. Default to 0.005.
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
regularization coefficient. Default to 0.2.
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
alpha is automatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param BaseNoise exploration_noise: add a noise to action for exploration.
Default to None. This is useful when solving hard-exploration problem.
:param bool deterministic_eval: whether to use deterministic action (mean
of Gaussian policy) instead of stochastic action sampled by the policy.
Default to True.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_spaces.low, action_spaces.high]. Default to True.
:param str action_bound_method: method to bound action to range [-1, 1], can be
either "clip" (for simply clipping the action) or empty string for no bounding.
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
.. seealso::
Please refer to :class:`~tianshou.policy.SACPolicy` for more detailed
explanation.
"""
def __init__(
self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
reward_fn: Callable[[np.ndarray, np.ndarray, Optional[dict]], np.ndarray],
tau: float = 0.005,
gamma: float = 0.99,
alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
reward_normalization: bool = False,
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
deterministic_eval: bool = True,
dict_observation_space: space.Dict = None,
future_k: float = 4,
strategy: str = 'offline',
**kwargs: Any,
) -> None:
super().__init__(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
gamma, alpha, reward_normalization, estimation_step, exploration_noise,
deterministic_eval, **kwargs
)
self.future_k = future_k
self.strategy = strategy
self.future_p = 1 - (1. / (1 + future_k))
self.reward_fn = reward_fn
# get index information of observation
self.dict_observation_space = dict_observation_space
current_idx = 0
self.index_range = {}
for (key, s) in dict_observation_space.spaces.items():
self.index_range[key] = np.arange(current_idx, current_idx + s.shape[0])
current_idx += s.shape[0]
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch:
# Step 1: gather all indices needed
if self.strategy == 'offline':
return super(SACHERPolicy, self).process_fn(batch, buffer, indices)
assert not self._rew_norm, \
"Reward normalization in computing n-step returns is unsupported now."
end_flag = buffer.done.copy()
end_flag[buffer.unfinished_index()] = True  # treat unfinished episodes as terminated
bsz = len(indices)  # batch size of the sampled transitions
indices = [indices]  # turn into a list to prepare expansion to next-state indices, e.g. [1, 3]
for _ in range(self._n_step - 1):
indices.append(
buffer.next(indices[-1])
) # append next state index e.g. [[1,3][2,4]]
indices = np.stack(indices)
terminal = indices[-1] # next state
# Step 2: sample new goals
batch = buffer[terminal] # batch.obs: s_{t+n}
new_goal = batch.obs_next[:, self.index_range['desired_goal']]
for i in range(bsz):
if np.random.random() < self.future_p:
goals = batch.info.achieved_goal[i]
if len(goals) != 0:
new_goal[i] = goals[int(np.random.random() * len(goals))]
# Step 3: relabel the batch's obs, obs_next and reward, then calculate Q
batch.obs[:, self.index_range['desired_goal']] = new_goal
batch.obs_next[:, self.index_range['desired_goal']] = new_goal
batch.rew = self.reward_fn(
batch.obs_next[:, self.index_range['achieved_goal']], new_goal, None
)
with torch.no_grad():
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
target_q_torch = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_),
) - self._alpha * obs_next_result.log_prob
target_q = to_numpy(target_q_torch.reshape(bsz, -1))
target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
# Step 4: calculate the n-step return
gamma_buffer = np.ones(self._n_step + 1)
for i in range(1, self._n_step + 1):
gamma_buffer[i] = gamma_buffer[i - 1] * self._gamma
target_shape = target_q.shape
bsz = target_shape[0]
# change target_q to 2d array
target_q = target_q.reshape(bsz, -1)
returns = np.zeros(target_q.shape)  # n-step return
gammas = np.full(indices[0].shape, self._n_step)
for n in range(self._n_step - 1, -1, -1):
now = indices[n]
gammas[end_flag[now] > 0] = n + 1
returns[end_flag[now] > 0] = 0.0
new_rew = []
old_obs_next = buffer.obs_next[now]
new_rew.append(
self.reward_fn(
old_obs_next[:, self.index_range['achieved_goal']], new_goal, None
)
)
returns = np.array(new_rew).reshape(bsz, 1) + self._gamma * returns
target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
target_q = target_q.reshape(target_shape)
# return values
batch.returns = to_torch_as(target_q, target_q_torch)
if hasattr(batch, "weight"): # prio buffer update
batch.weight = to_torch_as(batch.weight, target_q_torch)
return batch
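# Quick numeric check of the 'future' relabeling ratio used in process_fn above: with
# future_k = 4, a sampled transition keeps its original goal with probability
# 1 / (1 + k) = 0.2 and is relabeled with a future achieved goal with probability
# future_p = 1 - 1 / (1 + k) = 0.8. This is only a sketch of the sampling rule, not
# part of the training code.
def _example_future_goal_ratio(future_k: float = 4, n_samples: int = 100000):
    future_p = 1 - (1. / (1 + future_k))
    relabeled = np.random.random(n_samples) < future_p
    print('expected %.2f, observed %.2f' % (future_p, relabeled.mean()))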
import argparse
import os
import pprint
from functools import partial
import gym
import numpy as np
import torch
import yaml
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts
from gym.wrappers import FilterObservation, FlattenObservation
from tianshou.data import (
Collector,
PrioritizedReplayBuffer,
PrioritizedVectorReplayBuffer,
ReplayBuffer,
VectorReplayBuffer,
)
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
if __name__ == '__main__':
'''
load param
'''
with open('/content/config_sac_her_pnp.yaml', "r") as stream:
try:
config = yaml.safe_load(stream)
except yaml.YAMLError as exc:
print(exc)
'''
make env
'''
def make_env():
return gym.wrappers.FlattenObservation(gym.make(config['env']))
def make_test_env(i):
if config['record_test']:
return gym.wrappers.RecordVideo(
gym.wrappers.FlattenObservation(gym.make(config['env'])),
video_folder='log/' + config['env'] + '/video' + str(i),
episode_trigger=lambda x: True
)
else:
return gym.wrappers.FlattenObservation(gym.make(config['env']))
env = gym.make(config['env'])
dict_observation_space = env.observation_space
env = gym.wrappers.FlattenObservation(env)
obs = env.reset()
state_shape = len(obs)
action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv(
[make_env for _ in range(config['training_num'])], norm_obs=config['norm_obs']
)
if config['norm_obs']:
print('updating env norm...')
train_envs.reset()
for _ in range(1000):
_, _, done, _ = train_envs.step(
[env.action_space.sample() for _ in range(config['training_num'])]
)
if np.any(done):
env_ind = np.where(done)[0]
train_envs.reset(env_ind)
print('updating done!')
train_envs.update_obs_rms = False
test_envs = SubprocVectorEnv(
[partial(make_test_env, i) for i in range(config['test_num'])],
norm_obs=config['norm_obs'],
obs_rms=train_envs.obs_rms,
update_obs_rms=False
)
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])
train_envs.seed(config['seed'])
test_envs.seed(config['seed'])
'''
build and init network
'''
if not (torch.cuda.is_available()):
config['device'] = 'cpu'
# actor
net_a = Net(
state_shape, hidden_sizes=config['hidden_sizes'], device=config['device']
)
actor = ActorProb(
net_a,
action_shape,
max_action=env.action_space.high[0],
device=config['device'],
unbounded=True,
conditioned_sigma=True
).to(config['device'])
actor_optim = torch.optim.Adam(actor.parameters(), lr=config['actor_lr'])
# critic
net_c1 = Net(
state_shape,
action_shape,
hidden_sizes=config['hidden_sizes'],
concat=True,
device=config['device']
)
net_c2 = Net(
state_shape,
action_shape,
hidden_sizes=config['hidden_sizes'],
concat=True,
device=config['device']
)
critic1 = Critic(net_c1, device=config['device']).to(config['device'])
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=config['critic_lr'])
critic2 = Critic(net_c2, device=config['device']).to(config['device'])
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=config['critic_lr'])
# auto alpha
if config['auto_alpha']:
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=config['device'])
alpha_optim = torch.optim.Adam([log_alpha], lr=config['alpha_lr'])
config['alpha'] = (target_entropy, log_alpha, alpha_optim)
'''
set up policy
'''
policy = SACHERPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
tau=config['tau'],
gamma=config['gamma'],
alpha=config['alpha'],
estimation_step=config['estimation_step'],
action_space=env.action_space,
reward_normalization=False,
dict_observation_space=dict_observation_space,
reward_fn=env.compute_reward,
future_k=config['replay_k'],
strategy=config['strategy']
)
# load policy
if config['resume_path']:
policy.load_state_dict(
torch.load(config['resume_path'], map_location=config['device'])
)
print("Loaded agent from: ", config['resume_path'])
'''
set up collector
'''
if config['training_num'] > 1:
if config['use_PER']:
buffer = PrioritizedVectorReplayBuffer(
total_size=config['buffer_size'],
buffer_num=len(train_envs),
alpha=config['per_alpha'],
beta=config['per_beta']
)
else:
buffer = VectorReplayBuffer(config['buffer_size'], len(train_envs))
else:
if config['use_PER']:
buffer = PrioritizedReplayBuffer(
size=config['buffer_size'],
alpha=config['per_alpha'],
beta=config['per_beta']
)
else:
buffer = ReplayBuffer(config['buffer_size'])
train_collector = HERCollector(
policy=policy,
env=train_envs,
buffer=buffer,
exploration_noise=True,
dict_observation_space=dict_observation_space,
reward_fn=env.compute_reward,
replay_k=config['replay_k'],
strategy=config['strategy']
)
test_collector = Collector(policy, test_envs)
# warm up
train_collector.collect(n_step=config['start_timesteps'], random=True)
'''
logger
'''
log_file = config['info']
log_path = os.path.join(config['logdir'], config['env'], 'sac', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(config))
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
# save function
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def save_checkpoint_fn(epoch, env_step, gradient_step):
    ckpt_dir = os.path.join(log_path, f'{epoch}')
    os.makedirs(ckpt_dir, exist_ok=True)  # ensure the per-epoch folder exists before saving
    torch.save(policy.state_dict(), os.path.join(ckpt_dir, 'policy.pth'))
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
config['epoch'],
config['step_per_epoch'],
config['step_per_collect'],
config['test_num'],
config['batch_size'],
save_fn=save_fn,
save_checkpoint_fn=save_checkpoint_fn,
logger=logger,
update_per_step=config['update_per_step'],
test_in_train=False
)
pprint.pprint(result)
# Learning Curve
#learning_curve_tianshou(log_dir=log_path + '/', window=25)
# Load model, optimisers and buffer
#checkpoint = torch.load('/content/policy.pth')
# Record Episode Video
num_episodes = 10
for episode in range(num_episodes):
env = ts.env.DummyVectorEnv([lambda: FlattenObservation(FilterObservation(gym.make("FetchPickAndPlace-v1"))) for _ in range(1)])
policy.eval()
collector = ts.data.Collector(policy, env, exploration_noise=False)
# attach the stand-alone collect_and_record helper to this collector instance
collector.collect_and_record = collect_and_record
collector.collect_and_record(self=collector, video_dir=log_path + f'/final_agent/video{episode}/', n_episode=1,
render=1 / 60)