Skip to content

Instantly share code, notes, and snippets.

@TeaPoly
Created March 29, 2023 03:11
Show Gist options
  • Save TeaPoly/f701526acdb8babb01fa5df6bcdbf0e6 to your computer and use it in GitHub Desktop.
Save TeaPoly/f701526acdb8babb01fa5df6bcdbf0e6 to your computer and use it in GitHub Desktop.
apply the stacking to inputs
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2022 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from typing import Optional
import torch
class LfrLayer(torch.nn.Module):
    """Low frame rate (LFR) stacking layer.

    Stacks each frame with its left/right context frames along the feature
    dimension and subsamples in time by ``stride``, lowering the frame rate.

    Args:
        left (int): Number of left-context frames to stack.
        right (int): Number of right-context frames to stack.
        stride (int): Temporal subsampling factor.
        pad_val (float): Padding value used for the left/right context.
    """

    def __init__(
        self,
        left: int = 3,
        right: int = 3,
        stride: int = 6,
        pad_val: float = 0.0,
    ):
        """Construct an LFR object."""
        super().__init__()
        self.stride = stride
        self.right = right
        self.left = left
        # Stacking window length: left context + current frame + right context.
        self.win = left + 1 + right
        self.pad_val = pad_val

    def forward(self, x: torch.Tensor, masks: Optional[torch.Tensor] = None):
        """Apply context stacking and temporal subsampling to the inputs.

        Args:
            x (torch.Tensor): Input tensor of shape [batch, time, depth].
            masks (torch.Tensor, optional): Mask of shape [batch, 1, time].

        Returns:
            Tuple of:
            - stacked tensor [batch, ceil(time / stride), depth * win],
            - subsampled masks [batch, 1, ceil(time / stride)], or None
              when ``masks`` is None.
        """
        if self.left == 0 and self.right == 0:
            # No context to stack.  BUG FIX: the original returned a bare
            # tensor here, violating the (x, masks) tuple contract of the
            # normal path and breaking callers that unpack the result.
            # NOTE(review): striding is also skipped on this branch,
            # matching the original behaviour — confirm that is intended.
            return x, masks
        # Pad the time dimension with left/right context frames.
        x_pad = torch.nn.functional.pad(
            x, (0, 0, self.left, self.right, 0, 0), value=self.pad_val
        )
        x = self._apply_stack(x_pad)
        if masks is None:
            return x, None
        # Subsample the mask to match the reduced frame rate.
        return x, masks[:, :, ::self.stride]

    def _apply_stack(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Stack shifted, strided copies of the padded sequence.

        Args:
            x (torch.Tensor): Padded input of shape
                [batch, time + left + right, depth].

        Returns:
            torch.Tensor: [batch, ceil(time / stride), depth * win].
        """
        max_len = x.size(1)
        # Original (unpadded) sequence length.
        length = max_len - self.right - self.left
        # Make win copies of the padded sequence, each offset by one
        # time step, applying the temporal stride to each copy.
        pieces = [x[:, i:i + length:self.stride] for i in range(self.win)]
        # Stack the copies along the feature dimension.
        # (torch.cat is the canonical name; torch.concat is just an alias.)
        return torch.cat(pieces, 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment