@calebdre
Last active March 2, 2020 14:16
Pokemon Model
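```python
import torch.nn as nn


class PokeAgent(nn.Module):
    def __init__(self, vocab_size, num_choices):
        super(PokeAgent, self).__init__()
        # Map each token index to a 32-dimensional learned vector.
        self.embedding = nn.Embedding(vocab_size, 32)
        # One hidden layer with a ReLU non-linearity.
        self.linear = nn.Linear(32, 32)
        self.activation = nn.ReLU()
        # Project to one unnormalized score (logit) per available choice.
        self.out = nn.Linear(32, num_choices)

    def forward(self, turn_input):
        # turn_input is assumed to be a 1-D LongTensor of token indices
        # describing a single turn (no batch dimension).
        embedding = self.embedding(turn_input)  # (seq_len, 32)
        # Average over the tokens to get one fixed-size vector for the turn.
        embedding = embedding.mean(dim=0)       # (32,)
        x = self.linear(embedding)
        x = self.activation(x)
        x = self.out(x)                         # (num_choices,)
        return x
```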
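The model treats a turn as a bag of tokens: it embeds each token, averages the embeddings, and passes the result through a small MLP to get one logit per choice. The sketch below shows how it might be called and trained for one step. It repeats the class so it runs on its own; the sizes (`VOCAB_SIZE = 50`, `NUM_CHOICES = 9`), the token indices in `turn`, and the target choice `4` are hypothetical values for illustration, not taken from the gist.

```python
import torch
import torch.nn as nn


class PokeAgent(nn.Module):
    """Same architecture as above: mean-pooled embeddings -> MLP -> logits."""

    def __init__(self, vocab_size: int, num_choices: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, 32)
        self.linear = nn.Linear(32, 32)
        self.activation = nn.ReLU()
        self.out = nn.Linear(32, num_choices)

    def forward(self, turn_input: torch.Tensor) -> torch.Tensor:
        pooled = self.embedding(turn_input).mean(dim=0)
        return self.out(self.activation(self.linear(pooled)))


torch.manual_seed(0)

# Hypothetical sizes: 50 distinct state tokens, 9 possible choices per turn.
VOCAB_SIZE, NUM_CHOICES = 50, 9
agent = PokeAgent(VOCAB_SIZE, NUM_CHOICES)

# One turn encoded as 6 token indices (made-up values).
turn = torch.tensor([3, 17, 17, 42, 8, 0], dtype=torch.long)

logits = agent(turn)                   # shape: (NUM_CHOICES,)
probs = torch.softmax(logits, dim=-1)  # probability distribution over choices
action = int(torch.argmax(probs))      # greedy choice index

# One supervised update toward a hypothetical "correct" choice (index 4).
optimizer = torch.optim.SGD(agent.parameters(), lr=0.1)
target = torch.tensor([4])             # shape (1,), dtype int64
loss = nn.functional.cross_entropy(logits.unsqueeze(0), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because `forward` averages over `dim=0`, it expects a single unbatched turn. If turns were stacked into a `(batch, seq_len)` tensor, `dim=0` would average across the batch instead of the tokens, so a batched variant would likely need to pool over `dim=1` (and handle padding tokens separately).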