@pclucas14, created July 24, 2018
Wrapper for EMD (Earth Mover's Distance) in PyTorch
import torch
from torch.autograd import Function
from pyinn.utils import Stream, load_kernel
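# The four CUDA kernels below appear to be taken from the approxmatch/matchcost
# implementation accompanying Fan et al., "A Point Set Generation Network for 3D
# Object Reconstruction from a Single Image" (the original TensorFlow EMD op).
# approxmatch builds an approximate transport plan `match` between two point sets
# via repeated soft-assignment passes; matchcost contracts that plan against
# pairwise Euclidean distances; matchcostgrad1/matchcostgrad2 are the analytic
# gradients with respect to xyz1 and xyz2, treating the plan as constant.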
approxmatch_kernel = '''
extern "C"
__global__ void approxmatch(int b, int n, int m, const float * __restrict__ xyz1, const float * __restrict__ xyz2, float * __restrict__ match, float * temp){
    // Per-block scratch, carved out of `temp`: remaining unassigned mass on each
    // side (remainL/remainR) and the per-pass assignment ratios (ratioL/ratioR).
    float * remainL = temp + blockIdx.x*(n+m)*2;
    float * remainR = temp + blockIdx.x*(n+m)*2 + n;
    float * ratioL  = temp + blockIdx.x*(n+m)*2 + n + m;
    float * ratioR  = temp + blockIdx.x*(n+m)*2 + n + m + n;
    // Unequal set sizes: points on the smaller side carry proportionally more mass.
    // Note the integer division, kept from the original source; it assumes the
    // larger size is a multiple of the smaller one.
    float multiL, multiR;
    if (n >= m){
        multiL = 1;
        multiR = n/m;
    }else{
        multiL = m/n;
        multiR = 1;
    }
    const int Block = 1024;
    __shared__ float buf[Block*4];
    for (int i = blockIdx.x; i < b; i += gridDim.x){
        for (int j = threadIdx.x; j < n*m; j += blockDim.x)
            match[i*n*m+j] = 0;
        for (int j = threadIdx.x; j < n; j += blockDim.x)
            remainL[j] = multiL;
        for (int j = threadIdx.x; j < m; j += blockDim.x)
            remainR[j] = multiR;
        __syncthreads();
        // Soft-assignment passes with a progressively wider matching radius:
        // level = -4^j for j = 7..-1, then level = 0 on the final pass so any
        // remaining mass is distributed regardless of distance.
        for (int j = 7; j >= -2; j--){
            float level = -powf(4.0f, j);
            if (j == -2){
                level = 0;
            }
            // Pass 1: for each left point k, compute the normalizer of the
            // Gaussian weights over all right points; store ratioL = remainL/suml.
            for (int k0 = 0; k0 < n; k0 += blockDim.x){
                int k = k0 + threadIdx.x;
                float x1 = 0, y1 = 0, z1 = 0;
                if (k < n){
                    x1 = xyz1[i*n*3+k*3+0];
                    y1 = xyz1[i*n*3+k*3+1];
                    z1 = xyz1[i*n*3+k*3+2];
                }
                float suml = 1e-9f;
                for (int l0 = 0; l0 < m; l0 += Block){
                    int lend = min(m, l0+Block) - l0;
                    // stage a tile of xyz2 (and its remaining mass) in shared memory
                    for (int l = threadIdx.x; l < lend; l += blockDim.x){
                        float x2 = xyz2[i*m*3+l0*3+l*3+0];
                        float y2 = xyz2[i*m*3+l0*3+l*3+1];
                        float z2 = xyz2[i*m*3+l0*3+l*3+2];
                        buf[l*4+0] = x2;
                        buf[l*4+1] = y2;
                        buf[l*4+2] = z2;
                        buf[l*4+3] = remainR[l0+l];
                    }
                    __syncthreads();
                    for (int l = 0; l < lend; l++){
                        float x2 = buf[l*4+0];
                        float y2 = buf[l*4+1];
                        float z2 = buf[l*4+2];
                        float d = level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
                        float w = __expf(d)*buf[l*4+3];
                        suml += w;
                    }
                    __syncthreads();
                }
                if (k < n)
                    ratioL[k] = remainL[k]/suml;
            }
            /* reference (non-tiled) version kept from the original source:
            for (int k=threadIdx.x;k<n;k+=gridDim.x){
                float x1=xyz1[i*n*3+k*3+0];
                float y1=xyz1[i*n*3+k*3+1];
                float z1=xyz1[i*n*3+k*3+2];
                float suml=1e-9f;
                for (int l=0;l<m;l++){
                    float x2=xyz2[i*m*3+l*3+0];
                    float y2=xyz2[i*m*3+l*3+1];
                    float z2=xyz2[i*m*3+l*3+2];
                    float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*remainR[l];
                    suml+=w;
                }
                ratioL[k]=remainL[k]/suml;
            }*/
            __syncthreads();
            // Pass 2: for each right point l, total the mass the left side wants
            // to send it, and cap consumption at the remaining capacity.
            for (int l0 = 0; l0 < m; l0 += blockDim.x){
                int l = l0 + threadIdx.x;
                float x2 = 0, y2 = 0, z2 = 0;
                if (l < m){
                    x2 = xyz2[i*m*3+l*3+0];
                    y2 = xyz2[i*m*3+l*3+1];
                    z2 = xyz2[i*m*3+l*3+2];
                }
                float sumr = 0;
                for (int k0 = 0; k0 < n; k0 += Block){
                    int kend = min(n, k0+Block) - k0;
                    for (int k = threadIdx.x; k < kend; k += blockDim.x){
                        buf[k*4+0] = xyz1[i*n*3+k0*3+k*3+0];
                        buf[k*4+1] = xyz1[i*n*3+k0*3+k*3+1];
                        buf[k*4+2] = xyz1[i*n*3+k0*3+k*3+2];
                        buf[k*4+3] = ratioL[k0+k];
                    }
                    __syncthreads();
                    for (int k = 0; k < kend; k++){
                        float x1 = buf[k*4+0];
                        float y1 = buf[k*4+1];
                        float z1 = buf[k*4+2];
                        float w = __expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
                        sumr += w;
                    }
                    __syncthreads();
                }
                if (l < m){
                    sumr *= remainR[l];
                    float consumption = fminf(remainR[l]/(sumr+1e-9f), 1.0f);
                    ratioR[l] = consumption*remainR[l];
                    remainR[l] = fmaxf(0.0f, remainR[l]-sumr);
                }
            }
            /* reference (non-tiled) version kept from the original source:
            for (int l=threadIdx.x;l<m;l+=blockDim.x){
                float x2=xyz2[i*m*3+l*3+0];
                float y2=xyz2[i*m*3+l*3+1];
                float z2=xyz2[i*m*3+l*3+2];
                float sumr=0;
                for (int k=0;k<n;k++){
                    float x1=xyz1[i*n*3+k*3+0];
                    float y1=xyz1[i*n*3+k*3+1];
                    float z1=xyz1[i*n*3+k*3+2];
                    float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*ratioL[k];
                    sumr+=w;
                }
                sumr*=remainR[l];
                float consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
                ratioR[l]=consumption*remainR[l];
                remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
            }*/
            __syncthreads();
            // Pass 3: commit the matched mass into `match` and reduce the
            // remaining mass on the left accordingly.
            for (int k0 = 0; k0 < n; k0 += blockDim.x){
                int k = k0 + threadIdx.x;
                float x1 = 0, y1 = 0, z1 = 0;
                if (k < n){
                    x1 = xyz1[i*n*3+k*3+0];
                    y1 = xyz1[i*n*3+k*3+1];
                    z1 = xyz1[i*n*3+k*3+2];
                }
                float suml = 0;
                for (int l0 = 0; l0 < m; l0 += Block){
                    int lend = min(m, l0+Block) - l0;
                    for (int l = threadIdx.x; l < lend; l += blockDim.x){
                        buf[l*4+0] = xyz2[i*m*3+l0*3+l*3+0];
                        buf[l*4+1] = xyz2[i*m*3+l0*3+l*3+1];
                        buf[l*4+2] = xyz2[i*m*3+l0*3+l*3+2];
                        buf[l*4+3] = ratioR[l0+l];
                    }
                    __syncthreads();
                    // guard added: k can run past n here, so avoid a stray read of ratioL
                    float rl = (k < n) ? ratioL[k] : 0.0f;
                    if (k < n){
                        for (int l = 0; l < lend; l++){
                            float x2 = buf[l*4+0];
                            float y2 = buf[l*4+1];
                            float z2 = buf[l*4+2];
                            float w = __expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
                            match[i*n*m+(l0+l)*n+k] += w;
                            suml += w;
                        }
                    }
                    __syncthreads();
                }
                if (k < n)
                    remainL[k] = fmaxf(0.0f, remainL[k]-suml);
            }
            /* reference (non-tiled) version kept from the original source:
            for (int k=threadIdx.x;k<n;k+=blockDim.x){
                float x1=xyz1[i*n*3+k*3+0];
                float y1=xyz1[i*n*3+k*3+1];
                float z1=xyz1[i*n*3+k*3+2];
                float suml=0;
                for (int l=0;l<m;l++){
                    float x2=xyz2[i*m*3+l*3+0];
                    float y2=xyz2[i*m*3+l*3+1];
                    float z2=xyz2[i*m*3+l*3+2];
                    float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*ratioL[k]*ratioR[l];
                    match[i*n*m+l*n+k]+=w;
                    suml+=w;
                }
                remainL[k]=fmaxf(0.0f,remainL[k]-suml);
            }*/
            __syncthreads();
        }
    }
}'''
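# matchcost computes, per batch element, sum_{k,l} ||xyz1_k - xyz2_l|| * match[l,k],
# staging tiles of xyz2 in shared memory and finishing with a block-wide tree reduction.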
matchcost_kernel = '''
extern "C"
__global__ void matchcost(int b, int n, int m, const float * __restrict__ xyz1, const float * __restrict__ xyz2, const float * __restrict__ match, float * __restrict__ out){
    __shared__ float allsum[512];
    const int Block = 1024;
    __shared__ float buf[Block*3];
    for (int i = blockIdx.x; i < b; i += gridDim.x){
        float subsum = 0;
        for (int k0 = 0; k0 < n; k0 += blockDim.x){
            int k = k0 + threadIdx.x;
            float x1 = 0, y1 = 0, z1 = 0;
            if (k < n){
                x1 = xyz1[i*n*3+k*3+0];
                y1 = xyz1[i*n*3+k*3+1];
                z1 = xyz1[i*n*3+k*3+2];
            }
            for (int l0 = 0; l0 < m; l0 += Block){
                int lend = min(m, l0+Block) - l0;
                // stage a tile of xyz2 in shared memory
                for (int l = threadIdx.x; l < lend*3; l += blockDim.x)
                    buf[l] = xyz2[i*m*3+l0*3+l];
                __syncthreads();
                if (k < n){
                    for (int l = 0; l < lend; l++){
                        float x2 = buf[l*3+0];
                        float y2 = buf[l*3+1];
                        float z2 = buf[l*3+2];
                        // cost contribution: Euclidean distance weighted by matched mass
                        float d = sqrtf((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
                        subsum += d*match[i*n*m+(l0+l)*n+k];
                    }
                }
                __syncthreads();
            }
        }
        // block-wide tree reduction of the per-thread partial sums
        allsum[threadIdx.x] = subsum;
        for (int j = 1; j < blockDim.x; j <<= 1){
            __syncthreads();
            if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x){
                allsum[threadIdx.x] += allsum[threadIdx.x+j];
            }
        }
        if (threadIdx.x == 0)
            out[i] = allsum[0];
        __syncthreads();
    }
}'''
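# Gradient w.r.t. xyz2: since d||x2 - x1||/dx2 = (x2 - x1)/||x2 - x1||, each output
# point accumulates match-weighted unit vectors over all of xyz1 (rsqrtf is clamped
# at 1e-20 to avoid dividing by zero for coincident points).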
matchcostgrad2_kernel = '''
extern "C"
__global__ void matchcostgrad2(int b, int n, int m, const float * __restrict__ xyz1, const float * __restrict__ xyz2, const float * __restrict__ match, float * __restrict__ grad2){
    __shared__ float sum_grad[256*3];
    for (int i = blockIdx.x; i < b; i += gridDim.x){
        // blockIdx.y partitions the m points of xyz2 across the second grid dimension
        int kbeg = m*blockIdx.y/gridDim.y;
        int kend = m*(blockIdx.y+1)/gridDim.y;
        for (int k = kbeg; k < kend; k++){
            float x2 = xyz2[(i*m+k)*3+0];
            float y2 = xyz2[(i*m+k)*3+1];
            float z2 = xyz2[(i*m+k)*3+2];
            float subsumx = 0, subsumy = 0, subsumz = 0;
            for (int j = threadIdx.x; j < n; j += blockDim.x){
                float x1 = x2 - xyz1[(i*n+j)*3+0];
                float y1 = y2 - xyz1[(i*n+j)*3+1];
                float z1 = z2 - xyz1[(i*n+j)*3+2];
                float d = match[i*n*m+k*n+j]*rsqrtf(fmaxf(x1*x1+y1*y1+z1*z1, 1e-20f));
                subsumx += x1*d;
                subsumy += y1*d;
                subsumz += z1*d;
            }
            // tree-reduce the per-thread partial gradients across the block
            sum_grad[threadIdx.x*3+0] = subsumx;
            sum_grad[threadIdx.x*3+1] = subsumy;
            sum_grad[threadIdx.x*3+2] = subsumz;
            for (int j = 1; j < blockDim.x; j <<= 1){
                __syncthreads();
                int j1 = threadIdx.x;
                int j2 = threadIdx.x + j;
                if ((j1 & j) == 0 && j2 < blockDim.x){
                    sum_grad[j1*3+0] += sum_grad[j2*3+0];
                    sum_grad[j1*3+1] += sum_grad[j2*3+1];
                    sum_grad[j1*3+2] += sum_grad[j2*3+2];
                }
            }
            if (threadIdx.x == 0){
                grad2[(i*m+k)*3+0] = sum_grad[0];
                grad2[(i*m+k)*3+1] = sum_grad[1];
                grad2[(i*m+k)*3+2] = sum_grad[2];
            }
            __syncthreads();
        }
    }
}'''
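# Gradient w.r.t. xyz1: the symmetric computation, one thread per xyz1 point,
# looping over all m points of xyz2.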
matchcostgrad1_kernel = '''
extern "C"
__global__ void matchcostgrad1(int b, int n, int m, const float * __restrict__ xyz1, const float * __restrict__ xyz2, const float * __restrict__ match, float * __restrict__ grad1){
    for (int i = blockIdx.x; i < b; i += gridDim.x){
        for (int l = threadIdx.x; l < n; l += blockDim.x){
            float x1 = xyz1[i*n*3+l*3+0];
            float y1 = xyz1[i*n*3+l*3+1];
            float z1 = xyz1[i*n*3+l*3+2];
            float dx = 0, dy = 0, dz = 0;
            for (int k = 0; k < m; k++){
                float x2 = xyz2[i*m*3+k*3+0];
                float y2 = xyz2[i*m*3+k*3+1];
                float z2 = xyz2[i*m*3+k*3+2];
                // match-weighted unit vector from xyz2_k toward xyz1_l
                float d = match[i*n*m+k*n+l]*rsqrtf(fmaxf((x1-x2)*(x1-x2)+(y1-y2)*(y1-y2)+(z1-z2)*(z1-z2), 1e-20f));
                dx += (x1-x2)*d;
                dy += (y1-y2)*d;
                dz += (z1-z2)*d;
            }
            grad1[i*n*3+l*3+0] = dx;
            grad1[i*n*3+l*3+1] = dy;
            grad1[i*n*3+l*3+2] = dz;
        }
    }
}
'''
# Reference launcher from the original CUDA source:
# void approxmatchLauncher(int b, int n, int m, const float * xyz1, const float * xyz2, float * match, float * temp){
#     approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
# }
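# Autograd wrapper around the kernels. forward launches approxmatch and matchcost;
# backward launches matchcostgrad1/matchcostgrad2. The (grid, block) choices mirror
# the <<<32,512>>> launch of the reference launcher above. This is the legacy
# (pre-torch-1.0) Function API: instantiate and call, e.g. EMD()(xyz1, xyz2).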
class EMD(Function):
    def forward(self, xyz1, xyz2):
        assert xyz1.dim() == 3 and xyz1.is_cuda and xyz2.is_cuda
        assert xyz1.shape[-1] == 3  # as done by Panos
        batch_size, num_pts, pt_dim = xyz1.size()
        _, m, _ = xyz2.size()
        # match has shape (b, m, n) to mirror the kernels' match[i*n*m + l*n + k] indexing
        match = torch.zeros(batch_size, m, num_pts).cuda()
        cost = torch.zeros(batch_size).cuda()
        # scratch for approxmatch: 2*(n+m) floats per batch element (remainL/remainR/ratioL/ratioR)
        temp = torch.zeros(batch_size, 2 * (m + num_pts)).cuda()
        n = xyz1.numel()
        with torch.cuda.device_of(xyz1):
            # 1) get matching
            f = load_kernel('approxmatch', approxmatch_kernel)
            f(block=(512, 1, 1),  # (CUDA_NUM_THREADS, 1, 1)
              grid=(32, 1, 1),    # (GET_BLOCKS(n), 1, 1)
              args=[batch_size, num_pts, m, xyz1.data_ptr(), xyz2.data_ptr(), match.data_ptr(), temp.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
            # 2) calculate matching cost
            g = load_kernel('matchcost', matchcost_kernel)
            g(block=(512, 1, 1),  # (CUDA_NUM_THREADS, 1, 1)
              grid=(32, 1, 1),    # (GET_BLOCKS(n), 1, 1)
              args=[batch_size, num_pts, m, xyz1.data_ptr(), xyz2.data_ptr(), match.data_ptr(), cost.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        # stash what backward() needs (legacy-Function stand-in for save_for_backward)
        self.xyz1 = xyz1
        self.xyz2 = xyz2
        self.match = match
        self.cost = cost
        self.num_pts = num_pts
        self.m = m
        self.numel = n
        self.batch_size = batch_size
        del temp
        return cost

    def backward(self, grad_cost):
        xyz1 = self.xyz1
        xyz2 = self.xyz2
        match = self.match
        m = self.m
        num_pts = self.num_pts
        batch_size = self.batch_size
        grad1 = torch.zeros_like(xyz1)
        grad2 = torch.zeros_like(xyz2)
        with torch.cuda.device_of(grad_cost):
            if xyz1.requires_grad:
                f = load_kernel('matchcostgrad1', matchcostgrad1_kernel)
                f(block=(512, 1, 1),  # (CUDA_NUM_THREADS, 1, 1)
                  grid=(32, 1, 1),    # (GET_BLOCKS(xyz1.numel()), 1, 1)
                  args=[batch_size, num_pts, m, xyz1.data_ptr(), xyz2.data_ptr(), match.data_ptr(), grad1.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
            if xyz2.requires_grad:
                g = load_kernel('matchcostgrad2', matchcostgrad2_kernel)
                g(block=(256, 1, 1),  # matchcostgrad2's shared buffer assumes blockDim.x <= 256
                  grid=(32, 32, 1),   # blockIdx.y splits the m points of xyz2
                  args=[batch_size, num_pts, m, xyz1.data_ptr(), xyz2.data_ptr(), match.data_ptr(), grad2.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        # cost is per-example, so broadcast grad_cost over the point and coordinate dims
        return grad1 * grad_cost.view(-1, 1, 1), grad2 * grad_cost.view(-1, 1, 1)
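
# Minimal usage sketch (not part of the original gist). Assumes a CUDA device and
# a PyTorch version that still supports the legacy Function API used above; the
# shapes are illustrative only.
if __name__ == '__main__':
    xyz1 = torch.rand(8, 1024, 3).cuda().requires_grad_()
    xyz2 = torch.rand(8, 1024, 3).cuda().requires_grad_()
    cost = EMD()(xyz1, xyz2)   # per-example transport cost, shape (8,)
    cost.sum().backward()      # fills xyz1.grad / xyz2.grad via the grad kernels
    print(cost.mean().item())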