Created
October 10, 2017 14:54
-
-
Save sonesuke/a17dfb5fb456c21ef72f8fbf8d5cf070 to your computer and use it in GitHub Desktop.
fdgmres
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cblas.h> | |
#include <math.h> | |
#include <memory.h> | |
#include <stdio.h> | |
/* Outcome reported by the GMRES solvers in this file. */
typedef enum {
    GMRES_SUCCESS = 0,          /* solution found to tolerance */
    GMRES_NOT_CONVERGENCE = 1   /* no convergence */
} GMRES_RESULT;

/* Callback computing a matrix-vector product: reads x, writes A*x to ax. */
typedef void (*axfunc)(double *x, double *ax);
/* Compute the Givens rotation matrix parameters for a and b. */ | |
void rotmat(double *c, double *s, double a, double b) | |
{ | |
if(b == 0.0) { | |
*c = 1.0; | |
*s = 0.0; | |
} else if(fabs(b) > fabs(a)) { | |
double temp = a / b; | |
*s = 1.0 / sqrt(1.0 + pow(temp, 2)); | |
*c = temp * (*s); | |
} else { | |
double temp = b / a; | |
*c = 1.0 / sqrt(1.0 + pow(temp, 2)); | |
*s = temp * (*c); | |
} | |
} | |
/* | |
gmres solves the linear system Ax=b | |
using the Generalized Minimal residual ( GMRESm ) method with restarts . | |
input axfunc a function to calculate A*x | |
x initial guess vector | |
b right hand side vector | |
tol error tolerance | |
output x solution vector | |
error error norm | |
flag GMRES_RESULT: | |
GMRES_SUCCESS = solution found to tolerance | |
GMRES_NOT_CONVERGENCE = no convergence | |
*/ | |
GMRES_RESULT gmres(const int N, axfunc f, double *x, double *b, double tol ,double* err) | |
{ | |
/* working memory */ | |
static double work[1000000]; | |
static int next = 0; | |
memset((void*)work, 0, 100000 * sizeof(double)); | |
/* initialize */ | |
GMRES_RESULT flag = GMRES_SUCCESS; | |
double bnrm2 = cblas_dnrm2(N, b, 1); | |
if(bnrm2 == 0.0) { | |
bnrm2 = 1.0; | |
} | |
/* r = b - Ax */ | |
double *r = work; next += N; /* N */ | |
f(x, r); | |
cblas_daxpy(N, -1, b, 1, r, 1); | |
cblas_dscal(N, -1, r, 1); | |
*err = cblas_dnrm2(N, r, 1) / bnrm2; | |
if ( *err < tol ) { | |
return 0; | |
} | |
/* initial workspace */ | |
double *V = work + next; next += N * (N+1); /* N x N+1 */ | |
double *H = work + next; next += (N+1) * N; /* N+1 x N */ | |
double *cs = work + next; next += N; /* N */ | |
double *sn = work + next; next += N; /* N */ | |
double *e1 = work + next; next += N; /* N */ | |
e1[0] = 1.0; | |
double *s = work + next; next += N; /* N */ | |
double *w = work + next; next += N; /* N */ | |
double *y = work + next; next += N; /* N */ | |
double *vi = work + next; next += N; | |
/* V(:,1) = r / norm( r ) */; | |
double rnrm = cblas_dnrm2(N, r, 1); | |
cblas_daxpy(N, 1.0 /rnrm , r, 1, V, N+1); | |
/* s = norm(r) * e1 */ | |
cblas_daxpy(N, rnrm , e1, 1, s, 1); | |
/* construct orthonormal */ | |
int i = 0; | |
for(i = 0; i < N; ++i) { | |
/* basis using Gram-Schmidt */ | |
/* w = A*V(:,i) */ | |
cblas_dcopy(N, &(V[i]), N+1, vi, 1); | |
f(vi, w); | |
for(int k = 0; k < i + 1; ++k) { | |
H[k*N + i] = cblas_ddot(N, w, 1, &(V[k]), N+1); | |
cblas_daxpy(N, -H[k*N + i], &(V[k]), N+1, w, 1); | |
} | |
/* H(i+1,i) = norm( w ) */ | |
H[(i+1)*N + i] = cblas_dnrm2(N, w, 1); | |
/* V(:,i+1) = w / H(i+1,i) */ | |
cblas_daxpy(N, 1 / H[(i+1)*N + i], w, 1, &(V[i+1]), N+1); | |
/* apply Givens rotation */ | |
for (int k = 0; k < i; ++k) { | |
double temp1= cs[k]*H[k*N + i] + sn[k]*H[(k+1)*N + i]; | |
H[(k+1)*N + i] = -sn[k]*H[k*N+i] + cs[k]*H[(k+1)*N + i]; | |
H[k*N + i] = temp1; | |
} | |
/* form i-th rotation matrix */ | |
rotmat(&(cs[i]), &(sn[i]), H[i*N+i], H[(i+1)*N+i]); | |
/* approximate residual norm */ | |
double temp4 = cs[i] * s[i]; | |
s[i+1] = -sn[i] * s[i]; | |
s[i] = temp4; | |
H[i*N + i] = cs[i]*H[i*N + i] + sn[i]*H[(i+1)*N + i]; | |
H[(i+1)*N + i] = 0; | |
*err = fabs(s[i+1]) / bnrm2; | |
if (*err <= tol) { | |
/* update approximation */ | |
/* and exit */ | |
cblas_dcopy(i+1, s, 1, y, 1); | |
cblas_dtrsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, i+1, H, N, y, 1); | |
cblas_dgemv(CblasRowMajor, CblasNoTrans, N, i+1, 1, V, N+1, y, 1, 1, x, 1); | |
break; | |
} | |
} | |
if( *err <= tol) { | |
return flag; | |
} | |
/* update approximation */ | |
/* y = H(1:m,1:m) \ s(1:m); */ | |
/* x = x + V(:,1:m)*y; */ | |
cblas_dcopy(N, s, 1, y, 1); | |
cblas_dtrsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, N, H, N, y, 1); | |
cblas_dgemv(CblasRowMajor, CblasNoTrans, N, N, 1, V, N+1, y, 1, 1, x, 1); | |
/* check convergence */ | |
f(x, r); | |
cblas_daxpy(N, -1, b, 1, r, 1); | |
cblas_dscal(N, -1, r, 1); | |
/* check convergence */ | |
s[i+1] = cblas_dnrm2(N, r, 1); | |
*err = s[i+1] / bnrm2; | |
if (*err > tol) { | |
flag = GMRES_NOT_CONVERGENCE; | |
} | |
return flag; | |
} | |
/* Callback evaluating a vector function: reads the first array, writes the
   result to the second. */
typedef void (*ffunc)(double*, double*);

/*
 * Forward-difference directional derivative of f at x0 along x:
 *
 *     dfx = (f(x0 + h*x) - f0) / h,   h = 1e-6
 *
 *   N    number of entries in every vector
 *   f    function to differentiate
 *   f0   reference value subtracted from f(x0 + h*x)
 *        (presumably f(x0) -- confirm against callers)
 *   x0   base point
 *   x    direction
 *   dfx  output, N entries
 *
 * Scratch space is a function-local static array (not reentrant) that holds
 * at most FD_MAX_DIM entries per vector.  If N exceeds that, dfx is filled
 * with NaN instead of overrunning the buffer.  N <= 0 is a no-op.
 */
void fd(const int N, ffunc f, double* f0, double* x0, double* x, double* dfx)
{
    enum { FD_MAX_DIM = 100 };
    static double work[2 * FD_MAX_DIM];
    const double h = 1e-6;

    if (N <= 0) {
        return;
    }
    if (N > FD_MAX_DIM) {
        for (int i = 0; i < N; ++i) {
            dfx[i] = NAN;
        }
        return;
    }

    /* initialize work */
    memset(work, 0, 2 * (size_t)N * sizeof work[0]);

    /* Two disjoint scratch vectors: f must never be handed the same array
       as both its input and its output. */
    double *farg = work;      /* x0 + h*x    */
    double *ff   = work + N;  /* f(x0 + h*x) */

    cblas_dcopy(N, x0, 1, farg, 1);
    cblas_daxpy(N, h, x, 1, farg, 1);
    f(farg, ff);

    /* ff = (ff - f0) / h */
    cblas_daxpy(N, -1, f0, 1, ff, 1);
    cblas_dscal(N, 1 / h, ff, 1);

    /* Copy out last, so dfx may share storage with f0, x0 or x. */
    cblas_dcopy(N, ff, 1, dfx, 1);
}
/* | |
gmres solves the linear system Ax=b | |
using the Generalized Minimal residual ( GMRESm ) method with restarts . | |
input axfunc a function to calculate A*x | |
x initial guess vector | |
b right hand side vector | |
tol error tolerance | |
output x solution vector | |
error error norm | |
flag GMRES_RESULT: | |
GMRES_SUCCESS = solution found to tolerance | |
GMRES_NOT_CONVERGENCE = no convergence | |
*/ | |
GMRES_RESULT fdgmres(const int N, ffunc f, double* f0, double* x0, double* x, double tol ,double* err) | |
{ | |
/* working memory */ | |
static double work[1000000]; | |
static int next = 0; | |
memset((void*)work, 0, 100000 * sizeof(double)); | |
/* initialize */ | |
GMRES_RESULT flag = GMRES_SUCCESS; | |
double *b = work; next += N; /* N */ | |
cblas_daxpy(N, -1, f0, 1, b, 1); | |
double bnrm2 = cblas_dnrm2(N, b, 1); | |
if(bnrm2 == 0.0) { | |
bnrm2 = 1.0; | |
} | |
/* r = b - dfx */ | |
double *r = work; next += N; /* N */ | |
cblas_dscal(N, 0, x, 1); | |
cblas_dcopy(N, b, 1, r, 1); | |
for (int i = 0; i < N; ++i) { | |
printf("%f\t", r[i]); | |
} | |
*err = cblas_dnrm2(N, r, 1) / bnrm2; | |
if ( *err < tol ) { | |
return 0; | |
} | |
/* initial workspace */ | |
double *V = work + next; next += N * (N+1); /* N x N+1 */ | |
double *H = work + next; next += (N+1) * N; /* N+1 x N */ | |
double *cs = work + next; next += N; /* N */ | |
double *sn = work + next; next += N; /* N */ | |
double *e1 = work + next; next += N; /* N */ | |
e1[0] = 1.0; | |
double *s = work + next; next += N; /* N */ | |
double *w = work + next; next += N; /* N */ | |
double *y = work + next; next += N; /* N */ | |
double *vi = work + next; next += N; | |
/* V(:,1) = r / norm( r ) */; | |
double rnrm = cblas_dnrm2(N, r, 1); | |
cblas_daxpy(N, 1.0 /rnrm , r, 1, V, N+1); | |
/* s = norm(r) * e1 */ | |
cblas_daxpy(N, rnrm , e1, 1, s, 1); | |
/* construct orthonormal */ | |
int i = 0; | |
for(i = 0; i < N; ++i) { | |
/* basis using Gram-Schmidt */ | |
/* w = A*V(:,i) */ | |
cblas_dcopy(N, &(V[i]), N+1, vi, 1); | |
fd(N, f, f0, x0, vi, w); | |
for(int k = 0; k < i + 1; ++k) { | |
H[k*N + i] = cblas_ddot(N, w, 1, &(V[k]), N+1); | |
cblas_daxpy(N, -H[k*N + i], &(V[k]), N+1, w, 1); | |
} | |
/* H(i+1,i) = norm( w ) */ | |
H[(i+1)*N + i] = cblas_dnrm2(N, w, 1); | |
/* V(:,i+1) = w / H(i+1,i) */ | |
cblas_daxpy(N, 1 / H[(i+1)*N + i], w, 1, &(V[i+1]), N+1); | |
/* apply Givens rotation */ | |
for (int k = 0; k < i; ++k) { | |
double temp1= cs[k]*H[k*N + i] + sn[k]*H[(k+1)*N + i]; | |
H[(k+1)*N + i] = -sn[k]*H[k*N+i] + cs[k]*H[(k+1)*N + i]; | |
H[k*N + i] = temp1; | |
} | |
/* form i-th rotation matrix */ | |
rotmat(&(cs[i]), &(sn[i]), H[i*N+i], H[(i+1)*N+i]); | |
/* approximate residual norm */ | |
double temp4 = cs[i] * s[i]; | |
s[i+1] = -sn[i] * s[i]; | |
s[i] = temp4; | |
H[i*N + i] = cs[i]*H[i*N + i] + sn[i]*H[(i+1)*N + i]; | |
H[(i+1)*N + i] = 0; | |
*err = fabs(s[i+1]) / bnrm2; | |
if (*err <= tol) { | |
/* update approximation */ | |
/* and exit */ | |
cblas_dcopy(i+1, s, 1, y, 1); | |
cblas_dtrsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, i+1, H, N, y, 1); | |
cblas_dgemv(CblasRowMajor, CblasNoTrans, N, i+1, 1, V, N+1, y, 1, 1, x, 1); | |
break; | |
} | |
} | |
if( *err <= tol) { | |
return flag; | |
} | |
/* update approximation */ | |
/* y = H(1:m,1:m) \ s(1:m); */ | |
/* x = x + V(:,1:m)*y; */ | |
cblas_dcopy(N, s, 1, y, 1); | |
cblas_dtrsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, N, H, N, y, 1); | |
cblas_dgemv(CblasRowMajor, CblasNoTrans, N, N, 1, V, N+1, y, 1, 1, x, 1); | |
/* check convergence */ | |
f(x, r); | |
cblas_daxpy(N, -1, b, 1, r, 1); | |
cblas_dscal(N, -1, r, 1); | |
/* check convergence */ | |
s[i+1] = cblas_dnrm2(N, r, 1); | |
*err = s[i+1] / bnrm2; | |
if (*err > tol) { | |
flag = GMRES_NOT_CONVERGENCE; | |
} | |
return flag; | |
} | |
/* Demo parameters used by ax() and main(). */
#define N   (3)     /* dimension passed to gmres() */
#define N2  (2)     /* dimension passed to fdgmres() */
#define tol (1e-1)  /* tolerance passed to both solvers */
void ax(double *x, double *ret) | |
{ | |
double A[] = { | |
1.0, 2, 1, | |
1, 1.0, 1, | |
2, 0, 1.0, | |
}; | |
cblas_dgemv(CblasRowMajor, CblasNoTrans, N, N, 1, A, N, x, 1, 0, ret, 1); | |
} | |
/*
 * Gradient of the demo objective (x1 - 1)^2 + (x2 - 2)^2:
 * ret[k] = 2*x[k] - 2*center[k] for the two components.
 */
void df(double* x, double* ret)
{
    static const double center[2] = { 1.0, 2.0 };

    for (int k = 0; k < 2; ++k) {
        ret[k] = 2 * x[k] - 2 * center[k];
    }
}
int main(void) { | |
double b[] = { | |
1, | |
2, | |
3, | |
}; | |
double x[] = { | |
3, | |
1, | |
2, | |
}; | |
double err = 0; | |
GMRES_RESULT flag = gmres(N, ax, x, b, tol, &err); | |
printf("flag %d\n", flag); | |
printf("err %e\n", err); | |
for (int i = 0; i < N; ++i) { | |
printf("%f\n", x[i]); | |
} | |
double f0[] = { | |
-2, | |
-2, | |
}; | |
double x0[] = { | |
0, | |
1, | |
}; | |
double dx[] = { | |
0, | |
0, | |
}; | |
flag = fdgmres(N2, df, f0, x0, dx, tol, &err); | |
printf("flag %d\n", flag); | |
printf("err %e\n", err); | |
for (int i = 0; i < N2; ++i) { | |
printf("%f\n", dx[i]); | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment