from typing import Dict, Optional, Union
from fragile.core import DiscreteEnv
import gym
from gym import spaces
import numpy as np
from mathy import MathTypeKeysMax, MathyEnvState
from mathy.envs.gym import MathyGymEnv
# No need for plangym. As long as it has a valid step_batch method and the
# observation/action spaces, it is fine. The `make_transitions` method only uses
# step_batch. I added the minimum stuff that it needs to work. (A quick demo of
# this interface is sketched right after the class.)
class FragileEnvironment:
    """Fragile Environment for solving Mathy problems."""

    problem: Optional[str]

    def __init__(
        self,
        name: str,
        environment: str = "poly",
        difficulty: str = "normal",
        problem: Optional[str] = None,
        max_steps: int = 64,
        **kwargs,
    ):
        super(FragileEnvironment, self).__init__()
        self.name = name
        self._env: MathyGymEnv = gym.make(
            f"mathy-{environment}-{difficulty}-v0",
            np_observation=True,
            error_invalid=False,
            env_problem=problem,
            **kwargs,
        )
        self.observation_space = spaces.Box(
            low=0, high=MathTypeKeysMax, shape=(256, 256, 1), dtype=np.uint8,
        )
        self.action_space = spaces.Discrete(self._env.action_size)
        self.problem = problem
        self.max_steps = max_steps
        self._env.reset()

    def get_state(self) -> np.ndarray:
        assert self._env.state is not None, "env required to get_state"
        return self._env.state.to_np()

    def set_state(self, state: np.ndarray):
        assert self._env is not None, "env required to set_state"
        self._env.state = MathyEnvState.from_np(state)
        return state

    def step(
        self,
        action: int,
        state: Optional[np.ndarray] = None,
    ) -> tuple:
        assert self._env is not None, "env required to step"
        assert state is not None, "only works with state stepping"
        self.set_state(state)
        obs, reward, _, info = self._env.step(action)
        oob = not info.get("valid", False)
        new_state = self.get_state()
        return new_state, obs, reward, oob, info

    def step_batch(
        self, actions, states=None
    ) -> tuple:
        data = [self.step(action, state) for action, state in zip(actions, states)]
        new_states, observs, rewards, oobs, infos = [], [], [], [], []
        for d in data:
            new_state, obs, _reward, end, info = d
            new_states.append(new_state)
            observs.append(obs)
            rewards.append(_reward)
            oobs.append(end)
            infos.append(info)
        return new_states, observs, rewards, oobs, infos

    def reset(self):
        assert self._env is not None, "env required to reset"
        obs = self._env.reset()
        return self.get_state(), obs
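# A quick, illustrative check of the minimal interface described above (step_batch
# plus the observation/action spaces). It assumes the mathy gym environments are
# installed; the name "mathy_v0" and the helper name are arbitrary placeholders.
def _demo_step_batch(n_walkers: int = 4):
    env = FragileEnvironment(name="mathy_v0")
    state, obs = env.reset()
    # Step the same starting state with a batch of random actions.
    actions = np.random.randint(0, env.action_space.n, size=n_walkers)
    states = [state.copy() for _ in range(n_walkers)]
    new_states, observs, rewards, oobs, infos = env.step_batch(actions=actions, states=states)
    return rewards, oobs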
# I added the `make_transitions` method so you don't have to mess with the states
# directly. That is already handled by the superclass.
class MathyFragileEnv(DiscreteEnv):
    """The DiscreteEnv acts as an interface with `plangym` discrete actions.

    It can interact with any environment that accepts discrete actions and \
    follows the interface of `plangym`.
    """

    def __init__(
        self,
        name: str,
        environment: str = "poly",
        difficulty: str = "normal",
        problem: Optional[str] = None,
        max_steps: int = 64,
        **kwargs,
    ):
        """Initialize a :class:`MathyFragileEnv`."""
        self._env = FragileEnvironment(
            name=name,
            environment=environment,
            difficulty=difficulty,
            problem=problem,
            max_steps=max_steps,
            **kwargs,
        )
        self._n_actions = self._env.action_space.n
        # Skip DiscreteEnv.__init__ and initialize the base Environment with the
        # state and observation shapes directly.
        super(DiscreteEnv, self).__init__(
            states_shape=self._env.get_state().shape,
            observs_shape=self._env.observation_space.shape,
        )

    def __getattr__(self, item):
        return getattr(self._env, item)

    def make_transitions(
        self, states: np.ndarray, actions: np.ndarray, dt: Union[np.ndarray, int]
    ) -> Dict[str, np.ndarray]:
        """
        Step the underlying :class:`plangym.Environment` using the ``step_batch`` \
        method of the ``plangym`` interface.
        """
        new_states, observs, rewards, oobs, infos = self._env.step_batch(
            actions=actions, states=states
        )
        # I changed the order of how you get terminals and oobs, because the envs in
        # plangym return oobs by default. This way you won't need to modify anything
        # in case I make changes in the future (unlikely, but better safe than sorry).
        terminals = [inf.get("done", False) for inf in infos]
        data = {
            "states": np.array(new_states),
            "observs": np.array(observs),
            "rewards": np.array(rewards),
            "oobs": np.array(oobs),
            "terminals": np.array(terminals),
        }
        return data
# Creating the swarm
# When choosing reward scales, set the lower one to 1 and the most important scale
# to a bigger integer to avoid unnecessary exponentiation.
# If you don't need to recover the full path you don't need the Tree. In case you
# need it, it now takes a callable like:
# tree=lambda: HistoryTree(prune=True, names=["states", "observs", "actions", "whatever"])
# I also removed default values that don't need to be changed.

# Creating the env callables:
env_callable = lambda: MathyFragileEnv(**your_env_parameters)  # No parallelization

# With parallelization:
from fragile.distributed import ParallelEnv

env_callable = ParallelEnv(env_callable=env_callable)
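# `mathy_dist`, as well as `Swarm`, `DiscreteMasked`, `EnvRewards`, and `config` used
# below, are not defined in this gist. As an illustrative placeholder only, here is a
# minimal sketch of a distance function, assuming it follows the same signature as
# fragile's default distance: two batches of observations in, one distance per pair
# of walkers out.
def mathy_dist(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    # Flatten each observation and measure the euclidean distance between pairs.
    return np.linalg.norm(x.reshape(len(x), -1) - y.reshape(len(y), -1), axis=1)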
swarm = Swarm(
    model=lambda env: DiscreteMasked(env=env),
    env=env_callable,
    reward_limit=EnvRewards.WIN,
    n_walkers=config.n_walkers,
    max_epochs=config.max_iters,
    reward_scale=1,
    distance_scale=3,
    distance_function=mathy_dist,
    show_pbar=False,
)
# Also, the Swarm now exposes `swarm.best_state`, `swarm.best_obs`, etc., so you don't
# have to use introspection to get the best value found (see the sketch below).
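# A minimal usage sketch for running the search and reading back the best solution.
# It assumes the swarm above was built successfully and that this fragile version
# exposes a `run()` method (older releases used `run_swarm()`).
swarm.run()
best_obs = swarm.best_obs  # observation of the best walker found
best_state = MathyEnvState.from_np(swarm.best_state)  # recover the Mathy env state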