@grey-area
Last active December 19, 2022 22:07
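"""Incremental (autoregressive) decoding with nn.TransformerDecoder.

A standard nn.TransformerDecoder re-processes the entire target sequence at
every generation step. The subclasses below instead cache the input to each
decoder layer, so that at every step only the newest position is pushed
through the stack while its self-attention still attends over the whole
prefix. The __main__ block checks that the two approaches produce the same
outputs and compares their wall-clock times.
"""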
import time
from typing import List, Optional, Tuple

import torch
from torch import Tensor, nn
from tqdm import tqdm


def subsequent_mask(size):
    # Standard causal mask: -inf above the main diagonal, 0 elsewhere.
    return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)
class AutoregressiveTransformerDecoderLayer(nn.TransformerDecoderLayer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _autoregressive_sa_block(self, x: Tensor, memory: Tensor,
                                 key_padding_mask: Optional[Tensor]) -> Tensor:
        # Self-attention where the query is only the newest position (x) and the
        # keys/values are the full target prefix (memory). No causal mask is
        # needed, since the query cannot see beyond the prefix anyway.
        x = self.self_attn(x, memory, memory,
                           attn_mask=None,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    def autoregressive_forward(self, tgt: Tensor, memory: Tensor, memory_mask: Optional[Tensor] = None,
                               tgt_key_padding_mask: Optional[Tensor] = None,
                               memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # Only the last target position needs to be processed at this step.
        x = tgt[-1:]

        if self.norm_first:
            x = x + self._autoregressive_sa_block(self.norm1(x), self.norm1(tgt), tgt_key_padding_mask)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._autoregressive_sa_block(x, tgt, tgt_key_padding_mask))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
            x = self.norm3(x + self._ff_block(x))

        return x
class AutoregressiveTransformerDecoder(nn.TransformerDecoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_initial_inputs_list(self, x: Tensor) -> List[Tensor]:
        # One cached input sequence per layer: the first entry is the decoder
        # input so far; the rest start empty and grow by one position per step.
        _, B, F = x.shape
        remaining_inputs = [torch.zeros(0, B, F, device=x.device, dtype=x.dtype) for _ in range(len(self.layers))]
        return [x] + remaining_inputs

    def autoregressive_forward(self, inputs: List[Tensor], memory: Tensor, memory_mask: Optional[Tensor] = None,
                               tgt_key_padding_mask: Optional[Tensor] = None,
                               memory_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, List[Tensor]]:
        with torch.no_grad():
            for i, mod in enumerate(self.layers):
                x = inputs[i]
                x = mod.autoregressive_forward(x, memory,
                                               memory_mask=memory_mask,
                                               tgt_key_padding_mask=tgt_key_padding_mask,
                                               memory_key_padding_mask=memory_key_padding_mask)
                # Append this layer's new output to the cached input of the next layer.
                inputs[i + 1] = torch.cat([inputs[i + 1], x], dim=0)

            if self.norm is not None:
                x = self.norm(x)

        return x, inputs
if __name__ == "__main__":
    d_model = 4
    num_layers = 3
    nhead = 4
    L = 400

    normal_transformer_decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
    normal_transformer_decoder = nn.TransformerDecoder(normal_transformer_decoder_layer, num_layers=num_layers).eval()

    autoregressive_transformer_decoder_layer = AutoregressiveTransformerDecoderLayer(d_model=d_model, nhead=nhead)
    autoregressive_transformer_decoder = AutoregressiveTransformerDecoder(autoregressive_transformer_decoder_layer, num_layers=num_layers).eval()

    # Copy the weights from the standard decoder so both models are identical.
    autoregressive_transformer_decoder.load_state_dict(normal_transformer_decoder.state_dict())

    # Initial state
    initial_x = torch.randn(1, 1, d_model)
    memory = torch.randn(L, 1, d_model)
    x = initial_x.clone()

    start = time.time()
    # Generation with the standard decoder: re-run the full sequence each step.
    for i in tqdm(range(L)):
        sequence_length = x.size(0)
        tgt_mask = subsequent_mask(sequence_length)
        # In iteration i - 1, a sequence of length i attends to a sequence of
        # length i, but we only need the output for the last element.
        output = normal_transformer_decoder(x, memory, tgt_mask=tgt_mask)
        last_output = output[-1:]
        x = torch.cat([x, last_output], dim=0)
    final_x1 = x.clone()
    print(f"Time for nn.TransformerDecoder: {time.time() - start}")

    # Initial state
    x = initial_x.clone()
    inputs = autoregressive_transformer_decoder.get_initial_inputs_list(x)

    start = time.time()
    # Generation with the caching decoder: only the newest position is processed each step.
    for i in tqdm(range(L)):
        output, inputs = autoregressive_transformer_decoder.autoregressive_forward(inputs, memory)
        inputs[0] = torch.cat([inputs[0], output], dim=0)
    final_x2 = inputs[0].clone()
    print(f"Time for AutoregressiveTransformerDecoder: {time.time() - start}")

    print(f"Max abs error: {torch.max(torch.abs(final_x1 - final_x2))}")