Skip to content

Instantly share code, notes, and snippets.

@jbeezley
Created April 19, 2011 02:42
Show Gist options
  • Save jbeezley/926712 to your computer and use it in GitHub Desktop.
Save jbeezley/926712 to your computer and use it in GitHub Desktop.
A Python script that imports WRF-Fire output files into a mayavi2 session.
#!/usr/bin/env python
'''
WRF2mayavi.py
Jonathan Beezley
April 18, 2011
This is a python script to import WRF-Fire output files into a mayavi2
session. To run it, you will need to have mayavi2 installed, which
requires a number of dependencies. See your package manager or the
Enthought Python Distribution. Failing that 'easy_install mayavi' may
work. The code here is very rough, but it seems to automate the most
difficult part of using mayavi with WRF netcdf files.
Common usage:
python WRF2mayavi.py -t 0 -v FGRNHFX,GRNHFX,NFUEL_CAT -w surface_wind:UF:VF,atm_wind:U:V:W wrfout_d01*
Try `python WRF2mayavi.py -h` for a complete list of options.
For future work:
1. Make surface vector fields parallel to the surface rather than horizontal
2. Allow headless output to a VTK file (or a series of VTK files) that can be
loaded into a mayavi session on a different computer.
3. Do whatever is necessary to create animations from multiple frames.
'''
from netCDF4 import Dataset
from scipy.ndimage.interpolation import map_coordinates
import numpy as np
import optparse
# Multiplier applied to the horizontal (x, y) coordinates.  It is applied in
# WRFHorizontalCoord.dxdy and again in getStructuredGrid.  The commented-out
# --scale option in the main block describes it as a "vertical to horizontal
# scaling factor".
scale_degree=1.
# Shared tvtk grids, one per grid type.  getStructuredGrid creates each one
# lazily and reuses it for later variables on the same grid type, but only
# when seperate_data is False.
surface_point_grid=None
fire_point_grid=None
atm_point_grid=None
# Grids collected by getStructuredGrid when seperate_data is True or a
# reduction is requested (normally (grid, name) pairs).  get_all_point_grids
# returns them after the shared grids.
reduced_point_grid=[]
# True: every variable gets its own grid.  False: variables are attached to
# the shared grids above.  (The spelling is kept because the name is
# referenced elsewhere in the file.)
seperate_data=True
class ncFileDef(object):
    """Read-only wrapper around a netCDF4 Dataset with a selected time index.

    The wrapper can be used as a context manager so the underlying Dataset is
    closed deterministically::

        with ncFileDef('wrfout_d01') as f:
            a = f.read('T2')
    """

    def __init__(self, filename, tdim=0):
        """Open *filename* read-only.

        tdim: index applied to the first dimension of variables whose first
        dimension is unlimited.  It may be an int or a list of ints; the main
        block passes the list returned by parse_time.
        """
        self.filename = filename
        self.tdim = tdim
        self.nc = Dataset(filename, 'r')

    def close(self):
        """Close the underlying Dataset."""
        self.nc.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()
        return False  # never suppress exceptions

    def get_var(self, varname):
        """Return the netCDF4 Variable named by *varname*.

        *varname* may be a name, any object whose str() is the name, or an
        ncVectorDef.  For a vector, the first component variable is returned.
        """
        if isinstance(varname, ncVectorDef):
            varname = str(varname.vars[0])
        return self.nc.variables[str(varname)]

    def read(self, varname):
        """Read *varname* into memory.

        Variables whose first dimension is unlimited are sliced at self.tdim
        and squeezed.  All other variables are read whole.
        """
        v = self.get_var(varname)
        first_dim = self.nc.dimensions[v.dimensions[0]]
        if first_dim.isunlimited():
            # Ellipsis (rather than ':') also works for 1-D record variables,
            # where 'v[t, :]' would index a non-existent second axis.
            return v[self.tdim, ...].squeeze()
        return v[:]

    def setTime(self, time):
        """Select the time index used by subsequent read() calls."""
        self.tdim = time
class ncWRFFile(ncFileDef):
    """ncFileDef specialised for WRF output.

    Reads are adjusted for the fire subgrid (trimmed to the refined mass grid)
    and for staggered dimensions (averaged back to cell centres).
    """

    def srxsry(self):
        """Return the fire-subgrid refinement ratios (srx, sry).

        Computed as len(*_subgrid) // (len(mass dimension) + 1).  Returns
        (1, 1) when the file has no subgrid dimensions.
        """
        dims = self.nc.dimensions
        try:
            nx = len(dims['west_east'])
            ny = len(dims['south_north'])
            nsx = len(dims['west_east_subgrid'])
            nsy = len(dims['south_north_subgrid'])
        except KeyError:
            # No fire subgrid in this file.
            return 1, 1
        # Integer division: the ratios are used as slice bounds and as grid
        # point counts, which must be ints (plain '/' yields floats on Py3).
        return nsx // (nx + 1), nsy // (ny + 1)

    def is_subgrid(self, v):
        """True when the variable's last dimension is 'west_east_subgrid'."""
        return self.get_var(v).dimensions[-1] == 'west_east_subgrid'

    def stag(self, v):
        """Return 'X', 'Y' or 'Z' for the variable's staggered dimension, or ''.

        Only the first match is reported: the last, second-to-last and (for
        variables with more than three dimensions) third-to-last dimensions
        are checked in that order.
        """
        dims = self.get_var(v).dimensions
        if dims[-1].endswith('_stag'):
            return 'X'
        if len(dims) > 1 and dims[-2].endswith('_stag'):
            return 'Y'
        if len(dims) > 3 and dims[-3].endswith('_stag'):
            return 'Z'
        return ''

    def is_surf(self, v):
        """True when the variable has exactly three dimensions.

        Presumably these are (Time, south_north, west_east); confirm this for
        non-standard files.
        """
        return len(self.get_var(v).dimensions) == 3

    def read(self, varname):
        """Read *varname*, trimming subgrid padding and de-staggering."""
        a = ncFileDef.read(self, varname)
        if self.is_subgrid(varname):
            srx, sry = self.srxsry()
            # Subgrid dimensions are (ny+1)*sry by (nx+1)*srx.  Drop the extra
            # trailing block so the shape is (ny*sry, nx*srx), matching
            # WRFFireCoord.  y is the second-to-last axis, x the last.
            a = a[..., :-sry, :-srx]
        stag = self.stag(varname)
        # Average neighbouring values along the staggered axis.
        if 'X' in stag:
            a = (a[..., 1:] + a[..., :-1]) / 2.
        if 'Y' in stag:
            a = (a[..., 1:, :] + a[..., :-1, :]) / 2.
        if 'Z' in stag:
            a = (a[..., 1:, :, :] + a[..., :-1, :, :]) / 2.
        return a
class ncVariableDef(object):
    """A named NetCDF variable, optionally paired with a coordinate definition."""

    def __init__(self, varname, coords=None):
        # Coordinate definition (anything with a read(file) method) or None.
        self.coords = coords
        self.varname = varname

    def __str__(self):
        return str(self.varname)

    def read(self, file):
        """Return this variable's data as produced by file.read(name)."""
        name = str(self)
        return file.read(name)

    def getCoords(self, file):
        """Return the coordinates produced by the attached coordinate definition."""
        coord_def = self.coords
        return coord_def.read(file)
class WRFVariable(ncVariableDef):
    """A scalar WRF variable whose coordinates are derived from its dimensions."""

    def __init__(self, varname, coords=None):
        # Forward coords so an explicitly supplied coordinate definition is
        # honoured.  When it is None, getCoords builds one on first use.
        ncVariableDef.__init__(self, varname, coords)

    def getCoords(self, file):
        """Return [x, y, z] coordinate arrays with a leading vertical axis.

        Surface variables get a length-1 leading axis on every coordinate.
        For 3-D (atmospheric) variables, x and y are repeated along a new
        leading axis to match the number of levels in z.
        """
        sub = file.is_subgrid(self.varname)
        surf = file.is_surf(self.varname)
        if not self.coords:
            self.coords = WRFCoordinateDef(surf, sub)
        c = ncVariableDef.getCoords(self, file)
        if surf and len(c[2].shape) == 3:
            # The surface height still carries a leading axis; keep the first slice.
            c[2] = c[2][0, :, :].squeeze()
        if surf:
            c = [np.reshape(x, (1,) + x.shape) for x in c]
        else:
            nz = c[2].shape[0]
            c[0] = np.repeat(c[0].reshape((1,) + c[0].shape), nz, 0)
            c[1] = np.repeat(c[1].reshape((1,) + c[1].shape), nz, 0)
        return c
class ncCoordinateDef(object):
    """A set of 2 or 3 NetCDF variables that together give coordinate arrays."""

    def __init__(self, vars):
        # Exactly 2 or 3 coordinate variables are supported.  The previous
        # test 'len != 2 or len != 3' was always true and rejected every input.
        if len(vars) not in (2, 3):
            raise Exception('Coordinates must be 2 or 3 dimensional')
        self.vars = [ncVariableDef(v, None) for v in vars]

    def read(self, file):
        """Read every coordinate variable from *file* and return them as a list."""
        return [v.read(file) for v in self.vars]
class WRFVerticalCoord(object):
    """Vertical coordinate computed as (PH + PHB) / 9.81.

    PH and PHB are WRF's perturbation and base-state geopotential, so the
    result is a height.  They are read through file.read, which for
    ncWRFFile averages the staggered dimension.
    """

    def read(self, file):
        gravity = 9.81
        perturbation, base = (ncVariableDef(name, None).read(file)
                              for name in ('PH', 'PHB'))
        return (perturbation + base) / gravity
class WRFHorizontalCoord(object):
    """Cell-centred horizontal coordinate ('x' or 'y') on the WRF mass grid."""

    def __init__(self, dim):
        # 'x' or 'y'; any other value makes read() return None.
        self.dim = dim

    def dxdy(self, file):
        """Return the grid spacing (DX, DY), each multiplied by scale_degree."""
        attrs = file.nc
        return attrs.DX * scale_degree, attrs.DY * scale_degree

    def nxny(self, file):
        """Return the mass-grid sizes (len(west_east), len(south_north))."""
        dims = file.nc.dimensions
        return len(dims['west_east']), len(dims['south_north'])

    def read(self, file):
        """Return a 2-D (ny, nx) array of x or y cell-centre coordinates.

        Cell centres lie at dx/2, 3dx/2, ..., (nx - 1/2) dx (and likewise
        for y).  Returns None when self.dim is neither 'x' nor 'y'.
        """
        nx, ny = self.nxny(file)
        dx, dy = self.dxdy(file)
        centres_x = np.linspace(dx / 2, nx * dx - dx / 2, nx)
        centres_y = np.linspace(dy / 2, ny * dy - dy / 2, ny)
        grid_x, grid_y = np.meshgrid(centres_x, centres_y)
        if self.dim == 'x':
            return grid_x
        if self.dim == 'y':
            return grid_y
        return None
class WRFFireCoord(WRFHorizontalCoord):
    """Horizontal coordinate on the refined fire subgrid.

    The spacing is divided by, and the point count multiplied by, the
    refinement ratios from file.srxsry().
    """

    def dxdy(self, file):
        coarse_dx, coarse_dy = WRFHorizontalCoord.dxdy(self, file)
        ratio_x, ratio_y = file.srxsry()
        return coarse_dx / ratio_x, coarse_dy / ratio_y

    def nxny(self, file):
        coarse = WRFHorizontalCoord.nxny(self, file)
        refine = file.srxsry()
        return tuple(n * r for n, r in zip(coarse, refine))
class ncVectorDef(object):
    """A named vector field built from 2 or 3 component variables."""

    def __init__(self, vecname, vars):
        if not 2 <= len(vars) <= 3:
            raise Exception('Only 2 or 3 dimensional vectors are supported')
        self.vecname = vecname
        # Component variable objects; each must provide read(file).
        self.vars = vars

    def __str__(self):
        return str(self.vecname)

    def read(self, file):
        """Return a list with each component's data, in component order."""
        return [component.read(file) for component in self.vars]

    def getCoords(self, file):
        """Use the coordinates of the first component for the whole vector."""
        leading = self.vars[0]
        return leading.getCoords(file)
class WRFVector(ncVectorDef):
    """WRF vector field; getStructuredGrid uses this type to recognise vectors."""
class WRFCoordinateDef(ncCoordinateDef):
    """Coordinate definition selected from a WRF variable's grid type.

    surf: the variable lives on a 2-D surface grid.
    fire: the variable lives on the fire subgrid (surface only).
    vars: explicit coordinate variable names.  When given, surf and fire are
          ignored and ncCoordinateDef validates the names.

    The automatically chosen coordinates are:
      - surface, fire grid:    fire-subgrid x/y, ZSF height
      - surface, mass grid:    mass-grid x/y, HGT height
      - non-surface (3-D):     mass-grid x/y, (PH + PHB) / g height

    Raises Exception for a non-surface fire grid, which is unsupported.
    """

    def __init__(self, surf=False, fire=False, vars=None):
        # vars defaults to None rather than a shared mutable list.
        if vars:
            ncCoordinateDef.__init__(self, vars)
            return
        if not surf:
            if fire:
                raise Exception('non surface fire grid variables not supported')
            self.vars = [WRFHorizontalCoord('x'),
                         WRFHorizontalCoord('y'),
                         WRFVerticalCoord()]
        elif fire:
            self.vars = [WRFFireCoord('x'),
                         WRFFireCoord('y'),
                         ncVariableDef('ZSF')]
        else:
            self.vars = [WRFHorizontalCoord('x'),
                         WRFHorizontalCoord('y'),
                         ncVariableDef('HGT')]
def reduceArray(alist, n, **kwds):
    """Resample an array, or a list of arrays, onto a regular grid of n points per axis.

    Parameters
    ----------
    alist : ndarray or list of ndarray
        Array(s) to resample.  The sample coordinates come from the shape of
        the first array and are applied to every array.
    n : int or sequence of int
        Number of output samples along each axis.  A scalar applies to all
        axes.  Samples span index 0 to shape-1 inclusive.
    **kwds
        Forwarded to scipy.ndimage.map_coordinates (e.g. ``order``, ``mode``).

    Returns
    -------
    ndarray or list of ndarray
        A single array when *alist* was a single array, otherwise a list.

    Raises
    ------
    Exception
        If the (first) array is not 2- or 3-dimensional.
    """
    # A bare ndarray is itself indexable, so probing alist[0] would treat its
    # first row as "the array".  Check the type explicitly first.
    if isinstance(alist, np.ndarray):
        arrays, nolist = [alist], True
    else:
        try:
            alist[0]
        except TypeError:
            arrays, nolist = [alist], True
        else:
            arrays, nolist = list(alist), False

    a = arrays[0]
    ndim = len(a.shape)
    if ndim not in (2, 3):
        raise Exception('%id arrays not supported' % ndim)
    try:
        counts = list(n)
    except TypeError:
        counts = [n] * ndim
    # Complex step in mgrid means "number of points", endpoints inclusive.
    coords = np.mgrid[tuple(slice(0, size - 1, count * 1j)
                            for size, count in zip(a.shape, counts))]
    out = [map_coordinates(arr, coords, **kwds) for arr in arrays]
    return out[0] if nolist else out
def _vtk_point_order(arr):
    """Reverse the first three axes of *arr*, then flatten them (keeping trailing axes).

    VTK structured grids list points with the first entry of ``dimensions``
    varying fastest.  Reversing axes 0..2 and flattening in C order makes
    axis 0 of *arr* the fastest-varying index, matching the ``dims`` passed
    to tvtk.StructuredGrid by getStructuredGrid.
    """
    axes = (2, 1, 0) + tuple(range(3, arr.ndim))
    reordered = arr.transpose(axes).copy()
    return reordered.reshape((-1,) + arr.shape[3:])


def getStructuredGrid(file, var, reduce=None):
    """Convert a WRF variable or vector into point data on a tvtk StructuredGrid.

    Parameters
    ----------
    file : ncWRFFile
        Open WRF output file to read from.
    var : WRFVariable or WRFVector
        Field to convert.  str(var) becomes the point-data array name.
    reduce : int or sequence of int, optional
        When given, data and coordinates are resampled with reduceArray.

    Returns
    -------
    tvtk.StructuredGrid
        The grid the data was added to.  When ``seperate_data`` is False and
        no reduction is requested, this is the shared module-level grid for
        the variable's grid type (fire, surface or atmospheric).

    Side effects
    ------------
    May set the module-level fire/surface/atm point grids.  Appends
    ``(grid, name)`` to ``reduced_point_grid`` when ``seperate_data`` is True
    or *reduce* is given.
    """
    global surface_point_grid, atm_point_grid, fire_point_grid, reduced_point_grid
    try:
        from enthought.tvtk.api import tvtk  # ETS 3.x namespace
    except ImportError:
        from tvtk.api import tvtk  # ETS 4+ namespace

    a = var.read(file)
    c = var.getCoords(file)
    # A vector reads as a list of component arrays, a scalar as one array.
    if isinstance(var, WRFVector):
        aa = a + c
    else:
        aa = [a] + c
    if reduce:
        aa = reduceArray(aa, reduce)
        c = aa[-len(c):]
    # Surface fields are 2-D; give them a leading length-1 axis.
    aa = [x.reshape((1,) + x.shape) if len(x.shape) == 2 else x for x in aa]

    points = np.empty(aa[0].shape + (3,), dtype=float)
    # NOTE(review): WRFHorizontalCoord.dxdy already multiplies DX/DY by
    # scale_degree, so the built-in x/y coordinates are scaled twice when
    # scale_degree != 1.  Confirm which scaling is intended.
    points[..., 0] = c[0] * scale_degree
    points[..., 1] = c[1] * scale_degree
    points[..., 2] = c[2]
    # Integer-shaped reshape inside the helper (not size/3, a float on Py3).
    points = _vtk_point_order(points)
    dims = list(c[0].shape)
    sgrid = tvtk.StructuredGrid(dimensions=dims, points=points)

    # Reduced grids are never shared: they are recorded once, as a
    # (grid, name) pair, at the end of this function.
    if not seperate_data and not reduce:
        if file.is_subgrid(var):
            if fire_point_grid is None:
                fire_point_grid = sgrid
            sgrid = fire_point_grid
        elif file.is_surf(var):
            if surface_point_grid is None:
                surface_point_grid = sgrid
            sgrid = surface_point_grid
        else:
            if atm_point_grid is None:
                atm_point_grid = sgrid
            sgrid = atm_point_grid

    if isinstance(var, WRFVariable):
        idx = sgrid.point_data.add_array(_vtk_point_order(aa[0]))
        sgrid.point_data.get_array(idx).name = str(var)
    elif isinstance(var, WRFVector):
        # Pack the components into (..., 3); a 2-component vector gets w = 0.
        vec = np.zeros(aa[0].shape + (3,))
        for i, component in enumerate(aa[:len(a)]):
            vec[..., i] = component
        idx = sgrid.point_data.add_array(_vtk_point_order(vec))
        sgrid.point_data.get_array(idx).name = str(var)

    if seperate_data or reduce:
        reduced_point_grid.append((sgrid, str(var)))
    return sgrid
def get_all_point_grids():
    """Return (grid, label) pairs for every grid built so far.

    The shared grids that exist come first, in the order atmospheric,
    surface, fire.  They are followed by the entries of reduced_point_grid.
    """
    shared = ((atm_point_grid, "atmospheric"),
              (surface_point_grid, "surface"),
              (fire_point_grid, "fire"))
    grids = [(grid, label) for grid, label in shared if grid]
    grids.extend(reduced_point_grid)
    return grids
def parse_time(time):
    """Convert a comma-separated string of time indices (e.g. '0,3') to a list of ints."""
    return list(map(int, time.split(',')))
if __name__ == '__main__':
    # Command-line entry point: read the requested variables/vectors from a
    # WRF output file, then either write .vts files (-o) or open mayavi.
    parser = optparse.OptionParser(usage='usage: %prog [options] filename')
    parser.add_option('-v', '--variables', action="store", type="string", dest='variables',
                      default='', help='A comma separated list of variables')
    parser.add_option('-w', '--vectors', action='store', type='string', dest='vectors',
                      default='', help='A comma separated list of vectors (i.e. -w wind1:U:V:W,wind2:UF:VF)')
    parser.add_option('-t', '--time', action="store", type="string", dest="time",
                      default="0", help='The time slice to convert [default 0]')
    # Double-quoted so the apostrophe survives.  The former 'don''t' was two
    # adjacent literals and printed as "dont".
    parser.add_option('-o', '--output', action="store_true", dest="output",
                      default=False, help="Output to vts files only, don't open a mayavi window")
    # dest matches the flag's meaning: True => write uncompressed files.
    parser.add_option('-n', '--nocompress', action='store_true', dest='nocompress',
                      default=False, help='Do not save as a compressed file')
    #parser.add_option('-s','--scale',action="store",type="float",dest="scale",
    #                  default=scale_degree,help="vertical to horizontal scaling factor")
    (options, args) = parser.parse_args()
    if len(args) < 1:
        parser.print_help()
        raise Exception('NetCDF file name required.')
    # Only the first file argument is processed; any others are ignored.
    filename = args[0]
    ncfile = ncWRFFile(filename, parse_time(options.time))
    #scale_degree=options.scale

    variables = []
    if options.variables:
        variables.extend(WRFVariable(v.strip()) for v in options.variables.split(','))
    if options.vectors:
        # Each vector spec has the form NAME:COMP1:COMP2[:COMP3].
        for spec in options.vectors.split(','):
            parts = spec.split(':')
            components = [WRFVariable(v.strip()) for v in parts[1:]]
            variables.append(WRFVector(parts[0].strip(), components))

    if not options.output:
        try:
            from enthought.mayavi import mlab  # ETS 3.x namespace
        except ImportError:
            from mayavi import mlab  # ETS 4+ namespace
        mlab.options.backend = 'envisage'

    try:
        for var in variables:
            getStructuredGrid(ncfile, var)
    finally:
        # netCDF4 slicing returns in-memory numpy arrays, so the dataset is no
        # longer needed once the grids have been built.
        ncfile.nc.close()

    if options.output:
        try:
            from enthought.tvtk.api import tvtk  # ETS 3.x namespace
        except ImportError:
            from tvtk.api import tvtk  # ETS 4+ namespace
        for grid, name in get_all_point_grids():
            # NOTE(review): `input=` is the VTK 5 pipeline API; VTK >= 6
            # expects the input to be set via set_input_data(grid). Confirm
            # against the installed VTK.
            writer = tvtk.XMLStructuredGridWriter(file_name=name + '.vts', input=grid)
            if options.nocompress:
                writer.compressor = None
            else:
                writer.compressor = tvtk.ZLibDataCompressor()
            writer.write()
    else:
        for grid, name in get_all_point_grids():
            mlab.pipeline.add_dataset(grid, name=name)
        mlab.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment