@ttchengab
Created October 21, 2020 02:48
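"""
The following LSTM is the model used after OCR extraction: it predicts the key-value pairs from the text
extracted by the previous get_info() method.
"""
import torch.nn as nn


class ExtractLSTM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        # Map token indices to dense vectors of size embed_size.
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Two-layer bidirectional LSTM; with the default batch_first=False it
        # expects input of shape (seq_len, batch, embed_size).
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=2, bidirectional=True)
        # Forward and backward states are concatenated, hence hidden_size * 2.
        # The layer emits 5 class scores per input position.
        self.linear = nn.Linear(hidden_size * 2, 5)

    def forward(self, inpt):
        embedded = self.embed(inpt)
        feature, _ = self.lstm(embedded)
        oupt = self.linear(feature)
        return oupt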
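
As a rough usage sketch, the snippet below runs the model on a dummy batch to show the tensor shapes involved. The sizes (vocabulary of 100, sequence length 12, batch of 3) are arbitrary assumptions for illustration, and the source does not say what the five output classes mean. The class is repeated only so that the example runs on its own.

```python
import torch
import torch.nn as nn


class ExtractLSTM(nn.Module):
    # Same model as above, repeated so this sketch is self-contained.
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=2, bidirectional=True)
        self.linear = nn.Linear(hidden_size * 2, 5)

    def forward(self, inpt):
        embedded = self.embed(inpt)
        feature, _ = self.lstm(embedded)
        return self.linear(feature)


# Hypothetical sizes, chosen only for this example.
vocab_size, embed_size, hidden_size = 100, 16, 32
seq_len, batch_size = 12, 3

model = ExtractLSTM(vocab_size, embed_size, hidden_size)
model.eval()

# nn.LSTM defaults to batch_first=False, so the input is (seq_len, batch).
tokens = torch.randint(0, vocab_size, (seq_len, batch_size))

with torch.no_grad():
    logits = model(tokens)         # (seq_len, batch, 5): one score per class per position
    preds = logits.argmax(dim=-1)  # (seq_len, batch): highest-scoring class index per position

print(logits.shape, preds.shape)
```

Because the linear layer is applied at every time step, the model produces one prediction per input position rather than one per sequence.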