Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
from timeit import default_timer as time
import numpy as np
from numba import cuda
import os
import numpy
import torch
import ctypes
import math
from torch.autograd import Variable
@cuda.jit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)')
def cu_exp_matrix_mul(A, c, d, u, v, b, n, m):
    """CUDA kernel: batched product with an exponentiated, shifted matrix.

    For every batch index ``bi < b`` and row index ``ni < n`` it writes::

        v[bi, ni] = sum_{mi < m} exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi]

    The indexing in the body implies A is (n, m), c and u are (b, m), and
    d and v are (b, n).  Each thread computes exactly one output element.
    The x dimension of the launch covers the batch axis and the y dimension
    covers ``n``.
    """
    # Global coordinates of this thread: x -> batch index, y -> row of A.
    bi = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
    ni = cuda.threadIdx.y + cuda.blockIdx.y * cuda.blockDim.y
    # Threads outside the (b, n) output range (e.g. when the grid is rounded
    # up to whole blocks) must exit without touching memory.
    if ni >= n or bi >= b:
        return
    # Start from a float so the accumulator keeps one type through the loop.
    r = 0.0
    for mi in range(m):
        r += math.exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi]
    v[bi, ni] = r
@cuda.jit('(float32[:,:], float32[:,:], float32[:,:], float32[:], int32, int32, int32)')
def cu_exp_matrix_cost_sum(A, c, d, v, b, n, m):
    """CUDA kernel: per-batch weighted sum over an exponentiated, shifted matrix.

    For every batch index ``bi < b`` it writes::

        v[bi] = sum_{mi < m} sum_{ni < n} exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * A[ni, mi]

    The indexing in the body implies A is (n, m), c is (b, m), d is (b, n)
    and v is (b,).  Each thread handles one batch element on a 1-D launch
    over the x dimension.
    """
    # Global batch index of this thread.
    bi = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
    # Threads past the end of the batch (e.g. from rounding the grid up to
    # whole blocks) must exit without touching memory.
    if bi >= b:
        return
    # Start from a float so the accumulator keeps one type through the loop.
    r = 0.0
    # Loop order (mi outer, ni inner) is kept as written, which fixes the
    # floating-point summation order.
    for mi in range(m):
        for ni in range(n):
            r += math.exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * A[ni, mi]
    v[bi] = r
def get_devicendarray(t):
    """Wrap a CUDA float32 torch tensor as a Numba ``DeviceNDArray`` without copying.

    The returned array aliases ``t``'s device memory, so kernels that write
    to it modify ``t`` in place.  No owner or finalizer is handed to the
    ``MemoryPointer``, so the wrapper does not keep ``t`` alive: the caller
    must hold a reference to ``t`` while the array is in use.

    Args:
        t: a tensor whose ``type()`` is ``'torch.cuda.FloatTensor'``.

    Returns:
        A ``DeviceNDArray`` with ``t``'s shape and with ``t``'s strides
        converted from elements to bytes.  It is associated with PyTorch's
        current CUDA stream handle.

    Raises:
        AssertionError: if ``t`` is not a CUDA float32 tensor.
    """
    assert t.type() == 'torch.cuda.FloatTensor'
    # Use the context of the device that actually holds ``t`` rather than
    # whatever context happens to be current, so tensors on a non-default
    # GPU are wrapped under the right context.
    ctx = cuda.cudadrv.devices.get_context(t.device.index)
    # c_size_t is pointer-sized on every 64-bit platform.  c_ulong is only
    # 32 bits on Windows and would truncate the device address.
    # The factor 4 is sizeof(float32), guaranteed by the assert above.
    # NOTE(review): numel*4 understates the memory span of non-contiguous
    # tensors; confirm callers only pass contiguous ones.
    mp = cuda.cudadrv.driver.MemoryPointer(ctx, ctypes.c_size_t(t.data_ptr()), t.numel() * 4)
    return cuda.cudadrv.devicearray.DeviceNDArray(
        t.size(),
        [i * 4 for i in t.stride()],
        numpy.dtype('float32'),
        gpu_data=mp,
        stream=torch.cuda.current_stream().cuda_stream,
    )
def batch_expmat_product(A, c, d, u):
    """Compute ``v[k, i] = sum_j exp(-A[i, j] + c[k, j] + d[k, i]) * u[k, j]`` on the GPU.

    Args:
        A: CUDA float32 tensor of shape (n, m).
        c: CUDA float32 tensor of shape (b, m).
        d: CUDA float32 tensor of shape (b, n).
        u: CUDA float32 tensor of shape (b, m).

    Returns:
        A new tensor ``v`` of shape (b, n) with A's dtype and device.

    Raises:
        AssertionError: if an input is not 2-D, the sizes are inconsistent,
            or (via ``get_devicendarray``) an input is not a CUDA float32
            tensor.
    """
    # Check ranks first: A.size(1) below would otherwise raise IndexError on
    # a 1-D A before the intended "dimension mismatch" message is reported.
    assert A.dim() == 2 and c.dim() == 2 and d.dim() == 2 and u.dim() == 2, "dimension mismatch"
    b = c.size(0)
    n = A.size(0)
    m = A.size(1)
    assert c.size(1) == m and d.size(0) == b and d.size(1) == n and u.size(0) == b and u.size(1) == m, "size mismatch"
    # Output buffer.  Every in-range element is written by the kernel, so it
    # need not be zero-initialised.
    v = torch.empty(b, n, dtype=A.dtype, device=A.device)
    if b == 0 or n == 0:
        # Nothing to compute, and a zero-sized grid is an invalid launch.
        return v
    Ad, cd, dd, ud, vd = (get_devicendarray(x) for x in (A, c, d, u, v))
    # One thread per output element: x covers the batch axis, y covers rows
    # of A.  Round the grid up; the kernel skips out-of-range threads.
    threads = (16, 16)
    blocks = ((b + threads[0] - 1) // threads[0], (n + threads[1] - 1) // threads[1])
    # NOTE(review): launched on Numba's default stream; confirm this matches
    # the stream PyTorch uses for A, c, d, u in the calling code.
    cu_exp_matrix_mul[blocks, threads](Ad, cd, dd, ud, vd, b, n, m)
    return v
def batch_expmat_mat_sum(A, c, d):
    """Compute ``res[k] = sum_{i,j} exp(-A[i, j] + c[k, j] + d[k, i]) * A[i, j]`` on the GPU.

    Args:
        A: CUDA float32 tensor of shape (n, m).
        c: CUDA float32 tensor of shape (b, m).
        d: CUDA float32 tensor of shape (b, n).

    Returns:
        A new tensor ``res`` of shape (b,) with A's dtype and device.

    Raises:
        AssertionError: if an input is not 2-D, the sizes are inconsistent,
            or (via ``get_devicendarray``) an input is not a CUDA float32
            tensor.
    """
    # Check ranks first: A.size(1) below would otherwise raise IndexError on
    # a 1-D A before the intended "dimension mismatch" message is reported.
    assert A.dim() == 2 and c.dim() == 2 and d.dim() == 2, "dimension mismatch"
    b = c.size(0)
    n = A.size(0)
    m = A.size(1)
    assert c.size(1) == m and d.size(0) == b and d.size(1) == n, "size mismatch"
    # One output value per batch element, each written by the kernel.
    res = torch.empty(b, dtype=A.dtype, device=A.device)
    if b == 0:
        # Nothing to compute, and a zero-sized grid is an invalid launch.
        return res
    Ad, cd, dd, resd = (get_devicendarray(x) for x in (A, c, d, res))
    # 1-D launch, one thread per batch element.  Round the grid up; the
    # kernel skips out-of-range threads.
    threads = 256
    blocks = (b + threads - 1) // threads
    # NOTE(review): launched on Numba's default stream; confirm this matches
    # the stream PyTorch uses for A, c, d in the calling code.
    cu_exp_matrix_cost_sum[blocks, threads](Ad, cd, dd, resd, b, n, m)
    return res
# Smoke run: draw random CUDA inputs, then compute one batched product.
b, n, m = 100, 200, 300
# Drawn in the order A (n, m), c (b, m), d (b, n), u (b, m), t (b, n), so
# the random stream is consumed in the same order as one call per line.
A, c, d, u, t = (
    torch.randn(*shape).cuda()
    for shape in ((n, m), (b, m), (b, n), (b, m), (b, n))
)
# NOTE(review): t is not used in the visible code.
w = batch_expmat_product(A, c, d, u)
Copy link

netw0rkf10w commented Mar 2, 2021

@t-vi Thank you for sharing your code. Could you please add some comments explaining what each function does? I am trying to write some custom PyTorch functions in Numba, and I feel that your code is helpful, but unfortunately it is not easy to understand. Thank you so much!

Copy link

grinisrit commented Apr 27, 2021

To get the CUDA context for a tensor `t`, I think it is better to use the context of the device that holds `t` (this matters when more than one GPU is in use):

ctx = cuda.cudadrv.devices.get_context(t.device.index)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment