Skip to content

Instantly share code, notes, and snippets.

View Prajithp's full-sized avatar
🏠
Working from home

Prajith Prajithp

🏠
Working from home
View GitHub Profile
"""
The following LSTM is the model to use after OCR extraction: it predicts key-value pairs from the text
extracted by the previous get_info() method.
"""
class ExtractLSTM(nn.Module):
def __init__(self, vocab_size, embed_size, hidden_size):
super().__init__()
self.embed = nn.Embedding(vocab_size, embed_size)
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=2, bidirectional=True)
self.linear = nn.Linear(hidden_size * 2, 5)