Skip to content

Instantly share code, notes, and snippets.

@taoning
Last active June 20, 2023 17:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save taoning/ddb21cd03643ed6665597f4abd4a89c1 to your computer and use it in GitHub Desktop.
Save taoning/ddb21cd03643ed6665597f4abd4a89c1 to your computer and use it in GitHub Desktop.
from dataclasses import dataclass, field
import re
from typing import Generator

from lxml import objectify
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
import numpy as np
@dataclass
class ScatteringData:
    """Raw scattering-data strings for the two sides of a layer."""

    # Scattering data text for the front side; empty until assigned.
    front: str = ""
    # Scattering data text for the back side; empty until assigned.
    back: str = ""


@dataclass
class WavelengthData:
    """Transmission and reflection scattering data for one wavelength band.

    The defaults are produced by ``default_factory`` so that every instance
    owns its own ``ScatteringData`` objects.  A shared ``ScatteringData()``
    default would be rejected as a mutable default by newer Python versions
    and, where accepted, would make all instances write into the same object.
    """

    transmission: ScatteringData = field(default_factory=ScatteringData)
    reflection: ScatteringData = field(default_factory=ScatteringData)
@dataclass
class SystemBSDF:
    """Material metadata plus visible and solar scattering data of one layer."""

    # Material name (filled from Optical/Layer/Material/Name by parse_bsdf_xml).
    name: str
    # Material manufacturer (Optical/Layer/Material/Manufacturer).
    manufacturer: str
    # Device type (Optical/Layer/Material/DeviceType).
    devicetype: str
    # Layer geometry (Optical/Layer/Geometry).
    geometry: str
    # Scattering data for the "visible" wavelength band.
    visible: WavelengthData
    # Scattering data for the "solar" wavelength band.
    solar: WavelengthData
def get_nested_list_levels(nested_list: list) -> int:
    """Calculate the number of levels given a nested list.

    A value that is not a list counts as 0 levels, a flat (or empty) list as
    1 level, and each further level of list nesting adds one.

    Args:
        nested_list: the (possibly nested) list to measure
    Returns:
        The depth of the deepest nesting found.
    """
    if not isinstance(nested_list, list):
        return 0
    # default=0 so an empty list counts as one level instead of making
    # max() raise on an empty sequence.
    return max(map(get_nested_list_levels, nested_list), default=0) + 1
class TensorTree:
    """Wrapper around a parsed tensor tree.

    An anisotropic tensor tree should have 16 lists at its top level.

    Attributes:
        parsed: the parsed tensor tree (a nested list)
        depth: number of tree levels
    """

    def __init__(self, parsed) -> None:
        self.parsed = parsed
        self.depth = get_nested_list_levels(parsed)

    def lookup(self, xp, yp) -> list:
        """Pick the four top-level quadrants for a position and traverse each."""
        selected = [self.parsed[idx] for idx in self.get_branch_index(xp, yp)]
        return [self.traverse(quadrant, xp, yp) for quadrant in selected]

    def get_leaf_index(self, xp, yp) -> range:
        """Return four consecutive indices chosen by the signs of xp and yp."""
        # xp picks the half (0 or 8), yp the quarter inside that half (0 or 4).
        first = (0 if xp < 0 else 8) + (0 if yp < 0 else 4)
        return range(first, first + 4)

    def get_branch_index(self, xp, yp) -> range:
        """Return four indices spaced 4 apart, offset by the signs of xp and yp."""
        # Offsets: (-x, -y) -> 0, (+x, -y) -> 1, (-x, +y) -> 2, (+x, +y) -> 3.
        first = (0 if xp < 0 else 1) + (0 if yp < 0 else 2)
        return range(first, 16, 4)

    def traverse(self, quad, xp, yp, n: int = 1) -> list:
        """Descend into one quadrant, following the position (xp, yp).

        Args:
            quad: list holding the children of this quadrant
            xp: x position relative to the centre of the current node
            yp: y position relative to the centre of the current node
            n: current level, used to size the shift towards the child centre
        Returns:
            ``quad`` itself when it holds a single value, otherwise the four
            selected children (floats, short lists, or traversed sub-branches).
        """
        if len(quad) == 1:  # single leaf
            return quad
        # Re-centre the position on the child it falls into.
        shift = 1 / (2**n)
        child_x = xp + shift if xp < 0 else xp - shift
        child_y = yp + shift if yp < 0 else yp - shift
        level = n + 1
        # Branch-style indexing above the last level, leaf-style at it.
        if level < self.depth:
            chosen = self.get_branch_index(child_x, child_y)
        else:
            chosen = self.get_leaf_index(child_x, child_y)
        children = [quad[idx] for idx in chosen]
        if all(isinstance(child, float) for child in children):
            return children  # single leaf for each branch
        # Lists longer than 4 are further branches; anything else is a leaf.
        return [
            self.traverse(child, child_x, child_y, n=level)
            if len(child) > 4
            else child
            for child in children
        ]
def parse_bsdf_xml(path: str) -> SystemBSDF:
    """Read a BSDF XML file and collect its visible and solar scattering data.

    Every ``WavelengthData`` element under ``Optical/Layer`` is inspected.
    Its ``Wavelength`` text selects the band ("visible" or "solar") and its
    ``WavelengthDataDirection`` text (two words, e.g. "Transmission Front")
    selects which slot receives the ``ScatteringData`` text.  Comparisons are
    case-insensitive; entries with any other band, direction or side are
    ignored.

    Args:
        path: path to the XML file
    Returns:
        A SystemBSDF holding the material fields and the collected data.
    Raises:
        ValueError: if a direction text for a visible/solar entry does not
            consist of exactly two whitespace-separated words.
    """
    with open(path, "rb") as rdr:
        obj = objectify.fromstring(rdr.read())
    layer = obj.Optical.Layer
    # Build every holder explicitly so the two bands can never end up
    # writing into one shared ScatteringData object.
    bands = {
        "visible": WavelengthData(ScatteringData(), ScatteringData()),
        "solar": WavelengthData(ScatteringData(), ScatteringData()),
    }
    for entry in layer.findall("{http://windows.lbl.gov}WavelengthData"):
        band = bands.get(entry.Wavelength.text.lower())
        if band is None:
            continue
        block = entry.WavelengthDataBlock
        kind, side = block.WavelengthDataDirection.text.lower().split()
        if kind in ("transmission", "reflection") and side in ("front", "back"):
            # e.g. bands["visible"].transmission.front = <data text>
            setattr(getattr(band, kind), side, block.ScatteringData.text)
    # NOTE(review): the material fields are passed through as returned by
    # lxml.objectify (not converted to plain str) - confirm callers are fine
    # with that given the str annotations on SystemBSDF.
    return SystemBSDF(
        name=layer.Material.Name,
        manufacturer=layer.Material.Manufacturer,
        devicetype=layer.Material.DeviceType,
        geometry=layer.Geometry,
        visible=bands["visible"],
        solar=bands["solar"],
    )
# One token is either a number (optional sign, digits with an optional
# fraction or a bare fraction, optional exponent) or a single curly brace.
# Everything else (spaces, commas, newlines) is a separator and never matches.
_TOKEN_RE = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|[{}]")


def tokenize(inp: str) -> Generator[str, None, None]:
    """Generator for tokenizing a string that is separated by spaces or commas.

    Numbers keep their sign and exponent, so "-0.5" and "2e-3" each come out
    as a single token, and a comma directly in front of a number only acts as
    a separator.

    Args:
        inp: input string
    Yields:
        next token: "{", "}" or the text of one number
    """
    for match in _TOKEN_RE.finditer(inp):
        yield match.group(0)
def parse_branch(token: Generator[str, None, None]) -> list:
    """Parse tensor tree branches recursively by opening and closing curly braces.

    Tokens are consumed until the "}" that closes the current branch; the
    opening "{" of this branch must already have been consumed by the caller.

    Args:
        token: token generator object.
    Returns:
        children: parsed branches as a nested list of floats
    Raises:
        ValueError: if the tokens run out before the closing "}" is found,
            or if a token is neither a brace nor a valid float.
    """
    children = []
    while True:
        # A default avoids leaking a bare StopIteration on truncated data.
        value = next(token, None)
        if value is None:
            raise ValueError("Tensor tree data ended before a closing }")
        if value == "{":
            children.append(parse_branch(token))
        elif value == "}":
            return children
        else:
            children.append(float(value))
def parse_ttree(data_str: str) -> list:
    """Parse a tensor tree.

    Args:
        data_str: input data string
    Returns:
        A nested list that is the tree
    Raises:
        ValueError: if the data is empty or its first token is not "{".
    """
    tokenized = tokenize(data_str)
    # next() with a default: empty or blank data yields no token at all and
    # must be reported like any other malformed input.
    if next(tokenized, None) != "{":
        raise ValueError("Tensor tree data not starting with {")
    return parse_branch(tokenized)
def transform(vertices):
    """Transform input x y in square space into disk (Shirley-Chiu).

    Each point keeps its "square radius" max(|x|, |y|) as the disk radius,
    while the angle is derived from the ratio of the two coordinates within
    the triangular region of the square the point falls in.

    Args:
        vertices: array-like of (x, y) rows; integer input is converted to
            float before the mapping.
    Returns:
        Array with one (x, y) row in disk space per input row.
    """
    points = np.asarray(vertices)
    if not np.issubdtype(points.dtype, np.floating):
        # np.divide cannot write true-division results into an integer
        # output buffer, so work on a float copy.
        points = points.astype(float)
    xp, yp = points.T
    # Guarded ratios: entries with a zero denominator stay at 0.
    xoy = np.divide(xp, yp, out=np.zeros_like(xp), where=yp != 0)
    yox = np.divide(yp, xp, out=np.zeros_like(yp), where=xp != 0)
    upper_right = (xp + yp) > 0
    r = np.where(
        upper_right,
        np.where(xp > yp, xp, yp),
        np.where(xp < yp, -xp, -yp),
    )
    # Angle expressed in units of pi/4; the origin falls through to 0.
    eighths = np.where(
        upper_right,
        np.where(xp > yp, yox, 2 - xoy),
        np.where(xp < yp, 4 + yox, np.where(yp * yp > 0, 6 - xoy, 0)),
    )
    phi = np.pi / 4 * eighths
    return np.column_stack([r * np.cos(phi), r * np.sin(phi)])
def transform_path(sx, sy, side, depth):
    """Build the disk-space outline of one square patch.

    Args:
        sx: x coordinate of the patch corner the square starts from
        sy: y coordinate of the patch corner the square starts from
        side: side length of the square
        depth: tree level of the patch; deeper patches get fewer
            interpolation steps per edge (never fewer than 2)
    Returns:
        A matplotlib Path whose vertices were passed through ``transform``.
    """
    n_steps = max(2, int(12 / depth))
    far_x = sx + side
    far_y = sy + side
    # Closed square outline, starting and ending at (sx, sy).
    corners = ((sx, sy), (far_x, sy), (far_x, far_y), (sx, far_y), (sx, sy))
    outline = mpl.patches.Polygon(corners).get_path()
    # Subdivide the straight edges so they can bend when mapped onto the disk.
    dense = outline.interpolated(n_steps)
    return mpl.path.Path(transform(dense.vertices), dense.codes)
def plot_ttree(loaded, tin, pin, vmin, vmax, title=None):
    """Plot the tensor tree data looked up for one direction on a disk.

    The direction (tin, pin) is converted to a position in the square, the
    four matching quadrants are fetched with ``loaded.lookup`` and every leaf
    is drawn as a patch mapped onto the disk, coloured on a log scale.

    Args:
        loaded: TensorTree object providing ``lookup`` and ``depth``
        tin: polar angle in degrees; ``tin / 90`` is clamped to [0, 1]
            (presumably the incident theta - confirm against callers)
        pin: azimuth angle in degrees; any value is accepted and wrapped
            into [0, 360) (presumably the incident phi - confirm)
        vmin: lower limit of the logarithmic colour scale
        vmax: upper limit of the logarithmic colour scale
        title: optional title for the axes
    """
    norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax)
    patches = []
    colors = []

    def add_patch(value, depth, ss, sx, sy):
        """Record one leaf value together with its disk-space patch."""
        patches.append(mpl.patches.PathPatch(transform_path(sx, sy, ss, depth)))
        colors.append(value)

    def plot_quad(data, depth, ss, sx, sy):
        """Collect the patches of one quadrant recursively.

        Args:
            data: Quadrant data.
            depth: How deep are we within the tree
            ss: Patch side size.
            sx: Patch x coordinate.
            sy: Patch y coordinate.
        """
        if isinstance(data, float):
            add_patch(data, depth, ss, sx, sy)
        elif len(data) == 1:
            add_patch(data[0], depth, ss, sx, sy)
        elif len(data) == 4:
            depth += 1
            ss /= 2
            # Children are laid out differently above the last level than
            # at it (the two index schemes of the tree differ).
            if depth < loaded.depth:
                offsets = ((ss, ss), (0, ss), (ss, 0), (0, 0))
            else:
                offsets = ((ss, ss), (ss, 0), (0, ss), (0, 0))
            for child, (dx, dy) in zip(data, offsets):
                plot_quad(child, depth, ss, sx + dx, sy + dy)
        # Data of any other length is not drawn.

    r = max(min(tin / 90, 1), 0)
    # Wrap the azimuth so the region index below always lands in 0..4;
    # unwrapped values would either run past the tables or pick a wrong
    # entry through negative indexing.
    phi = np.deg2rad(pin % 360)
    quarter_pi = np.pi / 4
    rgn = int(np.floor((phi + quarter_pi) / (np.pi / 2)))
    # Disk (r, phi) back to square coordinates, one formula per region.
    oa = [
        r,
        (np.pi / 2 - phi) * r / quarter_pi,
        -r,
        (phi - 3 * np.pi / 2) * r / quarter_pi,
        r,
    ][rgn]
    ob = [
        phi * r / quarter_pi,
        r,
        (np.pi - phi) * r / quarter_pi,
        -r,
        (phi - 2 * np.pi) * r / quarter_pi,
    ][rgn]
    test = loaded.lookup(oa, ob)

    fig, ax = plt.subplots(figsize=(8, 8))
    # One unit square per looked-up quadrant, together covering [-1, 1]^2.
    plot_quad(test[0], 1, 1, 0, 0)
    plot_quad(test[1], 1, 1, -1, 0)
    plot_quad(test[2], 1, 1, 0, -1)
    plot_quad(test[3], 1, 1, -1, -1)
    pc = PatchCollection(patches)
    pc.set(array=np.array(colors), cmap='turbo', norm=norm)
    ax.add_collection(pc)
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.axis("off")
    if title is not None:
        ax.set_title(title)
    # Put the colour bar in its own axes just to the right of the plot.
    box = ax.get_position()
    cbarax = plt.axes([box.x0 * 1.1 + box.width * 1.08, box.y0, 0.03, box.height])
    fig.colorbar(pc, cax=cbarax, label="BSDF")
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment