@unixpickle
Created April 12, 2018 00:59
Fractal AI on Sonic
from fractalai.policy import GreedyPolicy
from fractalai.model import RandomDiscreteModel
from fractalai.fractalmc import FractalAI
from fractalai.environment import AtariEnvironment
from fractalai.monitor import AtariMonitorPolicy
from alex_sandbox.dqn.wrappers import SonicDiscretizer, AllowBacktracking
import numpy as np
import retro
from environment import RetroEnvironment
# TODO: replace observations with x,y.
# TODO: "death" when going backwards, just for debugging.
render = True  # It is more fun if the game is displayed on the screen
clone_seeds = True  # This will speed things up a bit
max_steps = 1e6  # Play until the game is finished.
skip_frames = 80  # The Agent cannot do anything during the intro anyway, so it is faster to skip some frames at the beginning
n_fixed_steps = 10  # Repeating each action for 10 frames still gives several decisions per second, which is more
# than enough to finish the first level
max_samples = 100  # Let's see how well it can perform using at most 100 samples per step
max_states = 10  # Let's set a really small number to make everything faster
time_horizon = 50  # 50 frames should be enough to see the consequences of a bad action
raw_env = retro.make(game='SonicTheHedgehog-Genesis',
                     state='GreenHillZone.Act1',
                     scenario='contest')
raw_env = SonicDiscretizer(raw_env)
env = RetroEnvironment(raw_env)
model = RandomDiscreteModel(env.env.action_space.n) # The Agent will take discrete actions at random
greedy = GreedyPolicy(env=env, model=model) # Our prior will be a random uniform policy
fractal = FractalAI(policy=greedy, max_samples=max_samples, max_states=max_states,
                    time_horizon=time_horizon, n_fixed_steps=n_fixed_steps)
fractal.evaluate(render=render, max_steps=max_steps, skip_frames=skip_frames)
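
SonicDiscretizer and AllowBacktracking are imported from a private `alex_sandbox.dqn.wrappers` module that is not included in this gist. As a rough guide, here is a sketch of SonicDiscretizer modelled on the wrapper of the same name from OpenAI's retro-contest baselines (an assumption about what the private copy does): it maps a handful of useful button combinations onto a Discrete action space.

import gym
import numpy as np

class SonicDiscretizer(gym.ActionWrapper):
    """Turn the MultiBinary(12) Genesis controller into a small Discrete action space."""

    def __init__(self, env):
        super(SonicDiscretizer, self).__init__(env)
        buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT",
                   "C", "Y", "X", "Z"]
        combos = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'],
                  ['DOWN'], ['DOWN', 'B'], ['B']]
        self._actions = []
        for combo in combos:
            arr = np.array([False] * 12)
            for button in combo:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a):
        # Translate a discrete action index into the 12-button array gym-retro expects.
        return self._actions[a].copy()

The second file in the gist is the environment module imported above as `environment`: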
"""
Environments for Fractal AI.
"""
from fractalai.environment import Environment
from fractalai.state import State
import gym
import numpy as np


class RetroEnvironment(Environment):
    """Simulator for retro environments."""

    def __init__(self, env):
        super(RetroEnvironment, self).__init__(name='retro')
        self._env = Savable(env)
        self._state = self.reset()

    @property
    def env(self):
        """Access to the Gym environment."""
        return self._env

    @property
    def num_actions(self):
        """Number of actions."""
        return self.env.action_space.n

    def set_seed(self, seed):
        np.random.seed(seed)

    def render(self, *args, **kwargs):  # pylint: disable=W0221
        self.env.render(*args, **kwargs)

    def reset(self) -> State:
        self._cum_reward = 100  # arbitrary positive offset for the cumulative reward
        if self._state is None:
            self._state = State()
        obs = self.env.reset()
        self.state.reset_state()
        self.state.update_state(observed=obs, microstate=self.env.get_state(), end=False,
                                reward=self._cum_reward)
        return self.state.create_clone()

    def set_simulation_state(self, state: State):
        """Restore the simulator to a previously saved State."""
        self._state = state
        self._cum_reward = state.reward
        self.env.set_state(state.microstate)

    def step_simulation(self, action: np.ndarray, fixed_steps=1) -> State:
        """Apply the (one-hot) action for up to `fixed_steps` frames and return the new State."""
        end = False
        for _ in range(fixed_steps):
            observed, reward, _end, info = self.env.step(action.argmax())
            self._cum_reward += reward
            end = end or _end
            if end:
                break
        self.state.reset_state()
        self.state.update_state(observed=observed,
                                microstate=self.env.get_state(),
                                reward=self._cum_reward,
                                end=end,
                                model_action=action,
                                policy_action=action,
                                model_data=[info])
        if end:
            self.env.reset()
        return self.state


class Savable(gym.Wrapper):
    """Add snapshot support (get_state/set_state) to a gym-retro environment and
    replace the scenario reward with the agent's progress along the x axis."""

    def __init__(self, env):
        super(Savable, self).__init__(env)
        self._x = None
        self._last_state = None
        self._last_action = None

    def get_state(self):
        return {'x': self._x, 'state': self._last_state, 'action': self._last_action}

    def set_state(self, state):
        self._x = state['x']
        self._last_state = state['state']
        self._last_action = state['action']
        if self._last_state is None:
            # The snapshot predates the first step, so just reset the emulator.
            self.env.reset()
            return
        self.env.unwrapped.em.set_state(self._last_state)
        self.env.unwrapped.data.reset()
        self.env.unwrapped.data.update_ram()
        self.env.step(self._last_action)

    def reset(self, **kwargs):
        self._x = None
        self._last_state = None
        self._last_action = None
        return self.env.reset(**kwargs)

    def step(self, action):
        # Save the emulator state *before* stepping, together with the action, so that
        # set_state() can restore the snapshot and replay the action.
        self._last_action = action
        self._last_state = self.env.unwrapped.em.get_state()
        last_x = self._x
        obs, _, done, info = self.env.step(action)
        self._x = info['x']
        last_x = last_x if last_x is not None else self._x
        # The reward is the change in horizontal position since the previous step.
        return obs, self._x - last_x, done, info
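
For reference, here is a minimal, hypothetical smoke test of the snapshot machinery above. It reuses the same game and wrappers as the script at the top of the gist (SonicDiscretizer from the private module, or the sketch given earlier); it is a sketch, not part of the original files.

import numpy as np
import retro

raw_env = retro.make(game='SonicTheHedgehog-Genesis',
                     state='GreenHillZone.Act1',
                     scenario='contest')
env = RetroEnvironment(SonicDiscretizer(raw_env))

root = env.reset()               # State snapshot at the start of the level
action = np.zeros(env.num_actions)
action[1] = 1.0                  # one-hot action; step_simulation() takes its argmax()

env.set_simulation_state(root)   # rewind the emulator to the snapshot...
leaf = env.step_simulation(action, fixed_steps=10)  # ...then roll 10 frames forward
print(leaf.reward)               # cumulative reward tracked for the new State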