Created
February 19, 2016 04:06
-
-
Save dwinston/f3e131b1655d1f306bf9 to your computer and use it in GitHub Desktop.
Generate pairs of band-structure (BS) and density-of-states (DOS) images with aligned vertical (energy) axes from MP structures, and persist them to GridFS.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# | |
# Electronic structure (es) image builder: | |
# | |
# Build static images (*.png) for plots of bandstructure (BS) and | |
# density-of-states (DOS) data for materials, storing them in a GridFS | |
# filesystem (e.g. 'es_plot.files' and 'es_plot.chunks' collections) on the | |
# same db as that of the source 'materials'/'electronic_structure' collections. | |
from __future__ import division, unicode_literals, print_function | |
import datetime | |
import json | |
import logging | |
import matplotlib | |
import os | |
import StringIO | |
import traceback | |
import warnings | |
import gridfs | |
import numpy as np | |
import pymongo | |
from matgendb.builders import core, util | |
from matgendb.query_engine import QueryEngine | |
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine | |
from pymatgen.electronic_structure.dos import CompleteDos | |
from pymatgen.electronic_structure.core import Spin | |
from pymatgen.electronic_structure.plotter import BSPlotter, DosPlotter | |
_log = util.get_builder_log("bs_dos_img")
_name = 'ESImagesBuilder'
# One log file per run, timestamped so concurrent/successive runs don't clash.
_file_name = '_'.join([
    _name, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") + '.log'])
# Log files live in ~/logs. expanduser('~') uses $HOME when it is set (same
# path as before) but does not raise KeyError when it is unset. Create the
# directory so that importing this module does not fail with IOError from
# FileHandler on a machine where ~/logs does not exist yet.
_log_dir = os.path.join(os.path.expanduser('~'), 'logs')
if not os.path.isdir(_log_dir):
    try:
        os.makedirs(_log_dir)
    except OSError:
        # Tolerate another process creating it concurrently; re-raise for
        # genuine failures (permissions, a file in the way, ...).
        if not os.path.isdir(_log_dir):
            raise
hdlr = logging.FileHandler(os.path.join(_log_dir, _file_name))
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
hdlr.setFormatter(formatter)
_log.addHandler(hdlr)
_log.setLevel(logging.DEBUG)
_log.propagate = False  # Only log to file, not to console
# Non-interactive backend: images are rendered headless and written to GridFS.
# NOTE(review): matplotlib.use only takes effect if pyplot has not already been
# imported (e.g. transitively by the pymatgen plotter import above) - confirm.
matplotlib.use('Agg')
def _no_timeout(): | |
""" | |
Pass to collection.find so that cursor does not time out on server | |
out on server after 10 minutes of inactivity. This is needed because | |
the builder will batch-process 10,000 input documents, which takes | |
longer than 10 minutes, and there are >30,000 input documents to | |
fetch and process. | |
""" | |
if pymongo.version.startswith('3'): | |
return {'no_cursor_timeout': True} | |
else: | |
return {'timeout': False} | |
class Builder(core.Builder):
    """Generate png images of bandstructures and DOS and put them in a GridFS
    collection.

    For each source document with both ``band_structure`` and ``dos`` foreign
    keys, ``bs_<material_id>.png`` and ``dos_<material_id>.png`` are written.
    The DOS plot reuses the y-limits of the band-structure plot so that the two
    images have aligned energy axes.
    """

    def __init__(self, *args, **kwargs):
        self._target_coll = None
        # Initialised so finalize() is safe even if get_items() never ran.
        self._cursor = None
        core.Builder.__init__(self, *args, **kwargs)
        warnings.filterwarnings("ignore",
                                "tight_layout : falling back to Agg renderer")

    def get_items(self, source=None, fs_cname='es_plot', crit=None, skip=None):
        """Put plots in GridFS in same db as source collection.

        :param source: Input materials collection
        :type source: QueryEngine
        :param fs_cname: GridFS collection name
        :type fs_cname: string
        :param crit: Filter criteria, e.g. "{ 'flag': True }".
        :type crit: dict
        :param skip: Name of file with JSON list of `task_id`s to skip.
        :type skip: str
        :return: cursor over source documents that have both BS and DOS keys
            (projected to task_id, band_structure, dos); it does not time out
            and is closed in finalize().
        """
        self._skip = frozenset()
        if skip:
            with open(skip) as f:
                self._skip = frozenset(json.load(f))
        self._fs = gridfs.GridFS(source.db, collection=fs_cname)
        fs_files = source.db[fs_cname + '.files']
        # create_index exists in every pymongo major; ensure_index was removed
        # in pymongo 4. Creating an existing index is a no-op on the server.
        fs_files.create_index("filename")
        self._es_coll = source.db.electronic_structure
        if not crit:  # reduce any False-y crit value to an empty dict
            crit = {}
        incl_bs_and_dos = {'dos': {'$exists': True},
                           'band_structure': {'$exists': True}}
        incl_bs_and_dos.update(crit)
        self._cursor = source.collection.find(
            incl_bs_and_dos, ['task_id', 'band_structure', 'dos'],
            **_no_timeout())
        count = self._count(source.collection, self._cursor, incl_bs_and_dos)
        _log.info("get_items: source={} crit={} count={:d}"
                  .format(source.collection, incl_bs_and_dos, count))
        return self._cursor

    @staticmethod
    def _count(collection, cursor, spec):
        """Number of documents matching spec, across pymongo versions.

        Prefers Collection.count_documents (pymongo >= 3.7, and the only option
        in pymongo 4, which removed Cursor.count); falls back to Cursor.count.
        """
        count_documents = getattr(collection, 'count_documents', None)
        if count_documents is None:
            return cursor.count()
        return count_documents(spec)

    def finalize(self, had_errors):
        """Close the source cursor. Always returns True."""
        # Explicitly close cursor because we told it to not time out.
        cursor = getattr(self, '_cursor', None)
        if cursor is not None:
            cursor.close()
        return True

    def process_item(self, item):
        """Build images for one item unless its task_id is in the skip list.

        Any exception is logged (with traceback) and swallowed so that one bad
        document does not abort the whole build.
        """
        try:
            if item['task_id'] not in self._skip:
                _log.debug('processing {}'.format(item['task_id']))
                self._process_item(item)
        except Exception as e:
            _log.error("Failed to process {} because {}. Traceback: {}".format(
                item['task_id'], str(e), traceback.format_exc()))

    def _process_item(self, item):
        """Render and store the BS plot, then the DOS plot with matching ylim."""
        ylim = self._process_bs(item['band_structure'])
        self._process_dos(item['dos'], ylim)

    def _get_es_doc(self, fk):
        """Fetch the electronic_structure doc a foreign-key dict points to,
        preferring the 'GGA+U' entry over 'GGA' when it is present and truthy."""
        which = 'GGA+U' if fk.get('GGA+U') else 'GGA'
        return self._es_coll.find_one({'_id': fk[which]['oid']})

    def _process_bs(self, bs_fk):
        """Plot and store the band structure; return its y-limits."""
        doc = self._get_es_doc(bs_fk)  # bs_fk holds foreign keys (fk)
        mid = doc['material_id']
        filename = 'bs_{}.png'.format(mid)
        redo = self._fs.exists({'filename': filename})
        if redo:
            _log.info("redoing bs for {}".format(mid))
        bs = BandStructureSymmLine.from_dict(doc)
        bs.bz_symmetry = doc.get('bz_symmetry', {})
        fig = WebBSPlotter(bs).get_plot()
        ylim = fig.ylim()  # Used by DOS plot
        self._store_png(fig, filename, redo)
        return ylim

    def _process_dos(self, dos_fk, ylim):
        """Plot the element-projected DOS with the given y-limits and store it.

        A ValueError from plotting is logged and the image is skipped.
        """
        doc = self._get_es_doc(dos_fk)
        mid = doc['material_id']
        filename = 'dos_{}.png'.format(mid)
        redo = self._fs.exists({'filename': filename})
        if redo:
            _log.info("redoing dos for {}".format(mid))
        dos = CompleteDos.from_dict(doc)
        dos_plotter = WebDosPlotter()
        dos_plotter.add_dos_dict(dos.get_element_dos())
        # Obtain the figure handle first so it can be closed if plotting fails.
        fig = dos_plotter.get_plot_vertical(ylim=ylim, handle_only=True)
        try:
            # ylim for DOS plot should match that of BS plot
            fig = dos_plotter.get_plot_vertical(ylim=ylim, plt=fig)
        except ValueError as e:
            # I expect 'x and y must have the same first dimension' for at
            # least 200 materials, and I expect 'truth value of an array with
            # more than one element is ambiguous' for at least 1000 materials,
            # until the database is repaired.
            _log.error("ValueError for {}: {}".format(mid, str(e)))
            fig.close()
            return
        except Exception:
            fig.close()
            raise
        self._store_png(fig, filename, redo)

    def _store_png(self, fig, filename, redo):
        """Save fig as a png into GridFS under filename, then close it.

        The buffer and figure are closed even if saving or the GridFS write
        fails, so figures do not accumulate over a long build. When redo is
        true, older files with the same name are deleted only after the new
        version has been written, in an attempt to avoid concurrent reads of a
        file while it is being deleted.
        """
        imgdata = StringIO.StringIO()
        try:
            fig.savefig(imgdata, format='png', dpi=100)
            _id = self._fs.put(imgdata.getvalue(), filename=filename,
                               content_type="image/png")
        finally:
            imgdata.close()
            fig.close()
        if redo:
            for grid_out in self._fs.find({"filename": filename}):
                if grid_out._id != _id:
                    self._fs.delete(grid_out._id)
# | |
# Obtain web-friendly images by subclassing pymatgen plotters. | |
# | |
class WebBSPlotter(BSPlotter):
    """BSPlotter variant producing smaller, web-friendly band-structure images.

    Compared with the pymatgen plotter it uses a 6 x 5.5 figure, 1-pt band
    lines, 0.5-pt axis spines and smaller tick/label fonts.
    """

    def get_plot(self, zero_to_efermi=True, ylim=None, smooth=False):
        """
        get a matplotlib object for the bandstructure plot.
        Blue lines are up spin, red lines are down
        spin.
        Args:
            zero_to_efermi: Automatically subtract off the Fermi energy from
                the eigenvalues and plot (E-Ef).
            ylim: Specify the y-axis (energy) limits; by default None let
                the code choose. It is vbm-4 and cbm+4 if insulator
                efermi-10 and efermi+10 if metal
            smooth: interpolates the bands by a spline cubic
        Returns:
            The pyplot module, with the current figure holding the plot.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        from matplotlib import rc
        import scipy.interpolate as scint

        plt = get_publication_quality_plot(6, 5.5)  # Was 12, 8
        rc('text', usetex=True)

        width = 4
        ticksize = int(width * 2.5)
        labelsize = int(width * 3)
        axes = plt.gca()
        axes.set_title(axes.get_title(), size=width * 4)
        axes.set_xlabel(axes.get_xlabel(), size=labelsize)
        axes.set_ylabel(axes.get_ylabel(), size=labelsize)
        plt.xticks(fontsize=ticksize)
        plt.yticks(fontsize=ticksize)
        for axis in ['top', 'bottom', 'left', 'right']:
            axes.spines[axis].set_linewidth(0.5)

        # Default energy window: +/-10 eV for metals, +/-4 eV beyond the band
        # edges for insulators.
        is_metal = self._bs.is_metal()
        if is_metal:
            e_min, e_max = -10, 10
        else:
            e_min, e_max = -4, 4
        band_linewidth = 1  # Was 3 in pymatgen

        data = self.bs_plot_data(zero_to_efermi)
        # (spin, line style) pairs to draw; down spin only when polarized.
        spins = [(Spin.up, 'b-')]
        if self._bs.is_spin_polarized:
            spins.append((Spin.down, 'r--'))

        for d, distances in enumerate(data['distances']):
            npts = len(distances)
            for i in range(self._nb_bands):
                for spin, style in spins:
                    band = data['energy'][d][str(spin)][i]
                    energies = [band[j] for j in range(npts)]
                    if smooth:
                        # Cubic spline sampled at 1000 points over the branch.
                        tck = scint.splrep(distances, energies)
                        step = (distances[-1] - distances[0]) / 1000
                        xs = [x * step + distances[0] for x in range(1000)]
                        ys = [scint.splev(xv, tck, der=0) for xv in xs]
                        plt.plot(xs, ys, style, linewidth=band_linewidth)
                    else:
                        plt.plot(distances, energies, style,
                                 linewidth=band_linewidth)

        self._maketicks(plt)

        # Main X and Y Labels
        plt.xlabel(r'$\mathrm{Wave\ Vector}$')
        if zero_to_efermi:
            plt.ylabel(r'$\mathrm{E\ -\ E_f\ (eV)}$')
        else:
            plt.ylabel(r'$\mathrm{Energy\ (eV)}$')
            # Draw Fermi energy, only if not the zero
            plt.axhline(self._bs.efermi, linewidth=2, color='k')

        # X range (K): up to the last distance point
        plt.xlim(0, data['distances'][-1][-1])

        if ylim is not None:
            plt.ylim(ylim)
        elif is_metal:
            if zero_to_efermi:
                plt.ylim(e_min, e_max)
            else:
                # Uses the public efermi attribute (as the axhline above does);
                # the former `_efermi` lookup raised AttributeError.
                efermi = self._bs.efermi
                plt.ylim(efermi + e_min, efermi + e_max)
        else:
            for cbm in data['cbm']:
                plt.scatter(cbm[0], cbm[1], color='r', marker='o', s=100)
            for vbm in data['vbm']:
                plt.scatter(vbm[0], vbm[1], color='g', marker='o', s=100)
            plt.ylim(data['vbm'][0][1] + e_min, data['cbm'][0][1] + e_max)

        plt.tight_layout()
        return plt
class WebDosPlotter(DosPlotter):
    """DosPlotter variant drawing a narrow vertical DOS image (energy on the
    y-axis, density on the x-axis) meant to sit beside a band-structure plot."""

    def get_plot_vertical(self, xlim=None, ylim=None,
                          plt=None, handle_only=False):
        """
        Get a matplotlib plot showing the DOS.
        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
            plt: Handle on existing plot.
            handle_only: Quickly return just a handle. Useful if this method
                raises an exception so that one can close() the figure.
        Returns:
            The plot handle (``plt`` if given, else a new 2 x 5.5 figure).
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = plt or get_publication_quality_plot(2, 5.5)
        if handle_only:
            return plt
        import prettyplotlib as ppl
        from prettyplotlib import brewer2mpl

        # Brewer 'Set1' supports 3..9 colors; colors cycle beyond that.
        ncolors = min(9, max(3, len(self._doses)))
        colors = brewer2mpl.get_map('Set1', 'qualitative', ncolors).mpl_colors

        width = 4
        ticksize = int(width * 2.5)
        labelsize = int(width * 3)
        axes = plt.gca()
        axes.set_title(axes.get_title(), size=width * 4)
        axes.set_xlabel(axes.get_xlabel(), size=labelsize)
        axes.set_ylabel(axes.get_ylabel(), size=labelsize)
        axes.xaxis.labelpad = 6

        # Note that this complicated processing of energies is to allow for
        # stacked plots in matplotlib.
        running = None  # cumulative densities per spin when self.stack
        alldensities = []
        allenergies = []
        for key, dos in self._doses.items():
            energies = dos['energies']
            densities = dos['densities']
            if running is None:
                running = {Spin.up: np.zeros(energies.shape),
                           Spin.down: np.zeros(energies.shape)}
            newdens = {}
            for spin in [Spin.up, Spin.down]:
                if spin in densities:
                    if self.stack:
                        running[spin] += densities[spin]
                        newdens[spin] = running[spin].copy()
                    else:
                        newdens[spin] = densities[spin]
            allenergies.append(energies)
            alldensities.append(newdens)

        keys = list(self._doses.keys())
        keys.reverse()
        alldensities.reverse()
        allenergies.reverse()
        allpts = []
        for i, key in enumerate(keys):
            color = colors[i % ncolors]
            xs = []
            ys = []
            for spin in [Spin.up, Spin.down]:
                if spin in alldensities[i]:
                    # Down-spin densities are negated (int(Spin.down) is
                    # presumably -1) and traversed in reverse so the outline
                    # forms one closed path for fill().
                    densities = list(int(spin) * alldensities[i][spin])
                    energies = list(allenergies[i])
                    if spin == Spin.down:
                        energies.reverse()
                        densities.reverse()
                    ys.extend(energies)
                    xs.extend(densities)
            allpts.extend(list(zip(xs, ys)))
            if self.stack:
                plt.fill(xs, ys, color=color, label=str(key))
            else:
                ppl.plot(xs, ys, color=color, label=str(key), linewidth=1)
            if not self.zero_at_efermi:
                # Horizontal dashed line at this DOS's Fermi level. Use a local
                # name so the caller's xlim argument is not clobbered (which
                # previously disabled automatic x-range determination below).
                cur_xlim = plt.xlim()
                efermi = self._doses[key]['efermi']
                ppl.plot(cur_xlim, [efermi, efermi], color=color,
                         linestyle='--', linewidth=1)

        if ylim:
            plt.ylim(ylim)
        if xlim:
            plt.xlim(xlim)
        else:
            # Fit the density axis to the points within the visible energy
            # window; leave matplotlib's default if no point falls inside.
            cur_ylim = plt.ylim()
            relevantx = [p[0] for p in allpts
                         if cur_ylim[0] < p[1] < cur_ylim[1]]
            if relevantx:
                plt.xlim(min(relevantx), max(relevantx))
        if self.zero_at_efermi:
            cur_xlim = plt.xlim()
            plt.plot(cur_xlim, [0, 0], 'k--', linewidth=1)

        plt.ylabel(r'$\mathrm{E\ -\ E_f\ (eV)}$')
        plt.xlabel(r'$\mathrm{Density\ of\ states}$')
        # A single tick at zero density; absolute DOS values are not labelled.
        plt.xticks([0], fontsize=ticksize)
        plt.yticks(fontsize=ticksize)
        plt.grid(which='major', axis='y')
        plt.legend(fontsize='x-small',
                   loc='upper right', bbox_to_anchor=(1.15, 1))
        leg = plt.gca().get_legend()
        if leg is not None:  # no legend when nothing was labelled
            leg.get_frame().set_alpha(0.25)
        plt.tight_layout()
        return plt
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment