# I2A Pacman Model
#
# Originally published as GitHub gist kentsommer/4e2ea8e70237330e595f45d2bcecba70
# by @kentsommer, created August 21, 2017.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.parameter import Parameter
# Pool and Inject Module
class PAI(nn.Module):
    """Pool-and-inject layer.

    Max-pools the input with ``nn.MaxPool2d(config.size)``, tiles the pooled
    map back to the input's height/width and concatenates it to the input
    along the channel axis.  The output therefore has twice as many channels
    as the input and the same spatial size.

    Args:
        config: object with a ``size`` attribute, used as the max-pool kernel
            size.  NOTE(review): presumably the frame size, which makes the
            pooling global over the whole map -- confirm with the caller.
    """

    def __init__(self, config):
        # Fixed: the original called super(I2A, self), which raises TypeError
        # because PAI does not inherit from I2A.
        super(PAI, self).__init__()
        self.config = config
        self.mp = nn.MaxPool2d(config.size)

    def forward(self, X, config=None):
        """Return ``cat([X, tile(max_pool(X))])`` along the channel axis.

        Args:
            X: tensor of shape (N, C, H, W).
            config: unused; now optional, kept for backward compatibility.

        Returns:
            Tensor of shape (N, 2 * C, H, W).
        """
        pooled = self.mp(X)
        # Nearest-neighbour resizing spreads each pooled value back over
        # (roughly) the window it was taken from; for a 1x1 pooled map it is
        # a plain broadcast.  The original Tensor.repeat(size) call passed too
        # few repeat dims for a 4-D tensor.
        tiled = F.interpolate(pooled, size=tuple(X.shape[-2:]), mode="nearest")
        # Inject into the input itself (the original concatenated the pooled
        # map with its own tiling, along the batch axis).
        return torch.cat([X, tiled], dim=1)


def _pad_same(x, conv):
    """Zero-pad ``x`` so the stride-1, unpadded ``conv`` keeps H and W.

    A k-wide kernel removes k - 1 rows/columns; for even k the extra one is
    added on the bottom/right.
    """
    kh, kw = conv.kernel_size
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(x, ((kw - 1) // 2, kw // 2, (kh - 1) // 2, kh // 2))


# Basic Block Module
class BB(nn.Module):
    """Basic convolutional building block.

    Layout (every convolution keeps the spatial size)::

        X -> pool-and-inject                 (2 * in_channels)
          -> left:  1x1 conv -> n1, 10x10 conv -> n1
          -> right: 1x1 conv -> n2,  3x3 conv -> n2
          -> concat (n1 + n2) -> 1x1 conv -> n3
          -> concat with X                   (in_channels + n3)

    NOTE(review): no non-linearity follows any convolution (none did in the
    original either), so apart from the max-pool the block is linear --
    confirm whether activations were intended.

    Args:
        n1, n2, n3: channel widths of the left branch, right branch and
            merge convolution.
        config: forwarded to ``PAI``.
        in_channels: channels of the block input.  New optional argument;
            defaults to ``n3``.

    Attributes:
        out_channels: channels of the block output, ``in_channels + n3``.
    """

    def __init__(self, n1, n2, n3, config, in_channels=None):
        super(BB, self).__init__()
        self.config = config
        self.PAI = PAI(config)
        self.n1 = n1
        self.n2 = n2
        self.n3 = n3
        self.in_channels = n3 if in_channels is None else in_channels
        self.out_channels = self.in_channels + n3
        # Both branches see the pool-and-inject output, which doubles the
        # channel count.  (The original wired in_channels from n1/n2/n3,
        # which did not match the tensors each conv receives, and left every
        # out_channels as the placeholder "?".)
        pai_channels = 2 * self.in_channels
        self.l1 = nn.Conv2d(in_channels=pai_channels,
                            out_channels=n1,
                            kernel_size=(1, 1),
                            stride=1, padding=0,
                            bias=True)
        # Spatial size is preserved by _pad_same in forward(); a symmetric
        # `padding` cannot do it for the even 10x10 kernel.
        self.l2 = nn.Conv2d(in_channels=n1,
                            out_channels=n1,
                            kernel_size=(10, 10),
                            stride=1, padding=0,
                            bias=True)
        self.r1 = nn.Conv2d(in_channels=pai_channels,
                            out_channels=n2,
                            kernel_size=(1, 1),
                            stride=1, padding=0,
                            bias=True)
        self.r2 = nn.Conv2d(in_channels=n2,
                            out_channels=n2,
                            kernel_size=(3, 3),
                            stride=1, padding=0,
                            bias=True)
        self.m = nn.Conv2d(in_channels=n1 + n2,
                           out_channels=n3,
                           kernel_size=(1, 1),
                           stride=1, padding=0,
                           bias=True)

    def forward(self, X, config=None):
        """Apply the block to ``X`` of shape (N, in_channels, H, W).

        Returns a tensor of shape (N, in_channels + n3, H, W).  ``config``
        is unused and optional (kept for backward compatibility).
        """
        # The original built self.PAI but never called it.
        injected = self.PAI(X)
        left = self.l2(_pad_same(self.l1(injected), self.l2))
        right = self.r2(_pad_same(self.r1(injected), self.r2))
        # Unpadded, the 10x10 and 3x3 outputs had different sizes and could
        # not be concatenated; the original also concatenated along dim 0.
        merged = self.m(torch.cat([left, right], dim=1))
        return torch.cat([X, merged], dim=1)


# Imagination-Augmented Agents (PACMAN Model)
class I2A(nn.Module):
    """Environment model: predicts a next frame and a reward distribution.

    The action vector is tiled over the frame and concatenated to it.  The
    result goes through a 1x1 stem convolution and two basic blocks, then
    splits into:

    * a frame head (1x1 conv back to ``in_channels``), and
    * a reward head (two 1x1 convs, global max-pool, linear, softmax).

    Args:
        config: forwarded to the basic blocks (and their pool-and-inject
            layers).
        in_channels: channels of the input and predicted frame.  Default 3.
            NOTE(review): inferred from the original ``lconv2`` -- confirm.
        num_actions: entries per row of the action tensor ``A``.  Default 5
            is an assumption -- confirm against the environment.
        num_rewards: number of reward classes (the original fc produced 5).
    """

    def __init__(self, config, in_channels=3, num_actions=5, num_rewards=5):
        super(I2A, self).__init__()
        self.config = config
        # Each basic block appends n3 = 64 channels: 64 -> 128 -> 192.
        self.b1 = BB(16, 32, 64, config, in_channels=64)
        self.b2 = BB(16, 32, 64, config, in_channels=self.b1.out_channels)
        # Stem.  The original declared in_channels=64, but this conv
        # receives the frame channels plus the tiled action channels.
        self.lconv1 = nn.Conv2d(in_channels=in_channels + num_actions,
                                out_channels=64,
                                kernel_size=(1, 1),
                                stride=1, padding=0,
                                bias=True)
        # Frame head: map the block output back to the frame's channels.
        self.lconv2 = nn.Conv2d(in_channels=self.b2.out_channels,
                                out_channels=in_channels,
                                kernel_size=(1, 1),
                                stride=1, padding=0,
                                bias=True)
        # Reward head.
        self.rconv1 = nn.Conv2d(in_channels=self.b2.out_channels,
                                out_channels=64,
                                kernel_size=(1, 1),
                                stride=1, padding=0,
                                bias=True)
        self.rconv2 = nn.Conv2d(in_channels=64,
                                out_channels=64,
                                kernel_size=(1, 1),
                                stride=1, padding=0,
                                bias=True)
        self.fc = nn.Linear(in_features=64,
                            out_features=num_rewards,
                            bias=False)
        # Explicit dim: implicit-dim Softmax is deprecated.  NOTE(review): if
        # this is trained with nn.CrossEntropyLoss, return logits instead,
        # since that loss applies log-softmax itself.
        self.sm = nn.Softmax(dim=1)

    def forward(self, X, A, config=None):
        """Predict the next frame and the reward distribution.

        Args:
            X: frame tensor of shape (N, in_channels, H, W).
            A: action tensor with N rows of ``num_actions`` entries
                (presumably one-hot -- confirm); cast to X's dtype.
            config: unused; now optional, kept for backward compatibility.

        Returns:
            ``(frame, reward)``: frame of shape (N, in_channels, H, W) and
            reward probabilities of shape (N, num_rewards).
        """
        n, _, h, w = X.shape
        # Broadcast each action vector over every pixel of the frame.  (The
        # original A.repeat(size) had the same dimension bug as PAI.)
        tiled = A.reshape(n, -1, 1, 1).type_as(X).expand(-1, -1, h, w)
        c1 = torch.cat([X, tiled], dim=1)
        p_conv1 = self.lconv1(c1)
        p_bb1 = self.b1(p_conv1)
        p_bb2 = self.b2(p_bb1)
        p_conv2 = self.lconv2(p_bb2)
        r_conv1 = self.rconv1(p_bb2)
        r_conv2 = self.rconv2(r_conv1)
        # nn.Linear acts on the last axis only, so collapse the spatial map
        # to one 64-vector per sample first.  NOTE(review): max vs. mean
        # pooling (or flattening) is a design choice -- confirm.
        r_pooled = F.adaptive_max_pool2d(r_conv2, 1).view(n, -1)
        r_fc = self.fc(r_pooled)
        r_sm = self.sm(r_fc)
        return p_conv2, r_sm
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment