@genzj
Created December 31, 2020 08:06
Migrate DAIN (https://github.com/baowenbo/DAIN) to PyTorch 1.7.1 with CUDA 10.2.
diff --git a/PWCNet/correlation_package_pytorch1_0/correlation.py b/PWCNet/correlation_package_pytorch1_0/correlation.py
index 80a8b09..fe8ab06 100644
--- a/PWCNet/correlation_package_pytorch1_0/correlation.py
+++ b/PWCNet/correlation_package_pytorch1_0/correlation.py
@@ -4,19 +4,16 @@ from torch.autograd import Function
import correlation_cuda
class CorrelationFunction(Function):
-
- def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
- super(CorrelationFunction, self).__init__()
- self.pad_size = pad_size
- self.kernel_size = kernel_size
- self.max_displacement = max_displacement
- self.stride1 = stride1
- self.stride2 = stride2
- self.corr_multiply = corr_multiply
- # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
-
- def forward(self, input1, input2):
- self.save_for_backward(input1, input2)
+ @staticmethod
+ def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
+ ctx.pad_size = pad_size
+ ctx.kernel_size = kernel_size
+ ctx.max_displacement = max_displacement
+ ctx.stride1 = stride1
+ ctx.stride2 = stride2
+ ctx.corr_multiply = corr_multiply
+ # ctx.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
+ ctx.save_for_backward(input1, input2)
with torch.cuda.device_of(input1):
rbot1 = input1.new()
@@ -24,12 +21,13 @@ class CorrelationFunction(Function):
output = input1.new()
correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
- self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
+ ctx.pad_size, ctx.kernel_size, ctx.max_displacement,ctx.stride1, ctx.stride2, ctx.corr_multiply)
return output
- def backward(self, grad_output):
- input1, input2 = self.saved_tensors
+ @staticmethod
+ def backward(ctx, grad_output):
+ input1, input2 = ctx.saved_tensors
with torch.cuda.device_of(input1):
rbot1 = input1.new()
@@ -39,7 +37,7 @@ class CorrelationFunction(Function):
grad_input2 = input2.new()
correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
- self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
+ ctx.pad_size, ctx.kernel_size, ctx.max_displacement,ctx.stride1, ctx.stride2, ctx.corr_multiply)
return grad_input1, grad_input2
@@ -56,7 +54,7 @@ class Correlation(Module):
def forward(self, input1, input2):
- result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2)
+ result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
return result
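
For context, here is a minimal sketch of the new-style torch.autograd.Function API that the hunk above migrates to (ScaleSquare is a made-up example, not part of the patch): forward and backward become @staticmethods, per-call state lives on ctx instead of self, and callers go through .apply(). New-style backward is also expected to return one gradient slot per forward argument, with None for non-tensor arguments.

import torch
from torch.autograd import Function

class ScaleSquare(Function):
    @staticmethod
    def forward(ctx, input, scale=2.0):
        # non-tensor hyper-parameters are stashed on ctx,
        # tensors go through save_for_backward
        ctx.scale = scale
        ctx.save_for_backward(input)
        return scale * input * input

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # one gradient per forward argument; None for the non-tensor scale
        return grad_output * 2.0 * ctx.scale * input, None

x = torch.randn(3, requires_grad=True)
ScaleSquare.apply(x, 3.0).sum().backward()  # d/dx of scale*x^2 is 2*scale*x
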
diff --git a/demo_MiddleBury.py b/demo_MiddleBury.py
index 82a6b71..21f7d63 100644
--- a/demo_MiddleBury.py
+++ b/demo_MiddleBury.py
@@ -10,7 +10,7 @@ import numpy
import networks
from my_args import args
-from scipy.misc import imread, imsave
+from skimage.io import imread, imsave
from AverageMeter import *
torch.backends.cudnn.benchmark = True # to speed up the
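
scipy.misc.imread and imsave were deprecated in SciPy 1.0 and removed in 1.2, which is why the two demo scripts switch to the functions of the same name in skimage.io. A quick illustrative check (the frame filenames are hypothetical, not from the repository):

from skimage.io import imread, imsave

img = imread('frame10.png')        # numpy uint8 array for an 8-bit PNG
imsave('frame10_copy.png', img)    # same (path, array) call order as the old scipy.misc API
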
diff --git a/demo_MiddleBury_slowmotion.py b/demo_MiddleBury_slowmotion.py
index 2bb4293..1b01bac 100644
--- a/demo_MiddleBury_slowmotion.py
+++ b/demo_MiddleBury_slowmotion.py
@@ -7,7 +7,7 @@ import numpy as np
import numpy
import networks
from my_args import args
-from scipy.misc import imread, imsave
+from skimage.io import imread, imsave
from AverageMeter import *
import shutil
@@ -183,4 +183,4 @@ if DO_MiddleBurryOther:
count = count + 1
-
\ No newline at end of file
+
diff --git a/my_package/compiler_args.py b/my_package/compiler_args.py
index 7451c0c..de8c256 100644
--- a/my_package/compiler_args.py
+++ b/my_package/compiler_args.py
@@ -4,35 +4,35 @@ nvcc_args = [
# Quadro: (None)
# NVIDIA NVS: (None)
# Jetson: (None)
- '-gencode', 'arch=compute_37,code=sm_37',
+ # '-gencode', 'arch=compute_37,code=sm_37',
# Tesla: (None)
# Quadro: K1200, K620, M1200, M520, M5000M, M4000M, M3000M, M2000M, M1000M, K620M, M600M, M500M
# NVIDIA NVS: 810
# GeForce / Titan: GTX 750 Ti, GTX 750, GTX 960M, GTX 950M, 940M, 930M, GTX 860M, GTX 850M, 840M, 830M
# Jetson: (None)
- '-gencode', 'arch=compute_50,code=sm_50',
+ # '-gencode', 'arch=compute_50,code=sm_50',
# Tesla: M60, M40
# Quadro: M6000 24GB, M6000, M5000, M4000, M2000, M5500M, M2200, M620
# NVIDIA NVS: (None)
# GeForce / Titan: GTX TITAN X, GTX 980 Ti, GTX 980, GTX 970, GTX 960, GTX 950, GTX 980, GTX 980M, GTX 970M, GTX 965M, 910M
# Jetson: (None)
- '-gencode', 'arch=compute_52,code=sm_52',
+ # '-gencode', 'arch=compute_52,code=sm_52',
# Tesla: P100
# Quadro: GP100
# NVIDIA: NVS: (None)
# GeForce / Titan: (None)
# Jetson: (None)
- '-gencode', 'arch=compute_60,code=sm_60',
+ # '-gencode', 'arch=compute_60,code=sm_60',
# Tesla: P40, P4
# Quadro: P6000, P5000, P4000, P2200, P2000, P1000, P620, P600, P400, P620, P520, P5200, P4200, P3200, P5000, P4000, P3000, P2000, P1000, P600, P500
# NVIDIA NVS: (None)
# GeForce / Titan: TITAN Xp, TITAN X, GTX 1080 Ti, GTX 1080, GTX 1070, GTX 1060, GTX 1050, GTX 1080, GTX 1070, GTX 1060
# Jetson: (None)
- '-gencode', 'arch=compute_61,code=sm_61',
+ # '-gencode', 'arch=compute_61,code=sm_61',
# Tesla: T4
# Quadro: RTX 8000, RTX 6000, RTX 5000, RTX 4000, RTX 5000, RTX 4000, RTX 3000, T2000, T1000
@@ -47,4 +47,4 @@ nvcc_args = [
'-w' # Ignore compiler warnings.
]
-cxx_args = ['-std=c++11', '-w']
\ No newline at end of file
+cxx_args = ['-std=c++14', '-w']
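
Rather than hard-coding which -gencode pairs stay enabled, one alternative (a sketch only, not part of the patch; local_gencode_args is a made-up helper) is to query the local GPU's compute capability through torch.cuda and emit a single matching pair:

import torch

def local_gencode_args():
    # e.g. (7, 5) on a Turing GPU -> '-gencode arch=compute_75,code=sm_75'
    major, minor = torch.cuda.get_device_capability(0)
    cc = f'{major}{minor}'
    return ['-gencode', f'arch=compute_{cc},code=sm_{cc}']

nvcc_args = local_gencode_args() + ['-w']  # '-w': ignore compiler warnings
cxx_args = ['-std=c++14', '-w']
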
diff --git a/my_package/test_module.py b/my_package/test_module.py
index c1797ec..ca9ccda 100755
--- a/my_package/test_module.py
+++ b/my_package/test_module.py
@@ -7,7 +7,7 @@ from torch.autograd import gradcheck
#from modules.InterpolationModule import InterpolationModule
#from modules.FilterInterpolationModule import FilterInterpolationModule
#from modules.FlowProjectionModule import FlowProjectionModule
-from my_package.DepthFlowProjection import DepthFlowProjectionModule
+from DepthFlowProjection import DepthFlowProjectionModule
#from modules.FilterInterpolationModule import AdaptiveWeightInterpolationModule
#from modules.SeparableConvModule import SeparableConvModule
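
test_module.py drives torch.autograd.gradcheck against the rebuilt extensions, and dropping the my_package. prefix presumably lets the script be run from inside my_package/ directly. A minimal gradcheck sketch (illustrative only, using a plain lambda instead of DepthFlowProjectionModule):

import torch
from torch.autograd import gradcheck

# gradcheck compares analytic and numeric gradients; double-precision inputs
# keep the finite-difference error within the default tolerances
x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
assert gradcheck(lambda t: (t * t).sum(), (x,), eps=1e-6, atol=1e-4)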