Fractal AI on Sonic
from fractalai.policy import GreedyPolicy
from fractalai.model import RandomDiscreteModel
from fractalai.fractalmc import FractalAI
from alex_sandbox.dqn.wrappers import SonicDiscretizer
import retro

from environment import RetroEnvironment

# TODO: replace observations with x, y.
# TODO: "death" when going backwards, just for debugging.

render = True       # It is more fun if the game is displayed on the screen
clone_seeds = True  # This will speed things up a bit
max_steps = 1e6     # Play until the game is finished
skip_frames = 80    # The agent cannot do anything during the intro anyway, so it is
                    # faster if we skip some frames at the beginning
n_fixed_steps = 10  # Taking one action every 10 frames is more than enough
                    # to finish the first level
max_samples = 100   # Let's see how well it can perform using at most 100 samples per step
max_states = 10     # Let's set a really small number to make everything faster
time_horizon = 50   # 50 frames should be enough to realise you have died

raw_env = retro.make(game='SonicTheHedgehog-Genesis',
                     state='GreenHillZone.Act1',
                     scenario='contest')
raw_env = SonicDiscretizer(raw_env)
env = RetroEnvironment(raw_env)

model = RandomDiscreteModel(env.env.action_space.n)  # The agent will take discrete actions at random
greedy = GreedyPolicy(env=env, model=model)          # Our prior will be a uniform random policy
fractal = FractalAI(policy=greedy, max_samples=max_samples, max_states=max_states,
                    time_horizon=time_horizon, n_fixed_steps=n_fixed_steps)
fractal.evaluate(render=render, max_steps=max_steps, skip_frames=skip_frames)
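
SonicDiscretizer comes from alex_sandbox, which is not included in this gist. Below is a minimal sketch of an equivalent wrapper, following the sonic_util.py helper from OpenAI's retro-baselines; the button layout and action set are assumptions based on that repo, and the real alex_sandbox version may differ.

import gym
import numpy as np


class SonicDiscretizer(gym.ActionWrapper):
    """Turn the 12-button Genesis controller into a small discrete action set."""

    def __init__(self, env):
        super(SonicDiscretizer, self).__init__(env)
        buttons = ["B", "A", "MODE", "START", "UP", "DOWN",
                   "LEFT", "RIGHT", "C", "Y", "X", "Z"]
        combos = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'],
                  ['DOWN'], ['DOWN', 'B'], ['B']]
        self._actions = []
        for combo in combos:
            arr = np.array([False] * 12)
            for button in combo:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a):
        # Map a discrete index back to a button array for the raw env.
        return self._actions[a].copy()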
environment.py, the module imported above with "from environment import RetroEnvironment":
""" | |
Environments for Fractal AI. | |
""" | |
from fractalai.environment import Environment | |
from fractalai.state import State | |
import gym | |
import numpy as np | |
class RetroEnvironment(Environment): | |
"""Simulator for retro environments.""" | |
def __init__(self, env): | |
super(RetroEnvironment, self).__init__(name='retro') | |
self._env = Savable(env) | |
self._state = self.reset() | |
@property | |
def env(self): | |
"""Access to the Gym environment.""" | |
return self._env | |
@property | |
def num_actions(self): | |
"""Number of actions.""" | |
return self.env.action_space.n | |
def set_seed(self, seed): | |
np.random.seed(seed) | |
def render(self, *args, **kwargs): # pylint: disable=W0221 | |
self.env.render(*args, **kwargs) | |
def reset(self) -> State: | |
self._cum_reward = 100 | |
if self._state is None: | |
self._state = State() | |
obs = self.env.reset() | |
self.state.reset_state() | |
self.state.update_state(observed=obs, microstate=self.env.get_state(), end=False, | |
reward=self._cum_reward) | |
return self.state.create_clone() | |
def set_simulation_state(self, state: State): | |
self._state = state | |
self._cum_reward = state.reward | |
self.env.set_state(state.microstate) | |
def step_simulation(self, action: np.array, fixed_steps=1) -> State: | |
end = False | |
for _ in range(fixed_steps): | |
observed, reward, _end, info = self.env.step(action.argmax()) | |
self._cum_reward += reward | |
end = end or _end | |
if end: | |
break | |
self.state.reset_state() | |
self.state.update_state(observed=observed, | |
microstate=self.env.get_state(), | |
reward=self._cum_reward, | |
end=end, | |
model_action=action, | |
policy_action=action, | |
model_data=[info]) | |
if end: | |
self.env.reset() | |
return self.state | |
class Savable(gym.Wrapper): | |
def __init__(self, env): | |
super(Savable, self).__init__(env) | |
self._x = None | |
self._last_state = None | |
self._last_action = None | |
def get_state(self): | |
return {'x': self._x, 'state': self._last_state, 'action': self._last_action} | |
def set_state(self, state): | |
if state['state'] is None: | |
self.env.reset() | |
return | |
self._x = state['x'] | |
self._last_state = state['state'] | |
self._last_action = state['action'] | |
self.env.unwrapped.em.set_state(self._last_state) | |
self.env.unwrapped.data.reset() | |
self.env.unwrapped.data.update_ram() | |
self.env.step(self._last_action) | |
def reset(self, **kwargs): | |
self._x = None | |
self._last_state = None | |
self._last_action = None | |
return self.env.reset(**kwargs) | |
def step(self, action): | |
self._last_action = action | |
self._last_state = self.env.unwrapped.em.get_state() | |
last_x = self._x | |
obs, _, done, info = self.env.step(action) | |
self._x = info['x'] | |
last_x = last_x if last_x is not None else self._x | |
return obs, self._x - last_x, done, info |
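
A quick way to sanity-check Savable is to snapshot the emulator, keep playing, and then rewind. This is a hypothetical smoke test, assuming gym-retro is installed, the Sonic ROM has been imported, and the contest scenario reports the player's horizontal position in info['x']; it wraps the raw environment directly (12-button actions) to avoid the alex_sandbox dependency.

import retro

env = Savable(retro.make(game='SonicTheHedgehog-Genesis',
                         state='GreenHillZone.Act1',
                         scenario='contest'))
env.reset()
noop = [False] * 12          # press no buttons
for _ in range(10):
    env.step(noop)
snapshot = env.get_state()   # {'x': ..., 'state': ..., 'action': ...}
x_at_snapshot = snapshot['x']
for _ in range(10):
    env.step(noop)
env.set_state(snapshot)      # rewind the emulator to the snapshot
assert env.get_state()['x'] == x_at_snapshot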