Last active
April 21, 2024 09:20
-
-
Save tenomoto/90e3a156a9629926fd8765dd2a1dc843 to your computer and use it in GitHub Desktop.
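Autograd L63: estimating the Lorenz-63 initial state by gradient descent with Enzyme reverse-mode automatic differentiation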
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdlib> | |
#include <cstdio> | |
#include <random> | |
#include <chrono> | |
// Activity markers consumed by the Enzyme compiler plugin when they appear in
// the argument list of __enzyme_autodiff. Enzyme looks them up by these exact
// global names, so they must not be renamed.
int enzyme_dup, enzyme_dupnoneed, enzyme_out, enzyme_const;

// Forward-mode entry point, kept for reference; not used by this program:
// template <typename R, typename... Args> R __enzyme_fwddiff(void*, Args...);

// Reverse-mode entry point. Only declared here; calls to it are rewritten by
// the Enzyme plugin into generated gradient code.
template <typename R, typename... Args>
R __enzyme_autodiff(void*, Args...);
/// Plain aggregate holding a three-component state vector (x, y, z).
struct double3 {
  double x;
  double y;
  double z;
};
/// Right-hand side (time tendency) of the Lorenz-63 system.
///
/// @param w  current state (x, y, z)
/// @param p  coefficient in the x tendency: -p*x + p*y
/// @param r  coefficient in the y tendency: (r - z)*x - y
/// @param b  damping coefficient of z in the z tendency: x*y - b*z
/// @return   the tendency (dx/dt, dy/dt, dz/dt)
double3 florenz(double3 w, const double p, const double r, const double b) {
  const double x = w.x;
  const double y = w.y;
  const double z = w.z;
  return double3{-p * x + p * y,
                 (r - z) * x - y,
                 x * y - b * z};
}
/// Integrates the Lorenz-63 system with forward Euler and records the trajectory.
///
/// Model parameters are fixed: p = 10, r = 32, b = 8/3, time step dt = 0.01.
///
/// @param w      initial state; stored in fw[0]
/// @param fw     output trajectory; must hold at least nstop elements
/// @param nstop  number of states to store (the initial state plus nstop - 1
///               Euler steps); nothing is written when nstop is 0
void fom(double3 w, double3 fw[], size_t nstop) {
  const double p = 10.0;
  const double r = 32.0;
  const double b = 8.0 / 3.0;
  const double dt = 0.01;
  if (nstop == 0) {
    return;  // fw may have no storage at all; writing fw[0] would be out of bounds
  }
  fw[0] = w;
  // size_t index matches the type of nstop (no signed/unsigned comparison).
  for (size_t i = 1; i < nstop; i++) {
    const double3 dw = florenz(w, p, r, b);
    w.x = w.x + dt * dw.x;
    w.y = w.y + dt * dw.y;
    w.z = w.z + dt * dw.z;
    fw[i] = w;
  }
}
double calc_cost(double x, double y, double z, double3 wo[], int nobs, int iobs) { | |
const int nstop = 201; | |
double3 w = {x, y, z}; | |
double3 wb[nstop]; | |
fom(w, wb, nstop); | |
double cost = 0.0; | |
int k = 0; | |
for (int i = iobs; i < nstop; i = i + iobs) { | |
double dx = wb[i].x - wo[k].x; | |
double dy = wb[i].y - wo[k].y; | |
double dz = wb[i].z - wo[k].z; | |
cost = cost + 0.5 * (dx * dx + dy * dy + dz * dz); | |
k++; | |
} | |
return cost; | |
} | |
/// Generates synthetic observations by perturbing a reference trajectory.
///
/// For each index i = iobs, 2*iobs, ... below n, appends
/// wt[i] + e * u (component-wise) to wo. Each u is an independent draw from
/// std::uniform_real_distribution(-1.0, 1.0), driven by a std::mt19937 seeded
/// with 514, so the output is identical on every call.
///
/// @param e     perturbation amplitude for each component
/// @param wt    reference trajectory with n elements
/// @param n     number of elements in wt
/// @param iobs  observation interval; nothing is written when iobs <= 0
/// @param wo    output; must hold at least (n - 1) / iobs elements when n > 0
void gen_obs(double3 e, double3 wt[], size_t n, int iobs, double3 wo[]) {
  if (iobs <= 0) {
    return;  // a non-positive stride would never terminate / index out of bounds
  }
  std::mt19937 rng(514);
  std::uniform_real_distribution<double> dist(-1.0, 1.0);
  const size_t step = static_cast<size_t>(iobs);
  // Output slot; previously declared without an initializer, which made every
  // wo[k] write undefined behavior.
  size_t k = 0;
  for (size_t i = step; i < n; i += step) {
    wo[k].x = wt[i].x + e.x * dist(rng);
    wo[k].y = wt[i].y + e.y * dist(rng);
    wo[k].z = wt[i].z + e.z * dist(rng);
    k++;
  }
}
/// Performs one gradient-descent step on the initial-state estimate w.
///
/// The gradient of calc_cost with respect to (w.x, w.y, w.z) is obtained with
/// Enzyme reverse-mode autodiff; the step size is alpha = 5e-4.
///
/// @param w     current estimate of the initial state
/// @param wo    observations passed through to calc_cost (nobs entries)
/// @param nobs  number of observations in wo
/// @param iobs  observation interval in time steps
/// @return      w - alpha * grad calc_cost(w)
double3 update(double3 w, double3 wo[], int nobs, int iobs) {
  const double alpha = 5.0e-4;
  // The observations are fixed data, not unknowns, so they are marked
  // enzyme_const. The previous enzyme_dupnoneed form needed a shadow buffer
  // declared as a variable-length array (non-standard C++). That buffer was
  // left uninitialized even though Enzyme accumulates into shadow memory, and
  // the gradient it held was never used.
  const double3 mu = __enzyme_autodiff<double3>((void*)calc_cost,
                                                enzyme_out, w.x,
                                                enzyme_out, w.y,
                                                enzyme_out, w.z,
                                                enzyme_const, wo,
                                                enzyme_const, nobs,
                                                enzyme_const, iobs);
  w.x = w.x - alpha * mu.x;
  w.y = w.y - alpha * mu.y;
  w.z = w.z - alpha * mu.z;
  return w;
}
int main() { | |
auto start = std::chrono::high_resolution_clock::now(); | |
const int nstop = 201; | |
double3 w = {1.0, 3.0, 5.0}; | |
double3 e = {0.1, 0.3, 0.5}; | |
double3 wt[nstop]; | |
int iobs = 60; | |
int nobs = (nstop - 1) / iobs; | |
double3 wo[nobs]; | |
int itermax = 100; | |
fom(w, wt, nstop); | |
gen_obs(e, wt, nstop, iobs, wo); | |
w = {1.1, 3.3, 5.5}; | |
for (int i = 0; i < itermax; i++) { | |
w = update(w, wo, nobs, iobs); | |
printf("i=%d w=%f %f %f\n", i, w.x, w.y, w.z); | |
} | |
auto stop = std::chrono::high_resolution_clock::now(); | |
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start); | |
printf("elapsed time=%lld\n", duration.count()); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
% uname -a
FreeBSD atago 13.2-RELEASE-p2 FreeBSD 13.2-RELEASE-p2 GENERIC amd64
% sysctl hw.model hw.ncpu hw.clockrate hw.physmem
hw.model: Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
hw.ncpu: 4
hw.clockrate: 3200
hw.physmem: 17093910528
% /usr/local/llvm12/bin/clang++ -O2 l63.cc -fplugin=${HOME}/src/Enzyme/enzyme/build/Enzyme/ClangEnzyme-12.so -o l63.exe
% ./l63.exe
i=0 w=1.056503 3.282663 5.327455
...
elapsed time=873