# Gist addy1997/9b28a83251bd471c51d58d5b06838cd9, created May 22, 2020 07:28.
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than it appears. To review it, open the
# file in an editor that reveals hidden Unicode characters.
# Creating the grid environment.

# RGB colours (floats in [0, 1]) for each integer cell code.
# NOTE(review): presumably these codes are the cell values read from the grid
# map file; there is no entry for 5 — confirm that code never appears.
COLORS = {
    0: [0.0, 0.0, 0.0],
    1: [0.5, 0.5, 0.5],
    2: [0.0, 0.0, 1.0],
    3: [0.0, 1.0, 0.0],
    4: [1.0, 0.0, 0.0],
    6: [1.0, 0.0, 1.0],
    7: [1.0, 1.0, 0.0],
}
class GridEnv(gym.Env):
    """Grid-world environment whose layout is loaded from ``map2.txt``.

    NOTE(review): this file does not itself import ``gym``, ``spaces``,
    ``seeding``, ``np``, ``os``, ``copy`` or ``plt``. It also does not define
    ``reset``, ``read_grid_map``, ``grid_map_observation`` or
    ``get_agent_states``. These must be provided elsewhere before the class
    can be instantiated.
    """

    # Class-wide counter used to give each instance its own matplotlib figure.
    num_env = 0
    metadata = {'render.modes': ['human']}

    def __init__(self, start, obs_shape, obstacle_mask, terminal_state):
        """Build the action/observation spaces and load the grid map.

        Args:
            start: Kept as ``self.start``. NOTE(review): the start state the
                agent actually uses comes from ``get_agent_states``.
            obs_shape: Not used; the observation shape is fixed at
                ``[128, 128, 3]``. NOTE(review): confirm intent.
            obstacle_mask: Kept as ``self.obstacle_mask``.
            terminal_state: Kept as ``self.terminal_state``.
        """
        # Keep constructor arguments rather than silently discarding them.
        self.start = start
        self.obstacle_mask = obstacle_mask
        self.terminal_state = terminal_state

        # Action space: position i in ``actions`` is action i of the space.
        self.actions = ['up', 'down', 'right', 'left', 'begin']
        # inv_actions[i] is the action that undoes actions[i] (same order).
        self.inv_actions = ['down', 'up', 'left', 'right', 'begin']
        # (row, col) offsets. Rows grow downward ('up' is -1 row) and columns
        # grow rightward, so 'right' is +1 column and 'left' is -1 column.
        self.actions_pos_dict = {
            'up': [-1, 0],
            'down': [1, 0],
            'right': [0, 1],
            'left': [0, -1],
            'begin': [0, 0],
        }
        self.action_space = spaces.Discrete(len(self.actions))

        # Observation space: image with values in [0, 1].
        self.obs_shape = [128, 128, 3]
        self.observation_space = spaces.Box(
            low=0, high=1, shape=self.obs_shape, dtype=np.float32)

        # Construct the grid from the map file located next to this module.
        file_path = os.path.dirname(os.path.realpath(__file__))
        self.insert_grid_map = os.path.join(file_path, 'map2.txt')
        self.initial_map = self.read_grid_map(self.insert_grid_map)
        self.current_map = copy.deepcopy(self.initial_map)
        self.observation = self.grid_map_observation(self.initial_map)
        self.grid_shape = self.initial_map.shape

        # Agent states are derived from the map that was just loaded.
        self.start_state, self.target_state = self.get_agent_states(
            self.initial_map)
        self.agent_state = copy.deepcopy(self.start_state)

        # Other parameters.
        self.verbose = False
        # The figure number is assigned lazily on the first render() call.
        self.fig_num = None
        self.fig = None
        # gym convention: no viewer until rendering is requested.
        self.viewer = None

        # Seed the RNG first so any randomness used by reset() is seeded.
        self.seed()
        self.reset()

    def make_obstacles(self, obstacles):
        """Store *obstacles* as the environment's obstacle mask and return it."""
        self.obstacle_mask = obstacles
        return obstacles

    def seed(self, seed=None):
        """Seed the environment RNG; return ``[seed]`` as the gym API expects."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def render(self, mode='human'):
        """Open this instance's matplotlib figure when ``self.verbose`` is set.

        Args:
            mode: Accepted for gym API compatibility. Only ``'human'`` is
                listed in ``metadata``.

        NOTE(review): nothing is drawn into the figure yet. Presumably
        ``self.observation`` should be shown (e.g. ``plt.imshow``) — confirm.
        """
        # Claim a figure number once per instance. Bumping the shared counter
        # on every call would open a fresh figure for each render.
        if getattr(self, 'fig_num', None) is None:
            GridEnv.num_env += 1
            self.fig_num = GridEnv.num_env
        if self.verbose:
            self.fig = plt.figure(self.fig_num)
            plt.show(block=False)
            plt.axis('off')
# GitHub footer: sign up for free, or sign in, to comment on this gist.
# grid