Q-Learning iterated Prisoner's Dilemma with Transformers, implemented in JAX
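Two transformer-based Q-learning agents play the iterated Prisoner's Dilemma against each other. Each round's joint action is encoded as a single token; an agent feeds its recent history of these tokens through a small transformer to produce Q-values for cooperate/defect, and both agents learn online from the payoff matrix with an epsilon-greedy policy.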
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
import optax
from functools import partial
class TransformerBlock(nn.Module):
    embed_dim: int
    num_heads: int
    mlp_dim: int

    def setup(self):
        self.self_attention = nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.embed_dim)
        self.layer_norm1 = nn.LayerNorm()
        self.mlp = nn.Sequential([
            nn.Dense(self.mlp_dim),
            nn.relu,
            nn.Dense(self.embed_dim)
        ])
        self.layer_norm2 = nn.LayerNorm()

    def __call__(self, x, train: bool):
        # Self-attention block (post-norm: residual add, then LayerNorm)
        attn_output = self.self_attention(x)
        x = x + attn_output
        x = self.layer_norm1(x)
        # Feed-forward block
        mlp_output = self.mlp(x)
        x = x + mlp_output
        x = self.layer_norm2(x)
        return x
class TransformerModel(nn.Module):
    num_layers: int
    embed_dim: int
    num_heads: int
    mlp_dim: int
    num_actions: int
    block_size: int

    def setup(self):
        # Tokens encode joint actions (action1 * 2 + action2, i.e. 0..3) plus a pad
        # token; num_actions * 2 + 1 gives 5 embeddings only because num_actions == 2.
        self.embedding = nn.Embed(num_embeddings=self.num_actions * 2 + 1, features=self.embed_dim)
        self.position_embedding = nn.Embed(num_embeddings=self.block_size, features=self.embed_dim)
        self.transformer_blocks = [TransformerBlock(self.embed_dim, self.num_heads, self.mlp_dim) for _ in range(self.num_layers)]
        self.final_layer = nn.Dense(self.num_actions)

    def __call__(self, x, train: bool = True):
        if x.ndim == 1:
            x = x.reshape(1, -1)
        seq_len = x.shape[1]
        x = self.embedding(x)
        position_indices = jnp.arange(seq_len)
        position_encodings = self.position_embedding(position_indices)
        x += position_encodings
        for block in self.transformer_blocks:
            x = block(x, train)
        # Only the output at the last position is used to predict the next move.
        return self.final_layer(x[:, -1])
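
# A quick shape check of the model above (a sketch; the sizes here are assumed,
# not from the source): tokens are joint actions in {0..3} plus pad token 4,
# and the output is one Q-value per action for the next round.
_m = TransformerModel(num_layers=1, embed_dim=8, num_heads=2, mlp_dim=16,
                      num_actions=2, block_size=8)
_p = _m.init(jax.random.PRNGKey(0), jnp.zeros((1, 8), dtype=jnp.int32), train=False)
_q = _m.apply(_p, jnp.array([[0, 3, 4, 4, 4, 4, 4, 4]], dtype=jnp.int32), train=False)
assert _q.shape == (1, 2)  # (batch, num_actions)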
def compute_loss(params, model, state, action, reward, next_state, gamma=0.99):
    q_values = model.apply(params, state, train=False)[0]
    next_q_values = model.apply(params, next_state, train=False)[0]
    # q_values = jnp.clip(q_values, 0, 1)
    # next_q_values = jnp.clip(next_q_values, 0, 1)
    # TD target: stop_gradient treats the bootstrapped target as a constant, as in
    # standard Q-learning; rewards are normalized by the maximum payoff.
    target = reward / MAX_PAYOFF + gamma * jnp.max(jax.lax.stop_gradient(next_q_values))
    loss = jnp.mean((q_values[action] - target) ** 2)
    return loss

def compute_loss_2(params, model, state, action, reward, next_state, gamma=0.99):
    # Debugging variant that also returns the intermediate quantities.
    q_values = model.apply(params, state, train=False)[0]
    next_q_values = model.apply(params, next_state, train=False)[0]
    target = reward / MAX_PAYOFF + gamma * jnp.max(jax.lax.stop_gradient(next_q_values))
    loss = jnp.mean((q_values[action] - target) ** 2)
    return loss, q_values, next_q_values, target
@partial(jax.jit, static_argnums=(0,))
def _select_action(model, params, state):
    q_values = model.apply(params, state, train=False)[0]
    return jnp.argmax(q_values).astype(int)

@partial(jax.jit, static_argnums=(1,))
def _accumulate_gradients(params, model, state, action, reward, next_state, gamma=0.99):
    state = jnp.array(state, dtype=jnp.int32).reshape(1, -1)
    next_state = jnp.array(next_state, dtype=jnp.int32).reshape(1, -1)
    loss, grads = jax.value_and_grad(compute_loss)(params, model, state, action, reward, next_state, gamma)
    return loss, grads

@partial(jax.jit, static_argnums=(3,))
def _update_params(params, grads, opt_state, optimizer):
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state
class PrisonersDilemmaEnv:
    def __init__(self):
        self.num_actions = 2
        # Action 0 = cooperate, 1 = defect; entries are (player1 payoff, player2 payoff).
        self.payoff_matrix = {
            (0, 0): (3, 3),
            (0, 1): (0, 5),
            (1, 0): (5, 0),
            (1, 1): (1, 1)
        }
        self.reset()

    def reset(self):
        return jnp.zeros((1,))

    def step(self, actions):  # TODO: support more than two actions
        rewards = self.payoff_matrix[actions[0], actions[1]]
        return rewards
class QLearningAgent:
    def __init__(self, model, params, learning_rate=0.001, block_size=128, memory_size=128):
        self.model = model
        self.params = params
        self.optimizer = optax.adam(learning_rate)
        self.opt_state = self.optimizer.init(params)
        self.block_size = block_size
        self.memory_size = memory_size
        self.history = []
        self.gradients = []
        self.losses = []

    def probs(self):
        state = self.get_state()
        probs = jax.nn.softmax(self.model.apply(self.params, state, train=False)[0])
        return probs

    def select_action(self, epsilon):
        # Epsilon-greedy: explore with probability epsilon, otherwise act greedily.
        state = self.get_state()
        if np.random.rand() < epsilon:
            result = np.random.randint(2), state
            print('Choosing random!', result[0])
            return result
        else:
            action = _select_action(self.model, self.params, state)
            return action, state

    def calc_rewards(self, other_action, env):
        q_values = self.model.apply(self.params, self.get_state(), train=False)[0]
        action = jnp.argmax(q_values).astype(int)
        rewards = env.step((action, other_action))
        return rewards

    def get_state(self, next_action=None):
        history = self.history + ([next_action] if next_action is not None else [])
        history = history[-self.memory_size:]  # Keep at most memory_size tokens.
        state = jnp.array(history, dtype=jnp.int32)
        padding_length = self.memory_size - len(state)
        if padding_length > 0:
            state = jnp.pad(state, (0, padding_length), constant_values=4)  # 4 is the pad token.
        return state.reshape(1, -1)

    def get_action_pair(self, action1, action2):
        return action1 * 2 + action2

    def add_to_history(self, action1, action2):
        # Each round is stored as a single token encoding the joint action.
        action_pair = action1 * 2 + action2
        self.history.append(action_pair)

    def accumulate_gradients(self, state, action, reward, next_state, gamma=0.99):
        loss, grads = _accumulate_gradients(self.params, self.model, state, action, reward, next_state, gamma)
        self.gradients.append(grads)
        self.losses.append(loss)

    def update(self):
        # Average the accumulated per-step gradients, then take one optimizer step.
        mean_grads = jax.tree_util.tree_map(lambda *x: jnp.mean(jnp.stack(x), axis=0), *self.gradients)
        self.params, self.opt_state = _update_params(self.params, mean_grads, self.opt_state, self.optimizer)
        self.gradients = []
env = PrisonersDilemmaEnv()
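
# The loss functions and training loop reference MAX_PAYOFF, GAMMA, p1, and p2,
# which are not defined elsewhere in the file; a minimal setup sketch follows.
# Every value below (model sizes, discount, RNG seed) is an assumption, not
# taken from the source.
MAX_PAYOFF = 5   # Largest entry in the payoff matrix, used to normalize rewards.
GAMMA = 0.99     # Discount factor (assumed).
BLOCK_SIZE = 128

model = TransformerModel(num_layers=2, embed_dim=32, num_heads=4, mlp_dim=64,
                         num_actions=env.num_actions, block_size=BLOCK_SIZE)
rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
dummy_input = jnp.zeros((1, BLOCK_SIZE), dtype=jnp.int32)
p1 = QLearningAgent(model, model.init(rng1, dummy_input, train=False),
                    block_size=BLOCK_SIZE, memory_size=BLOCK_SIZE)
p2 = QLearningAgent(model, model.init(rng2, dummy_input, train=False),
                    block_size=BLOCK_SIZE, memory_size=BLOCK_SIZE)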
# Training loop
num_episodes = 500
update_every = 10
epsilon = 0.25
for episode in range(num_episodes):
    action1, state1 = p1.select_action(epsilon)
    action2, state2 = p2.select_action(epsilon)
    action1 = int(action1)
    action2 = int(action2)
    p1.add_to_history(action1, action2)
    p2.add_to_history(action2, action1)
    reward1, reward2 = env.step((action1, action2))
    next_state1 = p1.get_state()
    next_state2 = p2.get_state()
    p1.accumulate_gradients(state1, action1, reward1, next_state1, gamma=GAMMA)
    p2.accumulate_gradients(state2, action2, reward2, next_state2, gamma=GAMMA)
    if episode % 1 == 0:  # Currently updates every episode, so each step averages a single gradient.
        p1.update()
        p2.update()
    if (episode + 1) % update_every == 0:
        print(f"Episode {episode + 1}: Models updated")
    epsilon = max(0.001, epsilon * 0.99)  # Decay exploration toward a small floor.
print("Training completed.")