@rmrao · Last active October 26, 2021
Initial implementation of vectorized Smith-Waterman in torch

from typing import Tuple
import torch
import torch.nn.functional as F
import itertools
from timeit import default_timer as timer
class SoftmaxWeightedMean(torch.autograd.Function):
    """Softmax-weighted mean over the last dimension, with a hand-written backward."""

    @staticmethod
    def forward(ctx, inputs):
        ctx.save_for_backward(inputs)
        return (inputs * inputs.softmax(-1)).sum(-1)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, = ctx.saved_tensors
        softmax = inputs.softmax(-1)
        output = inputs * softmax
        # d/dx_j sum_i x_i p_i = p_j * (1 + x_j - sum_i x_i p_i), where p = softmax(x).
        grad = (output + softmax * (1 - output.sum(-1, keepdim=True))) \
            * grad_output.unsqueeze(-1)
        return grad


softmax_mean = SoftmaxWeightedMean.apply  # type: ignore
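
# Illustrative sanity check (not part of the original gist; the helper name is
# hypothetical): verify the hand-written backward of SoftmaxWeightedMean
# against finite differences with torch.autograd.gradcheck, which requires
# double-precision inputs.
def _check_softmax_mean_grad():
    x = torch.randn(4, 6, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(softmax_mean, (x,))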

@torch.enable_grad()
def smithwaterman_nogap(similarity_matrix: torch.Tensor) -> torch.Tensor:
    """Reference Smith-Waterman without gap penalties, filled one cell at a time."""
    similarity_matrix.requires_grad_(True)
    # Put the batch dimension last so each DP cell holds a whole batch of scores.
    transposed_similarity = similarity_matrix.permute(1, 2, 0)
    transposed_similarity = F.pad(transposed_similarity, [0, 0, 1, 0, 1, 0])
    filled_similarity = torch.zeros_like(transposed_similarity)
    iter_row = range(1, filled_similarity.size(0))
    iter_col = range(1, filled_similarity.size(1))
    for row, col in itertools.product(iter_row, iter_col):
        diag = filled_similarity[row - 1, col - 1] + transposed_similarity[row, col]
        down = filled_similarity[row - 1, col]
        right = filled_similarity[row, col - 1]
        hij = torch.stack([diag, down, right]).max(0)[0]
        filled_similarity[row, col] = hij
    # Best score anywhere in the DP matrix, per batch element.
    values = filled_similarity.view(-1, filled_similarity.size(2)).max(0)[0]  # type: ignore
    return torch.autograd.grad(values.sum(), similarity_matrix)[0]
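
# Illustrative usage (not in the original gist; the helper name is
# hypothetical): with the hard max, the gradient of the best score is nonzero
# only on cells the optimal alignment matches via a diagonal move, so the
# returned tensor acts as a traceback mask.
def _demo_hard_traceback():
    sim = torch.randn(1, 4, 6)
    mask = smithwaterman_nogap(sim)
    assert mask.shape == sim.shape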

def diag_indices(tensor: torch.Tensor, diag: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the (row, col) indices of anti-diagonal `diag` of a batch-first tensor."""
    if diag not in range(tensor.size(1) + tensor.size(2) - 1):
        raise IndexError(f"diag {diag} is out of bounds for tensor of size {tensor.size()}")
    dim1 = torch.arange(
        max(diag - tensor.size(2) + 1, 0),
        min(diag + 1, tensor.size(1)))
    dim2 = torch.arange(
        min(diag, tensor.size(2) - 1),
        max(diag - tensor.size(1), -1),
        -1)
    return dim1, dim2
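
# Worked example (not in the original gist; the helper name is hypothetical):
# for a (batch, 2, 3) tensor, anti-diagonal 2 covers cells (0, 2) and (1, 1).
def _demo_diag_indices():
    rows, cols = diag_indices(torch.zeros(1, 2, 3), 2)
    assert rows.tolist() == [0, 1] and cols.tolist() == [2, 1]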

@torch.enable_grad()
def smithwaterman_nogap_vectorized(similarity_matrix: torch.Tensor) -> torch.Tensor:
    """Soft Smith-Waterman without gap penalties, vectorized over anti-diagonals.

    Every cell on an anti-diagonal depends only on the two previous diagonals,
    so each diagonal can be filled with a single batched update.
    """
    similarity_matrix.requires_grad_(True)
    padded_similarity = F.pad(similarity_matrix, [1, 0, 1, 0])
    batch_size, dim1_size, dim2_size = padded_similarity.size()
    filled_similarity = torch.zeros(
        batch_size, dim1_size + 1, dim2_size + 1,
        dtype=padded_similarity.dtype, device=padded_similarity.device)
    num_diags = dim1_size + dim2_size - 1
    for diag_offset in range(2, num_diags):
        row, col = diag_indices(padded_similarity, diag_offset)
        # filled_similarity carries one extra row/column of zeros, so shift by one.
        row_offset = row + 1
        col_offset = col + 1
        diag = filled_similarity[:, row_offset - 1, col_offset - 1] + \
            padded_similarity[:, row, col]
        down = filled_similarity[:, row_offset - 1, col_offset]
        right = filled_similarity[:, row_offset, col_offset - 1]
        # hij = torch.stack([diag, down, right]).max(0)[0]  # hard SW uses max
        hij = softmax_mean(torch.stack([diag, down, right], -1))
        filled_similarity[:, row_offset, col_offset] = hij
    values = filled_similarity.reshape(batch_size, -1).max(1)[0]  # type: ignore
    return torch.autograd.grad(values.sum(), similarity_matrix, create_graph=True)[0]
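
# Illustrative usage (not in the original gist; the helper name is
# hypothetical): the returned gradient has the same shape as the input and
# behaves as a soft alignment matrix, so the whole routine can serve as a
# differentiable alignment layer.
def _demo_soft_alignment():
    sim = torch.randn(2, 5, 7)
    soft_alignment = smithwaterman_nogap_vectorized(sim)
    assert soft_alignment.shape == sim.shape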

def time_func(func, num_iter: int = 7):
    total_time = 0.
    for _ in range(num_iter):
        inputs = torch.randn(32, 200, 300, device='cuda', requires_grad=True)
        start = timer()
        result = func(inputs)
        result.sum().backward()
        # Wait for queued CUDA kernels to finish before reading the timer.
        torch.cuda.synchronize()
        end = timer()
        assert inputs.grad is not None
        total_time += end - start
        # Release the autograd graph and cached blocks between iterations.
        torch.set_grad_enabled(False)
        torch.cuda.empty_cache()
        torch.set_grad_enabled(True)
    print(f"{func.__name__}: {total_time / num_iter:.4f} s/iter")

if __name__ == '__main__':
    # Small warm-up call so one-time CUDA initialization is excluded from timing.
    inputs = torch.randn(32, 10, 5, device='cuda', requires_grad=True)
    smithwaterman_nogap_vectorized(inputs)
    time_func(smithwaterman_nogap_vectorized)