|
import itertools
import math
from dataclasses import dataclass, field
from typing import Optional, List

import lxml.etree as etree
import numpy as np
import scipy.spatial
from pymatgen import Structure, Element
from phonopy.structure.atoms import symbol_map, atom_data
from ruamel.yaml import YAML
|
|
|
# Safe-mode YAML loader, used by load_eigs_band_yaml to parse phonopy output.
yaml = YAML(typ="safe")

# Namespace map registering the ``xlink`` prefix on the generated SVG elements
# (needed for the ``xlink:href`` attribute of the arrowhead ``<use>`` elements).
NSMAP = dict(xlink="http://www.w3.org/1999/xlink")
|
|
|
|
|
def mass_from_symbol(symbol: str):
    """Return the atomic mass phonopy tabulates for the element ``symbol``.

    The symbol is title-cased first, so ``"c"``, ``"C"`` and ``"HE"`` /
    ``"He"`` are all accepted. The mass is entry 3 of the element's record
    in phonopy's ``atom_data`` table, looked up via ``symbol_map``.
    """
    record = atom_data[symbol_map[symbol.title()]]
    return record[3]
|
|
|
|
|
def plane_dist(a, b, c):
    """Return the distance of point ``c`` from the plane through the origin spanned by ``a`` and ``b``.

    For three lattice vectors this is the spacing between adjacent lattice
    planes parallel to ``a`` and ``b``.

    Args:
        a, b: 3-vectors spanning the plane (lists or arrays, int or float).
        c: 3-vector whose distance from that plane is measured.

    Returns:
        Non-negative distance ``|n . c|`` with ``n`` the unit normal ``a x b / |a x b|``.

    Raises:
        ValueError: if ``a`` and ``b`` are parallel (no unique plane).
    """
    normal = np.cross(a, b)
    length = np.linalg.norm(normal)
    if length == 0:
        raise ValueError("plane_dist: a and b are parallel and do not span a plane")
    # Out-of-place division: the in-place form fails for integer inputs.
    normal = normal / length
    return abs(np.dot(normal, c))
|
|
|
|
|
def structure_bonds(structure: "Structure", max_dist=1.6):
    """Find all site pairs closer than ``max_dist``, including pairs across periodic images.

    Args:
        structure: pymatgen ``Structure``; only ``lattice.matrix`` (rows are
            the lattice vectors a, b, c) and ``cart_coords`` are read.
        max_dist: bond cutoff, in the same units as the Cartesian coordinates.

    Returns:
        Integer array of shape ``(n_bonds, 5)``. Each row ``(na, nb, nc, i, j)``
        means site ``i`` of the home cell is within ``max_dist`` of site ``j``
        translated by ``na*a + nb*b + nc*c``.

    Notes:
        Only non-negative image offsets are searched. A bond to a neighbour
        at a mixed-sign offset such as ``(1, -1, 0)`` has no all-non-negative
        equivalent, so it is not returned. Home-cell bonds appear in both
        directions (``i, j`` and ``j, i``). The image counts assume sites
        have fractional coordinates in [0, 1) (ModeRenderer translates sites
        with ``to_unit_cell=True``).
    """
    lattice = np.asarray(structure.lattice.matrix, dtype=float)
    ref_coords = np.asarray(structure.cart_coords, dtype=float)

    # Interplanar spacing along each lattice vector: d_k = 1 / |b_k|, where
    # b_k are the (non-2pi) reciprocal vectors, i.e. the columns of
    # inv(lattice), so that a_i . b_k = delta_ik. Indexing by k keeps each
    # spacing paired with its own axis.
    reciprocal = np.linalg.inv(lattice).T
    spacings = 1.0 / np.linalg.norm(reciprocal, axis=1)

    # Image k along an axis can contain a neighbour when (k - 1) * d < max_dist,
    # so search k = 0 .. ceil(max_dist / d) inclusive (always at least 0 and 1).
    image_ranges = [
        np.arange(max(int(np.ceil(max_dist / d)) + 1, 2)) for d in spacings
    ]

    chunks = []
    for offset in itertools.product(*image_ranges):
        offset = np.array(offset)
        # cartesian shift na*a + nb*b + nc*c
        image_coords = ref_coords + offset @ lattice

        distances = scipy.spatial.distance.cdist(ref_coords, image_coords)
        pairs = np.argwhere(distances < max_dist)
        if not offset.any():
            # in the home cell, drop each site's zero distance to itself
            pairs = pairs[pairs[:, 0] != pairs[:, 1]]

        chunks.append(np.hstack((np.tile(offset, (len(pairs), 1)), pairs)))

    return np.vstack(chunks)
|
|
|
|
|
def bond_to_positions(bond: np.array, structure: Structure):
    """Return the Cartesian end points of a bond row produced by structure_bonds.

    ``bond`` is ``(na, nb, nc, i, j)``. The first end point is site ``i`` as
    stored. The second is site ``j`` shifted by the lattice translation
    ``na*a + nb*b + nc*c``.
    """
    image = bond[:3]
    site_i, site_j = bond[3], bond[4]
    cart = structure.cart_coords
    shift = structure.lattice.matrix.T @ image
    return cart[site_i], cart[site_j] + shift
|
|
|
|
|
def load_eigs_band_yaml(filename: str):
    """Read mode frequencies and eigenvectors of the first q-point in a phonopy ``band.yaml``.

    Args:
        filename: path to a phonopy ``band.yaml`` whose modes include an
            ``eigenvector`` entry.

    Returns:
        ``(frequencies, eigs)``:
        - ``frequencies``: 1-D array, converted from THz to cm^-1.
        - ``eigs``: list with one 2-D array per mode. Each array holds
          component 0 of the eigenvector's last axis, i.e. the real part in
          phonopy's ``[re, im]`` layout.

    Raises:
        KeyError: if the modes of the first q-point carry no ``eigenvector``.
    """
    # 1 THz expressed in cm^-1 (1e12 Hz / c in cm/s)
    thz_to_cm1 = 33.35641

    with open(filename, 'r', encoding='utf-8') as f:
        data = yaml.load(f)

    band = data['phonon'][0]['band']

    try:
        eigs = [np.array(mode['eigenvector'])[:, :, 0] for mode in band]
    except KeyError:
        raise KeyError(
            "{!r} contains no 'eigenvector' entries for its modes".format(filename)
        ) from None

    frequencies = thz_to_cm1 * np.array([mode['frequency'] for mode in band])

    return frequencies, eigs
|
|
|
|
|
@dataclass(unsafe_hash=True)
class RenderSettings:
    """Visual settings used by ``ModeRenderer`` when writing SVG.

    Lengths (radii, widths, padding, arrow lengths, text size) are in the
    structure's Cartesian units (Å for pymatgen structures). They are
    multiplied by ``scaling`` when written to the SVG.

    NOTE(review): ``bonds``, ``displacements`` and ``atom_types`` are not
    annotated. They are therefore ordinary class attributes shared by every
    instance, not dataclass fields or constructor arguments, and mutating one
    in place changes it for all instances. Because ``translation`` and
    ``rotation`` hold NumPy arrays, the generated ``__eq__``/``__hash__``
    will generally raise when used.
    """

    # image size multiplier (SVG units per structure length unit)
    scaling: float = 30.0
    # Translation handed to pymatgen's ``Structure.translate_sites`` when a
    # ModeRenderer is constructed; under that method's defaults it is in
    # fractional coordinates. default_factory gives each instance its own
    # array: a bare ndarray default would be shared between instances and is
    # rejected by dataclasses on Python 3.11+ (unhashable default).
    translation: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # 3x3 rotation applied to Cartesian coordinates (after translation); the
    # first two rotated components become SVG x and y
    rotation: np.ndarray = field(default_factory=lambda: np.eye(3))

    # extra space added around the atom positions to fit the structure in the image
    padding: float = 2.0

    # text settings for mode info
    draw_info_text: bool = True
    # font size (multiplied by scaling). Evaluated once from the *default*
    # padding (= 0.5); it does not follow a padding passed to the constructor.
    info_text_size: float = padding / 4

    # can be None for a transparent background
    background_color: Optional[str] = '#FFFFFF'

    bonds = {
        'stroke': '#7F8285',
        'stroke-width': 0.1251,
        # whether to draw bonds across periodic bounds
        'draw-periodic': False,
    }

    displacements = {
        'color': '#EB1923',
        'stroke-width': bonds['stroke-width'],
        # length of the longest displacement arrow shaft
        'max-length': 1.44 / 2,
        # width of the arrowhead's base
        'arrow-width': 3 * bonds['stroke-width'],
    }

    # per-atom render settings, keyed by pymatgen Element
    atom_types = {
        Element.C: {
            'radius': 0.2607,
            'fill': '#111417',
            'stroke': '#111417',
            'stroke-width': 0.0834,
        },
        Element.H: {
            'radius': 0.1669,
            'fill': '#FFFFFF',
            'stroke': '#111417',
            'stroke-width': 0.0834,
        },
    }
|
|
|
|
|
class ModeRenderer:
    """Render a structure (optionally a supercell) and its phonon-mode displacements as SVG.

    The structure is read with pymatgen. Frequencies and eigenvectors come
    from a phonopy ``band.yaml`` via ``load_eigs_band_yaml`` (first q-point
    only).
    """

    def __init__(
        self,
        structure_filename: str,
        eigs_filename: str,
        # defaults to [1, 1, 1]
        supercell: Optional[List[int]] = None,
        settings: Optional[RenderSettings] = None,
    ):
        """Load the structure and mode data.

        Args:
            structure_filename: file readable by pymatgen's ``Structure.from_file`` (e.g. POSCAR).
            eigs_filename: phonopy ``band.yaml`` containing eigenvectors.
            supercell: scaling passed to ``Structure.make_supercell``; ``None``
                keeps the unit cell.
            settings: render settings; ``None`` uses a fresh ``RenderSettings()``.
                ``settings.translation`` is applied here, once, before the
                supercell is built. Replacing ``self.settings`` later does not
                re-apply it.
        """
        if settings is None:
            # one fresh default per renderer, rather than a single instance
            # created at function-definition time and shared by all renderers
            settings = RenderSettings()

        # read structure
        structure = Structure.from_file(structure_filename)
        unit_sites = structure.num_sites

        # initial translation
        structure.translate_sites(
            np.arange(structure.num_sites), settings.translation, to_unit_cell=True)

        # make supercell if specified
        if supercell is not None:
            structure.make_supercell(supercell)
        # Images per unit-cell site. Derived from the site count so it is also
        # right when ``supercell`` is a 3x3 scaling matrix (np.prod is not).
        supercell_images = structure.num_sites // unit_sites

        self.structure = structure
        self.bonds = structure_bonds(structure)

        self.frequencies, eigenvectors = load_eigs_band_yaml(eigs_filename)
        # NOTE(review): np.repeat copies each atom's row consecutively. This
        # assumes make_supercell lists all images of a unit-cell site next to
        # each other — confirm for the pymatgen version in use. Copying the
        # eigenvector unchanged to every image is presumably only meaningful
        # when the first q-point in band.yaml is Gamma.
        self.eigenvectors = np.array([
            np.repeat(e, supercell_images, axis=0) for e in eigenvectors
        ])

        self.settings = settings

    def render(self, mode_id: Optional[int]):
        """Build the SVG document.

        Args:
            mode_id: index of the mode whose displacements (and info text) are
                drawn; ``None`` draws only bonds and atoms.

        Returns:
            ``lxml.etree.ElementTree`` wrapping the ``<svg>`` root.
        """
        rset = self.settings

        # our coordinate transform consists of a rotation then scaling; it
        # accepts a single (3,) vector or a (3, N) matrix
        def coord_transform(coords):
            return rset.scaling * np.dot(rset.rotation, coords)

        # Bounds of the *transformed* atom positions, padded. Taking min/max
        # after the transform keeps width/height positive for any rotation
        # (transforming the untransformed min/max corners does not).
        projected = coord_transform(self.structure.cart_coords.T).T
        margin = rset.padding * rset.scaling
        min_pos = projected.min(axis=0) - margin
        max_pos = projected.max(axis=0) + margin
        range_pos = max_pos - min_pos

        svg = etree.Element("svg", nsmap=NSMAP, attrib={
            'width': '{:.1f}'.format(range_pos[0]),
            'height': '{:.1f}'.format(range_pos[1]),
            'version': '1.1',
            'viewBox': '{:.1f} {:.1f} {:.1f} {:.1f}'.format(
                min_pos[0], min_pos[1],
                range_pos[0], range_pos[1],
            ),
            'xmlns': "http://www.w3.org/2000/svg",
        })

        # definitions for use later
        defs = etree.SubElement(svg, 'defs')

        if rset.background_color is not None:
            etree.SubElement(svg, 'rect', attrib={
                'id': 'background',
                'x': '{:.1f}'.format(min_pos[0]),
                'y': '{:.1f}'.format(min_pos[1]),
                'width': '{:.1f}'.format(range_pos[0]),
                'height': '{:.1f}'.format(range_pos[1]),
                'fill': rset.background_color,
            })

        if rset.draw_info_text and mode_id is not None:
            text_size = rset.info_text_size * rset.scaling
            mode_info = etree.SubElement(svg, 'text', attrib={
                'id': 'mode-info',
                'x': '{:.1f}'.format(min_pos[0] + text_size / 3),
                'y': '{:.1f}'.format(max_pos[1] - text_size / 3),
                'font-family': 'monospace',
                'font-size': '{}'.format(text_size),
                'font-weight': 'bold',
            })
            mode_info.text = "id: {} | frequency: {:.4f} cm^-1".format(mode_id, self.frequencies[mode_id])

        # draw order: bonds, then displacement arrows, then atoms on top
        self.svg_bonds(svg, coord_transform)

        if mode_id is not None:
            self.svg_displacements(svg, defs, mode_id, coord_transform)

        self.svg_atoms(svg, coord_transform)

        return etree.ElementTree(svg)

    def svg_atoms(self, svg, coord_transform):
        """Draw one circle per site, grouped per element with that element's style.

        Raises:
            KeyError: if a site's species has no entry in ``settings.atom_types``.
        """
        rset = self.settings

        atom_group = etree.SubElement(svg, 'g', attrib={
            'id': 'atoms'
        })

        # Element -> <g> element. Kept locally instead of being written into
        # rset.atom_types, which is a class-level dict shared by all settings.
        groups = {}
        for element, style in rset.atom_types.items():
            groups[element] = etree.SubElement(atom_group, 'g', attrib={
                'id': '{}-atoms'.format(element.symbol.lower()),
                'fill': style['fill'],
                'stroke': style['stroke'],
                'stroke-width': str(style['stroke-width'] * rset.scaling),
            })

        for site in self.structure.sites:
            x, y, _ = coord_transform(site.coords)
            try:
                style = rset.atom_types[site.specie]
            except KeyError:
                raise KeyError(
                    "no render settings in atom_types for species {}".format(site.specie)
                ) from None

            etree.SubElement(groups[site.specie], 'circle', attrib={
                'r': str(style['radius'] * rset.scaling),
                'cx': '{:.1f}'.format(x),
                'cy': '{:.1f}'.format(y),
            })

    def svg_displacements(self, svg: etree.Element, defs: etree.Element, mode_id: int, coord_transform):
        """Draw one arrow per site for mode ``mode_id``.

        Each displacement is the eigenvector row divided by sqrt(site mass).
        The set is scaled so the longest arrow shaft is
        ``displacements['max-length']``. Shafts and arrowheads both use
        ``displacements['color']``.
        """
        rset = self.settings
        color = rset.displacements['color']

        # arrowhead: triangle with base from (-1, 0) to (1, 0) and tip at
        # (0, 1.732), i.e. pointing along +y, scaled to arrow-width
        etree.SubElement(defs, 'path', attrib={
            'id': 'arrowhead',
            'd': 'M -1 0 h 2 l -1 1.732 z',
            'fill': color,
            'stroke-width': '0',
            'transform': 'scale({})'.format(rset.scaling * rset.displacements['arrow-width'] / 2)
        })

        displacement_group = etree.SubElement(svg, 'g', nsmap=NSMAP, attrib={
            'id': 'displacements',
            'stroke': color,
            'stroke-width': str(rset.displacements['stroke-width'] * rset.scaling),
        })

        eigs = self.eigenvectors[mode_id]
        # scale by 1 / sqrt(mass) to get displacements
        masses = np.array([mass_from_symbol(s.name) for s in self.structure.species])
        disps = (eigs.T / np.sqrt(masses)).T

        # normalize, then scale to arrow length (an all-zero mode stays zero
        # instead of dividing by zero)
        longest = np.max(np.linalg.norm(disps, axis=1))
        if longest > 0:
            disps = rset.displacements['max-length'] * disps / longest

        for pos, disp in zip(self.structure.cart_coords, disps):
            pos = coord_transform(pos)
            disp = coord_transform(disp)
            end = pos + disp

            etree.SubElement(displacement_group, 'path', attrib={
                'd': 'M {:.1f} {:.1f} {:.1f} {:.1f}'.format(
                    pos[0], pos[1], end[0], end[1]
                )
            })

            # add arrow head: rotate the +y-pointing triangle onto the
            # displacement direction about the shaft's end point
            etree.SubElement(displacement_group, 'use', nsmap=NSMAP, attrib={
                '{http://www.w3.org/1999/xlink}href': '#arrowhead',
                'x': '{:.1f}'.format(end[0]),
                'y': '{:.1f}'.format(end[1]),
                'transform': 'rotate({:.1f} {:.1f} {:.1f})'.format(
                    math.degrees(math.atan2(disp[1], disp[0]) - np.pi / 2), end[0], end[1]
                ),
            })

    def svg_bonds(self, svg, coord_transform):
        """Draw one line per bond in ``self.bonds``.

        Bonds to a non-zero image offset are skipped unless
        ``settings.bonds['draw-periodic']`` is true.
        """
        rset = self.settings

        bond_group = etree.SubElement(svg, 'g', attrib={
            'id': 'bonds',
            'stroke': rset.bonds['stroke'],
            'stroke-width': str(rset.bonds['stroke-width'] * rset.scaling),
        })

        for bond in self.bonds:
            if not rset.bonds['draw-periodic'] and np.any(bond[:3] != 0):
                continue

            pos_a, pos_b = bond_to_positions(bond, self.structure)
            pos_a, pos_b = coord_transform(pos_a), coord_transform(pos_b)

            etree.SubElement(bond_group, 'path', attrib={
                'd': 'M {:.1f} {:.1f} {:.1f} {:.1f}'.format(
                    pos_a[0], pos_a[1], pos_b[0], pos_b[1]
                )
            })
|
|
|
|
|
def main():
    """Render every mode of ``band.yaml`` for ``POSCAR`` (2x1x1 supercell) to ``mode_<n>.svg``.

    Files are numbered from 1; each filename is printed as it is written.
    """
    # see the RenderSettings class for available settings (e.g. colors/sizes/etc)
    settings = RenderSettings(
        # fractional shift along a; presumably -0.5 Å for a = 12.944 Å — confirm
        translation=np.array([-0.5 / 12.944, 0, 0]),
        # need to swap y/z since we render the x, y plane
        rotation=np.array([
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
        ]),
    )

    # Pass settings to the constructor: the translation is only applied
    # there, so assigning renderer.settings afterwards silently dropped it.
    renderer = ModeRenderer("POSCAR", "band.yaml", [2, 1, 1], settings=settings)

    for i in range(len(renderer.frequencies)):
        filename = 'mode_{}.svg'.format(i + 1)

        print(filename)
        renderer.render(i).write(
            filename,
            encoding='utf-8',
            pretty_print=True,
            xml_declaration=True
        )


if __name__ == '__main__':
    main()