Last active
July 8, 2020 22:26
-
-
Save lzkelley/716d728a489e62973d0a9426a9ca1aa3 to your computer and use it in GitHub Desktop.
Draw an Arepo Voronoi mesh and, optionally, a scalar variable on top of it
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def voronoi_mesh(path, snap_num, ndim=2, int_type=np.int32, double_type=np.float32):
    """Load an Arepo voronoi mesh from the 2D `voronoi_mesh_###` files.

    Arguments
    ---------
    path : str, directory containing the `voronoi_mesh_###` file
    snap_num : int, snapshot number used to build the (zero-padded) file name
    ndim : int, number of coordinates stored per vertex
    int_type : numpy dtype of the integers stored in the file
    double_type : numpy dtype of the vertex coordinates stored in the file

    Returns
    -------
    mesh : dict with the entries
        'ngas_tot' : total number of gas cells
        'nel_tot' : total number of edges
        'ndt_tot' : total number of delaunay tetrahedra
        'nedges' : (ngas_tot,) integer array
        'nedge_offset' : (ngas_tot,) integer array
        'edge_list' : (nel_tot,) integer array
        'xyz_edges' : (ndt_tot*ndim,) flat array of vertex coordinates

    Raises
    ------
    RuntimeError : if the file ends before all of the expected data has been read, if the
        header holds a negative count, or if data is left over after the expected arrays
        (all of which suggest the wrong `int_type` / `double_type` / `ndim`).

    """
    fname = os.path.join(path, "voronoi_mesh_{:03d}".format(snap_num))

    with open(fname, 'rb') as data:

        def _read(dtype, count, label):
            """Read exactly `count` items of `dtype`, or raise `RuntimeError`."""
            count = int(count)
            # A negative count would make `np.fromfile` read the whole rest of the file
            if count < 0:
                raise RuntimeError(
                    "Invalid (negative) size {} for '{}' in '{}'!".format(count, label, fname))
            arr = np.fromfile(data, dtype, count)
            # `np.fromfile` silently returns fewer items when the file is too short
            if arr.size != count:
                raise RuntimeError(
                    "Unexpected end of file '{}' while reading '{}': expected {} items, "
                    "found {}!  Wrong sizes for int/float?".format(fname, label, count, arr.size))
            return arr

        # Read header
        ngas_tot, nel_tot, ndt_tot = _read(int_type, 3, 'header')
        mesh = {
            'ngas_tot': ngas_tot,    # total number of gas cells
            'nel_tot': nel_tot,      # total number of edges
            'ndt_tot': ndt_tot,      # total number of delaunay tetrahedra
        }

        # Read "edges" (really: vertices) data
        mesh['nedges'] = _read(int_type, ngas_tot, 'nedges')
        mesh['nedge_offset'] = _read(int_type, ngas_tot, 'nedge_offset')
        mesh['edge_list'] = _read(int_type, nel_tot, 'edge_list')
        # Use python integers for the product so that it cannot overflow `int_type`
        mesh['xyz_edges'] = _read(double_type, int(ndt_tot) * int(ndim), 'xyz_edges')

        # Make sure we've hit the end of the file.  Probe a single byte, so that left-over
        # data shorter than one `int_type` item is also detected.
        extra = np.fromfile(data, np.uint8, 1)
        if extra.size != 0:
            raise RuntimeError("Did not reach expected end of file!  Wrong sizes for int/float?")

    return mesh
def image_slice(fname, int_type=np.int32, double_type=np.float32, verbose=True):
    """Read in the image-slices/projections files produced directly by arepo

    These are the "<parameter>_slice_###" and "proj_<parameter>_field_###" files.

    Arguments
    ---------
    fname : str, path of the file to read
    int_type : numpy dtype of the two header integers (the image dimensions)
    double_type : numpy dtype of the image values
    verbose : bool, print the image dimensions read from the header

    Returns
    -------
    vals : ndarray of `double_type`, with the shape given by the two header integers

    Raises
    ------
    ValueError : if the header is incomplete, contains a negative dimension, or the file
        holds fewer values than the header announces.

    """
    with open(fname, 'rb') as data:
        # Read header
        nums = np.fromfile(data, int_type, 2)   # two integers: the image dimensions
        if verbose:
            print("num = ", nums)
        if nums.size != 2:
            raise ValueError("File '{}' is too short to contain the header!".format(fname))
        # A negative count would make `np.fromfile` read the whole rest of the file
        if np.any(nums < 0):
            raise ValueError("Invalid image dimensions {} in '{}'!".format(nums, fname))

        # `np.product` was removed in numpy 2.0, use `np.prod`; accumulate in 64 bits so that
        # large images cannot overflow the header's integer type
        npix = int(np.prod(nums, dtype=np.int64))
        # Read density
        vals = np.fromfile(data, double_type, npix)

    # `np.fromfile` silently returns fewer items when the file is too short
    if vals.size != npix:
        raise ValueError("File '{}' has {} values, but the header dimensions {} require {}!".format(
            fname, vals.size, nums, npix))

    return vals.reshape(tuple(int(nn) for nn in nums))
def snapshot(path, snap_num=None, part_type=0, pars=None, verbose=False, header=False):
    """Load particle data, or header information, from a `snap_###.hdf5` snapshot file.

    Arguments
    ---------
    path : str, directory containing the snapshot file
    snap_num : int or None, snapshot number used to build the file name
        May only be omitted if `header` is True, in which case snapshot 0 is used.
    part_type : int, particle type; data is read from the group "PartType<part_type>"
    pars : None, str, or list of str
        Name(s) of the dataset(s) to load.  If `None`, all datasets of the particle group
        are loaded.
    verbose : bool, print the groups, datasets and attributes found in the file
    header : bool, return header information instead of particle data

    Returns
    -------
    If `header` is True:
        keys : list of dataset names in the particle group
        head : dict of the attributes of the 'Header' group
        params : dict of the attributes of the 'Parameters' group
    Otherwise:
        data : a single array if `pars` is a str, else a dict of arrays keyed by dataset name

    Raises
    ------
    ValueError : if `snap_num` is not given and `header` is False

    """
    if snap_num is None:
        if not header:
            raise ValueError("`snap_num` must be provided unless `header` is True!")
        snap_num = 0

    fname = os.path.join(path, "snap_{:03d}.hdf5".format(snap_num))
    single_flag = isinstance(pars, str)
    part = "PartType{:1d}".format(part_type)

    # Open explicitly read-only: older h5py versions default to a writable mode
    with h5py.File(fname, 'r') as h5in:
        keys = list(h5in[part].keys())
        if header:
            head = dict(h5in['Header'].attrs.items())
            params = dict(h5in['Parameters'].attrs.items())
            return keys, head, params

        if pars is None:
            pars = keys

        if verbose:
            top_keys = list(h5in.keys())
            print("File keys: ", top_keys)
            for kk in top_keys:
                # Best-effort listing: entries without `keys()` (e.g. datasets) are reported
                # as failed rather than aborting the load
                try:
                    print("\t{} keys:".format(kk), list(h5in[kk].keys()))
                    print("\t{} attrs:".format(kk), list(h5in[kk].attrs.keys()))
                except Exception:
                    print(kk, "failed")

            print("Particle '{}' keys: ".format(part), keys)

        # `[:]` copies the data into memory, so it stays valid after the file is closed
        if single_flag:
            data = h5in[part][pars][:]
        else:
            data = {kk: h5in[part][kk][:] for kk in pars}

    return data
def draw_mesh(ax, mesh, vals=None, fix_poly=True,
              smap=None, region=None, periodic=None, lines_flag=True, **kwargs):
    """Plot Arepo Voronoi cells: edges, and filled colors for cell parameters.

    Each gas cell's edges are drawn if `lines_flag` is True.
    Each gas cell gets a color-filled polygon if `vals` are provided for each cell.

    Arguments
    ---------
    ax : matplotlib.axes.Axes instance
    mesh : dict, storing arepo voronoi mesh data
        The entries 'ndt_tot', 'xyz_edges', 'edge_list', 'nedge_offset' and 'nedges' are used.
    vals : values to be plotted for each polygon (gas-cell; e.g. density, internal-energy, etc)
        For example, this can be provided from the Arepo files:
        "<parameter>_slice_###" and,
        "proj_<parameter>_field_###"
        Or from snapshot data.
    fix_poly : bool
        Sometimes cells don't seem to close properly... try to fix that.
        If the first and last vertex of a cell coincide, the last one is dropped from the
        filled polygon.
    smap : `matplotlib.cm.ScalarMappable` instance specifying colormap
        If not given (and `vals` is), one is constructed with `zplot.smap`.
    region : ndarray (2,2) specifying region to be plotted (for speed)
        First axis is dimension, second axis is left/right edge
        e.g. [[0.0, 2.0], [-1.0, 1.0]] means a region spanning x: [0.0, 2.0] and y: [-1.0, 1.0]
        Cells without any vertex inside of the region are skipped.
    periodic : None or list
        Specification of periodic boundary locations for each dimension: either `None`
        (dimension is not periodic) or the [lower, upper] boundary.  A cell with a vertex
        outside of a boundary is drawn a second time, shifted by the box length in that
        dimension.
    lines_flag : bool
        Whether or not lines should be plotted for cell edges
        (if `False`, then only the fill color is used for `vals` being plotted)
    **kwargs : keyword arguments
        Additional arguments passed to `ax.plot`

    Returns
    -------
    smap : the scalar-mappable used for the fill colors (the given `smap` if `vals` is None)
    extr : result of `zmath.minmax` on the plotted cell values, or `None` if `vals` is None

    Raises
    ------
    ValueError : if nothing would be plotted (`lines_flag` False and no `vals`), or if a cell
        has `NSIDE_MAX` or more vertices.

    """
    if (not lines_flag) and (vals is None):
        raise ValueError("Nothing is being plotted!")

    NDIM = 2
    # Maximum number of sides per polygon; larger values are treated as an error
    # can be oddly large for some reason
    NSIDE_MAX = 20

    ndt = mesh['ndt_tot']                   # total number of delaunay tetrahedra
    # called "edges" in arepo, but really vertices of edges
    xyz = mesh['xyz_edges'].reshape(ndt, NDIM)
    edge_list = mesh['edge_list']
    nedge_offset = mesh['nedge_offset']     # offset into `edge_list` for each gas cell
    nedges = mesh['nedges']                 # number of edges for each gas cell
    tot_num = len(nedge_offset)

    if periodic is not None:
        periodic = [np.array(pp) if pp is not None else pp for pp in periodic]

    if region is not None:
        region = np.atleast_2d(region)

    fill_flag = (vals is not None)
    # Cells are accumulated in lists: with periodic reflections the number of drawn cells is
    # not known in advance, and every (reflected) copy needs its own entry.
    patches = []     # one polygon patch per drawn cell
    colors = []      # the value of `vals` for each patch
    segments = []    # vertex arrays of each drawn cell, separated by NaN rows
    nan_row = np.full((1, NDIM), np.nan)

    def add_cell(ee, verts):
        """Add an individual gas cell (set of vertices) to the collection

        Arguments
        ---------
        ee : int, the index of this cell (used to look up its value in `vals`)
        verts : (N, NDIM) ndarray of the vertices of this cell

        """
        # Store lines connecting each vertex together; the NaN row breaks the line so that
        # consecutive cells are not connected to each other
        if lines_flag:
            segments.append(verts)
            segments.append(nan_row)

        if fill_flag:
            outline = verts
            if fix_poly and np.allclose(verts[0], verts[-1]):
                outline = verts[:-1]
            # Significantly faster to assemble as list of polygon patches and plot together
            # instead of plotting one at a time
            patches.append(mpl.patches.Polygon(outline))
            colors.append(vals[ee])

        return

    # Iterate over each cell
    for ee in tqdm.tqdm(range(tot_num), total=tot_num, leave=False):
        oo = nedge_offset[ee]    # The offset in the edge-list for this cell
        ne = nedges[ee]          # Number of "edges" (vertices) for this cell

        if ne >= NSIDE_MAX:
            err = "Number of edges for element {} = {}, exceeds max {}!".format(ee, ne, NSIDE_MAX)
            raise ValueError(err)

        # Nothing to draw for a cell without vertices
        if ne < 1:
            continue

        # Gather all of the vertices of this cell.  `astype` makes an independent float64
        # copy, so patches never share memory with later cells.
        verts = xyz[edge_list[oo:oo+ne]].astype(float)

        # If `region` is given, check whether any vertex is within the specified region
        if (region is not None) and (not any_within(verts, region)):
            continue

        # Store this cell
        add_cell(ee, verts)

        if periodic is None:
            continue

        # If this is a periodic box, reflect relevant points
        for dd in range(NDIM):
            bounds = periodic[dd]
            # Skip non-periodic dimensions
            if bounds is None:
                continue

            if np.any(verts[:, dd] < bounds[0]):
                shift = (bounds[1] - bounds[0])
            elif np.any(verts[:, dd] > bounds[1]):
                shift = -(bounds[1] - bounds[0])
            else:
                continue

            dup = np.copy(verts)
            dup[:, dd] += shift
            add_cell(ee, dup)

    # Plot filled-polygons for cell values
    extr = None
    if fill_flag:
        colors = np.asarray(colors, dtype=float)
        extr = zmath.minmax(colors)
        if smap is None:
            smap = zplot.smap(extr, cmap='viridis')

        coll = mpl.collections.PatchCollection(patches, cmap=smap.cmap, norm=smap.norm)
        coll.set_array(colors)
        ax.add_collection(coll)

    # Plot lines between vertices
    if lines_flag:
        lines = np.concatenate(segments) if segments else np.zeros((0, NDIM))
        ax.plot(*lines.T, **kwargs)

    return smap, extr
def plot_mesh_2d(path, snap_num=0, param='Density', parse_param=None,
                 region=None, periodic=None, smap=None):
    """Create a figure of the 2D Arepo voronoi mesh, optionally colored by a scalar parameter.

    Arguments
    ---------
    path : str, directory containing the mesh and snapshot files
    snap_num : int, snapshot number to plot
    param : str or None
        Name of the snapshot dataset used to color the cells; if `None` only the mesh lines
        are drawn.
    parse_param : None or callable
        Function applied to the loaded values before plotting.
    region : None or (2,2) specifying the plotted region, also used as the axes limits
    periodic : None or list, periodic boundaries for each dimension (see `draw_mesh`)
        Used as the axes limits if no `region` is given.
    smap : None or `matplotlib.cm.ScalarMappable` instance specifying colormap

    Returns
    -------
    fig : the created `matplotlib` figure

    """
    # Load the arepo mesh data
    mesh = voronoi_mesh(path, snap_num)

    # Optionally load an additional scalar parameter from snapshots
    vals = None
    if param is not None:
        snap = readio.snapshot(path, snap_num=snap_num, pars=[param])
        vals = snap[param]
        # Process the given values by some user-provided function
        if parse_param is not None:
            vals = parse_param(vals)

    # Create figure
    fig, ax = plt.subplots(figsize=[12, 10])
    plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=0.98)

    # Draw mesh and/or scalar variable
    smap, extr = draw_mesh(
        ax, mesh, vals=vals, fix_poly=True, lines_flag=True,
        lw=0.25, color='k', alpha=0.25, region=region, periodic=periodic, smap=smap)

    if region is None and periodic is not None:
        region = periodic

    if region is not None:
        ax.set(xlim=region[0], ylim=region[1])

    # Without a scalar parameter (and without a user-provided `smap`) there is nothing to
    # show in a colorbar
    if smap is not None:
        plt.colorbar(smap, ax=ax, orientation='horizontal', pad=0.05)

    return fig
# the `njit` decorator comes from `numba`, e.g. `from numba import njit`
@njit()
def any_within(poly, bounds):
    """Check whether at least one vertex of `poly` lies strictly inside of `bounds`.

    Arguments
    ---------
    poly : 2D array of vertices, indexed as [vertex, dimension]
    bounds : 2D array of bounds, indexed as [dimension, lower/upper]

    Returns
    -------
    bool : True if some vertex satisfies `lower < value < upper` in every dimension

    """
    num_pnts, num_dims = poly.shape
    for pp in range(num_pnts):
        inside = True
        for dd in range(num_dims):
            val = poly[pp, dd]
            # Written as a negated "inside" test so that NaN values count as outside
            if not ((val > bounds[dd, 0]) and (val < bounds[dd, 1])):
                inside = False
                break

        if inside:
            return True

    return False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment