Skip to content

Instantly share code, notes, and snippets.

@rezer0dai
Created December 15, 2019 22:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rezer0dai/1ad41c547f7b6a720964e3d9540aa6c7 to your computer and use it in GitHub Desktop.
Save rezer0dai/1ad41c547f7b6a720964e3d9540aa6c7 to your computer and use it in GitHub Desktop.
from utils.her import HER, CreditAssignment
from utils.gahil import GAHIL
# Module-level GAHIL instance; ReacherHER.update_goal registers transitions with it.
# NOTE(review): ACTION_SIZE is neither defined nor imported in the visible part
# of this file -- confirm it is in scope before this line executes.
GAN_REWARD = GAHIL(action_size=ACTION_SIZE)
import random
class ReacherHER(HER):
    """HER subclass whose goal update also feeds transitions to ``GAN_REWARD``.

    NOTE(review): this block was recovered from a paste that had lost all of
    its indentation.  The nesting below is the most plausible reading and
    should be confirmed against the original file.
    """

    def update_goal(self, rewards, goals, states, states_1, n_goals, n_states,
                    actions, her_step_inds, n_steps):
        """Select hindsight goals along an episode and score it via ``GAN_REWARD``.

        While walking the episode a hindsight goal index ``gid`` is chosen at
        the start of every relabelling run, and the candidate goal for a step
        is the leading ``len(goal)`` entries of ``states_1[gid + delta]``.
        Steps whose selected goal index is fewer than 3 steps away from the
        current step are handed to ``GAN_REWARD.register_target``.  Finally the
        whole episode is handed to ``GAN_REWARD.register_other_with_reward``.

        Args:
            rewards: Not used by this implementation; kept for the interface.
            goals: Per-step goals (needs ``clone()``, ``numpy()``, ``len()``
                and iteration -- presumably a torch tensor, TODO confirm).
            states: Per-step states.
            states_1: Per-step states that hindsight goals are sliced from.
            n_goals: Per-step n-step goals; only cloned and returned.
            n_states: Per-step n-step states; returned untouched.
            actions: Per-step actions.
            her_step_inds: Per-step flags; a truthy flag marks the step for
                hindsight relabelling.
            n_steps: Per-step offsets (entries may be ``None``) optionally
                added to the step index when choosing a goal.

        Returns:
            Tuple ``(gahil, h_goals, states, h_n_goals, n_states)`` where
            ``gahil`` is whatever ``GAN_REWARD.register_other_with_reward``
            returns and ``h_goals`` / ``h_n_goals`` are clones of ``goals`` /
            ``n_goals``.  The write-back of relabelled goals is currently
            disabled, so the clones are returned unmodified.
        """
        MAX_HER_STEP = 1  # size of the look-back window over her_step_inds
        gid = 0  # index into states_1 of the currently selected goal
        delta = 0  # relabelled steps processed since gid was selected

        h_goals = goals.clone()
        h_n_goals = n_goals.clone()

        for i, (g, s, n_g, _, u, step, n_so, a) in enumerate(zip(
                goals, states, n_goals, n_states, her_step_inds, n_steps,
                states_1, actions)):
            # A run is "active" when any flag in the preceding window is set.
            window_start = max(0, i - MAX_HER_STEP)
            her_active = bool(sum(her_step_inds[window_start:i]))

            if not her_active and u:
                # Start of a new run: choose the goal index.  Author's note:
                # choosing a nearby state has a long-term adverse effect --
                # the policy may learn to just stay in place.
                if 0 == random.randint(0, 3):
                    # Clamp the bound: on a short episode it would go negative
                    # and random.randint would raise ValueError.
                    upper = max(0, len(goals) - self.n_step - MAX_HER_STEP - 1)
                    gid = random.randint(0, upper)
                elif step is not None and 0 == random.randint(0, 2):
                    gid = i + step + random.randint(0, 2)
                else:
                    gid = i + random.randint(0, 2)
                delta = 0

            if her_active or u:
                # Only relabel while both the goal index and the current step
                # leave room for an n-step look-ahead inside the episode.
                if (gid + delta + self.n_step < len(goals)
                        and i < len(goals) - self.n_step):
                    hindsight = states_1[gid + delta][:len(g)]
                    # n_g is only consumed by the disabled write-back below.
                    g, n_g = hindsight, hindsight
                delta += 1

            # Register a target only when the selected goal index is fewer
            # than 3 steps away from the current step.
            offset = gid + delta - i
            if -3 < offset < 3:
                GAN_REWARD.register_target(
                    s.reshape(1, -1).numpy(),
                    n_so.reshape(1, -1).numpy(),
                    g.reshape(1, -1).numpy(),
                    a.reshape(1, -1).numpy())

            # Write-back of relabelled goals is intentionally left disabled:
            # h_goals[i] = g.clone()
            # h_n_goals[i] = n_g.clone()

        # Old design: GAN_REWARD works over numpy arrays.
        gahil = GAN_REWARD.register_other_with_reward(
            states.numpy(), states_1.numpy(), h_goals.numpy(), actions.numpy())
        return (gahil, h_goals, states, h_n_goals, n_states)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment