@arthurmensch
Created April 15, 2019 21:00

import json
import math
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
from os.path import join, expanduser
from joblib import Parallel, delayed
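

# make_hessian draws a random smooth game: a random symmetric matrix (with a
# spectrum shift controlled by `cond`) is blended with a random antisymmetric
# matrix through `asym` (asym=1 keeps only the antisymmetric part), then
# reshaped into an (n_players, n_actions, n_players, n_actions) interaction
# tensor.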
def make_hessian(n_players, n_actions, cond=.1,
                 asym=.5,
                 seed=None):
    np.random.seed(seed)
    size = n_players * n_actions
    A = np.random.randn(size, size)
    A = .5 * (A + A.T)
    vs, _ = np.linalg.eigh(A)
    max_v = np.max(vs)
    A -= np.eye(size) * (np.min(vs) + max_v * cond)
    vs, _ = np.linalg.eigh(A)
    B = np.random.randn(size, size)
    B = .5 * (B - B.T)
    H = A * (1 - asym) + B * asym
    return H.reshape((n_players, n_actions, n_players, n_actions))
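

# solve_nash runs an extra-gradient (mirror-prox) style scheme directly on the
# players' logits: the gradient at the current policies gives an extrapolated
# point, and the gradient at that extrapolated point drives the actual update,
# followed by a softmax normalization. Each player's gradient is only
# evaluated with probability `subsampling`, and the running average of the
# policies (`avg_policies`) is what the function returns.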
def solve_nash(H, n_iter=100, step_size=1., inner_step_size=1.,
               subsampling=1., history_file=None):
    torch.set_grad_enabled(False)
    # Set up the game tensor; H_value halves each player's own (diagonal)
    # block and is only used to report values
    H = torch.from_numpy(H)
    n_players, n_actions = H.shape[:2]
    H_value = H.clone()
    H_value[range(n_players), :, range(n_players), :] /= 2

    log_policies = torch.randn(n_players, n_actions).type(H.type())
    avg_policies = torch.zeros(n_players, n_actions).type(H.type())

    timing = 0
    gradient_computations = 0

    values_r = []
    policies_r = []
    gradient_computations_r = []
    timings_r = []
    gap_r = []

    policies = torch.softmax(log_policies, dim=1)
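
    # One extra-gradient iteration per loop turn; `timing` only accumulates
    # the wall-clock time of the updates, not of the monitoring below.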
    for t in range(n_iter):
        t0 = time.perf_counter()
        # Extrapolation step, computed only for the sampled players
        mask = torch.empty(n_players).uniform_(0, 1.) < subsampling
        if torch.any(mask):
            gradient_computations += torch.sum(mask.float()).item()
            grad = torch.einsum('ijkl, kl -> ij', H[mask], policies)
            extra_log_policies = log_policies.clone()
            extra_log_policies[mask] -= step_size * grad / (t + 1)
            extra_policies = policies.clone()
            extra_policies[mask] = torch.softmax(extra_log_policies[mask],
                                                 dim=1)
        else:
            extra_policies = policies.clone()
        # Update step: the gradient at the extrapolated point drives the
        # actual (softmax-normalized) update
        mask = torch.empty(n_players).uniform_(0, 1.) < subsampling
        if torch.any(mask):
            gradient_computations += torch.sum(mask.float()).item()
            extra_grad = torch.einsum('ijkl, kl -> ij', H[mask],
                                      extra_policies)
            log_policies[mask] -= step_size * extra_grad
            log_policies[mask] -= torch.logsumexp(log_policies[mask],
                                                  dim=1)[:, None]
        policies = torch.softmax(log_policies, dim=1)
        avg_policies *= (1 - 1 / (t + 1))
        avg_policies += policies / (t + 1)
        timing += time.perf_counter() - t0
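
        # -- Monitoring (excluded from the timed updates) --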
        # Value computation
        values = torch.sum(
            torch.einsum('ijkl, kl -> ij', H_value, avg_policies)
            * policies, dim=1)
        # Gap estimate: each player runs a short mirror ascent against the
        # averaged policies
        adv_log_policies = torch.log(avg_policies)
        for tt in range(100):
            adv_policies = torch.softmax(adv_log_policies, dim=1)
            adv_grad = torch.einsum('ijkl, kl -> ij', H,
                                    avg_policies - 2 * adv_policies)
            gap = torch.sum(torch.einsum('ijkl, kl -> ij', H,
                                         adv_policies - avg_policies)
                            * adv_policies)
            adv_log_policies += adv_grad * inner_step_size / (tt + 1)
            adv_log_policies -= torch.logsumexp(adv_log_policies,
                                                dim=1)[:, None]
        if t % 100 == 0:
            print(f'Iter {t}, values {values}, gap {gap}')
        values_r.append(values.tolist())
        gap_r.append(gap.item())
        gradient_computations_r.append(gradient_computations)
        timings_r.append(timing)
        policies_r.append(policies.tolist())

    history = {'values': values_r,
               'policies': policies_r,
               'gap': gap_r,
               'gradient_computations': gradient_computations_r,
               'timings': timings_r,
               'iterations': list(range(n_iter)),
               'subsampling': subsampling,
               'n_players': n_players,
               }
    if history_file is not None:
        with open(history_file, 'w+') as f:
            json.dump(history, f)
    return avg_policies.numpy(), history
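

# Usage sketch (illustrative only, not part of the original experiments):
#     H = make_hessian(n_players=3, n_actions=5, asym=.8, seed=0)
#     policies, history = solve_nash(H, n_iter=500, subsampling=.5)
#     # `policies` holds the averaged mixed strategy of each player.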


def plot_compare(output_dir):
    fig, axes = plt.subplots(1, 4, figsize=(14, 4), constrained_layout=True)
    handles = []
    player_handles = []
    labels = []
    player_labels = []
    for index in range(4):
        with open(join(output_dir, f'history_{index}.json'), 'r') as f:
            res = json.load(f)
        timings = res['timings']
        iterations = res['iterations']
        values = res['values']
        p = res['subsampling']
        n_players = res['n_players']
        gap = res['gap']
        values = np.array(values)
        for player in range(n_players):
            cmap = sns.light_palette((23 * player, 90, 60), input="husl",
                                     n_colors=10, reverse=True)
            h, = axes[0].plot(timings, values[:, player], color=cmap[index])
            axes[1].plot(iterations, values[:, player], color=cmap[index])
            if index == 0:
                player_handles.append(h)
                player_labels.append(f'Player {player}')
        cmap = sns.light_palette((0, 90, 60), input="husl",
                                 n_colors=5, reverse=True)
        h, = axes[2].plot(timings, gap, color=cmap[index])
        handles.append(h)
        labels.append(f'p = {p:.1f}')
        axes[3].plot(iterations, gap, color=cmap[index])
    fig.legend(handles, labels, ncol=2,
               bbox_to_anchor=[0.6, 0.9],
               loc='upper left', frameon=False)
    fig.legend(player_handles, player_labels, ncol=2,
               bbox_to_anchor=[0.1, 0.9],
               loc='upper left', frameon=False)
    axes[0].set_xlabel('CPU time')
    axes[1].set_xlabel('Iteration')
    axes[2].set_xlabel('CPU time')
    axes[2].set_ylabel('VI Gap')
    axes[3].set_xlabel('Iteration')
    axes[0].set_ylabel('Reward')
    sns.despine(fig)
    plt.savefig(join(output_dir, 'compare.pdf'))


def run():
    n_players = 2
    n_actions = 2
    n_iter = 1000
    step_size = 1
    subsampling = .5
    H = make_hessian(n_players, n_actions, asym=1., seed=1)
    policies, history = solve_nash(H, n_iter=n_iter, step_size=step_size,
                                   subsampling=subsampling)
    return policies, history
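

# run_many solves the same 4-player, 4-action game for several subsampling
# probabilities p in np.linspace(.25, 1, 4), in parallel; each run writes a
# history_{i}.json file that plot_compare reads back.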
def run_many(output_dir):
    n_players = 4
    n_actions = 4
    n_iter = 1000
    step_size = 1
    H = make_hessian(n_players, n_actions, asym=.9, seed=1)
    Parallel(n_jobs=4)(
        delayed(solve_nash)(H, n_iter=n_iter, step_size=step_size,
                            subsampling=subsampling,
                            history_file=join(output_dir,
                                              f'history_{i}.json'))
        for i, subsampling in enumerate(np.linspace(.25, 1, 4)))


output_dir = expanduser('~/output/games_rl/subsampling_simple')
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
run_many(output_dir)
plot_compare(output_dir)