@affjljoo3581
Created September 14, 2020 13:00
Test code for the KoSpeech SpeechTransformer model: it runs a forward pass with a label-smoothed cross-entropy loss, backpropagates, and then repeats the forward pass and loss computation.
import torch
from kospeech.models.acoustic.transformer.transformer import SpeechTransformer
from kospeech.criterion.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyLoss
# Build a small SpeechTransformer and the label-smoothed cross-entropy criterion.
transformer = SpeechTransformer(
    num_classes=128, d_model=64, input_dim=80, pad_id=0, eos_id=3, d_ff=256,
    num_heads=8, num_encoder_layers=2, num_decoder_layers=2, dropout_p=0.1)
criterion = LabelSmoothedCrossEntropyLoss(
    num_classes=128,  # must match the model's num_classes so the smoothing mass is spread correctly
    ignore_index=0, smoothing=0.1, reduction='mean',
    architecture='transformer')

# Dummy batch: 32 utterances, 64 frames of 80-dim features, and random target token ids.
inputs = torch.rand((32, 64, 80), dtype=torch.float)
lengths = torch.empty((32,), dtype=torch.long).fill_(64)
targets = torch.randint(0, 128, (32, 64), dtype=torch.long)

# First forward pass and loss computation.
preds = transformer(inputs, lengths, targets)
loss = criterion(preds.contiguous().view(-1, preds.size(-1)), targets.contiguous().view(-1))
print(loss)

# Backpropagate, then run a second forward pass and recompute the loss.
loss.backward()
preds = transformer(inputs, lengths, targets)
loss = criterion(preds.contiguous().view(-1, preds.size(-1)), targets.contiguous().view(-1))
print(loss)
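
A full training step would additionally update the model parameters after the backward pass. The following is a minimal sketch, not part of the original gist; it reuses the transformer, criterion, inputs, lengths, and targets defined above and assumes a plain torch.optim.Adam optimizer with an arbitrary learning rate.

import torch.optim as optim

# Hypothetical optimizer setup; the learning rate is an illustrative choice.
optimizer = optim.Adam(transformer.parameters(), lr=1e-4)

# One training step: clear gradients, forward, compute loss, backward, update.
optimizer.zero_grad()
preds = transformer(inputs, lengths, targets)
loss = criterion(preds.contiguous().view(-1, preds.size(-1)), targets.contiguous().view(-1))
loss.backward()
optimizer.step()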