@kngwyu
Last active July 17, 2021 11:05
PPO with a Gaussian policy, implemented in JAX
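"""PPO with a clipped surrogate objective and a diagonal Gaussian policy whose
log standard deviation is a state-independent learned parameter, written with
JAX, Haiku, optax, rlax, and distrax for continuous-control Gym environments
(Hopper-v3 by default).

The command-line interface is provided by Typer; a typical invocation, assuming
the file is saved as ppo.py, would be:

    python ppo.py --total-steps 100000 --env Hopper-v3
"""
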
import dataclasses
import functools
import typing as t

import chex
import distrax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import typer
from chex import Array as JaxArray

Observation = np.ndarray
AgentOutput = t.Any
Action = np.ndarray
Actor = t.Callable[[Observation], t.Tuple[Action, AgentOutput]]


@dataclasses.dataclass
class ReturnReporter:
    """Accumulates per-step rewards and records episodic returns."""

    reward_sum: float = 0.0
    episode_returns: t.List[float] = dataclasses.field(default_factory=list)

    def experience(self, reward: float, done: bool) -> None:
        self.reward_sum += reward
        if done:
            print(f"Episodic return: {self.reward_sum}")
            self.episode_returns.append(self.reward_sum)
            self.reward_sum = 0.0


@dataclasses.dataclass
class RolloutResult:
    """Transitions and per-step network outputs collected from one rollout."""

    observations: t.List[Observation]
    actions: t.List[Action] = dataclasses.field(default_factory=list)
    rewards: t.List[float] = dataclasses.field(default_factory=list)
    terminals: t.List[bool] = dataclasses.field(default_factory=list)
    outputs: t.List[AgentOutput] = dataclasses.field(default_factory=list)


def rollout(
    env: gym.Env,
    initial_obs: Observation,
    n_steps: int,
    actor: Actor,
    reporter: t.Callable[[float, bool], None],
) -> t.Tuple[Observation, RolloutResult]:
    """Runs the actor in the environment for n_steps and collects transitions."""
    last_obs = initial_obs
    result = RolloutResult(observations=[last_obs])
    for _ in range(n_steps):
        jnp_obs = jnp.array(np.expand_dims(last_obs, axis=0))
        action, output = jax.lax.stop_gradient(actor(jnp_obs))
        obs, reward, terminal, _ = env.step(action)
        obs = obs.flatten()
        result.observations.append(obs)
        result.actions.append(action)
        result.rewards.append(reward)
        result.terminals.append(terminal)
        result.outputs.append(output)
        reporter(reward, terminal)
        if terminal:
            last_obs = env.reset()
        else:
            last_obs = obs
    return last_obs, result


class GaussianPiAndVNet(hk.Module):
    """A small MLP torso with a Gaussian policy head and a state-value head."""

    def __init__(self, action_dim: int) -> None:
        super().__init__()
        self._action_dim = action_dim

    def __call__(
        self,
        observation: np.ndarray,
    ) -> t.Tuple[JaxArray, JaxArray, JaxArray]:
        """Process a batch of observations."""
        torso = hk.Sequential(
            [hk.Flatten(), hk.Linear(128), jax.nn.relu, hk.Linear(64), jax.nn.relu]
        )
        hidden = torso(observation)
        # Mean of the Gaussian policy, one output per action dimension.
        pi_mu = hk.Linear(self._action_dim)(hidden)
        # State-independent log standard deviation, learned as a free parameter.
        pi_log_std = hk.get_parameter(
            "pi_log_std",
            (1, self._action_dim),
            init=jnp.zeros,
        )
        baseline = hk.Linear(1)(hidden)
        baseline = jnp.squeeze(baseline, axis=-1)
        return pi_mu, pi_log_std, baseline


@chex.dataclass(frozen=True, mappable_dataclass=False)
class PPOBatch:
    """Flattened rollout data used for PPO updates."""

    observation: JaxArray
    action: JaxArray
    reward: JaxArray
    mask: JaxArray
    advantage: JaxArray
    value_target: JaxArray
    log_prob: JaxArray

    def __getitem__(self, idx: JaxArray) -> "PPOBatch":
        """Selects a minibatch by an array of indices."""
        return self.__class__(
            observation=self.observation[idx],
            action=self.action[idx],
            reward=self.reward[idx],
            mask=self.mask[idx],
            advantage=self.advantage[idx],
            value_target=self.value_target[idx],
            log_prob=self.log_prob[idx],
        )


def make_ppo_batch(
    rollout_result: RolloutResult,
    next_value: JaxArray,
    gamma: float,
    gae_lambda: float,
) -> PPOBatch:
    """Computes GAE advantages and value targets and packs them into a PPOBatch."""
    observation, action, reward, terminal = map(
        jnp.array, dataclasses.astuple(rollout_result)[:-1]
    )
    mu, logstd, value = map(jnp.concatenate, zip(*rollout_result.outputs))
    # Append the bootstrap value for the state reached after the last step.
    value = jnp.concatenate([value, next_value])
    mask = 1.0 - terminal
    # Truncated generalized advantage estimation over the rollout:
    #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #   A_t = sum_{l >= 0} (gamma * lambda)^l * delta_{t+l}
    # with the discount zeroed at terminal steps via mask.
    advantage = rlax.truncated_generalized_advantage_estimation(
        reward, mask * gamma, gae_lambda, value
    )
    value_target = advantage + value[:-1]
    policy = distrax.LogStddevNormal(mu, logstd)
    return PPOBatch(
        observation=observation,
        action=action,
        reward=reward,
        mask=mask,
        advantage=advantage,
        value_target=value_target,
        log_prob=policy.log_prob(action),
    )


class Agent:
    """Action sampling and the PPO clipped-surrogate loss."""

    def __init__(
        self,
        network: hk.Transformed,
        clip_epsilon: float,
        entropy_coeff: float,
    ) -> None:
        self._network = network
        self._clip_epsilon = clip_epsilon
        self._entropy_coef = entropy_coeff

    @functools.partial(jax.jit, static_argnums=0)
    def act(
        self,
        observation: JaxArray,
        *,
        rng_key: JaxArray,
        params: hk.Params,
    ) -> t.Tuple[JaxArray, t.Tuple[JaxArray, JaxArray, JaxArray]]:
        """Samples an action and returns it together with the raw network outputs."""
        mu, logstd, value = self._network.apply(params, observation)
        _, step_key = jax.random.split(rng_key)
        distrib = distrax.LogStddevNormal(mu.flatten(), logstd.flatten())
        return distrib.sample(seed=step_key), (mu, logstd, value)

    def _loss(self, params: hk.Params, batch: PPOBatch) -> JaxArray:
        net_vmap = jax.vmap(self._network.apply, (None, 0))
        mu, logstd, value = net_vmap(params, batch.observation)
        policy = distrax.LogStddevNormal(mu, logstd)
        log_prob = policy.log_prob(batch.action)
        # Probability ratio r_t = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t),
        # from summed per-dimension log probabilities.
        prob_ratio = jnp.exp(jnp.sum(log_prob - batch.log_prob, axis=-1))
        clipped_prob_ratio = jnp.clip(
            prob_ratio,
            1.0 - self._clip_epsilon,
            1.0 + self._clip_epsilon,
        )
        # Clipped surrogate: min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t).
        clipped_objective = jnp.fmin(
            prob_ratio * batch.advantage, clipped_prob_ratio * batch.advantage
        )
        policy_loss = -jnp.mean(clipped_objective)
        entropy_loss = -jnp.mean(policy.entropy())
        value_loss = jnp.mean(rlax.l2_loss(value - batch.value_target))
        return policy_loss + value_loss + self._entropy_coef * entropy_loss


def get_updater(
    loss_function: t.Callable[..., t.Any],
    updater: optax.TransformUpdateFn,
) -> t.Callable[..., t.Tuple[hk.Params, optax.OptState]]:
    """Returns a jitted function that applies one gradient step of the loss."""

    @jax.jit
    def update(
        params: hk.Params,
        opt_state: optax.OptState,
        ppo_batch: PPOBatch,
    ) -> t.Tuple[hk.Params, optax.OptState]:
        g = jax.grad(loss_function)(params, ppo_batch)
        updates, new_opt_state = updater(g, opt_state)
        return optax.apply_updates(params, updates), new_opt_state

    return update


def batch_sample_indices(
    n_instances: int,
    n_minibatches: int,
    rng_key: JaxArray,
) -> t.Iterable[JaxArray]:
    """Yields shuffled index arrays, one per minibatch."""
    indices = jax.random.permutation(rng_key, n_instances)
    minibatch_size = n_instances // n_minibatches
    for start in range(0, n_instances, minibatch_size):
        yield indices[start : start + minibatch_size]


def main(
    total_steps: int = 100000,
    n_rollout_steps: int = 128,
    n_minibatches: int = 1,
    n_epochs: int = 10,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    clip_epsilon: float = 0.1,
    entropy_coeff: float = 0.01,
    seed: int = 1,
    env: str = "Hopper-v3",
    render: bool = False,
) -> None:
    env = gym.make(env)
    current_obs = env.reset()
    action_dim = env.action_space.shape[0]
    network = hk.without_apply_rng(
        hk.transform(lambda ts: GaussianPiAndVNet(action_dim)(ts))
    )
    # Construct the agent and the optimizer.
    agent = Agent(network, clip_epsilon, entropy_coeff)
    opt = optax.adam(3e-4, eps=1e-4)
    updater = get_updater(agent._loss, opt.update)
    rng_seq = hk.PRNGSequence(seed)
    # Initialize the network parameters and the optimizer state.
    params = jax.jit(network.init)(
        next(rng_seq),
        np.expand_dims(current_obs, axis=0),
    )
    opt_state = opt.init(params)
    reporter = ReturnReporter()
    for _ in range(total_steps // n_rollout_steps):
        # Collect a rollout with the current parameters.
        current_obs, rollout_result = rollout(
            env,
            current_obs,
            n_rollout_steps,
            functools.partial(
                agent.act,
                rng_key=next(rng_seq),
                params=params,
            ),
            reporter.experience,
        )
        # Bootstrap value for the state reached after the last rollout step.
        _, _, next_value = network.apply(params, np.expand_dims(current_obs, 0))
        ppo_batch = make_ppo_batch(rollout_result, next_value, gamma, gae_lambda)
        # Several epochs of shuffled minibatch updates on the same rollout.
        for _ in range(n_epochs):
            indices_iter = batch_sample_indices(
                n_rollout_steps,
                n_minibatches,
                next(rng_seq),
            )
            for indices in indices_iter:
                minibatch = ppo_batch[indices]
                params, opt_state = updater(params, opt_state, minibatch)
        if render:
            env.render()
    np.save("result.npy", np.array(reporter.episode_returns))


if __name__ == "__main__":
    typer.run(main)
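
The script records one value per finished episode in ReturnReporter.episode_returns and saves them to result.npy; below is a minimal sketch for plotting that learning curve afterwards, assuming matplotlib is installed (it is not used by the gist itself).

import matplotlib.pyplot as plt
import numpy as np

returns = np.load("result.npy")  # one entry per completed episode
plt.plot(returns)
plt.xlabel("episode")
plt.ylabel("episodic return")
plt.show()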