Skip to content

Instantly share code, notes, and snippets.

@hushell
Created September 30, 2021 00:09
Show Gist options
  • Save hushell/a28052e6000d4d279b941d9829b5dc7f to your computer and use it in GitHub Desktop.
Save hushell/a28052e6000d4d279b941d9829b5dc7f to your computer and use it in GitHub Desktop.
AvaTr V2
import copy
from typing import Optional, List
from dotmap import DotMap
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from asteroid_filterbanks import make_enc_dec
class AvaTr(nn.Module):
    """Speaker-conditioned ("avatar") Transformer masking network.

    Processing done by :meth:`forward`:

    1. The filterbank encoder returned by ``make_enc_dec`` maps the waveform
       to a ``B x C x T`` representation, which goes through the encoder
       activation and global layer normalization (``GlobLN``).
    2. A grouped 1-D convolution over that representation produces the
       positional embedding.
    3. A learned per-speaker embedding (the "avatar") is used as the decoder
       query of an encoder-decoder ``Transformer``.
    4. The sigmoid of the decoder output masks the encoder memory, and the
       filterbank decoder maps the masked representation back to a waveform.

    Args:
        n_spk: Number of entries in the speaker embedding table.
        embed_dim: Size of each speaker embedding. Must equal ``d_model``
            because the embedding is fed directly to the decoder attention.
        d_model: Transformer width. Must equal the encoder's
            ``n_feats_out``; the grouped positional convolution additionally
            requires it to be divisible by 16.
        nhead: Number of attention heads.
        num_encoder_layers: Number of Transformer encoder layers.
        num_decoder_layers: Number of Transformer decoder layers.
        dim_feedforward: Hidden size of the feed-forward sub-layers.
        dropout: Dropout probability used inside the Transformer.
        activation: Name of the feed-forward activation.
        normalize_before: Use pre-norm (True) or post-norm (False) layers.
        return_intermediate_dec: Forwarded to the Transformer decoder.
        kernel_size: Filterbank kernel size.
        n_filters: Number of filterbank filters.
        stride: Filterbank stride.
        enc_activation: ``"relu"`` applies a ReLU after the encoder; any
            other value applies the identity.
        fb_name: Filterbank name passed to ``make_enc_dec``.
        sample_rate: Sample rate passed to ``make_enc_dec``.
        **fb_kwargs: Extra keyword arguments passed to ``make_enc_dec``.

    Raises:
        ValueError: If ``embed_dim`` differs from ``d_model``.
        AssertionError: If the encoder output width differs from ``d_model``.
    """

    def __init__(
            self,
            # Avatars
            n_spk,
            embed_dim=128,
            # Transformer
            d_model=128,
            nhead=4,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=1024,
            dropout=0.1,
            activation="relu",
            normalize_before=False,
            return_intermediate_dec=False,
            # waveform encoder/decoder params
            kernel_size=16,
            n_filters=128,
            stride=8,
            enc_activation="relu",
            fb_name="free",
            sample_rate=8000,
            **fb_kwargs):
        super().__init__()

        # The avatar embedding becomes the decoder query, which the decoder's
        # d_model-wide attention and LayerNorm consume directly. Fail here
        # with a clear message instead of deep inside the first forward pass.
        if embed_dim != d_model:
            raise ValueError(
                "embed_dim must equal d_model because the speaker embedding "
                "is used as the Transformer decoder query. Received "
                f"{embed_dim} and {d_model}"
            )

        # Record every constructor argument so that get_model_args() returns
        # everything needed to re-instantiate the class. enc_activation and
        # fb_kwargs are included for that reason.
        hparams = dict(
            # Avatars
            n_spk=n_spk,
            embed_dim=embed_dim,
            # Transformer
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            normalize_before=normalize_before,
            return_intermediate_dec=return_intermediate_dec,
            # waveform encoder/decoder params
            kernel_size=kernel_size,
            n_filters=n_filters,
            stride=stride,
            enc_activation=enc_activation,
            fb_name=fb_name,
            sample_rate=sample_rate,
        )
        hparams.update(fb_kwargs)
        self.model_args = DotMap()
        for key, value in hparams.items():
            self.model_args[key] = value

        # conv & deconv
        self.conv, self.deconv = make_enc_dec(
            fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride,
            sample_rate=sample_rate, **fb_kwargs
        )
        n_feats = self.conv.n_feats_out
        assert d_model == n_feats, (
            "Number of filterbank output channels"
            " and number of model channels should "
            "be the same. Received "
            f"{n_feats} and {d_model}"
        )
        self.enc_activation = nn.ReLU() if enc_activation == 'relu' else nn.Identity()
        self.enc_norm = GlobLN(n_feats)

        # Avatars
        self.avatar = nn.Embedding(n_spk, embed_dim)

        # Transformer
        self.tsfm = Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            normalize_before=normalize_before,
            return_intermediate_dec=return_intermediate_dec,
        )
        self.mask_act = nn.Sigmoid()

        # Convolutional positional encoding. With kernel 129, padding 64 and
        # the default stride of 1 the output keeps the input length T.
        self.pos_conv = nn.Conv1d(
            d_model, d_model,
            kernel_size=129,
            padding=128 // 2,
            groups=16,
        )

    def forward(self, inputs, mask=None):
        """Estimate the waveform of the requested speaker.

        Args:
            inputs: Pair ``(wav, spk_id)``. ``wav`` is the mixture waveform;
                1-D and 2-D tensors get batch/channel axes added so the
                encoder sees three dimensions. ``spk_id`` holds the indices
                looked up in the speaker embedding table.
            mask: Optional key-padding mask forwarded to the Transformer,
                which hands it to the encoder self-attention and the decoder
                cross-attention. NOTE(review): it is presumably expected in
                encoder-frame resolution (``B x T``) - confirm with callers.

        Returns:
            The decoded waveform, right-padded or trimmed to the length of
            ``wav``. A channel axis of size one is squeezed away.
        """
        wav, spk_id = inputs

        # Handle 1D, 2D or n-D inputs
        if wav.ndim == 1:
            wav = wav.unsqueeze(0).unsqueeze(1)
        elif wav.ndim == 2:
            wav = wav.unsqueeze(1)

        # mix feature
        mix_rep_0 = self.enc_norm(self.enc_activation(self.conv(wav)))  # B x C x T

        # positional encoding
        pos_embed = self.pos_conv(mix_rep_0)  # B x C x T

        # avatar embedding
        query_embed = self.avatar(spk_id)  # B x C

        # mask generation; the padding mask is passed through rather than
        # being silently dropped
        hs, mix_rep_t = self.tsfm(mix_rep_0, query_embed, pos_embed, mask=mask)
        mix_mask = self.mask_act(hs)

        # masking
        masked_rep = mix_rep_t * mix_mask

        # source prediction
        out_wavs = pad_x_to_y(self.deconv(masked_rep), wav)
        if out_wavs.shape[1] == 1:  # task == enh_single
            out_wavs = out_wavs.squeeze(1)
        return out_wavs

    def serialize(self, scheduler_dict):
        """Serialize model and args.

        Args:
            scheduler_dict: Object stored unchanged under ``scheduler_dict``.

        Returns:
            dict with keys ``model_name``, ``state_dict``, ``model_args``
            and ``scheduler_dict``.
        """
        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
            scheduler_dict=scheduler_dict
        )
        return model_conf

    def get_state_dict(self):
        """Return the state dict (hook for modifying it before sharing the model)."""
        return self.state_dict()

    def get_model_args(self):
        """Return the recorded constructor args as a plain dict."""
        return self.model_args.toDict()
class Transformer(nn.Module):
    """Encoder-decoder Transformer working on ``B x C x T`` feature maps.

    The encoder self-attends over the T time steps of ``src``. The decoder
    starts from an all-zero target of the same length and uses the
    (time-repeated) query embedding as its query positional term.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Hidden size of the feed-forward sub-layers.
        dropout: Dropout probability.
        activation: Name of the feed-forward activation.
        normalize_before: Pre-norm (True) or post-norm (False) layers. A final
            encoder LayerNorm is only added in the pre-norm case.
        return_intermediate_dec: If True the decoder returns the stacked
            outputs of all its layers instead of only the last one.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, query_embed, pos_embed, mask=None):
        """Run the encoder and decoder.

        Args:
            src: Input features, ``B x C x T``.
            query_embed: Query embedding, ``B x C``; repeated along time.
            pos_embed: Positional embedding, ``B x C x T``.
            mask: Optional key-padding mask given to the encoder
                self-attention and the decoder cross-attention.

        Returns:
            Tuple ``(hs, memory)``. ``memory`` is the encoder output as
            ``B x C x T``. ``hs`` is the decoder output as ``B x C x T``, or
            ``L x B x C x T`` when the decoder returns its intermediate
            layers as a 4-D stack.
        """
        # B x C x T -> T x B x C
        _, _, n_frames = src.shape
        src = src.permute(2, 0, 1)
        pos_embed = pos_embed.permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(0).repeat(n_frames, 1, 1)  # B x C -> T x B x C

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)  # T, B, C
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)

        if hs.dim() == 4:
            # Stacked intermediate outputs: L x T x B x C -> L x B x C x T.
            # A 3-index permute would raise on this 4-D tensor.
            hs = hs.permute(0, 2, 3, 1)
        else:
            hs = hs.permute(1, 2, 0)  # T x B x C -> B x C x T
        return hs, memory.permute(1, 2, 0)
class TransformerEncoder(nn.Module):
    """Stack of identical encoder layers with an optional final norm.

    Args:
        encoder_layer: Layer that is deep-copied ``num_layers`` times.
        num_layers: Number of copies in the stack.
        norm: Optional module applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Feed ``src`` through every layer in order.

        Args:
            src: Input sequence.
            mask: Attention mask handed to each layer as ``src_mask``.
            src_key_padding_mask: Key-padding mask handed to each layer.
            pos: Positional embedding handed to each layer.

        Returns:
            Output of the last layer, passed through ``norm`` when one is set.
        """
        hidden = src
        for block in self.layers:
            hidden = block(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        return hidden if self.norm is None else self.norm(hidden)
class TransformerDecoder(nn.Module):
    """Stack of identical decoder layers with an optional final norm.

    Args:
        decoder_layer: Layer that is deep-copied ``num_layers`` times.
        num_layers: Number of copies in the stack.
        norm: Optional module applied to layer outputs (see ``forward``).
        return_intermediate: If True, ``forward`` returns the stacked outputs
            of all layers instead of only the last one.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Feed ``tgt`` through every layer, attending to ``memory``.

        Args:
            tgt: Decoder input sequence.
            memory: Encoder output the layers cross-attend to.
            tgt_mask: Attention mask for the decoder self-attention.
            memory_mask: Attention mask for the cross-attention.
            tgt_key_padding_mask: Key-padding mask for ``tgt``.
            memory_key_padding_mask: Key-padding mask for ``memory``.
            pos: Positional embedding added to the ``memory`` keys.
            query_pos: Positional embedding added to the decoder queries.

        Returns:
            The output of the last layer (normalised when ``norm`` is set),
            or, when ``return_intermediate`` is True, a tensor stacking the
            (normalised when ``norm`` is set) output of every layer along a
            new leading dimension.
        """
        output = tgt
        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                # norm is optional: calling it unconditionally raised a
                # TypeError whenever norm was None.
                intermediate.append(
                    output if self.norm is None else self.norm(output)
                )

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate and intermediate:
                # Keep the final entry identical to the returned output.
                intermediate[-1] = output

        if self.return_intermediate:
            return torch.stack(intermediate)
        return output
class TransformerEncoderLayer(nn.Module):
    """Single encoder layer: self-attention followed by a feed-forward net.

    Both sub-layers are wrapped in a residual connection with dropout.
    ``normalize_before`` selects whether LayerNorm is applied to the input of
    each sub-layer (pre-norm) or to the residual sum (post-norm).

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward net.
        dropout: Dropout probability.
        activation: Name of the feed-forward activation.
        normalize_before: Pre-norm when True, post-norm otherwise.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _self_attend(self, x, pos, src_mask, src_key_padding_mask):
        """Self-attention whose queries/keys carry ``pos`` but whose values do not."""
        qk = self.with_pos_embed(x, pos)
        attended, _ = self.self_attn(
            qk, qk, value=x,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        return attended

    def _feed_forward(self, x):
        """linear1 -> activation -> dropout -> linear2."""
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        """Post-norm variant: LayerNorm is applied after each residual sum."""
        attended = self._self_attend(src, pos, src_mask, src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attended))
        transformed = self._feed_forward(src)
        src = self.norm2(src + self.dropout2(transformed))
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm is applied to each sub-layer's input."""
        attended = self._self_attend(
            self.norm1(src), pos, src_mask, src_key_padding_mask
        )
        src = src + self.dropout1(attended)
        transformed = self._feed_forward(self.norm2(src))
        src = src + self.dropout2(transformed)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm implementation."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
    """Single decoder layer: self-attention, cross-attention, feed-forward.

    Every sub-layer is wrapped in a residual connection with dropout.
    ``normalize_before`` selects whether LayerNorm is applied to the input of
    each sub-layer (pre-norm) or to the residual sum (post-norm).

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward net.
        dropout: Dropout probability.
        activation: Name of the feed-forward activation.
        normalize_before: Pre-norm when True, post-norm otherwise.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _self_attend(self, x, query_pos, tgt_mask, tgt_key_padding_mask):
        """Self-attention whose queries/keys carry ``query_pos``; values do not."""
        qk = self.with_pos_embed(x, query_pos)
        attended, _ = self.self_attn(
            qk, qk, value=x,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        return attended

    def _cross_attend(self, x, memory, pos, query_pos,
                      memory_mask, memory_key_padding_mask):
        """Attend from ``x`` (plus ``query_pos``) to ``memory`` (keys plus ``pos``)."""
        attended, _ = self.multihead_attn(
            query=self.with_pos_embed(x, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return attended

    def _feed_forward(self, x):
        """linear1 -> activation -> dropout -> linear2."""
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: LayerNorm is applied after each residual sum."""
        own = self._self_attend(tgt, query_pos, tgt_mask, tgt_key_padding_mask)
        tgt = self.norm1(tgt + self.dropout1(own))

        cross = self._cross_attend(tgt, memory, pos, query_pos,
                                   memory_mask, memory_key_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(cross))

        transformed = self._feed_forward(tgt)
        tgt = self.norm3(tgt + self.dropout3(transformed))
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm is applied to each sub-layer's input."""
        own = self._self_attend(self.norm1(tgt), query_pos,
                                tgt_mask, tgt_key_padding_mask)
        tgt = tgt + self.dropout1(own)

        cross = self._cross_attend(self.norm2(tgt), memory, pos, query_pos,
                                   memory_mask, memory_key_padding_mask)
        tgt = tgt + self.dropout2(cross)

        transformed = self._feed_forward(self.norm3(tgt))
        tgt = tgt + self.dropout3(transformed)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm implementation."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(tgt, memory, tgt_mask, memory_mask,
                   tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
class GlobLN(nn.Module):
    """Global Layer Normalization (globLN).

    Each sample is standardised with a single mean and variance computed over
    all of its non-batch dimensions, then scaled and shifted per channel by
    the learnable ``gamma`` and ``beta``.

    Args:
        channel_size: Number of channels (length of ``gamma`` and ``beta``).
    """

    def __init__(self, channel_size):
        super().__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Apply the per-channel gain and bias.

        Assumes input of size `[batch, channel, *]`. The channel axis is
        swapped to the last position so ``gamma``/``beta`` broadcast over it,
        then swapped back.
        """
        channels_last = normed_x.transpose(1, -1)
        scaled = self.gamma * channels_last + self.beta
        return scaled.transpose(1, -1)

    def forward(self, x, eps=1e-8):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
            eps: Constant added to the variance before the square root.

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        # Statistics are taken over every dimension except the batch one.
        reduce_dims = list(range(1, x.dim()))
        mean = x.mean(dim=reduce_dims, keepdim=True)
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        normed = (x - mean) / torch.sqrt(variance + eps)
        return self.apply_gain_and_bias(normed)
def pad_x_to_y(x: torch.Tensor, y: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Right-pad or right-trim first argument to have same size as second argument.

    Padding uses zeros. When `x` is longer than `y` along `axis`, the
    negative pad amount trims `x` from the right instead.

    Args:
        x (torch.Tensor): Tensor to be padded.
        y (torch.Tensor): Tensor to pad `x` to.
        axis (int): Axis to pad on. Any valid axis of `x` is accepted
            (negative values count from the end).

    Returns:
        torch.Tensor, `x` padded (or trimmed) to match `y`'s size along `axis`.

    Raises:
        IndexError: If `axis` is out of range for `x` or `y`.
    """
    inp_len = y.shape[axis]
    output_len = x.shape[axis]
    # F.pad takes (left, right) pairs starting from the LAST dimension, so
    # leave one untouched pair for every dimension that follows `axis`.
    n_trailing = x.dim() - 1 - (axis % x.dim())
    pad = [0, 0] * n_trailing + [0, inp_len - output_len]
    return nn.functional.pad(x, pad)
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(copies)
def _get_activation_fn(activation):
    """Return an activation function given a string.

    Args:
        activation: One of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        The matching function from ``torch.nn.functional``.

    Raises:
        RuntimeError: If ``activation`` is not one of the supported names.
    """
    # NOTE(review): F.glu halves the size of the dimension it is applied to,
    # so "glu" looks incompatible with the dim_feedforward-wide linear2 of
    # the layers in this file - confirm before using it.
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name; it previously omitted "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment