Last active
April 4, 2020 18:29
-
-
Save Guillemdb/405b03f4dd885b0f46018f0f190df6f9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, Optional, Union | |
from fragile.core import DiscreteEnv | |
import gym | |
from gym import spaces | |
import numpy as np | |
from mathy import MathTypeKeysMax, MathyEnvState | |
from mathy.envs.gym import MathyGymEnv | |
# No need for plangym. As long as the environment has a valid step_batch method and the
# observation/action spaces, it is OK. The `make_transitions` method only uses step_batch.
# I added the minimum it needs to work.
class FragileEnvironment:
    """Minimal state-stepping wrapper around a Mathy gym environment.

    The wrapper exposes ``get_state``/``set_state``, a ``step`` that restores
    an explicit state before acting, and a ``step_batch`` that applies
    ``step`` to paired actions and states.
    """

    problem: Optional[str]

    def __init__(
        self,
        name: str,
        environment: str = "poly",
        difficulty: str = "normal",
        problem: Optional[str] = None,
        max_steps: int = 64,
        **kwargs,
    ):
        """Create the wrapped Mathy environment and reset it.

        Args:
            name: Identifier for this environment; stored as ``self.name``.
            environment: Middle part of the gym id
                ``mathy-{environment}-{difficulty}-v0``.
            difficulty: Difficulty part of the same gym id.
            problem: Passed to ``gym.make`` as ``env_problem``.
            max_steps: Stored as ``self.max_steps``; not otherwise used by
                this class.
            **kwargs: Forwarded unchanged to ``gym.make``.
        """
        # This class has no explicit base, so the parent initializer is
        # ``object.__init__``, which raises TypeError when given keyword
        # arguments. Keep the name on the instance instead of forwarding it.
        super().__init__()
        self.name = name
        self._env: MathyGymEnv = gym.make(
            f"mathy-{environment}-{difficulty}-v0",
            np_observation=True,
            error_invalid=False,
            env_problem=problem,
            **kwargs,
        )
        self.observation_space = spaces.Box(
            low=0,
            high=MathTypeKeysMax,
            shape=(256, 256, 1),
            dtype=np.uint8,
        )
        self.action_space = spaces.Discrete(self._env.action_size)
        self.problem = problem
        self.max_steps = max_steps
        self._env.reset()

    def get_state(self) -> np.ndarray:
        """Return the wrapped environment's current state via ``to_np()``."""
        assert self._env.state is not None, "env required to get_state"
        return self._env.state.to_np()

    def set_state(self, state: np.ndarray):
        """Replace the wrapped environment's state and return ``state``.

        The state object is rebuilt with ``MathyEnvState.from_np(state)``.
        """
        assert self._env is not None, "env required to set_state"
        self._env.state = MathyEnvState.from_np(state)
        return state

    def step(
        self,
        action: int,
        state: Optional[np.ndarray] = None,
    ) -> tuple:
        """Restore ``state``, apply ``action`` and report the outcome.

        Args:
            action: Action passed to the wrapped environment's ``step``.
            state: State to restore before stepping. Required despite the
                ``None`` default.

        Returns:
            ``(new_state, obs, reward, oob, info)``. ``oob`` is ``True``
            unless ``info`` has a truthy ``"valid"`` entry. The done flag
            returned by the wrapped environment is discarded here.
        """
        assert self._env is not None, "env required to step"
        assert state is not None, "only works with state stepping"
        self.set_state(state)
        obs, reward, _, info = self._env.step(action)
        oob = not info.get("valid", False)
        return self.get_state(), obs, reward, oob, info

    def step_batch(self, actions, states=None) -> tuple:
        """Apply ``step`` to each ``(action, state)`` pair, in order.

        Args:
            actions: Iterable of actions.
            states: Iterable of states paired with ``actions`` by ``zip``.
                Required despite the ``None`` default.

        Returns:
            Five lists ``(new_states, observs, rewards, oobs, infos)`` holding
            the per-pair results of ``step``.

        Raises:
            TypeError: If ``states`` is ``None``.
        """
        if states is None:
            # Same exception type that zip(actions, None) would raise, but
            # with a message that names the actual problem.
            raise TypeError("step_batch requires one state per action, got states=None")
        new_states, observs, rewards, oobs, infos = [], [], [], [], []
        for action, state in zip(actions, states):
            new_state, obs, reward, oob, info = self.step(action, state)
            new_states.append(new_state)
            observs.append(obs)
            rewards.append(reward)
            oobs.append(oob)
            infos.append(info)
        return new_states, observs, rewards, oobs, infos

    def reset(self):
        """Reset the wrapped environment and return ``(state, observation)``."""
        assert self._env is not None, "env required to reset"
        obs = self._env.reset()
        return self.get_state(), obs
# I added the make_transitions method so you don't have to mess with the states directly.
# That is already handled by the superclass.
class MathyFragileEnv(DiscreteEnv):
    """Discrete-action environment that wraps a :class:`FragileEnvironment`.

    Transitions are computed with the wrapped environment's ``step_batch``
    method and returned as a dictionary of arrays.
    """

    def __init__(
        self,
        name: str,
        environment: str = "poly",
        difficulty: str = "normal",
        problem: Optional[str] = None,
        max_steps: int = 64,
        **kwargs,
    ):
        """Initialize a :class:`MathyFragileEnv`.

        All arguments are forwarded unchanged to :class:`FragileEnvironment`.
        """
        self._env = FragileEnvironment(
            name=name,
            environment=environment,
            difficulty=difficulty,
            problem=problem,
            max_steps=max_steps,
            **kwargs,
        )
        self._n_actions = self._env.action_space.n
        # ``super(DiscreteEnv, self)`` starts the lookup *after* DiscreteEnv in
        # the MRO, so DiscreteEnv.__init__ itself is skipped.
        # NOTE(review): presumably deliberate because DiscreteEnv.__init__
        # expects a plangym environment -- confirm against fragile.
        super(DiscreteEnv, self).__init__(
            states_shape=self._env.get_state().shape,
            observs_shape=self._env.observation_space.shape,
        )

    def __getattr__(self, item):
        """Delegate attributes not found on this object to the wrapped env.

        Python calls this only after normal attribute lookup has failed.
        """
        # If ``_env`` itself is missing (e.g. the instance is being copied or
        # unpickled before its attributes are restored), reading ``self._env``
        # below would call __getattr__ again and recurse without end.
        if item == "_env":
            raise AttributeError(item)
        return getattr(self._env, item)

    def make_transitions(
        self, states: np.ndarray, actions: np.ndarray, dt: Union[np.ndarray, int]
    ) -> Dict[str, np.ndarray]:
        """Step the wrapped environment with its ``step_batch`` method.

        Args:
            states: States to step from, one per action.
            actions: Actions to apply, one per state.
            dt: Accepted for interface compatibility; not used here.

        Returns:
            Dictionary with the keys ``"states"``, ``"observs"``,
            ``"rewards"``, ``"oobs"`` and ``"terminals"``, each converted
            with ``np.array``.
        """
        new_states, observs, rewards, oobs, infos = self._env.step_batch(
            actions=actions, states=states
        )
        # step_batch returns the out-of-bounds flags directly; the terminal
        # flags are read from each info dict's "done" entry (False if absent).
        terminals = [info.get("done", False) for info in infos]
        return {
            "states": np.array(new_states),
            "observs": np.array(observs),
            "rewards": np.array(rewards),
            "oobs": np.array(oobs),
            "terminals": np.array(terminals),
        }
# Creating the swarm | |
# When choosing reward scales set the lower one to 1 and the most important scale | |
# to a bigger integer to avoid unnecessary exponentiation | |
# If you don't need to recover the full path you don't need the Tree. In case you need it,
# it now takes a callable like
# tree=lambda: HistoryTree(prune=True, names=["states", "observs", "actions", "whatever"])
# I also removed default values that don't need to be changed. | |
# creating env callables: | |
# Factory that builds a fresh environment on every call. `your_env_parameters`
# is a placeholder for a dict of MathyFragileEnv keyword arguments.
env_callable = lambda: MathyFragileEnv(**your_env_parameters)  # No parallelization
# With parallelization
from fragile.distributed import ParallelEnv

env_callable = ParallelEnv(env_callable=env_callable)
# NOTE(review): Swarm, DiscreteMasked, EnvRewards, config and mathy_dist are not
# defined or imported in the visible part of this file; they must be provided
# before this snippet can run.
swarm = Swarm(
    model=lambda env: DiscreteMasked(env=env),
    env=env_callable,
    reward_limit=EnvRewards.WIN,
    n_walkers=config.n_walkers,
    max_epochs=config.max_iters,
    reward_scale=1,
    distance_scale=3,
    distance_function=mathy_dist,
    show_pbar=False,
)
# Also, the Swarm now has `swarm.best_state`, `swarm.best_obs`, etc., so you don't have to
# use introspection to get the best value found.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment