Skip to content

Instantly share code, notes, and snippets.

@jizongFox
Created November 4, 2022 20:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jizongFox/ab80f62bf10d8cb759f40a5e0ba3c924 to your computer and use it in GitHub Desktop.
Save jizongFox/ab80f62bf10d8cb759f40a5e0ba3c924 to your computer and use it in GitHub Desktop.
splitformer dynamic code
import time
import typing as t
from typing import Union
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
class DynamicLinear(nn.Linear):
    """``nn.Linear`` whose features can be switched off by a channel mask.

    The same mask is applied to the input features and to the output
    features, so the masked paths require ``in_features == out_features``.
    ``forward`` keeps the full weight and zeroes masked channels, while
    ``inference`` slices the weight down to the kept channels. For a binary
    mask both return the same values.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)

    def forward(self, input: Tensor, mask: t.Optional[Tensor] = None) -> Tensor:
        """Apply the layer, optionally restricted to the channels selected by ``mask``.

        Args:
            input: tensor whose last dimension has ``in_features`` entries.
            mask: optional tensor whose last dimension matches ``input``'s.
                Non-zero entries mark kept channels; their values also scale
                the input and the bias. Every row must keep the same number
                of channels so the result can be reshaped.

        Returns:
            ``F.linear(input, weight, bias)`` when ``mask`` is None. Otherwise
            the kept output channels, reshaped to ``(input.shape[0], -1)``.
        """
        if mask is None:
            return F.linear(input, self.weight, self.bias)
        assert mask.shape[-1] == input.shape[-1], (input.shape, mask.shape)
        b = input.shape[0]
        output = (input * mask) @ self.weight.t()
        if self.bias is not None:  # layers built with bias=False have no bias to add
            output = output + self.bias[None, ...] * mask
        # drop the masked-out output channels
        return output.masked_select(mask.bool()).view(b, -1)

    def inference(self, input: Tensor, mask: t.Optional[Tensor] = None) -> Tensor:
        """Apply the layer using only the weight rows/columns of the kept channels.

        ``mask`` is treated as binary (non-zero = kept). It must be a single
        mask shared by the whole batch, e.g. of shape ``(in_features,)`` or
        ``(1, in_features)``. For a binary mask the result equals ``forward``.
        """
        if mask is None:
            return F.linear(input, self.weight, self.bias)
        assert mask.shape[-1] == input.shape[-1], (input.shape, mask.shape)
        keep = mask.reshape(-1).bool()
        assert keep.numel() == input.shape[-1], \
            f"inference expects one mask shared by the batch, got {tuple(mask.shape)}"
        # sub-matrix with the same (out, in) layout as self.weight
        weight = self.weight[keep][:, keep]
        bias = None if self.bias is None else self.bias[keep]
        # F.linear computes input @ weight.t() + bias, exactly as forward() does
        return F.linear(input[..., keep], weight, bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
class DynamicConv2D(nn.Conv2d):
    """2-D convolution whose channels can be switched off by a channel mask.

    The same mask is applied to the input channels and to the output channels,
    so the masked paths require ``in_channels == out_channels``. Only dense
    (``groups == 1``) and depthwise (``groups == in_channels``) convolutions are
    supported. ``forward`` zeroes masked channels on the full-size weight.
    ``convert`` permanently prunes the layer to the kept channels, after which
    ``inference`` runs on the reduced input.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True,
                 padding_mode: str = 'zeros', device=None, dtype=None) -> None:
        assert groups == 1 or groups == in_channels, f"groups only support {1} or {in_channels}, given {groups}"
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
                         device, dtype)

    def forward(self, input: Tensor, mask: t.Optional[Tensor] = None) -> Tensor:
        """Convolve ``input``, zeroing the channels not selected by ``mask``.

        Args:
            input: ``(N, C, H, W)`` tensor.
            mask: optional ``(C,)`` mask shared by the batch, or ``(N, C)`` /
                ``(1, C)`` mask. Its values scale both the input and output
                channels.

        Raises:
            ValueError: if ``mask`` is neither 1-D nor 2-D.
        """
        if mask is None:
            return super(DynamicConv2D, self).forward(input)
        assert mask.shape[-1] == input.shape[1]
        # reshape the mask so it broadcasts over (N, C, H, W)
        if mask.dim() == 1:
            mask = mask[None, ..., None, None]
        elif mask.dim() == 2:
            mask = mask[..., None, None]
        else:
            raise ValueError(mask.shape)
        assert mask.dim() == input.dim(), (mask.shape, input.shape)
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input * mask, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            self.weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups) * mask
        return F.conv2d(input * mask, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups) * mask

    def inference(self, input: Tensor, ):
        """Plain convolution with the current (possibly pruned) weight."""
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            self.weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def convert(self, mask: Tensor) -> None:
        """Prune the layer in place so it only keeps the channels selected by ``mask``.

        Args:
            mask: channel mask of shape ``(C,)`` or ``(1, C)`` with
                ``C == in_channels``. Non-zero entries mark kept channels.

        After the call, the weight, the bias (if any), ``in_channels``,
        ``out_channels`` and, for depthwise layers, ``groups`` describe the
        reduced layer.
        """
        assert mask.shape[-1] == self.in_channels, (mask.shape, self.in_channels)
        keep = mask.squeeze()
        assert keep.dim() == 1, mask
        keep = keep.bool()
        selected_channel_length = int(keep.sum())
        depthwise = self.groups != 1
        with torch.no_grad():
            if depthwise:
                # weight is (C, 1, k1, k2): each output channel sees only its own input channel
                weight = self.weight[keep]
            else:
                # weight is (C_out, C_in, k1, k2): keep the kept rows and the kept columns
                weight = self.weight[keep][:, keep]
            self.weight = nn.Parameter(weight.contiguous())
            if self.bias is not None:
                self.bias = nn.Parameter(self.bias[keep].contiguous())
        self.in_channels = self.out_channels = selected_channel_length
        if depthwise:
            self.groups = selected_channel_length
class timer:
    """Context manager that prints the wall-clock time spent inside its block.

    When CUDA is available the device is synchronised on entry and exit, so
    queued asynchronous kernels are included in the measurement. The duration
    in seconds is also stored on ``self.elapsed``.
    """

    def __enter__(self):
        self._sync()
        self._cur_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._sync()
        self.elapsed = time.perf_counter() - self._cur_time
        print(f"used time: {self.elapsed:.3e} s")

    @staticmethod
    def _sync() -> None:
        # torch.cuda.synchronize() raises on machines without CUDA
        if torch.cuda.is_available():
            torch.cuda.synchronize()
if __name__ == "__main__":
    # Benchmark a masked full-width 1x1 conv against the same conv physically
    # pruned to the kept channels, then check that both give the same result.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim = 200
    n_repeats, n_iters = 4, 100  # repeated timings let the first ones absorb warm-up cost

    # NOTE: `input` shadows the builtin; the name is kept for compatibility.
    input = torch.randn(8, dim, 224, 224, device=device)
    conv1x1 = DynamicConv2D(dim, dim, kernel_size=1, bias=False).to(device)
    # random binary channel mask of shape (1, dim)
    mask = torch.randint(0, 2, size=(1, dim), dtype=torch.float, device=device)

    # pure forward benchmark: no autograd graph is needed
    with torch.no_grad():
        for _ in range(n_repeats):
            with timer():
                for _ in range(n_iters):
                    output1 = conv1x1(input, mask)

        print("converting")
        keep = mask.squeeze(0).bool()
        input2 = input[:, keep].contiguous()
        conv1x1.convert(mask)

        for _ in range(n_repeats):
            with timer():
                for _ in range(n_iters):
                    output2 = conv1x1.inference(input2)

    print(output1.shape, output2.shape)
    # masked-out channels of output1 are zero; the kept ones must match the pruned conv
    expected = output1[:, keep]
    assert torch.allclose(output2, expected, rtol=1e-2, atol=1e-2), (
        (output2 - expected).abs().max())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment