Skip to content

Instantly share code, notes, and snippets.

@sonesuke
Created October 9, 2017 12:49
Show Gist options
  • Save sonesuke/9a7756b4dc68ab5a27c26c43f4c22c49 to your computer and use it in GitHub Desktop.
Save sonesuke/9a7756b4dc68ab5a27c26c43f4c22c49 to your computer and use it in GitHub Desktop.
gmres
#include <cblas.h>
#include <math.h>
#include <memory.h>
#include <stdlib.h>
/* Outcome reported by the GMRES solvers in this file. */
typedef enum {
    GMRES_SUCCESS,         /* solution found to tolerance */
    GMRES_NOT_CONVERGENCE  /* no convergence */
} GMRES_RESULT;

/*
 * Callback that applies the operator A.
 * It is called as f(v, out) and is expected to store A*v in out.
 */
typedef void (*axfunc)(double *, double *);
/* Compute the Givens rotation matrix parameters for a and b. */
void rotmat(double *c, double *s, double a, double b)
{
if(b == 0.0) {
*c = 1.0;
*s = 0.0;
} else if(fabs(b) > fabs(a)) {
double temp = a / b;
*s = 1.0 / sqrt(1.0 + pow(temp, 2));
*c = temp * (*s);
} else {
double temp = b / a;
*c = 1.0 / sqrt(1.0 + pow(temp, 2));
*s = temp * (*c);
}
}
/*
 * gmres solves the linear system A*x = b with the Generalized Minimal
 * Residual (GMRES) method. One cycle of at most N Krylov steps is run,
 * with Givens rotations keeping the Hessenberg matrix upper triangular.
 * No restart is performed.
 *
 * input   N      dimension of the system (length of x and b)
 *         f      callback computing A*v: f(v, out) stores A*v in out
 *         x      initial guess vector
 *         b      right hand side vector
 *         tol    tolerance on the relative residual norm ||r|| / ||b||
 * output  x      solution vector (updated in place)
 *         err    relative residual norm that was reached
 * return  GMRES_RESULT:
 *         GMRES_SUCCESS          = solution found to tolerance
 *         GMRES_NOT_CONVERGENCE  = no convergence; also returned, with
 *                                  *err = HUGE_VAL and x untouched, when
 *                                  the workspace cannot be allocated
 *
 * The workspace is allocated per call, so the function keeps no state
 * between calls and N is limited only by available memory.
 */
GMRES_RESULT gmres(const int N, axfunc f, double *x, double *b, double tol, double *err)
{
    if (N <= 0) {
        /* empty system: nothing to solve */
        *err = 0.0;
        return GMRES_SUCCESS;
    }

    const size_t n = (size_t)N;
    const int ldv = N + 1; /* row stride of V: N rows, N+1 columns */
    const int ldh = N;     /* row stride of H: N+1 rows, N columns */

    /*
     * Workspace layout (doubles):
     *   r n | V n(n+1) | H (n+1)n | cs n | sn n | s n+1 | w n | y n | vi n
     * for a total of 2n^2 + 9n + 1 <= 2n(n+5). The guard below rejects
     * sizes whose byte count would not fit in size_t.
     */
    if (n > ((size_t)-1 / (2 * sizeof(double))) / (n + 5)) {
        *err = HUGE_VAL;
        return GMRES_NOT_CONVERGENCE;
    }
    double *work = calloc(2 * n * n + 9 * n + 1, sizeof *work);
    if (work == NULL) {
        *err = HUGE_VAL;
        return GMRES_NOT_CONVERGENCE;
    }
    double *r  = work;
    double *V  = r + n;
    double *H  = V + n * (n + 1);
    double *cs = H + (n + 1) * n;
    double *sn = cs + n;
    double *s  = sn + n;     /* n+1 entries: s[i+1] is written for i = n-1 */
    double *w  = s + n + 1;
    double *y  = w + n;
    double *vi = y + n;

    double bnrm2 = cblas_dnrm2(N, b, 1);
    if (bnrm2 == 0.0) {
        bnrm2 = 1.0;
    }

    /* r = b - A*x, formed as -(A*x - b) */
    f(x, r);
    cblas_daxpy(N, -1.0, b, 1, r, 1);
    cblas_dscal(N, -1.0, r, 1);
    const double rnrm = cblas_dnrm2(N, r, 1);
    *err = rnrm / bnrm2;
    if (*err < tol || rnrm == 0.0) {
        /* already converged; rnrm == 0 would also make 1/rnrm below blow up */
        free(work);
        return GMRES_SUCCESS;
    }

    /* V(:,0) = r / ||r|| (V is zero-initialised) and s = ||r|| * e1 */
    cblas_daxpy(N, 1.0 / rnrm, r, 1, V, ldv);
    s[0] = rnrm;

    int m = 0; /* number of Krylov directions actually built */
    for (int i = 0; i < N; ++i) {
        double *hcol = H + i;                /* hcol[k*n] is H(k, i) */
        double *hii  = hcol + (size_t)i * n; /* hii[0] = H(i,i), hii[n] = H(i+1,i) */

        /* w = A * V(:,i), orthogonalised against V(:,0..i) (Gram-Schmidt) */
        cblas_dcopy(N, V + i, ldv, vi, 1);
        f(vi, w);
        for (int k = 0; k <= i; ++k) {
            double *hki = hcol + (size_t)k * n;
            hki[0] = cblas_ddot(N, w, 1, V + k, ldv);
            cblas_daxpy(N, -hki[0], V + k, ldv, w, 1);
        }

        /* H(i+1,i) = ||w||, V(:,i+1) = w / H(i+1,i) */
        const double wnrm = cblas_dnrm2(N, w, 1);
        hii[n] = wnrm;
        if (wnrm != 0.0) {
            cblas_daxpy(N, 1.0 / wnrm, w, 1, V + i + 1, ldv);
        }

        /* apply the previous Givens rotations to the new column */
        for (int k = 0; k < i; ++k) {
            double *hk = hcol + (size_t)k * n;
            const double upper = cs[k] * hk[0] + sn[k] * hk[n];
            hk[n] = -sn[k] * hk[0] + cs[k] * hk[n];
            hk[0] = upper;
        }

        /* form the i-th rotation and use it to eliminate H(i+1,i) */
        rotmat(&cs[i], &sn[i], hii[0], hii[n]);
        s[i + 1] = -sn[i] * s[i];
        s[i] = cs[i] * s[i];
        hii[0] = cs[i] * hii[0] + sn[i] * hii[n];
        hii[n] = 0.0;

        /* |s[i+1]| is the residual norm of the current least-squares fit */
        m = i + 1;
        *err = fabs(s[i + 1]) / bnrm2;
        if (*err <= tol || wnrm == 0.0) {
            /* converged, or the Krylov space cannot be extended further */
            break;
        }
    }

    /* y = H(0:m,0:m) \ s(0:m);  x = x + V(:,0:m) * y */
    cblas_dcopy(m, s, 1, y, 1);
    cblas_dtrsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, m, H, ldh, y, 1);
    cblas_dgemv(CblasRowMajor, CblasNoTrans, N, m, 1.0, V, ldv, y, 1, 1.0, x, 1);

    GMRES_RESULT flag = GMRES_SUCCESS;
    if (!(*err <= tol)) {
        /* the estimate did not reach tol: measure the true residual */
        f(x, r);
        cblas_daxpy(N, -1.0, b, 1, r, 1);
        cblas_dscal(N, -1.0, r, 1);
        *err = cblas_dnrm2(N, r, 1) / bnrm2;
        if (!(*err <= tol)) {
            /* negated test so that a NaN residual is reported as failure */
            flag = GMRES_NOT_CONVERGENCE;
        }
    }

    free(work);
    return flag;
}
/*
 * Callback evaluated by fowarddiff. It is called as f(in, out) and is
 * expected to store its result in out; fowarddiff ignores the returned
 * double.
 */
typedef double (*ffunc)(double*, double*);

/*
 * Forward-difference directional derivative:
 *     dfx = (f(x0 + h*x) - f0) / h,   h = 1e-6
 *
 * input   N    length of every vector
 *         f    function to differentiate
 *         f0   value of f at x0, supplied by the caller
 *         x0   base point
 *         x    direction
 * output  dfx  the difference quotient
 *
 * The perturbed point and the function value live in two separate
 * scratch buffers, so f never sees its input and output aliased and dfx
 * may overlap any of the input vectors. If the scratch memory cannot be
 * allocated, or N <= 0, dfx is left unchanged.
 */
void fowarddiff(const int N, ffunc f, double* f0, double* x0, double* x, double* dfx)
{
    if (N <= 0) {
        return;
    }

    const size_t n = (size_t)N;
    const double h = 1e-6;

    /* zero-initialised, as the original static buffer was */
    double *scratch = calloc(2 * n, sizeof *scratch);
    if (scratch == NULL) {
        return;
    }
    double *farg = scratch;     /* x0 + h*x */
    double *ff   = scratch + n; /* f(x0 + h*x), then the quotient */

    cblas_dcopy(N, x0, 1, farg, 1);
    cblas_daxpy(N, h, x, 1, farg, 1);
    f(farg, ff);
    cblas_daxpy(N, -1, f0, 1, ff, 1);
    cblas_dscal(N, 1/h, ff, 1);
    cblas_dcopy(N, ff, 1, dfx, 1);

    free(scratch);
}
/*
 * fdgmres solves the linear system A*x = b with the GMRES method.
 *
 * input   N      dimension of the system (length of x and b)
 *         f      callback computing A*v: f(v, out) stores A*v in out
 *         x      initial guess vector
 *         b      right hand side vector
 *         tol    error tolerance
 * output  x      solution vector (updated in place)
 *         err    error norm
 * return  GMRES_RESULT:
 *         GMRES_SUCCESS          = solution found to tolerance
 *         GMRES_NOT_CONVERGENCE  = no convergence
 *
 * The body used to be a line-for-line copy of gmres, so it now forwards
 * to gmres and the solver is maintained in one place.
 * NOTE(review): the name suggests a finite-difference variant built on
 * fowarddiff was planned; nothing of that kind is implemented here.
 */

/* gmres is defined earlier in this file; the prototype is repeated so this
   wrapper stays valid if the definitions are reordered or split up. */
GMRES_RESULT gmres(const int N, axfunc f, double *x, double *b, double tol, double *err);

GMRES_RESULT fdgmres(const int N, axfunc f, double *x, double *b, double tol, double *err)
{
    return gmres(N, f, x, b, tol, err);
}
#include <stdio.h>
#define N 3
#define tol 1e-1
/*
 * Operator callback for the demo: multiplies the fixed example matrix
 * (stored row-major, N x N) by x and stores the product in ret.
 */
void ax(double *x, double *ret)
{
    static const double A[] = {
        1.0, 2.0, 1.0,
        1.0, 1.0, 1.0,
        2.0, 0.0, 1.0,
    };
    const double alpha = 1;
    const double beta = 0; /* ret is overwritten, not accumulated into */

    cblas_dgemv(CblasRowMajor, CblasNoTrans, N, N, alpha, A, N,
                x, 1, beta, ret, 1);
}
int main(void) {
double b[] = {
1,
2,
3,
};
double x[] = {
3,
1,
2,
};
double err = 0;
GMRES_RESULT flag = gmres(N, ax, x, b, tol, &err);
printf("flag %d\n", flag);
printf("err %e\n", err);
for (int i = 0; i < N; ++i) {
printf("%f\n", x[i]);
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment