Created
December 15, 2019 22:16
-
-
Save rezer0dai/1ad41c547f7b6a720964e3d9540aa6c7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from utils.her import HER, CreditAssignment | |
from utils.gahil import GAHIL | |
# Module-level GAHIL instance. ReacherHER.update_goal registers HER targets on it
# (register_target) and obtains its returned `gahil` value from it
# (register_other_with_reward).
# NOTE(review): ACTION_SIZE is not defined in the visible code; it must be defined
# (e.g. earlier in the script/notebook this snippet came from) before this line runs.
GAN_REWARD = GAHIL(action_size=ACTION_SIZE)
import random | |
class ReacherHER(HER):
    """HER subclass (named for the Reacher task) whose rewards come from ``GAN_REWARD``.

    Provides its own ``update_goal``; the ``HER`` base class is defined elsewhere
    (``utils.her``) and is not visible here.

    NOTE(review): the indentation of this code was lost in the source it was
    recovered from. The nesting below is a reconstruction — in particular, that
    ``in_range``/``register_target`` run once per loop iteration (not only inside
    ``if her_active or u``). Confirm against the original.
    """

    def update_goal(self, rewards, goals, states, states_1, n_goals, n_states, actions, her_step_inds, n_steps):
        """Pick hindsight goals, register matching transitions on GAN_REWARD, and get rewards.

        The per-step inputs are walked in lockstep by index ``i``.

        Starting a relabel:
            When ``her_step_inds[i]`` is truthy and no flag is set in the previous
            ``MAX_HER_STEP`` indices, a new source index ``gid`` is drawn at random
            and ``delta`` is reset to 0.

        Replacing the goal:
            At every index where a relabel is starting or a flag is set in the
            previous ``MAX_HER_STEP`` indices, the local goal ``g`` is replaced by
            ``states_1[gid + delta][:len(g)]``, and ``delta`` then increases by one.
            This happens only while both ``gid + delta`` and ``i`` are below
            ``len(goals) - self.n_step``.

        Registering targets:
            Whenever ``gid + delta`` is within 2 of ``i``, the transition
            ``(s, n_so, g, a)`` is passed to ``GAN_REWARD.register_target``.

        Output:
            The replaced goals are not written back (those assignments are commented
            out), so ``h_goals`` and ``h_n_goals`` are unmodified clones of
            ``goals`` and ``n_goals``.

        Args:
            rewards: not used by this method.
            goals, n_goals: per-step goals; must support ``.clone()`` (and
                ``goals`` also ``.numpy()``) — presumably torch tensors, confirm.
            states: per-step states; returned unchanged.
            states_1: per-step next states; ``states_1[j][:len(g)]`` is used as a goal.
            n_states: iterated but otherwise unused; returned unchanged.
            actions: per-step actions.
            her_step_inds: per-step flags; a truthy entry may start a relabel.
            n_steps: per-step offsets (entries may be ``None``) that bias the choice
                of ``gid``.

        Returns:
            ``(gahil, h_goals, states, h_n_goals, n_states)``, where ``gahil`` is
            whatever ``GAN_REWARD.register_other_with_reward`` returns (presumably
            per-step rewards — confirm against GAHIL).
        """
        MAX_HER_STEP = 1  # look-back window (in indices) for "a relabel is already running"
        gid = 0    # index into states_1 from which substitute goals are taken
        delta = 0  # offset added to gid; advances by one per replaced goal
        idx = []  # NOTE(review): unused
        h_goals = goals.clone()
        h_n_goals = n_goals.clone()
        for i, (g, s, n_g, n, u, step, n_so, a) in enumerate(zip(goals, states, n_goals, n_states, her_step_inds, n_steps, states_1, actions)):
            # True if any flag among the previous MAX_HER_STEP indices is set,
            # i.e. a relabel started there continues at this index.
            her_active = bool(sum(her_step_inds[(i-MAX_HER_STEP) if MAX_HER_STEP < i else 0:i]))
            # Start a new relabel by choosing gid. Choosing a nearby state has a
            # long-term adverse effect: the policy learns to just stay in place.
            if not her_active and u:
                # With probability 1/4: uniform in [0, len(goals) - self.n_step - MAX_HER_STEP - 1].
                # NOTE(review): random.randint raises ValueError when
                # len(goals) < self.n_step + MAX_HER_STEP + 1.
                if 0 == random.randint(0, 3): gid = random.randint(0, len(goals) - self.n_step - MAX_HER_STEP - 1)
                # Otherwise, with probability 1/3 when step is known: 0-2 indices past i + step.
                elif step is not None and 0 == random.randint(0, 2): gid = i+step+random.randint(0, 2)
                # Otherwise: 0-2 indices past the current index.
                else: gid = i+random.randint(0, 2)
                delta = 0
            if her_active or u:
                # Replace only while gid + delta and i are both below len(goals) - self.n_step.
                if gid+delta+self.n_step<len(goals) and i<len(goals)-self.n_step:
                    # g and n_g receive the same slice; n_g is not used afterwards.
                    g, n_g = states_1[gid+delta][:len(g)], states_1[gid+delta][:len(g)]
                    delta += 1

            def in_range(d):
                # True when gid + delta lies within 2 indices of i + d.
                dist = gid + delta
                diff = dist - (i + d)
                return (diff > -3 and diff < 3)

            # Uses the current gid/delta, which may be left over from an earlier relabel.
            if in_range(0):
                GAN_REWARD.register_target(
                    s.reshape(1, -1).numpy(), n_so.reshape(1, -1).numpy(), g.reshape(1, -1).numpy(), a.reshape(1, -1).numpy())
            # Replaced goals are not written back; h_goals/h_n_goals stay clones of the inputs.
            # h_goals[i] = g.clone()
            # h_n_goals[i] = n_g.clone()
        gahil = GAN_REWARD.register_other_with_reward(  # old design: GAHIL works over numpy arrays
            states.numpy(), states_1.numpy(), h_goals.numpy(), actions.numpy())
        return ( gahil, h_goals, states, h_n_goals, n_states )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment