test multiscale LETKF solver
"""non-cycled 1d test of LETKF with multiscale localization"""
import numpy as np
from scipy.linalg import eigh, lapack, solve
from scipy.fft import rfft, irfft, rfftfreq
from argparse import ArgumentParser
# set random seed for reproducibility
#np.random.seed(42)
def syminv(C):
    """inverse of a square symmetric positive definite matrix (via Cholesky)"""
    # alternative: eigenanalysis
    #evals, eigs = eigh(C)
    #evals = evals.clip(min=np.finfo(evals.dtype).eps)
    #C_inv = (eigs * (1.0 / evals)).dot(eigs.T)
    # Cholesky decomposition via LAPACK
    zz, info = lapack.dpotrf(C)
    C_inv, info = lapack.dpotri(zz)
    # lapack only returns the upper (or lower) triangular part
    C_inv = np.triu(C_inv) + np.triu(C_inv, k=1).T
    return C_inv
    # alternative: direct linear solve
    #return solve(C, np.identity(C.shape[0]), sym_pos=True)
def gausscov(r,l,w):
return w*np.exp(-(r/l)**2)
def expcov(r,l):
return np.exp(-2.*r/l)
def getdist(i,j):
"""find distances between point i and other points j in 1d periodic domain"""
ndim = len(j)
return np.abs(np.remainder(i-j + ndim/2.,ndim)-ndim/2.)
def gasp_cohn(r):
"""Gaspari-Cohn localization function (goes to zero at r=1)"""
eps = np.finfo(r.dtype).eps
r = (np.abs(2*r)).clip(min=eps)
loc = np.zeros(r.shape, r.dtype)
loc = np.where(r<=1, -0.25*r**5+0.5*r**4+0.625*r**3-5./3.*r**2+1, loc)
loc = np.where(np.logical_and(r > 1.,r <= 2.),
1./12.*r**5-0.5*r**4+0.625*r**3+5./3.*r**2-5.*r+4.-2./3./r, loc)
return loc
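# note: this is the compactly supported 5th-order piecewise polynomial of
# Gaspari and Cohn (1999); since the argument is doubled internally, the taper
# decays from 1 at r=0 to exactly 0 at r=1, so gasp_cohn(dist/lscale) cuts
# covariances off completely at separations of one localization scale.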
# CL args.
parser = ArgumentParser(description='test multiscale LETKF localization')
parser.add_argument('--lscales', type=float, nargs='+',required=True, help='localization scales in grid points')
parser.add_argument('--band_cutoffs', type=float, nargs='+',required=True, help='wavenumber cutoffs separating the scale bands (len(lscales)-1 values)')
parser.add_argument('--cov_param', type=float, default=70, help='length-scale parameter (grid points) of the true covariance')
parser.add_argument('--verbose', action='store_true', help='verbose output')
parser.add_argument('--l1norm', action='store_true', help='L1 instead of L2 error norm')
parser.add_argument('--nsamples', type=int, default=8, help='ensemble members')
parser.add_argument('--ntrials', type=int, default=100, help='number of trials')
parser.add_argument('--ndim', type=int, default=500, help='number of grid points')
parser.add_argument('--random_seed', type=int, default=0, help='random seed (default is to not set)')
parser.add_argument('--oberrvar', type=float, default=1.0, help='observation error variance')
args = parser.parse_args()
# update local namespace with CL args and values
locals().update(args.__dict__)
if verbose:
print(args)
if random_seed > 0: # set random seed for reproducibility
np.random.seed(random_seed)
# specify covariance and localization matrices
nlscales = len(lscales)
nband_cutoffs = len(band_cutoffs)
if nlscales > 1 and nband_cutoffs != nlscales-1:
    raise SystemExit('len(lscales) must equal len(band_cutoffs)+1')
local = np.zeros((nlscales,ndim,ndim),np.float64)
cov = np.zeros((ndim,ndim),np.float64) # B variance = 1
# fourier wavenumbers
wavenums = ndim*rfftfreq(ndim) # length ndim//2 + 1
# define true cov as sum of gaussians with gaussian weighting
clscales = np.arange(1,ndim//2)
wts = np.exp(-(clscales/cov_param)**2)
wts = wts/wts.sum()
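# the loop below builds B(r) = sum_l w_l*exp(-(r/l)**2) with w_l proportional
# to exp(-(l/cov_param)**2) and sum_l w_l = 1, so the background variance
# (the diagonal of cov) is exactly 1.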
for wt,clscale in zip(wts,clscales):
for i in range(ndim):
dist = getdist(i,np.arange(ndim))
cov[:,i] += gausscov(dist,clscale,wt)
# use exponential approximation
#for i in range(ndim):
# dist = getdist(i,np.arange(ndim))
# cov[:,i] = expcov(dist,cov_param)
#if verbose:
# import matplotlib.pyplot as plt
# plt.plot(np.arange(ndim), cov[ndim//2],color='k')
# plt.title('cov')
# plt.show()
# raise SystemExit
for n,lscale in enumerate(lscales):
for i in range(ndim):
dist = getdist(i,np.arange(ndim))
local[n,:,i] = gasp_cohn(dist/lscale) # gaspari-cohn polynomial (compact support)
local = local.clip(min=np.finfo(local.dtype).eps)
# eigenanalysis of true cov, compute optimal gain matrix
evals, evecs = eigh(cov)
evals = evals.clip(min=np.finfo(evals.dtype).eps)
scaled_evecs = np.dot(evecs, np.diag(np.sqrt(evals)))/np.sqrt(nsamples-1)
kfopt = np.dot(cov, syminv(cov + oberrvar*np.eye(ndim)))
paopt = np.dot((np.eye(ndim) - kfopt), cov)
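# with H = I and R = oberrvar*I, these are the optimal Kalman gain
# kfopt = B (B + R)^-1 and analysis covariance paopt = (I - kfopt) B;
# the ensemble gain estimates below are benchmarked against kfopt.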
#import matplotlib.pyplot as plt
#plt.plot(np.arange(ndim), kfopt[ndim//2],color='b')
#plt.plot(np.arange(ndim), paopt[ndim//2],color='r')
#plt.show()
#raise SystemExit
if verbose:
print('tr(paopt)/tr(pb) = ',np.trace(paopt)/np.trace(cov))
if l1norm:
kfopt_frobnorm = np.abs(kfopt).sum()
else:
kfopt_frobnorm = (kfopt**2).sum()
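# kfopt_frobnorm normalizes the error scores below: each trial contributes
# sqrt(||K_est - kfopt||_F**2 / ||kfopt||_F**2) by default, or the analogous
# ratio of L1 norms when --l1norm is set, averaged over ntrials.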
# create square root of localization matrices for each local volume
dist = np.zeros((ndim,ndim),np.float64)
indx = np.zeros((ndim,ndim),bool) # based on longest length scale
for i in range(ndim):
dist[i] = getdist(i,np.arange(ndim))
indx[i] = dist[i] < np.abs(lscales[0])
sqrtlocalloc_lst=[]; neig_lst=[]
for n,lscale in enumerate(lscales):
for i in range(ndim):
localloc = local[n][np.ix_(indx[i],indx[i])]
nlocal = localloc.shape[0]
# symmetric square root of localization (truncated eigenvector expansion)
evalsl, evecsl = eigh(localloc)
for ne in range(1,nlocal):
percentvar = evalsl[-ne:].sum()/evalsl.sum()
if percentvar > 0.99:
neig = ne
break
evecs_norml = (evecsl*np.sqrt(evalsl/percentvar)).T
if not i:
neig_lst.append(neig)
sqrtlocalloc = np.zeros((ndim,neig,nlocal),np.float64)
sqrtlocalloc[i,...] = evecs_norml[nlocal-neig:nlocal,:]
sqrtlocalloc_lst.append(sqrtlocalloc)
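# sqrtlocalloc_lst[n][i] holds a truncated square root of the localization
# matrix for scale n restricted to local volume i: its rows are eigenvectors
# scaled by sqrt(eigenvalue/percentvar), keeping the fewest leading modes
# (neig) that explain at least 99% of the variance.  These vectors are used
# below to build the modulated ensemble for B localization.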
nsamples_tot=0
for n in range(nlscales):
nsamples_tot += neig_lst[n]*nsamples
# run trials with different ensembles
meankferr_rloc = 0; meankferr_bloc = 0; meankferr_blocg = 0
kfmean_rloc = np.zeros((ndim,ndim),np.float64)
kfmean_bloc = np.zeros((ndim,ndim),np.float64)
kfmean_blocg = np.zeros((ndim,ndim),np.float64)
bandvar_mean = np.zeros(nlscales, np.float64)
totvar1=0; totvar2=0; totvar3=0
for ntrial in range(ntrials):
# generate ensemble (x is an array of unit normal random numbers)
x = np.random.normal(size=(ndim,nsamples))
x = x - x.mean(axis=1)[:,np.newaxis] # zero mean
# full ensemble
y = np.dot(scaled_evecs,x)
# spectral bandpass filtering (boxcar window).
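    # when more than one scale is requested, each member is split into nlscales
    # spectral bands with sharp cutoffs at band_cutoffs: band n holds wavenumbers
    # between successive cutoffs and the final band holds everything above the
    # last cutoff, so the bands sum exactly back to the full perturbation.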
if nlscales == 1:
yyl=[y]
else:
yyl=[]
yfilt_save = np.zeros_like(y)
yspec = rfft(y,axis=0)
for n,sigma in enumerate(band_cutoffs):
yfiltspec = np.where(wavenums[:,np.newaxis] < sigma, yspec, 0.+0.j)
yfilt = irfft(yfiltspec,axis=0)
yyl.append(yfilt-yfilt_save)
yfilt_save=yfilt
ysum = np.zeros_like(y)
for n in range(nband_cutoffs):
ysum += yyl[n]
yyl.append(y-ysum)
yy = np.asarray(yyl)
bandvar = np.zeros(nlscales, np.float64)
for n in range(nlscales):
bandvar[n] = ((yy[n]**2).sum(axis=-1)/(nsamples-1)).mean()
bandvar_mean += bandvar/nsamples
yyall = yy.sum(axis=0)
totvar1 += ((y**2).sum(axis=-1)/(nsamples-1)).mean()/nsamples
totvar2 += ((yyall**2).sum(axis=-1)/(nsamples-1)).mean()/nsamples
totvar3 += bandvar.sum()/nsamples
#diff = y-yyall
#print(diff.min(), diff.max(),totvar1,totvar2,bandvar.sum())
#continue
kfens_rloc = np.zeros((ndim,ndim),np.float64)
kfens_bloc = np.zeros((ndim,ndim),np.float64)
kfens_blocg = np.zeros((ndim,ndim),np.float64)
# global solve
cov_local = np.zeros((nlscales,ndim,ndim),np.float64)
# no cross covariance (by construction)
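    # the gain is K = [sum_n C_n o (Y_n Y_n^T)] [sum_n C_n o (Y_n Y_n^T) + R]^-1,
    # where C_n is the Gaspari-Cohn matrix for scale n, Y_n the band-n
    # perturbations (scaled so Y_n Y_n^T estimates the band covariance) and
    # 'o' the Schur (elementwise) product.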
for n in range(nlscales):
cov_local[n] = local[n]*np.dot(yy[n], yy[n].T)
hpbhtinv = syminv(cov_local.sum(axis=0) + oberrvar*np.eye(ndim))
for n in range(nlscales):
kfens_blocg += np.dot(cov_local[n], hpbhtinv)
# local solve
for i in range(ndim):
# find local grid points (obs since H=I)
if not ntrial:
# use largest (first) localization scale to define local volume
indx[i] = dist[i] < np.abs(lscales[0])
ylocal_full = y[np.ix_(indx[i],np.ones(nsamples,bool))].T
ylocal = np.zeros((nlscales,)+ylocal_full.shape,ylocal_full.dtype)
for n in range(nlscales):
ylocal[n] = yy[n][np.ix_(indx[i],np.ones(nsamples,bool))].T
# R localization
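        # the observation-error variance is divided by the Gaspari-Cohn taper for
        # each scale and the nlscales band ensembles are stacked, so with
        # m = nlscales*nsamples the local gain for point i is
        # x'_i (I_m + Yb Rt^-1 Yb^T)^-1 Yb Rt^-1, Rt being the tapered obs-error
        # covariance restricted to the local volume (stored in column i below).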
Yb_sqrtRinv_lst=[]; Yb_Rinv_lst=[]; ylocal_lst=[]
for n in range(nlscales):
taper = local[n,indx[i],i]
Yb_sqrtRinv_lst.append(np.sqrt(taper/oberrvar)*ylocal[n])
Yb_Rinv_lst.append((taper/oberrvar)*ylocal[n])
ylocal_lst.append(yy[n,i,:])
Yb_sqrtRinv = np.vstack(Yb_sqrtRinv_lst)
Yb_Rinv = np.vstack(Yb_Rinv_lst)
ytmp = np.concatenate(ylocal_lst)
pa = np.eye(nsamples*nlscales) + np.dot(Yb_sqrtRinv, Yb_sqrtRinv.T)
painv = syminv(pa); painv_YbRinv = np.dot(painv, Yb_Rinv)
        kfens_rloc[indx[i],i] = np.dot(ytmp, painv_YbRinv)
        # B localization: modulate the ensemble with eigenvectors of the 'local' localization matrix.
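        # each band member is multiplied elementwise by the neig localization
        # square-root vectors, giving neig*nsamples modulated members whose sample
        # covariance approximates C_n o P_n^b in the local volume; the same
        # ETKF-style solve as above is then applied to the enlarged
        # (nsamples_tot-member) ensemble with untapered observation error.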
Yb_sqrtRinv_lst=[]; ylocal_lst=[]
for n,lscale in enumerate(lscales):
neig = neig_lst[n]
sqrtlocalloc = sqrtlocalloc_lst[n]
nlocal = sqrtlocalloc.shape[-1]
nsamples2 = neig*nsamples; nsamp2 = 0
ylocal2 = np.zeros((nsamples2,nlocal),ylocal.dtype)
ylocal = yy[n][np.ix_(indx[i],np.ones(nsamples,bool))].T
for j in range(neig):
for nsamp in range(nsamples):
ylocal2[nsamp2,:] = ylocal[nsamp,:]*sqrtlocalloc[i,neig-j-1,:]
nsamp2 += 1
Yb_sqrtRinv = ylocal2/np.sqrt(oberrvar)
Yb_sqrtRinv_lst.append(Yb_sqrtRinv)
ylocal_lst.append(ylocal2[:,np.argmin(dist[i][indx[i]])])
Yb_sqrtRinv = np.vstack(Yb_sqrtRinv_lst)
ytmp = np.concatenate(ylocal_lst)
painv = syminv(np.eye(nsamples_tot) + np.dot(Yb_sqrtRinv, Yb_sqrtRinv.T))
kfens_bloc[indx[i],i] = np.dot(ytmp,np.dot(painv,Yb_sqrtRinv/np.sqrt(oberrvar)))
    # error relative to the optimal gain (Frobenius norm, or L1 with --l1norm)
diff_rloc = kfens_rloc-kfopt
diff_bloc = kfens_bloc-kfopt
diff_blocg = kfens_blocg-kfopt
if l1norm:
kferr_rloc = np.abs(diff_rloc).sum()
meankferr_rloc += (kferr_rloc/kfopt_frobnorm)/ntrials
kferr_bloc = np.abs(diff_bloc).sum()
meankferr_bloc += (kferr_bloc/kfopt_frobnorm)/ntrials
kferr_blocg = np.abs(diff_blocg).sum()
meankferr_blocg += (kferr_blocg/kfopt_frobnorm)/ntrials
else:
kferr_rloc = (diff_rloc**2).sum()
meankferr_rloc += np.sqrt(kferr_rloc/kfopt_frobnorm)/ntrials
kferr_bloc = (diff_bloc**2).sum()
meankferr_bloc += np.sqrt(kferr_bloc/kfopt_frobnorm)/ntrials
kferr_blocg = (diff_blocg**2).sum()
meankferr_blocg += np.sqrt(kferr_blocg/kfopt_frobnorm)/ntrials
kfmean_rloc += kfens_rloc/ntrials
kfmean_bloc += kfens_bloc/ntrials
kfmean_blocg += kfens_blocg/ntrials
#print(totvar1, totvar2, totvar3)
#print(bandvar_mean)
if verbose:
import matplotlib.pyplot as plt
x = np.linspace(-ndim//2,ndim//2-1,ndim)
plt.plot(x,kfmean_rloc[ndim//2],'r',label='mean est K (Rloc)')
plt.plot(x,kfmean_bloc[ndim//2],'b',label='mean est K (Bloc)')
plt.plot(x,kfopt[ndim//2],'k',label='K true')
plt.xlim(-lscale,lscale)
plt.legend()
plt.savefig('meangain.png')
plt.show()
# print out mean error
print("lscale = %s Kerr_localRloc = %s Kerr_localBloc = %s Kerr_globalBloc = %s" %\
(lscales[0],meankferr_rloc,meankferr_bloc,meankferr_blocg))
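# example invocation (the script file name is hypothetical; flag values are
# illustrative and map to the argparse options defined above):
#   python letkf_multiscale_test.py --lscales 100 25 --band_cutoffs 10 \
#       --nsamples 20 --ntrials 50 --random_seed 42 --verbose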