Created May 22, 2020 07:38
-
-
Save addy1997/866205e53f904f6dd1e5baf5b5efd14a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_agent_states(self, grid_map=None):
    """Locate the agent's start cell (value 4) and the target cell (value 3).

    Args:
        grid_map: Grid to search; anything ``np.asarray`` accepts. When
            omitted, the module-level ``insert_grid_map`` is used, so existing
            callers that pass no argument keep their current behaviour.

    Returns:
        ``(start_state, target_state)``: each is a list holding, for every
        axis, the index of the first matching cell in ``np.where`` order.

    Raises:
        SystemExit: via ``sys.exit`` when the start or target cell is missing,
            or when the start cell equals ``self.obstacle_mask``.
    """
    if grid_map is None:
        # NOTE(review): ``insert_grid_map`` is not defined in this method; it
        # is presumably a module-level global in the full file -- confirm.
        grid_map = insert_grid_map
    grid_map = np.asarray(grid_map)

    def first_match(value):
        # np.where yields one index array per axis; keep the first hit on each
        # axis, or None when ``value`` does not occur in the grid at all.
        return [axis[0] if len(axis) > 0 else None
                for axis in np.where(grid_map == value)]

    start_state = first_match(4)
    target_state = first_match(3)
    if None in start_state or None in target_state:
        sys.exit('Start or Target state not specified')
    # NOTE(review): list equality -- if ``obstacle_mask`` is stored as a tuple
    # this comparison is never True; confirm its type.
    if start_state == self.obstacle_mask:
        sys.exit('Obstacle encountered GAME OVER !!!')
    return start_state, target_state
def _render_(self, mode='human', close=False):
    """Draw ``self.observation`` into matplotlib figure ``self.fig_num``.

    Nothing is drawn when ``self.verbose == False``. The ``mode`` and
    ``close`` arguments are accepted for interface compatibility but unused.
    Always returns None.
    """
    is_quiet = self.verbose == False  # noqa: E712 -- exact equality kept on purpose
    if is_quiet:
        return None
    frame = self.observation
    window = plt.figure(self.fig_num)
    plt.clf()
    plt.imshow(frame)
    window.canvas.draw()
    # Brief pause, presumably so the interactive window gets a chance to refresh.
    plt.pause(0.0002)
    return None
# Accessors for the agent's current, start and target states.
def _get_agent_state(self):
    """Return the agent's current state as stored in ``self.agent_state``."""
    current = self.agent_state
    return current
def get_agent__start_target_state(self):
    """Return the stored ``(self.start_state, self.target_state)`` pair."""
    begin, goal = self.start_state, self.target_state
    return (begin, goal)
def jump_to_a_state(self, to_state):
    """Move the agent directly to the grid cell ``to_state`` (``[row, col]``).

    Returns:
        ``"Invalid state"`` when ``to_state`` lies outside ``self.current_map``.
        ``(self.observation, self.agent_state)`` after a successful move, where
        ``self.agent_state`` has become ``[row, col]``.
        ``None`` for destination / current-cell values not handled yet.

    Raises:
        SystemExit: via ``sys.exit("GAME OVER")`` when the destination cell
            holds the same value as the cell at ``self.obstacle_mask``.
    """
    row, col = to_state[0], to_state[1]
    n_rows, n_cols = self.current_map.shape[:2]
    # Validate up front: numpy raises IndexError past the end and silently
    # wraps negative indices, and an ``is None`` test on a numeric grid cell
    # can never succeed, so it cannot detect an invalid destination.
    if not (0 <= row < n_rows and 0 <= col < n_cols):
        return "Invalid state"

    obstacle_value = self.current_map[self.obstacle_mask[0], self.obstacle_mask[1]]
    if self.current_map[row, col] == obstacle_value:
        sys.exit("GAME OVER")
    if self.current_map[row, col] != 0:
        # TODO: add handling for the remaining destination values (4, 1, 3 and 7).
        return None

    # The agent's position must come from the instance; a bare ``agent_state``
    # is not defined in this method.
    cur_row, cur_col = self.agent_state[0], self.agent_state[1]
    # Value of the cell being left -> column of ``self.actions_pos_dict[actions]``
    # that is added to that cell.
    action_column = {4: 2, 6: 1, 7: 3}.get(self.current_map[cur_row, cur_col])
    if action_column is None:
        return None

    # NOTE(review): ``actions`` is neither a parameter nor an attribute here;
    # it must resolve to a module-level name or this raises NameError.
    # Confirm which action key is intended.
    self.current_map[cur_row, cur_col] += self.actions_pos_dict[actions][action_column]
    self.observation = self.grid_map_to_observation(self.current_map)
    self.agent_state = [row, col]
    # NOTE(review): the destination cell is not marked with the agent value;
    # confirm whether it should be.
    self._render_()
    return (self.observation, self.agent_state)
def close(self):
    """Close the matplotlib figure that ``_render_`` draws into.

    ``_render_`` uses ``plt.figure(self.fig_num)``, so that figure number is
    closed here rather than a hard-coded figure 1. The method falls back to
    figure 1 when ``fig_num`` has not been set. Always returns None.
    """
    plt.close(getattr(self, 'fig_num', 1))
    return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.