
@piotrMocz
Created June 8, 2014 20:04
LU decomposition via cuBLAS: finally working, though only for square matrices.
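```scala
import jcuda.{Pointer, Sizeof}
import jcuda.jcublas.{JCublas2, cublasHandle}
import jcuda.runtime.{JCuda, cudaMemcpyKind}
import gust.linalg.cuda.CuMatrix // assumed package for CuMatrix; adjust to your build

// LU factorization with partial pivoting of a square Float matrix on the GPU, using
// cuBLAS's batched getrf with a batch of one. Returns a copy of A holding L and U packed
// together; the pivot indices (P) and the status flag (info) are computed but not returned.
def LU(A: CuMatrix[Float])(implicit handle: cublasHandle): CuMatrix[Float] = {
  require(A.rows == A.cols, "Matrix has to be square.")

  // Factor a copy so that A itself is left untouched.
  val d_A = CuMatrix.create[Float](A.rows, A.cols)
  d_A := A
  val P = CuMatrix.create[Int](d_A.rows, 1) // pivot indices
  val info = CuMatrix.create[Int](1, 1)     // 0 on success, i > 0 if U(i,i) is exactly zero

  // The batched API takes a device-side array of device pointers, one per matrix.
  val A_ptr = Pointer.to(d_A.offsetPointer)
  val d_Aptr = new Pointer()
  JCuda.cudaMalloc(d_Aptr, Sizeof.POINTER)
  JCuda.cudaMemcpy(d_Aptr, A_ptr, Sizeof.POINTER, cudaMemcpyKind.cudaMemcpyHostToDevice)

  JCublas2.cublasSgetrfBatched(handle, d_A.rows, d_Aptr, d_A.majorSize,
    P.offsetPointer, info.offsetPointer, 1)

  JCuda.cudaFree(d_Aptr) // release the temporary pointer array
  d_A
}
```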
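To sanity-check the packed result on the CPU, a small reference implementation can help. The sketch below assumes cuBLAS follows the LAPACK getrf conventions: column-major storage with leading dimension `lda`, unit-lower L and U packed into one array, 1-based pivot indices, and `info = k > 0` when U(k,k) is exactly zero. The names `LuReference`, `sgetrf`, and `sgetrs` are hypothetical and only mirror LAPACK naming; they are not the cuBLAS or JCublas2 API. The solve step shows what the pivot vector `P` is needed for.

```scala
// Hypothetical CPU reference, assuming LAPACK-style getrf conventions (see lead-in).
object LuReference {

  /** LU with partial pivoting, in place, on a column-major n x n matrix.
    * Fills ipiv with 1-based pivot rows and returns info (0, or k if U(k,k) == 0). */
  def sgetrf(n: Int, a: Array[Float], lda: Int, ipiv: Array[Int]): Int = {
    var info = 0
    for (j <- 0 until n) {
      // Pick the row with the largest |a(i, j)| for i >= j.
      var p = j
      for (i <- j + 1 until n)
        if (math.abs(a(i + j * lda)) > math.abs(a(p + j * lda))) p = i
      ipiv(j) = p + 1
      if (a(p + j * lda) == 0.0f) {
        // Whole column is zero below the diagonal: record the first such step, go on.
        if (info == 0) info = j + 1
      } else {
        if (p != j) // swap full rows j and p
          for (k <- 0 until n) {
            val t = a(j + k * lda); a(j + k * lda) = a(p + k * lda); a(p + k * lda) = t
          }
        val pivot = a(j + j * lda)
        for (i <- j + 1 until n) a(i + j * lda) /= pivot // column j of L
        // Rank-1 update of the trailing submatrix.
        for (k <- j + 1 until n; i <- j + 1 until n)
          a(i + k * lda) -= a(i + j * lda) * a(j + k * lda)
      }
    }
    info
  }

  /** Solves A x = b from the packed factors and pivots produced by sgetrf. */
  def sgetrs(n: Int, a: Array[Float], lda: Int, ipiv: Array[Int], b: Array[Float]): Array[Float] = {
    val x = b.clone()
    // Apply the row swaps in the order they were made (x := P b).
    for (i <- 0 until n) {
      val p = ipiv(i) - 1
      if (p != i) { val t = x(i); x(i) = x(p); x(p) = t }
    }
    // Forward substitution with the unit lower triangle L.
    for (i <- 0 until n; k <- 0 until i) x(i) -= a(i + k * lda) * x(k)
    // Back substitution with U.
    for (i <- n - 1 to 0 by -1) {
      for (k <- i + 1 until n) x(i) -= a(i + k * lda) * x(k)
      x(i) /= a(i + i * lda)
    }
    x
  }
}
```

For example, copying `d_A` and `P` back to the host and running `sgetrs` on them should give the same solution as factoring the original matrix with `sgetrf`, up to floating-point rounding.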