Skip to content

Instantly share code, notes, and snippets.

View thomasahle's full-sized avatar
♟️

Thomas Dybdahl Ahle thomasahle

♟️
View GitHub Profile
@thomasahle
thomasahle / tetraminos.py
Created November 27, 2020 13:43
Packing tetrominoes
import collections
from ortools.linear_solver import pywraplp
# xy format
tetras = [
# Up downs
((0,0),(0,1),(1,1),(1,2)),
((0,0),(0,1),(-1,1),(-1,2)),
# left rights
((0,0),(1,0),(1,-1),(2,-1)),
class ParameterDeque(nn.Module):
def __init__(self) -> None:
super(ParameterDeque, self).__init__()
self.left = 0
self.right = 0 # Points at the first non-existing element
def _convert_idx(self, idx):
"""Get the absolute index for the list of modules"""
idx = operator.index(idx)
@thomasahle
thomasahle / train.py
Created December 19, 2021 21:43
Self-play
def play(r1, r2, replay_buffer):
privs = [game.make_priv(r1, 0).to(device),
game.make_priv(r2, 1).to(device)]
def play_inner(state):
cur = game.get_cur(state)
calls = game.get_calls(state)
assert cur == len(calls) % 2
if calls and calls[-1] == game.LIE_ACTION:
def update(state, t):
    """Recursively compute the policy-weighted score of ``state`` for iteration ``t``.

    Each action ``i`` available in ``state`` is scored by recursing into the
    successor ``state + i``. That score is weighted by the policy probability
    ``pi[i]``, and the weighted scores are summed. ``state.mean_score`` is then
    updated as a running mean over iterations ``0..t``.

    Args:
        state: current node. From usage it must support ``state + i`` for each
            action ``i`` and expose a mutable ``mean_score`` attribute
            (assumption — TODO confirm the concrete type).
        t: iteration counter, used both for the policy and the running mean.

    Returns:
        The policy-weighted score computed in this call.
    """
    # NOTE(review): the visible `compute_policy` stub is declared with only
    # `state`, but it is called here with `(state, t)`; confirm its signature.
    pi = compute_policy(state, t)
    score = 0
    for i in actions(state):
        # Recurse on the successor reached by action `i` (the original used the
        # undefined name `action`). `update` requires `t`, so pass the same
        # iteration counter down.
        score_i = update(state + i, t)
        score += pi[i] * score_i
    # Incremental mean: (old_mean * t + new_value) / (t + 1).
    state.mean_score = (state.mean_score * t + score) / (t + 1)
    return score
def compute_policy(state):
def update(state, t):
    """Recursively compute the policy-weighted score of ``state`` for iteration ``t``.

    Each action ``i`` available in ``state`` is scored by recursing into the
    successor ``state + i``. That score is weighted by the policy probability
    ``pi[i]``, and the weighted scores are summed. ``state.mean_score`` is then
    updated as a running mean over iterations ``0..t``.

    Args:
        state: current node. From usage it must support ``state + i`` for each
            action ``i`` and expose a mutable ``mean_score`` attribute
            (assumption — TODO confirm the concrete type).
        t: iteration counter, used both for the policy and the running mean.

    Returns:
        The policy-weighted score computed in this call.
    """
    # NOTE(review): the visible `compute_policy` stub is declared with only
    # `state`, but it is called here with `(state, t)`; confirm its signature.
    pi = compute_policy(state, t)
    score = 0
    for i in actions(state):
        # `update` requires `t`; the original recursive call omitted it and
        # would raise TypeError. Pass the same iteration counter down.
        score_i = update(state + i, t)
        score += pi[i] * score_i
    # Incremental mean: (old_mean * t + new_value) / (t + 1).
    state.mean_score = (state.mean_score * t + score) / (t + 1)
    return score
def compute_policy(state):
def play(r1, r2, replay_buffer):
privs = [game.make_priv(r1, 0).to(device),
game.make_priv(r2, 1).to(device)]
def play_inner(state):
cur = game.get_cur(state)
calls = game.get_calls(state)
assert cur == len(calls) % 2
if calls and calls[-1] == game.LIE_ACTION:
def play(r1, r2, replay_buffer):
privs = [game.make_priv(r1, 0), game.make_priv(r2, 1)]
def play_inner(state):
cur = game.get_cur(state)
calls = game.get_calls(state)
assert cur == len(calls) % 2
if calls and calls[-1] == game.LIE_ACTION:
prev_call = calls[-2] if len(calls) >= 2 else -1
def play(r1, r2, replay_buffer):
privs = [game.make_priv(r1, 0), game.make_priv(r2, 1)]
def play_inner(state):
cur = game.get_cur(state) # Current player id
calls = game.get_calls(state) # Bets made by player so far
if calls and calls[-1] == game.LIE_ACTION:
prev_call = calls[-2] if len(calls) >= 2 else -1
# If prev_call is good it mean we won (because our opponent called lie)
@thomasahle
thomasahle / scatter.py
Created June 7, 2022 22:41
How to use PyTorch scatter_add
[nav] In [48]: table = torch.arange(12, dtype=torch.float32).reshape(4,3)
[ins] In [49]: new_table = torch.zeros(4, 3)
[ins] In [50]: index = torch.tensor([1,1,0,3])
[ins] In [51]: index2 = index.unsqueeze(1).expand(4,3)
[ins] In [52]: table
Out[52]:
import torch.nn as nn
class TestModule(nn.Module):
    """Small test module with two child modules.

    ``a`` is a ``Linear`` layer mapping 10 features to 10 features, and ``b``
    is a nested ``SubTestModule``.
    """

    def __init__(self):
        super(TestModule, self).__init__()
        # Assign `a` before `b` so submodule/parameter registration order
        # matches the original.
        self.a = nn.Linear(in_features=10, out_features=10)
        self.b = SubTestModule()
class SubTestModule(nn.Module):
def __init__(self):