Created September 25, 2018 16:23
-
-
Save markroxor/c50a6bfc69da001180374a9e977ac21a to your computer and use it in GitHub Desktop.
Testing BECCA on an OpenAI Gym environment.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import becca.brain as becca_brain | |
from becca.base_world import World as BaseWorld | |
import gym | |
class World(BaseWorld):
    """Adapt an OpenAI Gym environment to BECCA's World interface.

    ``step`` takes a BECCA action vector and returns ``(sensors, reward)``.
    The action vector is turned into a discrete Gym action (the index of
    its largest element). Whenever an episode ends, the environment is
    reset automatically, so the brain sees one continuous stream of
    observations; the world's lifetime is therefore not tied to ``done``.

    NOTE(review): assumes the pre-0.26 Gym API (``env.seed`` exists and
    ``env.step`` returns a 4-tuple) — confirm against the installed gym.
    """

    def __init__(self, env_name, seed=None):
        """Create the Gym environment and take the initial observation.

        Parameters
        ----------
        env_name : str
            Gym environment id passed to ``gym.make`` (e.g. "CartPole-v0").
            The action space must expose ``.n`` (i.e. be discrete).
        seed : int, optional
            If given, passed to ``env.seed`` for reproducible episodes.
        """
        BaseWorld.__init__(self)
        # NOTE(review): the name is fixed to 'gym_cartPole' even when a
        # different env_name is used; kept unchanged so anything keyed on
        # the name stays compatible.
        self.name = 'gym_cartPole'
        self.visualize_interval = 1e4
        self.env = gym.make(env_name)
        if seed is not None:
            self.env.seed(seed)
        # One sensor per observation dimension, one action per discrete
        # Gym action.
        self.n_sensors = self.env.observation_space.shape[0]
        self.n_actions = self.env.action_space.n
        print('states are', self.n_sensors, self.n_actions)
        self.sensors = self.env.reset()
        self.done = False

    def step(self, actions):
        """Advance the Gym environment by one step.

        Parameters
        ----------
        actions : array-like
            BECCA action vector; the index of its largest element is the
            discrete Gym action that is taken.

        Returns
        -------
        tuple
            ``(sensors, reward)``. ``sensors`` is the observation after the
            step; if the step ended an episode, it is the first observation
            of the freshly reset episode instead. ``reward`` is the reward
            the Gym environment returned for this step.

        Raises
        ------
        ValueError
            If ``actions[0]`` is exactly 0.5.
        """
        # Sanity check: a first action value of exactly 0.5 is treated as a
        # sign the brain passed an unexpected (fractional) action vector.
        # NOTE(review): confirm this is the intended guard.
        if actions[0] == .5:
            raise ValueError(
                "input actions are float - {}".format(actions[0]))

        # Advance the step counter. Nothing else in this class increments
        # it, so without this the periodic report below would fire on every
        # step (or never). NOTE(review): confirm the BECCA run loop does not
        # also increment it.
        self.timestep += 1

        self.sensors, reward, self.done, _ = self.env.step(np.argmax(actions))

        if self.timestep % self.visualize_interval == 0:
            # Give an update
            print("got reward of " + str(reward))

        # Truthiness rather than ``is True``: some Gym environments return
        # numpy.bool_, for which ``done is True`` is always False.
        if self.done:
            print("done")
            self.sensors = self.env.reset()
        return self.sensors, reward
if __name__ == "__main__":
    # Script entry point: run a BECCA brain against CartPole-v0.
    cartpole_world = World("CartPole-v0")
    becca_brain.run(cartpole_world)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.