WaDIQaM calculator for VapourSynth

Official implementation based on Chainer (requires CUDA)

from vapoursynth import core
import chainer
# chainer.global_config.cudnn_deterministic = False

from vs_wadiqam_chainer import wadiqam_fr, wadiqam_nr


model_folder_path = r"deepIQA-master\models" # path to the folder that contains the model's parameter files
# models can be downloaded from https://github.com/dmaniry/deepIQA/tree/master/models

src_rgb = ... # only RGBS / RGB24 / RGB48 are allowed
ref_rgb = ... # the same size and format as "src_rgb"


output_fr = wadiqam_fr(src_rgb, ref_rgb, model_folder_path, dataset='tid', top='patchwise', max_batch_size=2040)

# output_fr = core.text.FrameProps(output_fr, props=['Frame_WaDIQaM_FR'])


output_nr = wadiqam_nr(src_rgb, model_folder_path, dataset='tid', top='patchwise', max_batch_size=2040)

# output_nr = core.text.FrameProps(output_nr, props=['Frame_WaDIQaM_NR'])
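
The per-frame scores can also be read back directly from the frame properties, e.g. in a standalone script. A minimal sketch (frame 0 is used only as an example; requesting a frame triggers the evaluation on the GPU):

frame_fr = output_fr.get_frame(0)
print("WaDIQaM-FR of frame 0:", frame_fr.props['Frame_WaDIQaM_FR'])

frame_nr = output_nr.get_frame(0)
print("WaDIQaM-NR of frame 0:", frame_nr.props['Frame_WaDIQaM_NR'])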

Third-party implementation based on PyTorch

from vapoursynth import core
import torch

from vs_wadiqam_pytorch import wadiqam_fr


model_path = "models/WaDIQaM-FR-TID2008" # path to the model's parameter file
# models can be downloaded from https://github.com/lidq92/WaDIQaM/blob/master/models/WaDIQaM-FR-TID2008

use_cuda = torch.cuda.is_available()

src_rgbs = ... # only RGBS is allowed
ref_rgbs = ... # the same size and format as "src_rgbs"

output = wadiqam_fr(src_rgbs, ref_rgbs, model_path, seed=2019, use_cuda=use_cuda)

# output = core.text.FrameProps(output, props=['Frame_WaDIQaM'])
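
To aggregate the metric over a whole clip, the per-frame property can be collected and averaged. A minimal sketch, assuming "output" comes from the call above (every frame is evaluated on request, so this walks the entire clip):

scores = []
for i in range(output.num_frames):
    f = output.get_frame(i)
    scores.append(f.props['Frame_WaDIQaM'])

print("mean WaDIQaM over the clip:", sum(scores) / len(scores))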

vs_wadiqam_chainer.py

import vapoursynth as vs
from vapoursynth import core
is_api4 = hasattr(vs, "__api_version__") and vs.__api_version__.api_major == 4


def wadiqam_fr(clip1, clip2, model_folder_path, dataset="tid", top="patchwise", max_batch_size=2040):
    """Full-reference WaDIQaM calculator for VapourSynth

    Please download the model from https://github.com/dmaniry/deepIQA/tree/master/models

    A lower score indicates better visual image quality.
    The score will be stored as frame property 'Frame_WaDIQaM_FR' in the output clip.

    args:
        clip1, clip2: RGB input clips with the same size and format.
            The width and height of the clips must be multiples of 32.
            Integer clips with bit-depth other than 8 or 16 are not allowed.
            The first clip will be returned.

        model_folder_path: Path to the folder that contains the model's parameter files, e.g. "models".

        dataset: (str, "live" or "tid") Dataset used for training.
            Default is "tid".

        top: (str, "patchwise" or "weighted") Top layer and loss definition of the model.
            Default is "patchwise".

        max_batch_size: (int) Maximum size of a batch.
            The two input images are each divided into (width / 32) * (height / 32) patches.
            The memory may overflow if too many patches are fed to the model.
            Default is 2040 == (1920 / 32) * (1088 / 32).

    ref:
        [1] Bosse, S., Maniry, D., Müller, K. R., Wiegand, T., & Samek, W. (2018).
            Deep neural networks for no-reference and full-reference image quality assessment.
            IEEE Transactions on Image Processing, 27(1), 206-219.
        [2] https://github.com/dmaniry/deepIQA/
    """

    funcName = "wadiqam_fr"

    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    import chainer
    from chainer import Variable
    import chainer.functions as F
    import chainer.links as L
    from chainer import computational_graph
    from chainer import cuda
    from chainer import optimizers
    from chainer import serializers

    from functools import partial
    import os

    xp = cuda.cupy
    cuda.cudnn_enabled = True
    cuda.check_cuda_available()

    chainer.global_config.train = False
    chainer.global_config.enable_backprop = False
    chainer.global_config.autotune = True
    chainer.global_config.type_check = True

    if not isinstance(clip1, vs.VideoNode) or clip1.format.color_family != vs.RGB:
        raise TypeError(f'{funcName}: "clip1" must be an RGB clip!')

    if not isinstance(clip2, vs.VideoNode) or clip2.format.color_family != vs.RGB:
        raise TypeError(f'{funcName}: "clip2" must be an RGB clip!')

    if clip1.width != clip2.width or clip1.height != clip2.height:
        raise TypeError(f'{funcName}: "clip2" must be of the same size as "clip1"!')

    if clip1.width % 32 != 0 or clip1.height % 32 != 0:
        raise TypeError(f'{funcName}: The width and height of the clips must be multiples of 32!')

    if clip1.format.id != clip2.format.id:
        raise TypeError(f'{funcName}: "clip2" must be of the same format as "clip1"!')

    if clip1.format.sample_type == vs.INTEGER and clip1.format.bits_per_sample not in [8, 16]:
        raise TypeError(f'{funcName}: Integer clips with bit-depth other than 8 or 16 are not allowed!')
    class FRModel(chainer.Chain):
        def __init__(self, top="patchwise"):
            super(FRModel, self).__init__(
                conv1=L.Convolution2D(3, 32, 3, pad=1),
                conv2=L.Convolution2D(32, 32, 3, pad=1),
                conv3=L.Convolution2D(32, 64, 3, pad=1),
                conv4=L.Convolution2D(64, 64, 3, pad=1),
                conv5=L.Convolution2D(64, 128, 3, pad=1),
                conv6=L.Convolution2D(128, 128, 3, pad=1),
                conv7=L.Convolution2D(128, 256, 3, pad=1),
                conv8=L.Convolution2D(256, 256, 3, pad=1),
                conv9=L.Convolution2D(256, 512, 3, pad=1),
                conv10=L.Convolution2D(512, 512, 3, pad=1),
                fc1=L.Linear(512 * 3, 512),
                fc2=L.Linear(512, 1)
            )

            self.top = top

            if top == "weighted":
                fc1_a = L.Linear(512 * 3, 512)
                fc2_a = L.Linear(512, 1)
                self.add_link("fc1_a", fc1_a)
                self.add_link("fc2_a", fc2_a)

        def extract_features(self, x, train=True):
            h = F.relu(self.conv1(x))
            h = F.relu(self.conv2(h))
            self.h1 = h
            h = F.max_pooling_2d(h, 2)

            h = F.relu(self.conv3(h))
            h = F.relu(self.conv4(h))
            self.h2 = h
            h = F.max_pooling_2d(h, 2)

            h = F.relu(self.conv5(h))
            h = F.relu(self.conv6(h))
            self.h3 = h
            h = F.max_pooling_2d(h, 2)

            h = F.relu(self.conv7(h))
            h = F.relu(self.conv8(h))
            self.h4 = h
            h = F.max_pooling_2d(h, 2)

            h = F.relu(self.conv9(h))
            h = F.relu(self.conv10(h))
            self.h5 = h
            h = F.max_pooling_2d(h, 2)

            return h

        def forward(self, x_data, x_ref_data, y_data, train=True, n_patches_per_image=32):
            if not isinstance(x_data, Variable):
                x = Variable(x_data)
            else:
                x = x_data
                x_data = x.data

            self.n_images = y_data.shape[0]
            self.n_patches = x_data.shape[0]
            self.n_patches_per_image = n_patches_per_image

            x_ref = Variable(x_ref_data)

            h = self.extract_features(x)
            self.h = h
            h_ref = self.extract_features(x_ref)

            h = F.concat((h - h_ref, h, h_ref))
            h_ = h  # save intermediate features

            h = F.dropout(F.relu(self.fc1(h)), ratio=0.5)
            h = self.fc2(h)

            if self.top == "weighted":
                a = F.dropout(F.relu(self.fc1_a(h_)), ratio=0.5)
                a = F.relu(self.fc2_a(a)) + 0.000001
                t = Variable(y_data)
                self.weighted_loss(h, a, t)
            elif self.top == "patchwise":
                a = Variable(xp.ones_like(h.data))
                t = Variable(xp.repeat(y_data, n_patches_per_image))
                self.patchwise_loss(h, a, t)

            if train:
                return self.loss
            else:
                return self.loss, self.y

        def patchwise_loss(self, h, a, t):
            self.loss = F.sum(abs(h - F.reshape(t, (-1, 1))))
            self.loss /= self.n_patches

            if self.n_images > 1:
                h = F.split_axis(h, self.n_images, 0)
                a = F.split_axis(a, self.n_images, 0)
            else:
                h, a = [h], [a]

            self.y = h
            self.a = a

        def weighted_loss(self, h, a, t):
            self.loss = 0

            if self.n_images > 1:
                h = F.split_axis(h, self.n_images, 0)
                a = F.split_axis(a, self.n_images, 0)
                t = F.split_axis(t, self.n_images, 0)
            else:
                h, a, t = [h], [a], [t]

            for i in range(self.n_images):
                y = F.sum(h[i] * a[i], 0) / F.sum(a[i], 0)
                self.loss += abs(y - F.reshape(t[i], (1,)))

            self.loss /= self.n_images

            self.y = h
            self.a = a
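
    # extract_patches() builds a zero-copy strided view of the (H, W, 3) frame with shape
    # (H // 32, W // 32, 1, 32, 32, 3), i.e. one non-overlapping 32x32x3 tile per grid cell.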
    def extract_patches(arr, patch_shape=(32, 32, 3), extraction_step=32):
        extraction_step = [extraction_step] * 3

        patch_strides = arr.strides

        slices = tuple(slice(None, None, st) for st in extraction_step)
        indexing_strides = arr[slices].strides

        patch_indices_shape = ((np.array(arr.shape) - np.array(patch_shape)) // np.array(extraction_step)) + 1

        shape = tuple(list(patch_indices_shape) + list(patch_shape))
        strides = tuple(list(indexing_strides) + list(patch_strides))

        patches = as_strided(arr, shape=shape, strides=strides)

        return patches
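
    # benchmark() is invoked by std.ModifyFrame once per output frame, with f holding the
    # aligned frames of clip1 and clip2. Each frame is tiled into non-overlapping 32x32
    # patches, converted to float32 in the 0-255 range, fed through the model in batches of
    # at most max_batch_size patches, and the (weighted) patch scores are averaged into a
    # single score stored in the frame property 'Frame_WaDIQaM_FR'.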
    def benchmark(n, f, model, max_batch_size=2040):
        fout = f[0].copy()

        planes = f[0].format.num_planes

        if is_api4:
            img1 = np.stack(f[0], axis=2)
        else:
            img1 = np.stack([f[0].get_read_array(i) for i in range(planes)], axis=2)

        img1_patches = np.transpose(extract_patches(img1).reshape((-1, 32, 32, 3)), (0, 3, 1, 2))

        if is_api4:
            img2 = np.stack(f[1], axis=2)
        else:
            img2 = np.stack([f[1].get_read_array(i) for i in range(planes)], axis=2)

        img2_patches = np.transpose(extract_patches(img2).reshape((-1, 32, 32, 3)), (0, 3, 1, 2))

        if img1.dtype == np.uint8:
            img1_patches = xp.array(img1_patches.astype(np.float32))
            img2_patches = xp.array(img2_patches.astype(np.float32))
        elif img1.dtype == np.uint16:
            img1_patches = xp.array(img1_patches.astype(np.float32) * np.float32(255 / 65535))
            img2_patches = xp.array(img2_patches.astype(np.float32) * np.float32(255 / 65535))
        elif img1.dtype in [np.float16, np.float32, np.float64, np.float_]:
            img1_patches = xp.array(img1_patches.astype(np.float32) * np.float32(255))
            img2_patches = xp.array(img2_patches.astype(np.float32) * np.float32(255))
        else:
            raise TypeError("benchmark: unknown dtype.")

        t = xp.zeros((1, 1), dtype=np.float32)

        y = []
        weights = []

        for i in range(0, img1_patches.shape[0], max_batch_size):
            img1_batch = img1_patches[i:min(i + max_batch_size, img1_patches.shape[0])]
            img2_batch = img2_patches[i:min(i + max_batch_size, img2_patches.shape[0])]

            model.forward(img1_batch, img2_batch, t, False, n_patches_per_image=img1_batch.shape[0])

            y.append(xp.asnumpy(model.y[0].data))
            weights.append(xp.asnumpy(model.a[0].data))

        y = np.concatenate(y)
        weights = np.concatenate(weights)

        score = np.sum(y * weights) / np.sum(weights)

        fout.props['Frame_WaDIQaM_FR'] = np.float64(score)

        return fout
    model = FRModel(top=top)
    model_path = os.path.join(model_folder_path, f"fr_{dataset}_{top}.model")
    serializers.load_hdf5(model_path, model)
    model.to_gpu()

    return core.std.ModifyFrame(clip1, clips=[clip1, clip2],
        selector=partial(benchmark, model=model, max_batch_size=max_batch_size))


def wadiqam_nr(clip, model_folder_path, dataset="tid", top="patchwise", max_batch_size=2040):
    """No-reference WaDIQaM calculator for VapourSynth

    Please download the model from https://github.com/dmaniry/deepIQA/tree/master/models

    A lower score indicates better visual image quality.
    The score will be stored as frame property 'Frame_WaDIQaM_NR' in the output clip.

    args:
        clip: RGB input clip.
            Integer clips with bit-depth other than 8 or 16 are not allowed.
            The width and height of the clip must be multiples of 32.

        model_folder_path: Path to the folder that contains the model's parameter files, e.g. "models".

        dataset: (str, "live" or "tid") Dataset used for training.
            Default is "tid".

        top: (str, "patchwise" or "weighted") Top layer and loss definition of the model.
            Default is "patchwise".

        max_batch_size: (int) Maximum size of a batch.
            The input image is divided into (width / 32) * (height / 32) patches.
            The memory may overflow if too many patches are fed to the model.
            Default is 2040 == (1920 / 32) * (1088 / 32).

    ref:
        [1] Bosse, S., Maniry, D., Müller, K. R., Wiegand, T., & Samek, W. (2018).
            Deep neural networks for no-reference and full-reference image quality assessment.
            IEEE Transactions on Image Processing, 27(1), 206-219.
        [2] https://github.com/dmaniry/deepIQA/
    """

    funcName = "wadiqam_nr"

    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    import chainer
    from chainer import Variable
    import chainer.functions as F
    import chainer.links as L
    from chainer import computational_graph
    from chainer import cuda
    from chainer import optimizers
    from chainer import serializers

    from functools import partial
    import os

    xp = cuda.cupy
    cuda.cudnn_enabled = True
    cuda.check_cuda_available()

    chainer.global_config.train = False
    chainer.global_config.enable_backprop = False
    chainer.global_config.autotune = True
    chainer.global_config.type_check = True

    if not isinstance(clip, vs.VideoNode) or clip.format.color_family != vs.RGB:
        raise TypeError(f'{funcName}: "clip" must be an RGB clip!')

    if clip.width % 32 != 0 or clip.height % 32 != 0:
        raise TypeError(f'{funcName}: The width and height of "clip" must be multiples of 32!')

    if clip.format.sample_type == vs.INTEGER and clip.format.bits_per_sample not in [8, 16]:
        raise TypeError(f'{funcName}: Integer clips with bit-depth other than 8 or 16 are not allowed!')
    class NRModel(chainer.Chain):
        def __init__(self, top="patchwise"):
            super(NRModel, self).__init__(
                conv1=L.Convolution2D(3, 32, 3, pad=1),
                conv2=L.Convolution2D(32, 32, 3, pad=1),
                conv3=L.Convolution2D(32, 64, 3, pad=1),
                conv4=L.Convolution2D(64, 64, 3, pad=1),
                conv5=L.Convolution2D(64, 128, 3, pad=1),
                conv6=L.Convolution2D(128, 128, 3, pad=1),
                conv7=L.Convolution2D(128, 256, 3, pad=1),
                conv8=L.Convolution2D(256, 256, 3, pad=1),
                conv9=L.Convolution2D(256, 512, 3, pad=1),
                conv10=L.Convolution2D(512, 512, 3, pad=1),
                fc1=L.Linear(512, 512),
                fc2=L.Linear(512, 1)
            )

            self.top = top

            if top == "weighted":
                fc1_a = L.Linear(512, 512)
                fc2_a = L.Linear(512, 1)
                self.add_link("fc1_a", fc1_a)
                self.add_link("fc2_a", fc2_a)

        def forward(self, x_data, y_data, train=True, n_patches=32):
            if not isinstance(x_data, Variable):
                x = Variable(x_data)
            else:
                x = x_data
                x_data = x.data

            self.n_images = y_data.shape[0]
            self.n_patches = x_data.shape[0]
            self.n_patches_per_image = self.n_patches / self.n_images

            h = F.relu(self.conv1(x))
            h = F.relu(self.conv2(h))
            h = F.max_pooling_2d(h, 2)

            h = F.relu(self.conv3(h))
            h = F.relu(self.conv4(h))
            h = F.max_pooling_2d(h, 2)

            h = F.relu(self.conv5(h))
            h = F.relu(self.conv6(h))
            h = F.max_pooling_2d(h, 2)

            h = F.relu(self.conv7(h))
            h = F.relu(self.conv8(h))
            h = F.max_pooling_2d(h, 2)

            h = F.relu(self.conv9(h))
            h = F.relu(self.conv10(h))
            h = F.max_pooling_2d(h, 2)

            h_ = h
            self.h = h_

            h = F.dropout(F.relu(self.fc1(h_)), ratio=0.5)
            h = self.fc2(h)

            if self.top == "weighted":
                a = F.dropout(F.relu(self.fc1_a(h_)), ratio=0.5)
                a = F.relu(self.fc2_a(a)) + 0.000001
                t = Variable(y_data)
                self.weighted_loss(h, a, t)
            elif self.top == "patchwise":
                a = Variable(xp.ones_like(h.data))
                t = Variable(xp.repeat(y_data, n_patches))
                self.patchwise_loss(h, a, t)

            if train:
                return self.loss
            else:
                return self.loss, self.y

        def patchwise_loss(self, h, a, t):
            self.loss = F.sum(abs(h - F.reshape(t, (-1, 1))))
            self.loss /= self.n_patches

            if self.n_images > 1:
                h = F.split_axis(h, self.n_images, 0)
                a = F.split_axis(a, self.n_images, 0)
            else:
                h, a = [h], [a]

            self.y = h
            self.a = a

        def weighted_loss(self, h, a, t):
            self.loss = 0

            if self.n_images > 1:
                h = F.split_axis(h, self.n_images, 0)
                a = F.split_axis(a, self.n_images, 0)
                t = F.split_axis(t, self.n_images, 0)
            else:
                h, a, t = [h], [a], [t]

            for i in range(self.n_images):
                y = F.sum(h[i] * a[i], 0) / F.sum(a[i], 0)
                self.loss += abs(y - F.reshape(t[i], (1,)))

            self.loss /= self.n_images

            self.y = h
            self.a = a
    def extract_patches(arr, patch_shape=(32, 32, 3), extraction_step=32):
        extraction_step = [extraction_step] * 3

        patch_strides = arr.strides

        slices = tuple(slice(None, None, st) for st in extraction_step)
        indexing_strides = arr[slices].strides

        patch_indices_shape = ((np.array(arr.shape) - np.array(patch_shape)) // np.array(extraction_step)) + 1

        shape = tuple(list(patch_indices_shape) + list(patch_shape))
        strides = tuple(list(indexing_strides) + list(patch_strides))

        patches = as_strided(arr, shape=shape, strides=strides)

        return patches
    def benchmark(n, f, model, max_batch_size=2040):
        fout = f.copy()

        if is_api4:
            img1 = np.stack(f, axis=2)
        else:
            planes = f.format.num_planes
            img1 = np.stack([f.get_read_array(i) for i in range(planes)], axis=2)

        img1_patches = np.transpose(extract_patches(img1).reshape((-1, 32, 32, 3)), (0, 3, 1, 2))

        if img1.dtype == np.uint8:
            img1_patches = xp.array(img1_patches.astype(np.float32))
        elif img1.dtype == np.uint16:
            img1_patches = xp.array(img1_patches.astype(np.float32) * np.float32(255 / 65535))
        elif img1.dtype in [np.float16, np.float32, np.float64, np.float_]:
            img1_patches = xp.array(img1_patches.astype(np.float32) * np.float32(255))
        else:
            raise TypeError("benchmark: unknown dtype.")

        t = xp.zeros((1, 1), dtype=np.float32)

        y = []
        weights = []

        for i in range(0, img1_patches.shape[0], max_batch_size):
            img1_batch = img1_patches[i:min(i + max_batch_size, img1_patches.shape[0])]

            model.forward(img1_batch, t, False, n_patches=img1_batch.shape[0])

            y.append(xp.asnumpy(model.y[0].data))
            weights.append(xp.asnumpy(model.a[0].data))

        y = np.concatenate(y)
        weights = np.concatenate(weights)

        score = np.sum(y * weights) / np.sum(weights)

        fout.props['Frame_WaDIQaM_NR'] = np.float64(score)

        return fout
    model = NRModel(top=top)
    model_path = os.path.join(model_folder_path, f"nr_{dataset}_{top}.model")
    serializers.load_hdf5(model_path, model)
    model.to_gpu()

    return core.std.ModifyFrame(clip, clips=clip,
        selector=partial(benchmark, model=model, max_batch_size=max_batch_size))

vs_wadiqam_pytorch.py

import vapoursynth as vs
from vapoursynth import core
is_api4 = hasattr(vs, "__api_version__") and vs.__api_version__.api_major == 4


def wadiqam_fr(clip1, clip2, model_path, n_patches=32, seed=2019, use_cuda=None):
    """WaDIQaM calculator for VapourSynth

    Please download the model from https://github.com/lidq92/WaDIQaM/blob/master/models/WaDIQaM-FR-TID2008

    The score will be stored as frame property 'Frame_WaDIQaM' in the output clip.

    args:
        clip1, clip2: RGBS input clips with the same size.
            The first clip will be returned.

        model_path: Path to the model's parameter file, e.g. "models/WaDIQaM-FR-TID2008".

        n_patches: (int) The algorithm randomly takes a number of patches from a frame to evaluate the score.
            This sets the number of patches to be used.
            Larger values may lead to more accurate evaluations, but the algorithm may run slower,
            and the memory may overflow.
            Default is 32.

        seed: (int) Seed used to initialize the random number generator.
            Default is 2019.

        use_cuda: (bool) Whether to use CUDA to accelerate evaluation.
            Default is True if CUDA is available.

    ref:
        [1] Bosse, S., Maniry, D., Müller, K. R., Wiegand, T., & Samek, W. (2018).
            Deep neural networks for no-reference and full-reference image quality assessment.
            IEEE Transactions on Image Processing, 27(1), 206-219.
        [2] https://github.com/lidq92/WaDIQaM/
    """

    funcName = "vs_wadiqam"

    from functools import partial
    import random

    import numpy as np
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torch.autograd import Variable

    if not isinstance(clip1, vs.VideoNode) or clip1.format.id != vs.RGBS:
        raise TypeError(f'{funcName}: "clip1" must be a clip in RGBS!')

    if not isinstance(clip2, vs.VideoNode) or clip2.format.id != vs.RGBS:
        raise TypeError(f'{funcName}: "clip2" must be a clip in RGBS!')

    if clip1.width != clip2.width or clip1.height != clip2.height:
        raise TypeError(f'{funcName}: "clip2" must be of the same size as "clip1"!')

    if use_cuda is None:
        use_cuda = torch.cuda.is_available()
    class FRnet(nn.Module):
        def __init__(self, top="patchwise", use_cuda=torch.cuda.is_available()):
            super(FRnet, self).__init__()

            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
            self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
            self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
            self.conv5 = nn.Conv2d(64, 128, 3, padding=1)
            self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
            self.conv7 = nn.Conv2d(128, 256, 3, padding=1)
            self.conv8 = nn.Conv2d(256, 256, 3, padding=1)
            self.conv9 = nn.Conv2d(256, 512, 3, padding=1)
            self.conv10 = nn.Conv2d(512, 512, 3, padding=1)

            self.fc1 = nn.Linear(512 * 3, 512)
            self.fc2 = nn.Linear(512, 1)
            self.fc1_a = nn.Linear(512 * 3, 512)
            self.fc2_a = nn.Linear(512, 1)

            self.top = top
            self.use_cuda = use_cuda

        def extract_features(self, x):
            h = F.relu(self.conv1(x))
            h = F.relu(self.conv2(h))
            h = F.max_pool2d(h, 2)

            h = F.relu(self.conv3(h))
            h = F.relu(self.conv4(h))
            h = F.max_pool2d(h, 2)

            h = F.relu(self.conv5(h))
            h = F.relu(self.conv6(h))
            h = F.max_pool2d(h, 2)

            h = F.relu(self.conv7(h))
            h = F.relu(self.conv8(h))
            h = F.max_pool2d(h, 2)

            h = F.relu(self.conv9(h))
            h = F.relu(self.conv10(h))
            h = F.max_pool2d(h, 2)

            h = h.view(-1, 512)

            return h
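
        # Each 32x32 patch is reduced by the five 2x2 max-poolings to a 1x1 spatial map with
        # 512 channels, so extract_features() returns one 512-dimensional vector per patch.
        # forward() concatenates (h - h_ref, h, h_ref) into a 512*3 vector, regresses a
        # per-patch score h and a positive weight a, and returns their weighted mean q.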
        def forward(self, data, train=True):
            x, x_ref = data

            if self.use_cuda:
                x = x.cuda()
                x_ref = x_ref.cuda()

            x = Variable(x, volatile=not train)
            x_ref = Variable(x_ref, volatile=not train)

            h = self.extract_features(x)
            h_ref = self.extract_features(x_ref)
            h = torch.cat((h - h_ref, h, h_ref), 1)

            h_ = h  # save intermediate features
            self.h = h_

            h = F.dropout(F.relu(self.fc1(h_)), p=0.5, training=train)
            h = self.fc2(h)

            if self.top == "weighted":
                a = F.dropout(F.relu(self.fc1_a(h_)), p=0.5, training=train)
                a = F.relu(self.fc2_a(a)) + 0.000001  # small constant
            elif self.top == "patchwise":
                a = Variable(torch.ones_like(h.data), volatile=not train)

            q = torch.sum(h * a) / torch.sum(a)

            return q
    def RandomCropPatches(img1, img2, n_patches=32, seed=None):
        # img1: src, img2: ref, 3-D numpy arrays in CHW format
        random.seed(seed)

        _, h, w = img1.shape

        img1_crops = []
        img2_crops = []

        th = 32
        tw = 32

        for k in range(n_patches):
            i = random.randint(0, h - th)
            j = random.randint(0, w - tw)

            img1_crops.append(torch.from_numpy(img1[:, i:i+th, j:j+tw]))
            img2_crops.append(torch.from_numpy(img2[:, i:i+th, j:j+tw]))

        return [torch.stack(img1_crops), torch.stack(img2_crops)]
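
    # benchmark() is called by std.ModifyFrame once per output frame. Unlike the Chainer
    # version, which scores every 32x32 tile, this implementation samples n_patches random
    # 32x32 crops per frame; with a fixed seed the same crop positions are reused for every
    # frame, so the per-frame scores are reproducible across runs.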
    def benchmark(n, f, model, n_patches=32, seed=None):
        fout = f[0].copy()

        if is_api4:
            img1 = np.stack(f[0], axis=0)
            img2 = np.stack(f[1], axis=0)
        else:
            planes = f[0].format.num_planes
            img1 = np.stack([f[0].get_read_array(i) for i in range(planes)], axis=0)
            img2 = np.stack([f[1].get_read_array(i) for i in range(planes)], axis=0)

        data = RandomCropPatches(img1, img2, n_patches=n_patches, seed=seed)

        score = model(data, train=False)
        score = float(score.detach().cpu().numpy())

        fout.props['Frame_WaDIQaM'] = score

        return fout
    model = FRnet(top="weighted", use_cuda=use_cuda)
    model.load_state_dict(torch.load(model_path))

    if use_cuda:
        model = model.cuda()

    return core.std.ModifyFrame(clip1, clips=[clip1, clip2],
        selector=partial(benchmark, model=model, n_patches=n_patches, seed=seed))