Skip to content

Instantly share code, notes, and snippets.

@e96031413
Created December 18, 2023 06:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save e96031413/ecb09bb31a86d7d1a82b976fd5e349d9 to your computer and use it in GitHub Desktop.
Save e96031413/ecb09bb31a86d7d1a82b976fd5e349d9 to your computer and use it in GitHub Desktop.
"""
https://github.com/z-bingo/FastDVDNet/tree/master/arch
Reimplementation of 4 channel FastDVDNet in PyTorch
"""
import torch
import torch.nn as nn
import numpy as np
from thop import profile
class M_U_Net(nn.Module):
    """
    The Block Module in paper, a modified U-Net.

    Stages (h and w must be multiples of 4 so the skip additions line up):
      conv1: full resolution, in_channel -> 32 channels
      conv2: stride-2 downsample, 32 -> 64 channels at h/2
      conv3: stride-2 downsample to h/4, then PixelShuffle back to 64 ch at h/2
      conv4: applied to conv3 + conv2, PixelShuffle up to 32 ch at full res
      conv5: applied to conv4 + conv1, -> out_channel; result is added to ``ref``
    """

    def __init__(self, in_channel=12, out_channel=3):
        """
        :param in_channel: number of channels of the stacked ``data`` input
        :param out_channel: number of output channels (must match ``ref``)
        """
        super(M_U_Net, self).__init__()
        cbr = self._conv_bn_relu
        # Every repeated conv-BN-ReLU is built from *fresh* modules. Multiplying a
        # tuple of module instances (``(conv, bn, relu) * 2``) would register the
        # same objects twice, silently sharing weights and BatchNorm statistics.
        # The layer order, and therefore every state_dict key, is unchanged.
        self.conv1 = nn.Sequential(
            *cbr(in_channel, 90),
            *cbr(90, 32),
        )
        self.conv2 = nn.Sequential(
            *cbr(32, 64, stride=2),
            *self._repeat(64, 2),
        )
        self.conv3 = nn.Sequential(
            *cbr(64, 128, stride=2),
            *self._repeat(128, 4),
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.PixelShuffle(2),  # 256 ch at h/4 -> 64 ch at h/2
        )
        self.conv4 = nn.Sequential(
            *self._repeat(64, 2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.PixelShuffle(2),  # 128 ch at h/2 -> 32 ch at h
        )
        self.conv5 = nn.Sequential(
            *cbr(32, 32),
            nn.Conv2d(32, out_channel, 3, 1, 1),
        )

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, stride=1):
        """Return a new [3x3 Conv2d (padding 1), BatchNorm2d, ReLU] list."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, stride, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]

    @classmethod
    def _repeat(cls, channels, times):
        """Return ``times`` independent channels->channels conv-BN-ReLU triples, flattened."""
        return [
            layer
            for _ in range(times)
            for layer in cls._conv_bn_relu(channels, channels)
        ]

    def forward(self, data, ref):
        """
        :param data: [b, in_channel, h, w] stacked noisy frames (plus noise map, if used)
        :param ref: reference (middle) noisy frame; must broadcast against the
            [b, out_channel, h, w] network output
        :return: predicted residual + ref
        """
        conv1 = self.conv1(data)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3 + conv2)
        conv5 = self.conv5(conv4 + conv1)
        return conv5 + ref
class FastDVDNet(nn.Module):
    """
    Two-stage FastDVDNet.

    Stage 1 runs ``block1`` on every window of three consecutive noisy frames
    (plus the noise map, when enabled). Stage 2 runs ``block2`` on the
    ``in_frames - 2`` stage-1 outputs (plus the noise map) to produce the
    middle frame.
    """

    def __init__(self, in_frames=5, color=True, sigma_map=True):
        """
        class initial
        :param in_frames: number of input frames, e.g. T-2, T-1, T, T+1, T+2 (5); must be >= 3
        :param color: frames have 4 channels when True, 1 channel when False
        :param sigma_map: whether the input carries a noise map (estimated noise
            standard deviation) as one extra entry after the frames
        :raises ValueError: if ``in_frames`` < 3 (no 3-frame window exists)
        """
        super(FastDVDNet, self).__init__()
        if in_frames < 3:
            raise ValueError(f"in_frames must be >= 3, got {in_frames}")
        self.in_frames = in_frames
        self.sigma_map = sigma_map
        channel = 4 if color else 1
        map_frames = 1 if sigma_map else 0
        # block1 sees one 3-frame window (+ map); block2 sees the in_frames-2
        # stage-1 outputs (+ map). For in_frames=5 both widths are identical.
        in1 = (3 + map_frames) * channel
        in2 = (in_frames - 2 + map_frames) * channel
        self.block1 = M_U_Net(in1, channel)
        self.block2 = M_U_Net(in2, channel)

    def forward(self, input):
        """
        forward function
        :param input: [b, N, c, h, w]; the first ``in_frames`` entries along dim 1
            are the noisy frames, followed by the noise map when ``sigma_map``
            is True (so N = in_frames + 1 by default)
        :return: [b, c, h, w] denoised estimate of the middle frame
        :raises ValueError: if ``input`` is not 5-D or N is not as expected
        """
        n_expected = self.in_frames + (1 if self.sigma_map else 0)
        if input.dim() != 5 or input.size(1) != n_expected:
            raise ValueError(
                f"expected input of shape [b, {n_expected}, c, h, w], "
                f"got {tuple(input.shape)}"
            )
        frames = input[:, :self.in_frames]
        # Integer indexing drops dim 1, giving the map as [b, c, h, w].
        extra = [input[:, self.in_frames]] if self.sigma_map else []

        # First stage: denoise each 3-frame window towards its centre frame.
        stage1 = [
            self.block1(
                torch.cat([frames[:, i:i + 3].flatten(1, 2)] + extra, dim=1),
                frames[:, i + 1],
            )
            for i in range(self.in_frames - 2)
        ]
        # Second stage: fuse the stage-1 results into the middle frame.
        return self.block2(
            torch.cat(stage1 + extra, dim=1),
            frames[:, self.in_frames // 2],
        )
class SingleStageFastDVDNet(nn.Module):
    """
    Single-stage FastDVDNet variant: one M_U_Net sees all frames at once,
    stacked along the channel axis together with the optional noise map.
    """

    def __init__(self, in_frames=5, color=True, sigma_map=True):
        """
        :param in_frames: number of noisy input frames (must be >= 1)
        :param color: frames have 4 channels when True, 1 channel when False
        :param sigma_map: whether the input carries a noise map as one extra
            entry after the frames
        :raises ValueError: if ``in_frames`` < 1
        """
        super(SingleStageFastDVDNet, self).__init__()
        if in_frames < 1:
            raise ValueError(f"in_frames must be >= 1, got {in_frames}")
        self.in_frames = in_frames
        self.sigma_map = sigma_map
        channel = 4 if color else 1
        # All frames (+ map) stacked on channels; (5 + 1) * 4 = 24 for the defaults.
        in1 = (in_frames + (1 if sigma_map else 0)) * channel
        self.block1 = M_U_Net(in1, channel)

    def forward(self, data):
        """
        :param data: [b, N, c, h, w]; the first ``in_frames`` entries along dim 1
            are the noisy frames, followed by the noise map when ``sigma_map``
            is True
        :return: [b, c, h, w] denoised estimate of the middle frame
        :raises ValueError: if ``data`` is not 5-D or N is not as expected
        """
        n_expected = self.in_frames + (1 if self.sigma_map else 0)
        if data.dim() != 5 or data.size(1) != n_expected:
            raise ValueError(
                f"expected input of shape [b, {n_expected}, c, h, w], "
                f"got {tuple(data.shape)}"
            )
        frames = data[:, :self.in_frames]
        parts = [frames.flatten(1, 2)]
        if self.sigma_map:
            # Integer indexing drops dim 1, giving the map as [b, c, h, w].
            parts.append(data[:, self.in_frames])
        # The residual is added to the middle frame (index in_frames // 2, i.e.
        # frame T for 5 inputs), the same reference FastDVDNet uses.
        return self.block1(
            torch.cat(parts, dim=1),
            frames[:, self.in_frames // 2],
        )
if __name__ == '__main__':
    # Smoke test + complexity report. Reference figures recorded by the gist
    # author for the input built below (1 x (5 frames + 1 noise map) x 4 x 256 x 256);
    # they depend on the exact layer definitions, so re-run to refresh them:
    #   FastDVDNet()             MACs = 44.604325888G   Params = 1.466852M
    #   SingleStageFastDVDNet()  MACs = 11.575754752G   Params = 0.739906M
    in_frames = 5
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(1, in_frames + 1, 4, 256, 256, device=device)
    model = FastDVDNet(in_frames=in_frames).to(device)
    # model = SingleStageFastDVDNet(in_frames=in_frames).to(device)
    # Shape check only: inference mode keeps BatchNorm stats untouched and skips autograd.
    model.eval()
    with torch.no_grad():
        output = model(x)
    print("Input size: ", x.shape)
    print("Output size: ", output.shape)
    macs, params = profile(model, inputs=(x, ))
    print('MACs = ' + str(macs / 1000**3) + 'G')
    print('Params = ' + str(params / 1000**2) + 'M')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment