Skip to content

Instantly share code, notes, and snippets.

@muuki88
Created December 19, 2012 21:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save muuki88/4340776 to your computer and use it in GitHub Desktop.
Save muuki88/4340776 to your computer and use it in GitHub Desktop.
private static FloatMatrix multiply(FloatMatrix A, FloatMatrix B, boolean local) throws IOException {
if (A.getColumnDimension() != B.getRowDimension()) {
throw new IllegalArgumentException("Matrix inner dimensions must agree.");
}
CLContext context = JavaCL.createBestContext();
CLQueue queue = context.createDefaultQueue();
int resultLength = A.getRowDimension() * B.getColumnDimension();
Pointer<Float> aPtr = matrixToPointer(A);
Pointer<Float> bPtr = matrixToPointer(B);
Pointer<Float> resultPtr = allocateFloats(resultLength);
Pointer<Integer> q = allocateInt();
q.set(A.getColumnDimension()); // q is inner dimension
// Create OpenCL input buffers (using the native memory pointers aPtr
// and bPtr) :
CLBuffer<Float> aInputBuffer = context.createBuffer(Usage.Input, aPtr);
CLBuffer<Float> bInputBuffer = context.createBuffer(Usage.Input, bPtr);
CLBuffer<Integer> qInputBuffer = context.createIntBuffer(Usage.Input, q);
// Create an OpenCL output buffer :
CLBuffer<Float> resultBuffer = context.createBuffer(Usage.Output, resultPtr);
// Get and call the kernel :
MultiplicationKernel kernel = new MultiplicationKernel(context);
int[] localWorkSizes = new int[] { 8, 8 };
int[] globalWorkSizes = new int[] { A.getRowDimension(), B.getColumnDimension() };
CLEvent clEvent = null;
Pointer<Float> outPtr = null;
FloatMatrix matrix = null;
try {
if (local) {
clEvent = kernel.floatMatrixMultLocals(queue, //
resultBuffer, //
aInputBuffer, //
bInputBuffer, //
qInputBuffer, //
globalWorkSizes, //
localWorkSizes);
} else {
clEvent = kernel.floatMatrixMult(queue, //
resultBuffer, //
aInputBuffer, //
bInputBuffer, //
qInputBuffer, //
globalWorkSizes, //
localWorkSizes);
}
// blocks until
outPtr = resultBuffer.read(queue, clEvent);
// mulitiplication finished
matrix = pointerToFloatMatrix(outPtr, A.getRowDimension(), B.getColumnDimension());
} catch (CLException e) {
e.printStackTrace();
throw e;
} finally {
Pointer.release(aPtr, bPtr, outPtr, resultPtr, q);
aInputBuffer.release();
bInputBuffer.release();
qInputBuffer.release();
resultBuffer.release();
queue.release();
context.release();
clEvent.release();
aPtr = null;
bPtr = null;
outPtr = null;
resultPtr = null;
q = null;
aInputBuffer = null;
bInputBuffer = null;
qInputBuffer = null;
resultBuffer = null;
queue = null;
context = null;
clEvent = null;
}
return matrix;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment