-
-
Save sschaetz/f37e15ec2f059e13777b to your computer and use it in GitHub Desktop.
clFFT Bug on Surface Pro 3 i5
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#get_ipython().magic(u'matplotlib inline') | |
#get_ipython().magic(u"config InlineBackend.figure_format = 'retina'") | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import os | |
import pyopencl as cl | |
import pyopencl.array | |
import pyopencl.array as cla | |
import scipy.io | |
import matplotlib.pylab as pylab | |
# OpenCL C source for the reproduction; it is compiled and launched by testFFT().
#
# NOTE(review): this looks like a cut-down, clFFT-style generated forward
# kernel (FwdPass1 / FwdPass2 / fft_fwd) reduced to show a device-dependent
# result -- inferred from the names and the gist title; confirm.
#
# The FIX toggle at the top of the string:
#   "#define FIX"      -> FIX expands to nothing (original indexing);
#   "#define FIX +1"   -> every __local index, the size of `lds`, and the
#                         bound of its initialisation loop all shift by one.
# The shift is applied to every local write and read alike, so both variants
# should compute the same values. Presumably that is the point of the toggle:
# a difference between variants, or between devices, implicates the OpenCL
# implementation rather than this code.
#
# Data flow in fft_fwd (work-groups of exactly 64 work-items, per
# reqd_work_group_size):
#   * work-item 0 fills the `256 FIX`-float __local array `lds` with 1.0f,
#     followed by a barrier;
#   * each work-item starts with constant registers R0..R3 =
#     (1,1), (2,2), (3,3), (4,4); gbIn is never read (lwbIn is computed but
#     unused), so the output does not depend on the input data;
#   * FwdPass1 exchanges the .x parts, then the .y parts, of R0..R3 through
#     `lds`. All four __local pointer arguments are the same `lds` array, so
#     the barriers between each write/read phase are required;
#   * FwdPass2 stores R0..R3 to gbOut at me, me+16, me+32, me+48 within the
#     work-item's 64-element row (me = local id % 16, row = batch*4 + id/16).
# Tracing the index arithmetic, each 64-element row should end up as
#   out[j] = ((j % 16) // 4 + 1) * (1+1j)
# With cb[0].u == 64 (as set in testFFT) and 16 work-groups launched, rw is 1
# for every work-item, so all 64*64 outputs are written.
kstr = """
#define FIX
//#define FIX +1
__attribute__((always_inline)) void
FwdPass1(uint rw, uint b, uint me, uint inOffset, uint outOffset,
__local float *bufInRe, __local float *bufInIm, __local float *bufOutRe,
__local float *bufOutIm, float2 *R0, float2 *R1, float2 *R2, float2 *R3)
{
if(rw)
{
bufOutRe[outOffset + ( ((1*me + 0)/4)*16 + (1*me + 0)%4 + 0 )*1 FIX] = (*R0).x;
bufOutRe[outOffset + ( ((1*me + 0)/4)*16 + (1*me + 0)%4 + 4 )*1 FIX] = (*R1).x;
bufOutRe[outOffset + ( ((1*me + 0)/4)*16 + (1*me + 0)%4 + 8 )*1 FIX] = (*R2).x;
bufOutRe[outOffset + ( ((1*me + 0)/4)*16 + (1*me + 0)%4 + 12 )*1 FIX] = (*R3).x;
}
barrier(CLK_LOCAL_MEM_FENCE);
if(rw)
{
(*R0).x = bufOutRe[outOffset + ( 0 + me*1 + 0 + 0 )*1 FIX];
(*R1).x = bufOutRe[outOffset + ( 0 + me*1 + 0 + 16 )*1 FIX];
(*R2).x = bufOutRe[outOffset + ( 0 + me*1 + 0 + 32 )*1 FIX];
(*R3).x = bufOutRe[outOffset + ( 0 + me*1 + 0 + 48 )*1 FIX];
}
barrier(CLK_LOCAL_MEM_FENCE);
if(rw)
{
bufOutIm[outOffset + ( ((1*me + 0)/4)*16 + (1*me + 0)%4 + 0 )*1 FIX] = (*R0).y;
bufOutIm[outOffset + ( ((1*me + 0)/4)*16 + (1*me + 0)%4 + 4 )*1 FIX] = (*R1).y;
bufOutIm[outOffset + ( ((1*me + 0)/4)*16 + (1*me + 0)%4 + 8 )*1 FIX] = (*R2).y;
bufOutIm[outOffset + ( ((1*me + 0)/4)*16 + (1*me + 0)%4 + 12 )*1 FIX] = (*R3).y;
}
barrier(CLK_LOCAL_MEM_FENCE);
if(rw)
{
(*R0).y = bufOutIm[outOffset + ( 0 + me*1 + 0 + 0 )*1 FIX];
(*R1).y = bufOutIm[outOffset + ( 0 + me*1 + 0 + 16 )*1 FIX];
(*R2).y = bufOutIm[outOffset + ( 0 + me*1 + 0 + 32 )*1 FIX];
(*R3).y = bufOutIm[outOffset + ( 0 + me*1 + 0 + 48 )*1 FIX];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
__attribute__((always_inline)) void
FwdPass2(uint rw, uint b, uint me, uint inOffset, uint outOffset,
__local float *bufInRe, __local float *bufInIm, __global float2 *bufOut,
float2 *R0, float2 *R1, float2 *R2, float2 *R3)
{
if(rw)
{
bufOut[outOffset + ( 1*me + 0 + 0 )*1] = (*R0);
bufOut[outOffset + ( 1*me + 0 + 16 )*1] = (*R1);
bufOut[outOffset + ( 1*me + 0 + 32 )*1] = (*R2);
bufOut[outOffset + ( 1*me + 0 + 48 )*1] = (*R3);
}
}
typedef union { uint u; int i; } cb_t;
__kernel __attribute__((reqd_work_group_size (64,1,1)))
void fft_fwd(__constant cb_t *cb __attribute__((max_constant_size(32))),
__global const float2 * restrict gbIn, __global float2 * restrict gbOut)
{
uint me = get_local_id(0);
uint batch = get_group_id(0);
__local float lds[256 FIX];
if (me == 0)
{
for (uint i=0; i<256 FIX; i++)
{
lds[i] = 1.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
uint iOffset;
uint oOffset;
__global float2 *lwbIn;
__global float2 *lwbOut;
float2 R0 = (float2)(1.0, 1.0);
float2 R1 = (float2)(2.0, 2.0);
float2 R2 = (float2)(3.0, 3.0);
float2 R3 = (float2)(4.0, 4.0);
uint rw = (me < ((cb[0].u) - batch*4)*16) ? 1 : 0;
uint b = 0;
iOffset = (batch*4 + (me/16))*64;
oOffset = (batch*4 + (me/16))*64;
lwbIn = gbIn + iOffset;
lwbOut = gbOut + oOffset;
FwdPass1(rw, b, me%16, (me/16)*64, (me/16)*64, lds, lds, lds, lds, &R0, &R1, &R2, &R3);
FwdPass2(rw, b, me%16, (me/16)*64, 0, lds, lds, lwbOut, &R0, &R1, &R2, &R3);
}
"""
# In[17]:
def testFFT(kernelstring, deviceId, inData, platformId=0):
    """Build *kernelstring* on one OpenCL device, run ``fft_fwd`` and return its output.

    The launch geometry is fixed to match the ``fft_fwd`` kernel in ``kstr``:
    1024 work-items in work-groups of 64 (that kernel declares
    ``reqd_work_group_size(64,1,1)``). The result is a 64x64 complex64 array.
    The selected device is printed before anything is run.

    Parameters
    ----------
    kernelstring : str
        OpenCL C source defining a kernel ``fft_fwd`` that takes
        ``(constant buffer, input buffer, output buffer)``.
    deviceId : int
        Index into ``platform.get_devices()`` of the device to run on.
    inData : numpy.ndarray
        Host data uploaded to the kernel's input buffer byte-for-byte.
        Presumably complex64 with 64*64 elements (TODO confirm); the
        ``kstr`` kernel never reads it.
    platformId : int, optional
        Index into ``cl.get_platforms()``. Defaults to 0, the value that
        was previously hard-coded.

    Returns
    -------
    numpy.ndarray
        complex64 array of shape (64, 64). Elements the kernel does not
        write keep their initial value 1+1j.
    """
    platform = cl.get_platforms()[platformId]
    device = platform.get_devices()[deviceId]
    # A parenthesised single-argument print gives the same output on Python 2 and 3.
    print(device)

    # Dimensions; these must agree with the launch geometry the kernel assumes.
    d1 = 64
    d2 = 64
    size = d1 * d2
    mesh = (1024, 1, 1)   # global size: 16 work-groups
    bundle = (64, 1, 1)   # local size: must equal the kernel's reqd_work_group_size

    ctx = cl.Context([device])
    queue = cl.CommandQueue(
        ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)

    # Host data.
    # Read by the kernel as cb[0].u, which bounds which work-items write (rw).
    constData = (np.ones(2) * 64).astype(np.uint32)
    # enqueue_copy needs a contiguous host buffer. This is a no-op for arrays
    # that are already contiguous.
    inData = np.ascontiguousarray(inData)
    # Pre-fill the output with 1+1j so elements the kernel never writes stand out.
    outData = (np.ones(size) + 1.j * np.ones(size)).astype(
        np.complex64).reshape(d1, d2)

    # Device data.
    mf = cl.mem_flags
    constBuffer = cl.Buffer(ctx, mf.READ_WRITE, constData.nbytes)
    inBuffer = cl.Buffer(ctx, mf.READ_WRITE, inData.nbytes)
    outBuffer = cl.Buffer(ctx, mf.READ_WRITE, outData.nbytes)

    # Upload.
    cl.enqueue_copy(queue, constBuffer, constData)
    cl.enqueue_copy(queue, inBuffer, inData)
    cl.enqueue_copy(queue, outBuffer, outData)

    # To inspect the compiled program, read
    # prg.get_info(cl.program_info.BINARIES)[0] after build().
    prg = cl.Program(ctx, kernelstring).build()
    prg.fft_fwd(queue, mesh, bundle, constBuffer, inBuffer, outBuffer)

    # Download the result and make sure all queued work has completed.
    cl.enqueue_copy(queue, outData, outBuffer)
    queue.finish()
    return outData
def plot(data):
    """Display *data* as a grayscale image with a horizontal colorbar.

    A new 10x10-inch figure with a single axes is created. Pixels use
    nearest-neighbour interpolation so individual array elements stay
    distinct. The figure is shown with ``Figure.show()``; nothing is returned.
    """
    figure, axes = plt.subplots(figsize=(10, 10))
    image = axes.imshow(data, cmap='gray', interpolation='nearest')
    figure.colorbar(image, orientation='horizontal')
    figure.show()
# In[18]:
# Run the same kernel on OpenCL devices 0 and 1 of the first platform and
# compare the results. The input is seeded so repeated runs are comparable.
np.random.seed(10)
size = 64*64
inData = (np.random.randn(size) + 1.j*np.random.randn(size)).astype(
    np.complex64).reshape(64, 64)
out0 = testFFT(kstr, 0, inData)
out1 = testFFT(kstr, 1, inData)
# Each line prints the summed absolute difference of the real (then imaginary)
# parts, which is 0 when the devices agree, followed by the max/min of out0
# (then out1). numpy compares complex values lexicographically: real part first.
# Single-argument print with "%s" formatting gives the same text on Python 2
# (where this was a print statement) and on Python 3.
print("%s %s %s" % (np.sum(np.abs(np.real(out0) - np.real(out1))),
                    np.max(out0), np.min(out0)))
print("%s %s %s" % (np.sum(np.abs(np.imag(out0) - np.imag(out1))),
                    np.max(out1), np.min(out1)))
plot(np.abs(out0))
plot(np.abs(out1))
# log10 of the magnitude ratio: 0 wherever both devices produced the same value.
plot(np.log10(np.abs(out1/out0)))
# In[ ]:
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment