import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, nHidden):
        super(Attention, self).__init__()
        self.attention = nn.Linear(nHidden * 2, nHidden * 2)
        self.v = nn.Linear(nHidden * 2, 1, bias=False)

    def forward(self, rnn_output):
        T, b, h = rnn_output.size()
        # [T, b, h] -> apply tanh, then project down to a scalar per time step
        energy = torch.tanh(self.attention(rnn_output))                      # [T, b, h]
        attention_weights = F.softmax(self.v(energy), dim=0)                 # [T, b, 1]
        attended_output = torch.sum(attention_weights * rnn_output, dim=0)   # [b, h]
        return attended_output.unsqueeze(0)                                  # [1, b, h]


class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut, use_attention=False, dropout=0.0):
        super(BidirectionalLSTM, self).__init__()
        # Added dropout to the LSTM (note: with num_layers=1, PyTorch applies no
        # inter-layer dropout, so this only takes effect if num_layers is raised)
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, dropout=dropout, num_layers=1)
        self.embedding = nn.Linear(nHidden * 2, nOut)
        self.use_attention = use_attention
        if use_attention:
            self.attention = Attention(nHidden)

    def forward(self, input):
        recurrent, _ = self.rnn(input)             # [T, b, 2*nHidden]
        if self.use_attention:
            # Attention pools the sequence over time only when enabled
            recurrent = self.attention(recurrent)  # [1, b, 2*nHidden]
        output = self.embedding(recurrent)         # [T or 1, b, nOut]
        return output


class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False,
                 use_attention=True, lstm_dropout=0.0):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Enable BN in all conv layers (you can selectively enable it on certain layers too)
        convRelu(0, batchNormalization=True)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1, batchNormalization=True)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, batchNormalization=True)
        convRelu(3, batchNormalization=True)
        cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, batchNormalization=True)
        convRelu(5, batchNormalization=True)
        cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, batchNormalization=True)  # 512x1x16

        self.cnn = cnn
        # Two BiLSTM layers in sequence; if attention is enabled, the first layer's
        # output is pooled over time before the second layer sees it.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh, use_attention=use_attention, dropout=lstm_dropout),
            BidirectionalLSTM(nh, nh, nclass, use_attention=False, dropout=lstm_dropout)
        )

    def forward(self, input):
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)        # [b, c, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        output = self.rnn(conv)       # [w or 1, b, nclass]
        output = F.log_softmax(output, dim=2)
        return output
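

if __name__ == "__main__":
    # Quick shape-check sketch (not part of the original gist). The values
    # imgH=32, nc=1, nclass=37, nh=256 and the 100-pixel-wide dummy batch are
    # illustrative assumptions, not settings taken from the source.
    dummy = torch.randn(2, 1, 32, 100)  # [batch, channels, height, width]

    crnn_ctc = CRNN(imgH=32, nc=1, nclass=37, nh=256, use_attention=False)
    print(crnn_ctc(dummy).shape)  # torch.Size([26, 2, 37]): per-time-step log-probs (CTC-style)

    crnn_att = CRNN(imgH=32, nc=1, nclass=37, nh=256, use_attention=True)
    print(crnn_att(dummy).shape)  # torch.Size([1, 2, 37]): attention pools the sequence to length 1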