This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
from ortools.linear_solver import pywraplp | |
# xy format | |
tetras = [ | |
# Up downs | |
((0,0),(0,1),(1,1),(1,2)), | |
((0,0),(0,1),(-1,1),(-1,2)), | |
# left rights | |
((0,0),(1,0),(1,-1),(2,-1)), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ParameterDeque(nn.Module): | |
def __init__(self) -> None: | |
super(ParameterDeque, self).__init__() | |
self.left = 0 | |
self.right = 0 # Points at the first non-existing element | |
def _convert_idx(self, idx): | |
"""Get the absolute index for the list of modules""" | |
idx = operator.index(idx) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer): | |
privs = [game.make_priv(r1, 0).to(device), | |
game.make_priv(r2, 1).to(device)] | |
def play_inner(state): | |
cur = game.get_cur(state) | |
calls = game.get_calls(state) | |
assert cur == len(calls) % 2 | |
if calls and calls[-1] == game.LIE_ACTION: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update(state, t): | |
pi = compute_policy(state, t) | |
score = 0 | |
for i in actions(state): | |
score_i = update(state + action) | |
score += pi[i] * score_i | |
state.mean_score = (state.mean_score * t + score)/(t + 1) | |
return score | |
def compute_policy(state): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update(state, t): | |
pi = compute_policy(state, t) | |
score = 0 | |
for i in actions(state): | |
score_i = update(state + i) | |
score += pi[i] * score_i | |
state.mean_score = (state.mean_score * t + score)/(t + 1) | |
return score | |
def compute_policy(state): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer): | |
privs = [game.make_priv(r1, 0).to(device), | |
game.make_priv(r2, 1).to(device)] | |
def play_inner(state): | |
cur = game.get_cur(state) | |
calls = game.get_calls(state) | |
assert cur == len(calls) % 2 | |
if calls and calls[-1] == game.LIE_ACTION: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer): | |
privs = [game.make_priv(r1, 0), game.make_priv(r2, 1)] | |
def play_inner(state): | |
cur = game.get_cur(state) | |
calls = game.get_calls(state) | |
assert cur == len(calls) % 2 | |
if calls and calls[-1] == game.LIE_ACTION: | |
prev_call = calls[-2] if len(calls) >= 2 else -1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer): | |
privs = [game.make_priv(r1, 0), game.make_priv(r2, 1)] | |
def play_inner(state): | |
cur = game.get_cur(state) # Current player id | |
calls = game.get_calls(state) # Bets made by player so far | |
if calls and calls[-1] == game.LIE_ACTION: | |
prev_call = calls[-2] if len(calls) >= 2 else -1 | |
# If prev_call is good it mean we won (because our opponent called lie) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[nav] In [48]: table = torch.arange(12, dtype=torch.float32).reshape(4,3) | |
[ins] In [49]: new_table = torch.zeros(4, 3) | |
[ins] In [50]: index = torch.tensor([1,1,0,3]) | |
[ins] In [51]: index2 = index.unsqueeze(1).expand(4,3) | |
[ins] In [52]: table | |
Out[52]: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
class TestModule(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.a = nn.Linear(10, 10) | |
self.b = SubTestModule() | |
class SubTestModule(nn.Module): | |
def __init__(self): |
OlderNewer