@tomgrek
Last active March 5, 2018 04:44
PyTorch model for a simple AI chatbot
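import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # a no-op wrapper since PyTorch 0.4

# `words` (presumably the vocabulary list) is defined outside this snippet.
class BotBrain(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(len(words), 10)  # token index -> 10-d vector
        self.rnn = nn.LSTM(10, 20, 2, dropout=0.5)  # 2 layers, hidden size 20
        # Fixed zero (h_0, c_0): (num_layers, batch, hidden), so batch size must be 1
        self.h = (Variable(torch.zeros(2, 1, 20)), Variable(torch.zeros(2, 1, 20)))
        self.l_out = nn.Linear(20, len(words))  # hidden state -> vocabulary scores

    def forward(self, cs):
        # cs: LongTensor of token indices, shape (seq_len, 1)
        inp = self.embedding(cs)
        outp, h = self.rnn(inp, self.h)  # updated state h is discarded
        out = F.log_softmax(self.l_out(outp), dim=-1).view(-1, len(words))
        return out  # (seq_len, len(words)) log-probabilities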
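
The gist does not show how the model is called or trained, so the following is only a minimal sketch of one plausible use: next-word prediction on a single sentence. The toy `words` list, the `word_to_ix` lookup, and the choice of `F.nll_loss` are assumptions made for illustration, not part of the original. The class is repeated so the snippet runs on its own, and it uses plain tensors instead of `Variable`, which assumes PyTorch 0.4 or later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy vocabulary; the gist does not show how `words` is built.
words = ["<eos>", "hello", "how", "are", "you"]
word_to_ix = {w: i for i, w in enumerate(words)}


class BotBrain(nn.Module):
    """Same layers as the gist, repeated so this sketch runs on its own."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(len(words), 10)
        self.rnn = nn.LSTM(10, 20, 2, dropout=0.5)
        self.h = (torch.zeros(2, 1, 20), torch.zeros(2, 1, 20))
        self.l_out = nn.Linear(20, len(words))

    def forward(self, cs):
        inp = self.embedding(cs)
        outp, h = self.rnn(inp, self.h)
        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, len(words))


torch.manual_seed(0)
model = BotBrain()

# One sentence as a column of token indices: shape (seq_len, batch=1).
sentence = ["hello", "how", "are", "you", "<eos>"]
ids = torch.tensor([word_to_ix[w] for w in sentence]).unsqueeze(1)

inputs = ids[:-1]             # every token except the last
targets = ids[1:].squeeze(1)  # the token that follows each input

log_probs = model(inputs)     # shape (seq_len - 1, len(words))
loss = F.nll_loss(log_probs, targets)  # NLL pairs with log_softmax output
loss.backward()               # gradients for one (hypothetical) training step

print(log_probs.shape, loss.item())
```

Because the stored hidden state has a batch dimension of 1, the input has to be a single sequence shaped `(seq_len, 1)`; feeding several sequences at once would require a hidden state sized to match.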