Skip to content

Instantly share code, notes, and snippets.

@dwinston
Created February 19, 2016 04:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dwinston/f3e131b1655d1f306bf9 to your computer and use it in GitHub Desktop.
Save dwinston/f3e131b1655d1f306bf9 to your computer and use it in GitHub Desktop.
Generate pairs of band-structure (BS) and density-of-states (DOS) images with aligned vertical axes from MP structures, and persist them to GridFS.
# coding: utf-8
#
# Electronic structure (es) image builder:
#
# Build static images (*.png) for plots of bandstructure (BS) and
# density-of-states (DOS) data for materials, storing them in a GridFS
# filesystem (e.g. 'es_plot.files' and 'es_plot.chunks' collections) on the
# same db as that of the source 'materials'/'electronic_structure' collections.
from __future__ import division, unicode_literals, print_function
import datetime
import json
import logging
import matplotlib
import os
import StringIO
import traceback
import warnings
import gridfs
import numpy as np
import pymongo
from matgendb.builders import core, util
from matgendb.query_engine import QueryEngine
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
from pymatgen.electronic_structure.dos import CompleteDos
from pymatgen.electronic_structure.core import Spin
from pymatgen.electronic_structure.plotter import BSPlotter, DosPlotter
# Module-level logging setup. Everything below runs at import time.
# The logger itself comes from matgendb's builder utilities.
_log = util.get_builder_log("bs_dos_img")
_name = 'ESImagesBuilder'
# Log file name is "<_name>_<timestamp>.log"; the timestamp goes down to
# microseconds (%f).
_file_name = '_'.join([
    _name, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") + '.log'])
# The log file is opened under $HOME/logs.
# NOTE(review): this raises KeyError if HOME is unset, and presumably fails
# if the 'logs' directory does not already exist -- confirm that deployment
# creates it.
hdlr = logging.FileHandler(os.path.join(os.environ['HOME'], 'logs', _file_name))
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
hdlr.setFormatter(formatter)
_log.addHandler(hdlr)
_log.setLevel(logging.DEBUG)
_log.propagate = False # Only log to file, not to console
# Select matplotlib's 'Agg' backend for all plots made by this module.
# NOTE(review): backend selection usually has to happen before pyplot is
# first imported; the pymatgen plotter imports above run earlier -- confirm
# they do not import pyplot at module load.
matplotlib.use('Agg')
def _no_timeout():
    """
    Return the keyword arguments that disable server-side cursor timeout.

    Pass the result to collection.find (as ``**_no_timeout()``) so that the
    cursor does not time out on the server after 10 minutes of inactivity.
    This is needed because the builder will batch-process 10,000 input
    documents, which takes longer than 10 minutes, and there are >30,000
    input documents to fetch and process.

    Returns:
        ``{'no_cursor_timeout': True}`` for pymongo major version 3 or
        newer, ``{'timeout': False}`` for older (2.x) releases.
    """
    # Compare the numeric major version rather than testing for a leading
    # '3', so that pymongo 4.x and later are not mistaken for the legacy
    # 2.x API.
    major_text = pymongo.version.split('.', 1)[0]
    if major_text.isdigit() and int(major_text) >= 3:
        return {'no_cursor_timeout': True}
    return {'timeout': False}
class Builder(core.Builder):
    """Generate png images of bandstructures and DOS and put them in GridFS.

    For each input document that references both a band structure and a
    DOS, a ``bs_<material_id>.png`` and a ``dos_<material_id>.png`` are
    rendered. The DOS plot reuses the y-limits of the band-structure plot
    so that the two images share the same vertical axis.
    """

    def __init__(self, *args, **kwargs):
        self._target_coll = None
        core.Builder.__init__(self, *args, **kwargs)
        warnings.filterwarnings(
            "ignore", "tight_layout : falling back to Agg renderer")

    def get_items(self, source=None, fs_cname='es_plot', crit=None, skip=None):
        """Put plots in GridFS in same db as source collection.

        :param source: Input materials collection
        :type source: QueryEngine
        :param fs_cname: GridFS collection name
        :type fs_cname: string
        :param crit: Filter criteria, e.g. "{ 'flag': True }".
        :type crit: dict
        :param skip: Name of file with JSON list of `task_id`s to skip.
        :type skip: str
        :return: Cursor over the matching source documents; it is opened
            without a server-side timeout and closed in `finalize`.
        """
        self._skip = frozenset()
        if skip:
            with open(skip) as f:
                self._skip = frozenset(json.load(f))
        self._fs = gridfs.GridFS(source.db, collection=fs_cname)
        fs_files = getattr(source.db, fs_cname + '.files')
        fs_files.ensure_index("filename")
        self._es_coll = source.db.electronic_structure
        # Only documents that have both a DOS and a band structure can be
        # plotted; caller-supplied criteria are layered on top and may
        # override these two keys.
        incl_bs_and_dos = {'dos': {'$exists': True},
                           'band_structure': {'$exists': True}}
        incl_bs_and_dos.update(crit or {})
        self._cursor = source.collection.find(
            incl_bs_and_dos, ['task_id', 'band_structure', 'dos'],
            **_no_timeout())
        _log.info("get_items: source={} crit={} count={:d}"
                  .format(source.collection, incl_bs_and_dos,
                          self._cursor.count()))
        return self._cursor

    def finalize(self, had_errors):
        """Close the input cursor; always returns True.

        The cursor has to be closed explicitly because it was opened with
        its server-side timeout disabled.
        """
        # Tolerate finalize() being called when get_items() never ran.
        cursor = getattr(self, '_cursor', None)
        if cursor is not None:
            cursor.close()
        return True

    def process_item(self, item):
        """Render and store both plots for `item`, unless it is skipped.

        Any exception is logged with its traceback and swallowed, so that
        one bad document does not abort the whole build.
        """
        try:
            if item['task_id'] not in self._skip:
                _log.debug('processing {}'.format(item['task_id']))
                self._process_item(item)
        except Exception as e:
            # Use .get() here: if the failure was a missing 'task_id', the
            # handler itself must not raise.
            _log.error("Failed to process {} because {}. Traceback: {}".format(
                item.get('task_id'), str(e), traceback.format_exc()))

    def _process_item(self, item):
        """Plot the band structure first, then the DOS with matching y-limits."""
        ylim = self._plot_bs(item['band_structure'])
        self._plot_dos(item['dos'], ylim)

    def _lookup_es_doc(self, fk):
        """Fetch the electronic_structure document referenced by `fk`.

        `fk` holds foreign keys per functional; 'GGA+U' is preferred over
        'GGA' when it is present and truthy.
        """
        which = 'GGA+U' if fk.get('GGA+U') else 'GGA'
        oid = fk[which]['oid']
        doc = self._es_coll.find_one({'_id': oid})
        if doc is None:
            raise LookupError(
                "no electronic_structure document with _id {}".format(oid))
        return doc

    def _store_png(self, fig, filename, redo):
        """Save `fig` as a png in GridFS under `filename`.

        When `redo` is true, older files with the same name are deleted
        after the new one has been written.
        """
        imgdata = StringIO.StringIO()
        try:
            fig.savefig(imgdata, format='png', dpi=100)
            new_id = self._fs.put(imgdata.getvalue(), filename=filename,
                                  content_type="image/png")
        finally:
            imgdata.close()
        if redo:
            # Write the new version first (above) before deleting older
            # versions, in an attempt to avoid concurrent reads of a file
            # while it is being deleted.
            for grid_out in self._fs.find({"filename": filename}):
                if grid_out._id != new_id:
                    self._fs.delete(grid_out._id)

    def _plot_bs(self, bs_fk):
        """Render and store the band-structure image; return its y-limits."""
        doc = self._lookup_es_doc(bs_fk)
        mid = doc['material_id']
        filename = 'bs_{}.png'.format(mid)
        redo = self._fs.exists({'filename': filename})
        if redo:
            _log.info("redoing bs for {}".format(mid))
        bs = BandStructureSymmLine.from_dict(doc)
        bs.bz_symmetry = doc.get('bz_symmetry', {})
        fig = WebBSPlotter(bs).get_plot()
        try:
            ylim = fig.ylim()  # Used by DOS plot
            self._store_png(fig, filename, redo)
        finally:
            # Close even when saving fails, so figures do not accumulate
            # over a long build.
            fig.close()
        return ylim

    def _plot_dos(self, dos_fk, ylim):
        """Render and store the DOS image using the band-structure `ylim`."""
        doc = self._lookup_es_doc(dos_fk)
        mid = doc['material_id']
        filename = 'dos_{}.png'.format(mid)
        redo = self._fs.exists({'filename': filename})
        if redo:
            _log.info("redoing dos for {}".format(mid))
        dos = CompleteDos.from_dict(doc)
        dos_plotter = WebDosPlotter()
        dos_plotter.add_dos_dict(dos.get_element_dos())
        # Grab the figure handle first so it can be closed even if the
        # actual plotting raises.
        fig = dos_plotter.get_plot_vertical(ylim=ylim, handle_only=True)
        try:
            try:
                # ylim for DOS plot should match that of BS plot
                fig = dos_plotter.get_plot_vertical(ylim=ylim, plt=fig)
            except ValueError as e:
                # I expect 'x and y must have the same first dimension' for
                # at least 200 materials, and I expect 'truth value of an
                # array with more than one element is ambiguous' for at
                # least 1000 materials, until the database is repaired.
                _log.error("ValueError for {}: {}".format(mid, str(e)))
                return
            self._store_png(fig, filename, redo)
        finally:
            fig.close()
#
# Obtain web-friendly images by subclassing pymatgen plotters.
#
class WebBSPlotter(BSPlotter):
    """BSPlotter variant with smaller figure, fonts and line widths."""

    def get_plot(self, zero_to_efermi=True, ylim=None, smooth=False):
        """
        Get a matplotlib object for the bandstructure plot.
        Blue solid lines are up spin, red dashed lines are down spin.

        Args:
            zero_to_efermi: Automatically subtract off the Fermi energy from
                the eigenvalues and plot (E-Ef).
            ylim: Specify the y-axis (energy) limits; by default None lets
                the code choose. It is vbm-4 and cbm+4 if insulator,
                efermi-10 and efermi+10 if metal.
            smooth: Interpolate each band with a spline (1000 samples per
                branch).

        Returns:
            The plot handle obtained from get_publication_quality_plot,
            with the band structure drawn on it.
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        from matplotlib import rc

        plt = get_publication_quality_plot(6, 5.5)  # Was 12, 8
        rc('text', usetex=True)

        # Re-apply title, labels and ticks at web-friendly sizes.
        width = 4
        ticksize = int(width * 2.5)
        labelsize = int(width * 3)
        axes = plt.gca()
        axes.set_title(axes.get_title(), size=width * 4)
        axes.set_xlabel(axes.get_xlabel(), size=labelsize)
        axes.set_ylabel(axes.get_ylabel(), size=labelsize)
        plt.xticks(fontsize=ticksize)
        plt.yticks(fontsize=ticksize)
        for side in ['top', 'bottom', 'left', 'right']:
            axes.spines[side].set_linewidth(0.5)

        # Main internal config options: energy window around the reference.
        is_metal = self._bs.is_metal()
        e_min, e_max = (-10, 10) if is_metal else (-4, 4)
        band_linewidth = 1  # Was 3 in pymatgen

        data = self.bs_plot_data(zero_to_efermi)

        # Spin-up is always drawn; spin-down only for spin-polarized data.
        spin_styles = [(str(Spin.up), 'b-')]
        if self._bs.is_spin_polarized:
            spin_styles.append((str(Spin.down), 'r--'))

        for d, distances in enumerate(data['distances']):
            branch_energies = data['energy'][d]
            npts = len(distances)
            for i in range(self._nb_bands):
                for spin_key, style in spin_styles:
                    band = [branch_energies[spin_key][i][j]
                            for j in range(npts)]
                    self._plot_band_segment(plt, distances, band, style,
                                            band_linewidth, smooth)

        self._maketicks(plt)

        # Main X and Y Labels
        plt.xlabel(r'$\mathrm{Wave\ Vector}$')
        ylabel = r'$\mathrm{E\ -\ E_f\ (eV)}$' if zero_to_efermi \
            else r'$\mathrm{Energy\ (eV)}$'
        plt.ylabel(ylabel)

        # Draw Fermi energy, only if not the zero
        if not zero_to_efermi:
            plt.axhline(self._bs.efermi, linewidth=2, color='k')

        # X range (K): from 0 to the last distance point
        plt.xlim(0, data['distances'][-1][-1])

        if ylim is not None:
            plt.ylim(ylim)
        elif is_metal:
            if zero_to_efermi:
                plt.ylim(e_min, e_max)
            else:
                # Use the public efermi attribute for both bounds,
                # consistent with the Fermi-level line drawn above.
                efermi = self._bs.efermi
                plt.ylim(efermi + e_min, efermi + e_max)
        else:
            # Insulator: mark the band edges and frame the gap.
            for cbm in data['cbm']:
                plt.scatter(cbm[0], cbm[1], color='r', marker='o', s=100)
            for vbm in data['vbm']:
                plt.scatter(vbm[0], vbm[1], color='g', marker='o', s=100)
            plt.ylim(data['vbm'][0][1] + e_min, data['cbm'][0][1] + e_max)

        plt.tight_layout()
        return plt

    @staticmethod
    def _plot_band_segment(plt, distances, energies, style, linewidth,
                           smooth):
        """Draw one band over one branch, raw or spline-interpolated."""
        if not smooth:
            plt.plot(distances, energies, style, linewidth=linewidth)
            return
        # scipy is only needed for the smoothed variant.
        import scipy.interpolate as scint
        tck = scint.splrep(distances, energies)
        start = distances[0]
        step = (distances[-1] - start) / 1000
        xs = [n * step + start for n in range(1000)]
        ys = [scint.splev(x, tck, der=0) for x in xs]
        plt.plot(xs, ys, style, linewidth=linewidth)
class WebDosPlotter(DosPlotter):
    """DosPlotter variant that draws the DOS with energy on the y-axis."""

    def get_plot_vertical(self, xlim=None, ylim=None,
                          plt=None, handle_only=False):
        """
        Get a matplotlib plot showing the DOS, with density on the x-axis
        and energy on the y-axis.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination from the points that fall inside the y-range.
            ylim: Specifies the y-axis limits.
            plt: Handle on existing plot.
            handle_only: Quickly return just a handle. Useful if this method
                raises an exception so that one can close() the figure.

        Returns:
            The plot handle (the one passed in, or a newly created one).
        """
        from pymatgen.util.plotting_utils import get_publication_quality_plot
        plt = plt or get_publication_quality_plot(2, 5.5)
        if handle_only:
            return plt

        import prettyplotlib as ppl
        from prettyplotlib import brewer2mpl

        # Clamp the palette size to the 3..9 range.
        ncolors = min(9, max(3, len(self._doses)))
        colors = brewer2mpl.get_map('Set1', 'qualitative', ncolors).mpl_colors

        width = 4
        ticksize = int(width * 2.5)
        labelsize = int(width * 3)
        axes = plt.gca()
        axes.set_title(axes.get_title(), size=width * 4)
        axes.set_xlabel(axes.get_xlabel(), size=labelsize)
        axes.set_ylabel(axes.get_ylabel(), size=labelsize)
        axes.xaxis.labelpad = 6

        # Note that this complicated processing of energies is to allow for
        # stacked plots in matplotlib.
        running = None  # per-spin running sums, used when self.stack is set
        alldensities = []
        allenergies = []
        for dos in self._doses.values():
            energies = dos['energies']
            densities = dos['densities']
            if running is None:
                running = {Spin.up: np.zeros(energies.shape),
                           Spin.down: np.zeros(energies.shape)}
            newdens = {}
            for spin in [Spin.up, Spin.down]:
                if spin in densities:
                    if self.stack:
                        running[spin] += densities[spin]
                        newdens[spin] = running[spin].copy()
                    else:
                        newdens[spin] = densities[spin]
            allenergies.append(energies)
            alldensities.append(newdens)

        # Draw in reverse insertion order.
        keys = list(self._doses.keys())
        keys.reverse()
        alldensities.reverse()
        allenergies.reverse()

        allpts = []
        for i, key in enumerate(keys):
            xs = []
            ys = []
            for spin in [Spin.up, Spin.down]:
                if spin in alldensities[i]:
                    # int(spin) scales the densities by the spin's integer
                    # value before plotting.
                    densities = list(int(spin) * alldensities[i][spin])
                    energies = list(allenergies[i])
                    if spin == Spin.down:
                        energies.reverse()
                        densities.reverse()
                    ys.extend(energies)
                    xs.extend(densities)
            allpts.extend(zip(xs, ys))
            color = colors[i % ncolors]
            if self.stack:
                plt.fill(xs, ys, color=color, label=str(key))
            else:
                ppl.plot(xs, ys, color=color, label=str(key), linewidth=1)
            if not self.zero_at_efermi:
                # Horizontal line at this DOS's Fermi energy. The current
                # x-range is read into a local here so that the caller's
                # `xlim` argument is not overwritten.
                efermi = self._doses[key]['efermi']
                ppl.plot(plt.xlim(), [efermi, efermi], color=color,
                         linestyle='--', linewidth=1)

        if ylim is not None:
            plt.ylim(ylim)
        if xlim is not None:
            plt.xlim(xlim)
        else:
            # Fit the x-range to the densities that are visible within the
            # current y-range.
            ylim = plt.ylim()
            relevantx = [p[0] for p in allpts
                         if ylim[0] < p[1] < ylim[1]]
            plt.xlim(min(relevantx), max(relevantx))

        if self.zero_at_efermi:
            plt.plot(plt.xlim(), [0, 0], 'k--', linewidth=1)

        plt.ylabel(r'$\mathrm{E\ -\ E_f\ (eV)}$')
        plt.xlabel(r'$\mathrm{Density\ of\ states}$')
        plt.xticks([0], fontsize=ticksize)
        plt.yticks(fontsize=ticksize)
        plt.grid(which='major', axis='y')
        leg = plt.legend(fontsize='x-small',
                         loc='upper right', bbox_to_anchor=(1.15, 1))
        leg.get_frame().set_alpha(0.25)
        plt.tight_layout()
        return plt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment