Last active
December 4, 2022 13:16
-
-
Save j-faria/74516bf634f400651b2d5c75c70eb39f to your computer and use it in GitHub Desktop.
C++ and pybind11 code for matrix exponential
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <mkl.h> | |
#include <math.h> | |
#include <pybind11/pybind11.h> | |
#include <pybind11/numpy.h> | |
namespace py = pybind11; | |
using namespace pybind11::literals; | |
using namespace std; | |
// Matrix exponential | |
// Implementation by Po Liu using Intel's Math Kernel Library 11.0.4 in C++ | |
// Algorithm based on Arsigny's PhD thesis (PDF: github.com/poliu2s/MKL/blob/master/matrix_exponential_reference.pdf)
// Approximates exp(A) with the scaling-and-squaring scheme:
//   1. M = A / 2^scaling
//   2. E = sum_{k=0}^{accuracy-1} M^k / k!   (truncated Taylor series)
//   3. E is squared `scaling` times.
//
// Parameters:
//   A         square 2-D array of doubles. Any memory layout is accepted;
//             non C-contiguous input is copied into row-major order first.
//   accuracy  number of Taylor terms summed, counting the identity term.
//             Values < 2 leave only the identity term.
//   scaling   number of halvings of A and of subsequent squarings; must be >= 0.
//
// Returns a newly allocated n-by-n array.
// Throws std::runtime_error if A is not a square 2-D array, if scaling is
// negative, if the matrix is too large to be indexed with MKL_INT, or if a
// work buffer cannot be allocated.
py::array_t<double> matrix_exponential(py::array_t<double> A, int accuracy=10, int scaling=4)
{
    // Owns one 64-byte aligned MKL allocation and releases it on every exit
    // path, including exceptions.
    struct MklBuffer {
        double* ptr;

        explicit MklBuffer(size_t elems)
            : ptr(static_cast<double*>(mkl_malloc(elems * sizeof(double), 64)))
        {
            if (ptr == nullptr)
                throw std::runtime_error("mkl_malloc failed to allocate a work buffer");
        }
        ~MklBuffer() { mkl_free(ptr); }

        MklBuffer(const MklBuffer&) = delete;
        MklBuffer& operator=(const MklBuffer&) = delete;
    };

    if (A.ndim() != 2)
        throw std::runtime_error("Number of dimensions must be 2");
    if (A.shape(0) != A.shape(1))
        throw std::runtime_error("Matrix must be square");
    if (scaling < 0)
        throw std::runtime_error("scaling must be non-negative");

    // The raw pointer below is walked as dense row-major storage, which is
    // only valid for C-contiguous arrays. ensure() hands back the same array
    // when it already qualifies and a contiguous copy otherwise (e.g. for
    // transposed, Fortran-ordered or sliced input).
    auto dense = py::array_t<double, py::array::c_style | py::array::forcecast>::ensure(A);
    if (!dense)
        throw std::runtime_error("Input could not be converted to a C-contiguous array of doubles");

    const py::ssize_t dim = dense.shape(0);
    // Do the size arithmetic in size_t so that dim * dim cannot overflow int.
    const size_t count = static_cast<size_t>(dim) * static_cast<size_t>(dim);
    const MKL_INT n = static_cast<MKL_INT>(dim);
    const MKL_INT total = static_cast<MKL_INT>(count);
    if (total < 0 || static_cast<size_t>(total) != count)
        throw std::runtime_error("Matrix is too large for MKL_INT indexing");

    py::array_t<double> result({dim, dim});
    if (count == 0)
        return result;  // exp of an empty matrix is an empty matrix

    const double* matrix = dense.data();
    double* out = result.mutable_data();

    MklBuffer scaled(count);   // M = A / 2^scaling
    MklBuffer power_a(count);  // ping-pong storage for successive powers of M
    MklBuffer power_b(count);

    // M = A / 2^scaling, with the divisor computed once.
    const double divisor = pow(2.0, static_cast<double>(scaling));
    for (size_t i = 0; i < count; ++i)
        scaled.ptr[i] = matrix[i] / divisor;

    // Start the series with the identity (the k = 0 term).
    for (size_t i = 0; i < count; ++i)
        out[i] = 0.0;
    for (size_t i = 0; i < count; i += static_cast<size_t>(dim) + 1)
        out[i] = 1.0;

    // power holds M^k; scratch receives M^(k+1) and then the two swap roles.
    cblas_dcopy(total, scaled.ptr, 1, power_a.ptr, 1);
    double* power = power_a.ptr;
    double* scratch = power_b.ptr;

    double factorial = 1.0;
    for (int k = 1; k < accuracy; ++k) {
        factorial *= k;

        // E += M^k / k!
        for (size_t x = 0; x < count; ++x)
            out[x] += power[x] / factorial;

        // Prepare M^(k+1) only if another term will actually be added.
        if (k + 1 < accuracy) {
            cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                        n, n, n, 1.0, power, n, scaled.ptr, n, 0.0, scratch, n);
            double* swap_tmp = power;
            power = scratch;
            scratch = swap_tmp;
        }
    }

    // Squaring step: E = E * E, repeated `scaling` times. The output of dgemm
    // must not overlap its inputs, so E is copied out first; the two input
    // operands are only read and can therefore be the same buffer.
    for (int s = 0; s < scaling; ++s) {
        cblas_dcopy(total, out, 1, scratch, 1);
        cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                    n, n, n, 1.0, scratch, n, scratch, n, 0.0, out, n);
    }

    return result;
}
// Python extension module "example": exposes matrix_exponential() as
// expm(A, accuracy=10, scaling=4) with keyword-capable arguments.
PYBIND11_MODULE(example, m) {
    m.def("expm",
          &matrix_exponential,
          "Matrix exponential",
          py::arg("A"),
          py::arg("accuracy") = 10,
          py::arg("scaling") = 4);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
compile with a variation of