Reduced Embedding Decoder, ref: Tied & Reduced RNN-T Decoder (https://arxiv.org/pdf/2109.07513.pdf)
#!/usr/bin/env python3
# Copyright 2022 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import torch
import torch.nn as nn
import torch.nn.functional as F


class ReducedEmbeddingDecoder(torch.nn.Module):
    """This class implements the stateless decoder from the following paper:

        Tied & Reduced RNN-T Decoder
        https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        logit_weight: torch.Tensor,
        blank_id: int,
        n_head: int,
        context_size: int,
    ):
        """
        Args:
          logit_weight:
            The logit weight shared with the embedding,
            of shape (vocab_size, embedding_dim).
          blank_id:
            The ID of the blank symbol.
          n_head:
            The number of position vectors.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
        """
        super().__init__()
        self.n_head = n_head
        self.context_size = context_size
        self.embedding_dim = logit_weight.size(1)

        # Shared embedding: a frozen copy of the logit weight
        # with the blank row zeroed out.
        self.blank_id = blank_id
        self.embedding = torch.zeros_like(logit_weight, requires_grad=False)
        self.embedding[:, :] = logit_weight
        self.embedding[blank_id, :] = torch.zeros(
            self.embedding_dim,
            dtype=logit_weight.dtype,
            device=logit_weight.device,
            requires_grad=False,
        )

        # Multi-headed position vectors
        self.pos = nn.Parameter(
            torch.empty(n_head, context_size, self.embedding_dim))
        torch.nn.init.xavier_uniform_(self.pos)

        self.proj = torch.nn.Linear(self.embedding_dim, self.embedding_dim)
        self.norm = torch.nn.LayerNorm(self.embedding_dim, eps=1e-5)

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U) with blank prepended.
          need_pad:
            True during training: left-pad the label history so that the
            output length equals U. False during inference, where y holds
            exactly `context_size` labels and no padding is needed.
        Returns:
          Return a tensor of shape (N, U, embedding_dim).
        """
        # (N, U, embedding_dim)
        with torch.no_grad():
            embedding_out = F.embedding(y, self.embedding.to(y.device))

        if need_pad is True:
            # Pad (context_size - 1) zero frames on the left so the output
            # has the same length as the input.
            embedding_out = F.pad(
                embedding_out, (0, 0, self.context_size - 1, 0))
        else:
            # During inference time, there is no need to do extra padding
            # as we only need one output.
            assert embedding_out.size(1) == self.context_size

        embedding_out = self.multihead_reduced(embedding_out)
        # A cheap projection layer,
        embedding_out = self.proj(embedding_out)
        # stabilized with LayerNorm,
        embedding_out = self.norm(embedding_out)
        # followed by a Swish non-linearity.
        embedding_out = embedding_out * torch.sigmoid(embedding_out)
        return embedding_out
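
    # Note on the reduction below (a paraphrase of the code, not a formula
    # quoted from the paper): for every head h and output position t, the
    # context embeddings x_{t+c}, c = 0..context_size-1, are scored against
    # the learned position vectors p_{h, c} with dot products
    #     w_{h, c} = <x_{t+c}, p_{h, c}>,
    # and the head output is the weighted sum over the context,
    #     sum_c w_{h, c} * x_{t+c}.
    # The head outputs are summed and divided by (n_head * context_size).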

    def multihead_reduced(self, xs: torch.Tensor) -> torch.Tensor:
        """Multi-headed reduced decoder.

        Args:
          xs:
            A 3-D tensor of shape (N, U + context_size - 1, embedding_dim)
            containing the (padded) embeddings with blank prepended.
        Returns:
          Return a tensor of shape (N, U, embedding_dim).
        """
        # (N, U + context_size - 1, embedding_dim)
        #   -> (N, U, embedding_dim, context_size)
        xs_expand = xs.unfold(1, self.context_size, 1)
        # (N, U, embedding_dim, context_size)
        #   -> (N, U, context_size, embedding_dim)
        xs_expand = xs_expand.permute(0, 1, 3, 2)

        ys = None
        for i in range(self.n_head):
            # (N, U, context_size, embedding_dim) -> (N, U, context_size)
            weight = torch.sum(xs_expand * self.pos[i], -1)
            # (N, U, context_size) -> (N, U, context_size, embedding_dim)
            weight = torch.tile(
                torch.unsqueeze(weight, 3), (1, 1, 1, self.embedding_dim))
            # (N, U, context_size, embedding_dim) -> (N, U, embedding_dim).
            # Summing over dim=2 already yields (N, U, embedding_dim), so no
            # squeeze is needed (a bare squeeze would drop a size-1 batch or
            # label dimension).
            ys_i = torch.sum(xs_expand * weight, dim=2)
            if ys is None:
                ys = ys_i
            else:
                ys += ys_i

        # (N, U, embedding_dim)
        return ys / (self.n_head * self.context_size)
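

# A minimal usage sketch (not part of the original gist). It assumes a
# hypothetical tied output layer `joiner_out` and arbitrary sizes
# (vocab_size=500, embedding_dim=256, n_head=4, context_size=2); adapt them
# to your model. It only illustrates how forward() can be called in
# training mode (need_pad=True) and in step-by-step inference mode
# (need_pad=False).
if __name__ == "__main__":
    vocab_size, embedding_dim = 500, 256  # assumed sizes
    blank_id = 0

    # Hypothetical output (logit) layer whose weight is shared with the
    # decoder embedding.
    joiner_out = torch.nn.Linear(embedding_dim, vocab_size, bias=False)

    decoder = ReducedEmbeddingDecoder(
        logit_weight=joiner_out.weight.detach(),  # (vocab_size, embedding_dim)
        blank_id=blank_id,
        n_head=4,
        context_size=2,
    )

    # Training-style call: full label history (N, U) with blank prepended.
    y = torch.randint(1, vocab_size, (8, 10))
    y[:, 0] = blank_id
    out = decoder(y)  # (8, 10, 256)

    # Inference-style call: exactly `context_size` labels, no extra padding.
    out_step = decoder(y[:, :2], need_pad=False)  # (8, 1, 256)

    print(out.shape, out_step.shape)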