Created
April 20, 2016 18:13
-
-
Save poweic/95901268fca1dc35cc5de9c4ddd18b02 to your computer and use it in GitHub Desktop.
A simple script that converts a MATLAB symbolic expression into a C (MEX) implementation.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [] = sym2cpp( expr, filename, output_size, n_digits )
%SYM2CPP Convert a symbolic matrix into a C MEX source file and compile it.
%   SYM2CPP(EXPR, FILENAME, OUTPUT_SIZE, N_DIGITS) evaluates every element
%   of the symbolic matrix EXPR with VPA to N_DIGITS digits and rewrites
%   simple "a^b" sub-expressions as "pow(a, b)". It writes the result to
%   FILENAME.c using the template returned by GET_TEMPLATE, then builds
%   that file with MEX.
%
%   EXPR        - symbolic matrix of size dim1 x dim2.
%   FILENAME    - output base name; '.c' is appended.
%   OUTPUT_SIZE - [rows, cols] of the matrix the MEX function returns.
%                 Optional; defaults to size(EXPR). It must hold at least
%                 numel(EXPR) elements, because the generated code writes
%                 output[0] .. output[numel(EXPR)-1].
%   N_DIGITS    - number of digits passed to VPA (optional, default 64).
%
%   Errors with id 'sym2cpp:outputSize' if OUTPUT_SIZE is too small.
%   Errors with id 'sym2cpp:fopen' if FILENAME.c cannot be opened.

if nargin < 4 || isempty(n_digits)
    n_digits = 64;
end

[dim1, dim2] = size(expr);

if nargin < 3 || isempty(output_size)
    output_size = [dim1, dim2];
end
% The generated C writes one output[] slot per element of EXPR. A smaller
% output matrix would make the MEX function write out of bounds.
if numel(output_size) < 2 || output_size(1) * output_size(2) < dim1 * dim2
    error('sym2cpp:outputSize', ...
        'OUTPUT_SIZE must be [rows, cols] with room for all %d elements of EXPR.', ...
        dim1 * dim2);
end

% Rewrites "a^b" as "pow(a, b)". Only operands containing no parentheses
% and no * / + - are handled (e.g. "x^2"). Forms such as "(a+b)^2" or
% "x^(-1)" are NOT converted correctly; those need a real parser.
pattern_hat2power = '([^()*/+-]*)\^([^()*/+-]*)';

tmpl = get_template;
% Presumably used so exact rationals (e.g. 1/2) are emitted as decimal
% literals instead of C integer division.
expr_64 = vpa(expr, n_digits);

expr_str = '';
for i = 1:dim1
    for j = 1:dim2
        s = char(expr_64(i, j));
        s = regexprep(s, pattern_hat2power, 'pow($1, $2)');
        % MATLAB stores matrices in column-major order, so element (i, j)
        % lives at 0-based linear index (j-1)*dim1 + (i-1).
        idx = (j-1) * dim1 + (i-1);
        s = sprintf('  output[%2d] = %s;\n', idx, s);
        expr_str = [expr_str, s]; %#ok<AGROW>
    end
end

codes = sprintf(tmpl, expr_str, output_size(1), output_size(2));

c_fn = [filename, '.c'];
[fid, msg] = fopen(c_fn, 'w');
if fid < 0
    error('sym2cpp:fopen', 'Cannot open "%s" for writing: %s', c_fn, msg);
end
try
    % Write CODES verbatim. It is already fully formatted, so it must not
    % be used as a format string again (any '%' or '\' would be re-parsed).
    fprintf(fid, '%s', codes);
catch err
    fclose(fid);
    rethrow(err);
end
fclose(fid);

mex(c_fn);
end
function [tmpl] = get_template()
%GET_TEMPLATE Return the SPRINTF format used to generate the MEX C source.
%   TMPL is a char row vector to be passed to SPRINTF with three arguments,
%   in order:
%     %s - body of compute(), one "output[k] = ...;" line per element
%     %d - number of rows of the MEX output matrix
%     %d - number of columns of the MEX output matrix
%   Because TMPL is a format string, any literal percent sign added to the
%   C text below must be written as %%.
%
%   The generated mexFunction requires two real double inputs, theta and
%   omega. Their elements are exposed to compute() through the theta_k and
%   omega_k macros (k = 1..6). NOTE(review): nothing checks that theta and
%   omega actually have 6 elements; compute() may read past their ends if
%   the expression uses more elements than the caller passes.
tmpl = ['#include "mex.h"\n', ...
    '\n', ...
    '#define theta_1 theta[0]\n', ... % This is for my KDC project
    '#define theta_2 theta[1]\n', ...
    '#define theta_3 theta[2]\n', ...
    '#define theta_4 theta[3]\n', ...
    '#define theta_5 theta[4]\n', ...
    '#define theta_6 theta[5]\n', ...
    '#define omega_1 omega[0]\n', ...
    '#define omega_2 omega[1]\n', ...
    '#define omega_3 omega[2]\n', ...
    '#define omega_4 omega[3]\n', ...
    '#define omega_5 omega[4]\n', ...
    '#define omega_6 omega[5]\n', ...
    '\n', ...
    'void compute(double output[], double theta[], double omega[]) {\n%s}', ...
    '\n\n', ...
    'void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) {\n', ...
    '  size_t rows = %d, cols = %d;\n', ...
    '  double *output, *theta, *omega;\n', ... % declared up front for C89 compilers
    '  if (nrhs < 2) {\n', ...
    '    mexErrMsgIdAndTxt("sym2cpp:nrhs", "Two inputs required: theta and omega.");\n', ...
    '  }\n', ...
    '  if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1])) {\n', ...
    '    mexErrMsgIdAndTxt("sym2cpp:type", "theta and omega must be real double arrays.");\n', ...
    '  }\n', ...
    '  plhs[0] = mxCreateDoubleMatrix((mwSize) rows, (mwSize) cols, mxREAL);\n', ...
    '  output = mxGetPr(plhs[0]);\n', ...
    '  theta = mxGetPr(prhs[0]);\n', ...
    '  omega = mxGetPr(prhs[1]);\n', ...
    '  compute(output, theta, omega);\n', ...
    '}\n'];
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment