import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, nHidden):
        super(Attention, self).__init__()
        self.attention = nn.Linear(nHidden * 2, nHidden * 2)
        self.v = nn.Linear(nHidden * 2, 1, bias=False)

    def forward(self, rnn_output):
        T, b, h = rnn_output.size()
        # [T, b, h] -> apply tanh, then project each time step down to a scalar
        energy = torch.tanh(self.attention(rnn_output))                      # [T, b, h]
        attention_weights = F.softmax(self.v(energy), dim=0)                 # [T, b, 1]
        attended_output = torch.sum(attention_weights * rnn_output, dim=0)   # [b, h]
        return attended_output.unsqueeze(0)                                  # [1, b, h]
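
# --- Illustrative shape check (sketch, not part of the original gist) --------
# Assuming nHidden = 4: the module collapses a bidirectional sequence of shape
# [T, b, 2*nHidden] into a single context vector of shape [1, b, 2*nHidden]:
#
#     attn = Attention(nHidden=4)
#     x = torch.randn(10, 2, 8)    # T=10 time steps, batch b=2, 2*nHidden=8
#     attn(x).shape                # -> torch.Size([1, 2, 8])
# ------------------------------------------------------------------------------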

class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut, use_attention=False, dropout=0.0):
        super(BidirectionalLSTM, self).__init__()
        # Dropout is passed to the LSTM; note that PyTorch ignores it when
        # num_layers=1, since dropout only applies between stacked LSTM layers.
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, dropout=dropout, num_layers=1)
        self.embedding = nn.Linear(nHidden * 2, nOut)
        self.use_attention = use_attention
        if use_attention:
            self.attention = Attention(nHidden)

    def forward(self, input):
        recurrent, _ = self.rnn(input)             # [T, b, 2*nHidden]
        if self.use_attention:
            # Collapse the sequence into a single attended context vector
            recurrent = self.attention(recurrent)  # [1, b, 2*nHidden]
        output = self.embedding(recurrent)         # [T or 1, b, nOut]
        return output

class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False,
                 use_attention=True, lstm_dropout=0.0):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]               # kernel sizes
        ps = [1, 1, 1, 1, 1, 1, 0]               # paddings
        ss = [1, 1, 1, 1, 1, 1, 1]               # strides
        nm = [64, 128, 256, 256, 512, 512, 512]  # output channels

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Enable BN in all conv layers (you can selectively enable it on certain layers too)
        convRelu(0, batchNormalization=True)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1, batchNormalization=True)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, batchNormalization=True)
        convRelu(3, batchNormalization=True)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))        # 256x4x16
        convRelu(4, batchNormalization=True)
        convRelu(5, batchNormalization=True)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))        # 512x2x16
        convRelu(6, batchNormalization=True)                        # 512x1x16

        self.cnn = cnn
        # Two BiLSTM layers in sequence; only the first one applies attention.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh, use_attention=use_attention, dropout=lstm_dropout),
            BidirectionalLSTM(nh, nh, nclass, use_attention=False, dropout=lstm_dropout)
        )

    def forward(self, input):
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)        # [b, c, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        output = self.rnn(conv)       # [w or 1, b, nclass]
        output = F.log_softmax(output, dim=2)
        return output
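
# --- Minimal usage sketch (not part of the original gist) --------------------
# The values below (nclass=37, nh=256, a batch of 32x128 grayscale images) are
# illustrative assumptions, not settings prescribed by the author.
if __name__ == "__main__":
    model = CRNN(imgH=32, nc=1, nclass=37, nh=256,
                 use_attention=True, lstm_dropout=0.0)
    images = torch.randn(4, 1, 32, 128)  # [batch, channels, height, width]
    log_probs = model(images)
    # With use_attention=True the first BiLSTM collapses the width axis, so the
    # output is [1, batch, nclass]; with use_attention=False it stays
    # [w, batch, nclass], where w is the conv feature-map width.
    print(log_probs.shape)  # torch.Size([1, 4, 37])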