Initial implementation of vectorized Smith-Waterman in torch
import itertools
from timeit import default_timer as timer
from typing import Tuple

import torch
import torch.nn.functional as F
class SoftmaxWeightedMean(torch.autograd.Function):
    """Soft relaxation of max: computes sum_i(x_i * softmax(x)_i) over the
    last dimension, with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, inputs):
        ctx.save_for_backward(inputs)
        return (inputs * inputs.softmax(-1)).sum(-1)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, = ctx.saved_tensors
        softmax = inputs.softmax(-1)
        output = inputs * softmax
        # d/dx_j sum_i(x_i * softmax_i) = softmax_j * (1 + x_j - sum_i(x_i * softmax_i))
        grad = (output + softmax * (1 - output.sum(-1, keepdim=True))) \
            * grad_output.unsqueeze(-1)
        return grad


softmax_mean = SoftmaxWeightedMean.apply  # type: ignore
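

# Sketch of a sanity check, assuming double-precision inputs (which
# torch.autograd.gradcheck needs for tight tolerances); the helper name is
# illustrative. gradcheck compares the hand-written backward above against a
# finite-difference estimate.
def _check_softmax_mean_grad() -> bool:
    inputs = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    return torch.autograd.gradcheck(softmax_mean, (inputs,))
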
@torch.enable_grad()
def smithwaterman_nogap(similarity_matrix: torch.Tensor) -> torch.Tensor:
    """Reference implementation: fills the DP matrix one cell at a time."""
    similarity_matrix.requires_grad_(True)
    # Put the batch dimension last so each DP cell holds a vector over the batch.
    transposed_similarity = similarity_matrix.permute(1, 2, 0)
    transposed_similarity = F.pad(transposed_similarity, [0, 0, 1, 0, 1, 0])
    filled_similarity = torch.zeros_like(transposed_similarity)
    iter_row = range(1, filled_similarity.size(0))
    iter_col = range(1, filled_similarity.size(1))
    for row, col in itertools.product(iter_row, iter_col):
        # Extend the alignment diagonally, or carry the best score from
        # above / from the left (no gap penalty).
        diag = filled_similarity[row - 1, col - 1] + transposed_similarity[row, col]
        down = filled_similarity[row - 1, col]
        right = filled_similarity[row, col - 1]
        filled_similarity[row, col] = torch.stack([diag, down, right]).max(0)[0]
    # Best alignment score for each batch element.
    values = filled_similarity.view(-1, filled_similarity.size(2)).max(0)[0]  # type: ignore
    # The gradient of the score w.r.t. the similarities recovers the alignment.
    return torch.autograd.grad(values.sum(), similarity_matrix)[0]
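

# A tiny worked example (illustrative helper, not part of the timing code):
# with no gap penalty the recursion is
#   H[i, j] = max(H[i-1, j-1] + s[i, j], H[i-1, j], H[i, j-1]),
# so for the 2x2 similarity below the best score is 1.0 + 2.0 = 3.0 along the
# main diagonal, and the returned gradient is 1 exactly on those two cells.
def _demo_nogap_recursion() -> None:
    sim = torch.tensor([[[1.0, -1.0], [-1.0, 2.0]]])  # shape (batch=1, 2, 2)
    print(smithwaterman_nogap(sim))  # tensor([[[1., 0.], [0., 1.]]])
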
def diag_indices(tensor: torch.Tensor, diag: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Indices of anti-diagonal `diag` of a (batch, dim1, dim2) tensor,
    ordered from top-right to bottom-left."""
    if diag not in range(tensor.size(1) + tensor.size(2) - 1):
        raise IndexError(f"diag {diag} is out of bounds for tensor of size {tensor.size()}")
    dim1 = torch.arange(
        max(diag - tensor.size(2) + 1, 0),
        min(diag + 1, tensor.size(1)))
    dim2 = torch.arange(
        min(diag, tensor.size(2) - 1),
        max(diag - tensor.size(1), -1),
        -1)
    return dim1, dim2
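

# Quick illustration (helper name is illustrative): anti-diagonal 2 of a
# (batch, 3, 4) tensor visits cells (0, 2), (1, 1) and (2, 0).
def _demo_diag_indices() -> None:
    tensor = torch.arange(12).view(1, 3, 4)
    row, col = diag_indices(tensor, 2)
    print(row, col)             # tensor([0, 1, 2]) tensor([2, 1, 0])
    print(tensor[:, row, col])  # tensor([[2, 5, 8]])
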
@torch.enable_grad()
def smithwaterman_nogap_vectorized(similarity_matrix: torch.Tensor) -> torch.Tensor:
    """Vectorized variant: every cell on an anti-diagonal depends only on the
    two previous anti-diagonals, so each diagonal fills in one batched update."""
    similarity_matrix.requires_grad_(True)
    padded_similarity = F.pad(similarity_matrix, [1, 0, 1, 0])
    batch_size, dim1_size, dim2_size = padded_similarity.size()
    # One extra row/column of zeros so the reads below never index out of bounds.
    filled_similarity = torch.zeros(
        batch_size, dim1_size + 1, dim2_size + 1,
        dtype=padded_similarity.dtype, device=padded_similarity.device)
    num_diags = dim1_size + dim2_size - 1
    # Diagonals 0 and 1 touch only padding, so start at 2.
    for diag_offset in range(2, num_diags):
        row, col = diag_indices(padded_similarity, diag_offset)
        row_offset = row + 1
        col_offset = col + 1
        diag = filled_similarity[:, row_offset - 1, col_offset - 1] + \
            padded_similarity[:, row, col]
        down = filled_similarity[:, row_offset - 1, col_offset]
        right = filled_similarity[:, row_offset, col_offset - 1]
        # hij = torch.stack([diag, down, right]).max(0)[0]  # hard SW uses max
        hij = torch.stack([diag, down, right], -1)
        hij = softmax_mean(hij)
        filled_similarity[:, row_offset, col_offset] = hij
    values = filled_similarity.reshape(batch_size, -1).max(1)[0]  # type: ignore
    # create_graph=True keeps the graph so the caller can backprop through
    # this gradient (it is the function's return value).
    return torch.autograd.grad(values.sum(), similarity_matrix, create_graph=True)[0]
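

# Smoke-test sketch (illustrative helper): both variants map a
# (batch, len1, len2) similarity matrix to a soft alignment of the same
# shape. Their values differ by construction (hard max vs. softmax_mean),
# so only shapes are compared.
def _check_alignment_shapes() -> None:
    sim = torch.randn(2, 5, 7)
    assert smithwaterman_nogap(sim.clone()).shape == sim.shape
    assert smithwaterman_nogap_vectorized(sim.clone()).shape == sim.shape
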
def time_func(func, num_iter: int = 7):
    total_time = 0.
    for _ in range(num_iter):
        inputs = torch.randn(32, 200, 300, device='cuda', requires_grad=True)
        start = timer()
        result = func(inputs)
        result.sum().backward()
        # CUDA kernels launch asynchronously; synchronize so the timer
        # measures the work that was actually queued.
        torch.cuda.synchronize()
        end = timer()
        assert inputs.grad is not None
        total_time += (end - start)
        torch.set_grad_enabled(False)
        torch.cuda.empty_cache()
        torch.set_grad_enabled(True)
    print(total_time / num_iter)
if __name__ == '__main__':
    # Small warm-up run before timing.
    inputs = torch.randn(32, 10, 5, device='cuda', requires_grad=True)
    smithwaterman_nogap_vectorized(inputs)
    time_func(smithwaterman_nogap_vectorized)
    # time_func(smithwaterman_nogap_vectorized_swm)  # not defined in this file