'''
The error occurred around lines 28 to 33 of the original file.
The code above this snippet is omitted.
'''
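# The imports below are assumptions reconstructed from how the snippet uses
# them; the real imports live in the omitted code above (PyTorch 0.3-era API
# with Variable, pysc2 for the StarCraft II environment).
import numpy as np
import torch
from torch.autograd import Variable
from pysc2.lib import actions
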
while True:
    # Sync the parameters with shared model
    local_master_model.load_state_dict(shared_master_model.state_dict())

    # Reset n-step experience buffer
    entropies = []
    critic_values = []
    spatial_policy_log_probs = []
    rewards = []

    # step forward n steps
    for step in range(args['n_steps']):
        screen_observation = Variable(torch.from_numpy(game_inferface.get_screen_obs(
            timesteps=state,
            indexes=[4, 5, 6, 7, 8, 9, 14, 15],
        ))).cuda(args['gpu'])
        select_spatial_action_prob, value = local_master_model(screen_observation)
        log_select_spatial_action_prob = torch.log(select_spatial_action_prob)
        print("before:", select_spatial_action_prob)

        # mask spatial action
        selection_mask = torch.from_numpy(
            (state.observation['screen'][_SCREEN_PLAYER_RELATIVE] == 1).astype('float32'))
        selection_mask = Variable(selection_mask.view(1, -1), requires_grad=False).cuda(args['gpu'])
        select_spatial_action_prob = select_spatial_action_prob * selection_mask
        print("after:", select_spatial_action_prob)
        select_action = select_spatial_action_prob.multinomial()
        print("select action:", select_action)
        select_entropy = -(log_select_spatial_action_prob * select_spatial_action_prob).sum(1)
        master_log_action_prob = log_select_spatial_action_prob.gather(1, select_action)

        # record n-step experience
        entropies.append(select_entropy)
        spatial_policy_log_probs.append(master_log_action_prob)
        critic_values.append(value)

        # Step
        action = game_inferface.build_action(_SELECT_POINT, select_action[0].cpu())
        state = env.step([action])[0]
        temp_reward = np.asscalar(state.reward)
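
        # Hierarchical step: the master network has just selected a unit via
        # _SELECT_POINT; if a screen move is now available, the sub network
        # picks the movement target, otherwise a no-op is issued.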
        # sub agent step
        if _MOVE_SCREEN in state.observation['available_actions']:
            # sub model decision
            screen_observation = Variable(torch.from_numpy(game_inferface.get_screen_obs(
                timesteps=state,
                indexes=[4, 5, 6, 7, 8, 9, 14, 15],
            )), volatile=True).cuda(args['gpu'])
            spatial_action_prob, value = local_sub_model(screen_observation)
            spatial_action = spatial_action_prob.multinomial()
            action = game_inferface.build_action(_MOVE_SCREEN, spatial_action[0].cpu())
            state = env.step([action])[0]
        else:
            action = actions.FunctionCall(_NO_OP, [])
            state = env.step([action])[0]
        temp_reward += np.asscalar(state.reward)

        # update episodic information
        rewards.append(temp_reward)
        episode_reward += temp_reward
        episode_length += 1
        global_counter.value += 1

        episode_done = (episode_length >= args['max_eps_length']) or state.last()
        if episode_done:
            episode_length = 0
            env.reset()
            state = env.step([actions.FunctionCall(_NO_OP, [])])[0]
            break
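
    # Bootstrap the n-step return: zero if the episode terminated, otherwise
    # the master critic's value estimate for the last observed state.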
    R_t = torch.zeros(1)
    if not episode_done:
        screen_observation = Variable(torch.from_numpy(game_inferface.get_screen_obs(
            timesteps=state,
            indexes=[4, 5, 6, 7, 8, 9, 14, 15]
        ))).cuda(args['gpu'])
        _, value = local_master_model(screen_observation)
        R_t = value.data

    R_var = Variable(R_t).cuda(args['gpu'])
    critic_values.append(R_var)
    policy_loss = 0.
    value_loss = 0.
    gae_ts = torch.zeros(1).cuda(args['gpu'])
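
    # Walk the rollout backwards to accumulate the discounted return and the
    # generalized advantage estimate (GAE):
    #   R_t     = r_t + gamma * R_{t+1}
    #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #   A_t     = sum_k (gamma * tau)^k * delta_{t+k}
    # The policy loss is -(log pi * A_t) plus a 0.05 entropy bonus; the value
    # loss is a squared error against the n-step return.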
    for i in reversed(range(len(rewards))):
        R_var = rewards[i] + args['gamma'] * R_var
        advantage_var = R_var - critic_values[i]
        value_loss += 0.5 * advantage_var.pow(2)

        td_error = rewards[i] + args['gamma'] * critic_values[i+1].data - critic_values[i].data
        gae_ts = gae_ts * args['gamma'] * args['tau'] + td_error
        policy_loss += -(spatial_policy_log_probs[i] * Variable(gae_ts, requires_grad=False)
                         + 0.05 * entropies[i])

    optimizer.zero_grad()
    total_loss = policy_loss + 0.5 * value_loss
    total_loss.backward()
    torch.nn.utils.clip_grad_norm(local_master_model.parameters(), 40)
    ensure_shared_grad(local_master_model, shared_master_model)
    optimizer.step()

    if episode_done:
        summary_queue.put((global_counter.value, episode_reward))
        episode_reward = 0