Automatic Differentiation in C
/* Test for Automatic differentiation in C */ | |
/* MIT License - Ragni Matteo 2016 */ | |
#include <math.h> | |
/* Scalar type used for both components of every dual number. */
typedef double number;

// Math constants
#define AD_PI 3.141592653589793
#define AD_E 2.718281828459045

/*
 * Basic leaves.
 *
 * A dual number is a two-element array: [0] holds the value and [1] holds
 * the derivative with respect to the independent variable. The declaring
 * macros below expand to a complete declaration, including the trailing ';'.
 */

/* Declare the independent variable x with value k (derivative seed 1). */
#define AD_VARIABLE(x, k) number x[2] = {(k), 1.0};
/* Declare a constant x with value k (derivative 0). */
#define AD_CONSTANT(x, k) number x[2] = {(k), 0.0};
/* Declare a zero-initialised temporary for an intermediate result. */
#define AD_TEMP_VAR(x) number x[2] = {0.0, 0.0};
/* Copy dual number x into r. Wrapped in do { } while (0) so that
 * "AD_COPY(r, x);" is a single statement, also inside an unbraced if/else. */
#define AD_COPY(r, x) do { \
    (r)[0] = (x)[0]; \
    (r)[1] = (x)[1]; \
} while (0)
/*
 * Base functions.
 *
 * Every macro takes dual numbers (two-element arrays: [0] = value,
 * [1] = derivative) and stores the result's value in r[0] and its
 * derivative in r[1]. Both components are computed into locals before r is
 * written, so r may alias an operand (e.g. AD_PROD(a, a, b)). Each macro is
 * wrapped in do { } while (0) so "AD_XXX(...);" is a single statement, safe
 * inside an unbraced if/else. Do not pass arguments named ad_v_ or ad_d_.
 */

/* r = x + y;  r' = x' + y' */
#define AD_SUM(r, x, y) do { \
    double ad_v_ = (x)[0] + (y)[0]; \
    double ad_d_ = (x)[1] + (y)[1]; \
    (r)[0] = ad_v_; \
    (r)[1] = ad_d_; \
} while (0)
/* r = x * y;  r' = x*y' + y*x'  (product rule) */
#define AD_PROD(r, x, y) do { \
    double ad_v_ = (x)[0] * (y)[0]; \
    double ad_d_ = (x)[0] * (y)[1] + (y)[0] * (x)[1]; \
    (r)[0] = ad_v_; \
    (r)[1] = ad_d_; \
} while (0)
/* r = x - y;  r' = x' - y' */
#define AD_DIFF(r, x, y) do { \
    double ad_v_ = (x)[0] - (y)[0]; \
    double ad_d_ = (x)[1] - (y)[1]; \
    (r)[0] = ad_v_; \
    (r)[1] = ad_d_; \
} while (0)
/* r = x / y;  r' = (x'*y - x*y') / y^2  (quotient rule) */
#define AD_DIVIDE(r, x, y) do { \
    double ad_v_ = (x)[0] / (y)[0]; \
    double ad_d_ = ((x)[1] * (y)[0] - (x)[0] * (y)[1]) / ((y)[0] * (y)[0]); \
    (r)[0] = ad_v_; \
    (r)[1] = ad_d_; \
} while (0)
/* r = -x;  r' = -x' */
#define AD_NEGATIVE(r, x) do { \
    double ad_v_ = -(x)[0]; \
    double ad_d_ = -(x)[1]; \
    (r)[0] = ad_v_; \
    (r)[1] = ad_d_; \
} while (0)
/* r = |x|;  r' = x' * sign(x). x[0] == 0 is treated as negative, so the
 * derivative there is -x'. */
#define AD_ABS(r, x) do { \
    double ad_v_ = ((x)[0] > 0 ? (x)[0] : -(x)[0]); \
    double ad_d_ = (x)[1] * ((x)[0] > 0 ? 1.0 : -1.0); \
    (r)[0] = ad_v_; \
    (r)[1] = ad_d_; \
} while (0)
// Trascendent functions | |
#define AD_POW_XK(r, x, k) { \ | |
r[0] = pow(x[0], k[0]); \ | |
r[1] = k[0] * x[1] * pow(x[0], (k[0] - 1.0)); \ | |
} | |
#define AD_POW_KX(r, k, x) { \ | |
r[0] = pow(x[0], k[0]); \ | |
r[1] = pow(k[0], x[0]) * log(k[0]) * x[1]; \ | |
} | |
#define AD_POW_XY(r, x, y) { \ | |
r[0] = pow(x[0], y[0]); \ | |
r[1] = pow(x[0], y[0] - 1.0) * (x[1] * y[0] + x[0] * log(x[0]) * y[1]); \ | |
} | |
#define AD_SQRT(r, x) { \ | |
r[0] = pow(x[0], 0.5); \ | |
r[1] = 0.5 * x[1] * pow(x[0], -0.5); \ | |
} | |
#define AD_EXP(r, x) { \ | |
r[0] = exp(x[0]); \ | |
r[1] = x[1] * exp(x[0]); \ | |
} | |
#define AD_LOG(r, x) { \ | |
r[0] = log(x[0]); \ | |
r[1] = x[1] / x[0]; \ | |
} | |
// Trigonometric functions | |
#define AD_SIN(r, x) { \ | |
r[0] = sin(x[0]); \ | |
r[1] = x[1] * cos(x[0]); \ | |
} | |
#define AD_COS(r, x) { \ | |
r[0] = cos(x[0]); \ | |
r[1] = -x[1] * sin(x[0]); \ | |
} | |
#define AD_TAN(r, x) { \ | |
r[0] = tan(x[0]); \ | |
r[1] = x[1] / pow(cos(x[0]), 2.0); \ | |
} | |
#define AD_ASIN(r, x){ \ | |
r[0] = asin(x[0]); \ | |
r[1] = x[1] / sqrt(1.0 - pow(x[0], 2.0)); \ | |
} | |
#define AD_ACOS(r, x){ \ | |
r[0] = acos(x[0]); \ | |
r[1] = -x[1] / sqrt(1.0 + pow(x[0], 2.0));; \ | |
} | |
#define AD_ATAN(r, x){ \ | |
r[0] = atan(x[0]); \ | |
r[1] = x[1] / (1.0 + pow(x[0], 2.0)); \ | |
} |
/* Automatic differentiation in C - Test file */ | |
/* | |
* We will test the following function: | |
* | |
* f(x) = x² + cos( tan ( x² )) | |
* | |
* df | |
* --(x) = 2x * ( 1 - sin( tan ( x² )) / cos²( x² ) ) | |
* dx | |
* | |
 * To use automatic differentiation we must declare an intermediate
 * variable for the result of each operation, unary or binary (see the
 * temporaries in test_function). For now, I don't have a way to avoid them.
*/ | |
#include <stdio.h> | |
#include "ad.h" | |
// Atomaticaly differentiated function. | |
number * test_function(double *ret, double _x) { | |
AD_VARIABLE(x, _x); | |
AD_CONSTANT(c2, 2.0); | |
AD_TEMP_VAR(t_0); | |
AD_TEMP_VAR(t_1); | |
AD_TEMP_VAR(t_2); | |
AD_POW_XK(t_0, x, c2); | |
AD_TAN(t_1, t_0); | |
AD_COS(t_2, t_1); | |
AD_SUM(ret, t_0, t_2); | |
return ret; | |
} | |
// Exact function | |
double func(double x) { | |
return pow(x, 2) + cos(tan(pow(x, 2))); | |
} | |
// Exact derivative | |
double dfunc(double x) { | |
return (2 * x) * (1 - sin(tan(pow(x, 2))) / (pow(cos(pow(x, 2)),2))); | |
} | |
/*
 * Tabulate f and f' on [0, 1) in steps of 0.025. Each row prints:
 *   x, exact f(x), AD value, exact f'(x), AD derivative
 * so the automatic-differentiation columns can be compared with the exact ones.
 */
int main(int argc, char **argv) {
    (void)argc;   /* command-line arguments are not used */
    (void)argv;

    enum { N_STEPS = 40 };        /* 40 * 0.025 = 1.0, so x covers [0, 1) */
    const double step = 0.025;
    double ret_A[2];

    /* x is derived from an integer counter rather than accumulated with
     * x += step: 0.025 is not exactly representable, so repeated addition
     * drifts and could change the number of rows printed. */
    for (int i = 0; i < N_STEPS; i++) {
        const double x = i * step;
        test_function(ret_A, x);
        const double t_1 = func(x);
        const double t_2 = dfunc(x);
        printf("%3.3f\t%3.3f\t%3.3f\t%3.3f\t%3.3f\n",
               x, t_1, ret_A[0], t_2, ret_A[1]);
    }
    return 0;
}
0.000 1.000 1.000 0.000 0.000 | |
0.025 1.001 1.001 0.050 0.050 | |
0.050 1.002 1.002 0.100 0.100 | |
0.075 1.006 1.006 0.149 0.149 | |
0.100 1.010 1.010 0.198 0.198 | |
0.125 1.016 1.016 0.246 0.246 | |
0.150 1.022 1.022 0.293 0.293 | |
0.175 1.030 1.030 0.339 0.339 | |
0.200 1.039 1.039 0.384 0.384 | |
0.225 1.049 1.049 0.427 0.427 | |
0.250 1.061 1.061 0.469 0.469 | |
0.275 1.073 1.073 0.508 0.508 | |
0.300 1.086 1.086 0.545 0.545 | |
0.325 1.100 1.100 0.580 0.580 | |
0.350 1.115 1.115 0.613 0.613 | |
0.375 1.131 1.131 0.642 0.642 | |
0.400 1.147 1.147 0.668 0.668 | |
0.425 1.164 1.164 0.690 0.690 | |
0.450 1.181 1.181 0.709 0.709 | |
0.475 1.199 1.199 0.722 0.722 | |
0.500 1.218 1.218 0.731 0.731 | |
0.525 1.236 1.236 0.734 0.734 | |
0.550 1.254 1.254 0.729 0.729 | |
0.575 1.272 1.272 0.717 0.717 | |
0.600 1.290 1.290 0.696 0.696 | |
0.625 1.307 1.307 0.665 0.665 | |
0.650 1.323 1.323 0.621 0.621 | |
0.675 1.338 1.338 0.562 0.562 | |
0.700 1.351 1.351 0.486 0.486 | |
0.725 1.362 1.362 0.388 0.388 | |
0.750 1.370 1.370 0.264 0.264 | |
0.775 1.375 1.375 0.109 0.109 | |
0.800 1.375 1.375 -0.085 -0.085 | |
0.825 1.370 1.370 -0.328 -0.328 | |
0.850 1.358 1.358 -0.631 -0.631 | |
0.875 1.338 1.338 -1.010 -1.010 | |
0.900 1.307 1.307 -1.485 -1.485 | |
0.925 1.263 1.263 -2.079 -2.079 | |
0.950 1.202 1.202 -2.821 -2.821 | |
0.975 1.120 1.120 -3.739 -3.739 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment