Skip to content

Instantly share code, notes, and snippets.

@ctralie
Created July 26, 2022 17:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ctralie/979b420570dc6ea65dee2ab9f8a49705 to your computer and use it in GitHub Desktop.
Save ctralie/979b420570dc6ea65dee2ab9f8a49705 to your computer and use it in GitHub Desktop.
Merge trees on time series with simplification
"""
Programmer: Chris Tralie
Purpose: To provide a basic ordered merge tree class for interval and circular domains,
along with methods to construct the merge tree from a time series, to plot it and
its associated persistence diagram, and to simplify the merge trees by persistence
"""
import numpy as np
import matplotlib.pyplot as plt
def plot_diagrams(
    diagrams,
    plot_only=None,
    title=None,
    xy_range=None,
    labels=None,
    markers=None,
    sizes=None,
    colors=None,
    colormap="default",
    ax_color=np.array([0.0, 0.0, 0.0]),
    diagonal=True,
    lifetime=False,
    equal=True,
    legend=True,
    show=False,
    ax=None
):
    """A helper function to plot persistence diagrams.

    Parameters
    ----------
    diagrams: ndarray (n_pairs, 2) or list of diagrams
        A diagram or list of diagrams. If diagram is a list of diagrams,
        then plot all on the same plot using different colors.
    plot_only: list of numeric
        If specified, an array of only the diagrams that should be plotted.
    title: string, default is None
        If title is defined, add it as title of the plot.
    xy_range: list of numeric [xmin, xmax, ymin, ymax]
        User provided range of axes. This is useful for comparing
        multiple persistence diagrams.
    labels: string or list of strings
        Legend labels for each diagram.
        If none are specified, we use H_0, H_1, H_2,... by default.
    markers: string or list of strings
        Markers for each diagram
        If none are specified, we use dots by default.
    sizes: int or list of ints
        Sizes of each marker
        If none are specified, use 20 by default
    colors: string or list of strings
        Colors for each diagram
        If none are specified, use the default sequence from matplotlib
    colormap: string, default is 'default'
        Any of matplotlib's style sheets, passed to plt.style.use.
        See all available styles with
        .. code:: python
            import matplotlib.pyplot as plt
            print(plt.style.available)
    ax_color: any valid matplotlib color type.
        See https://matplotlib.org/api/colors_api.html for complete API.
    diagonal: bool, default is True
        Plot the diagonal x=y line.
    lifetime: bool, default is False. If True, diagonal is turned to False.
        Plot life time of each point instead of birth and death.
        Essentially, visualize (x, y-x).
    equal: bool, default is True. If True, plot axes equal
    legend: bool, default is True
        If true, show the legend.
    show: bool, default is False
        Call plt.show() after plotting. If you are using self.plot() as part
        of a subplot, set show=False and call plt.show() only once at the end.
    ax: matplotlib Axes, default is None
        Axes to draw into; plt.gca() is used when not given.
    """
    ax = ax or plt.gca()
    plt.style.use(colormap)
    xlabel, ylabel = "Birth", "Death"
    if not isinstance(diagrams, list):
        # Must have diagrams as a list for processing downstream
        diagrams = [diagrams]
    if labels is None:
        # Default labels H_0, H_1, ... by position in the list
        labels = ["$H_{{{}}}$".format(i) for i, _ in enumerate(diagrams)]
    if markers is None:
        markers = ["o"] * len(diagrams)
    if sizes is None:
        sizes = [20] * len(diagrams)
    if colors is None:
        colors = ["C{}".format(i) for i in range(len(diagrams))]
    if plot_only:
        diagrams = [diagrams[i] for i in plot_only]
        # Only subset per-diagram label lists; a single string label is
        # broadcast below (indexing it would pick individual characters)
        if isinstance(labels, list):
            labels = [labels[i] for i in plot_only]
    if not isinstance(labels, list):
        labels = [labels] * len(diagrams)
    if not isinstance(markers, list):
        markers = [markers] * len(diagrams)
    if not isinstance(sizes, list):
        sizes = [sizes] * len(diagrams)
    if not isinstance(colors, list):
        colors = [colors] * len(diagrams)

    # Construct copy with proper type of each diagram
    # so we can freely edit them.
    diagrams = [dgm.astype(np.float32, copy=True) for dgm in diagrams]

    # find min and max of all visible diagrams
    concat_dgms = np.concatenate(diagrams).flatten()
    has_inf = np.any(np.isinf(concat_dgms))
    finite_dgms = concat_dgms[np.isfinite(concat_dgms)]

    # clever bounding boxes of the diagram
    if not xy_range:
        # define bounds of diagram (np.min/np.max fail on an empty array)
        if finite_dgms.size > 0:
            ax_min, ax_max = np.min(finite_dgms), np.max(finite_dgms)
        else:
            ax_min = ax_max = 0.0
        x_r = ax_max - ax_min

        # Give plot a nice buffer on all sides.  x_r == 0 when every finite
        # value is identical (e.g. a single point); use a unit buffer so the
        # axis limits are not degenerate.
        buffer = 1 if x_r == 0 else x_r / 5

        x_down = ax_min - buffer / 2
        x_up = ax_max + buffer

        y_down, y_up = x_down, x_up
    else:
        x_down, x_up, y_down, y_up = xy_range

    yr = y_up - y_down

    if lifetime:
        # Don't plot landscape and diagonal at the same time.
        diagonal = False

        # reset y axis so it doesn't go much below zero
        y_down = -yr * 0.05
        y_up = y_down + yr

        # set custom ylabel
        ylabel = "Lifetime"

        # set diagrams to be (x, y-x)
        for dgm in diagrams:
            dgm[:, 1] -= dgm[:, 0]

        # plot horizon line
        ax.plot([x_down, x_up], [0, 0], c=ax_color)

    # Plot diagonal
    if diagonal:
        ax.plot([x_down, x_up], [x_down, x_up], "--", c=ax_color)

    # Plot inf line
    if has_inf:
        # put inf line slightly below top
        b_inf = y_down + yr * 0.95
        ax.plot([x_down, x_up], [b_inf, b_inf], "--", c="k", label=r"$\infty$")

        # convert each inf in each diagram with b_inf
        for dgm in diagrams:
            dgm[np.isinf(dgm)] = b_inf

    # Plot each diagram
    for dgm, label, marker, size, color in zip(diagrams, labels, markers, sizes, colors):
        # plot persistence pairs
        ax.scatter(dgm[:, 0], dgm[:, 1], size, c=color, label=label, marker=marker)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.set_xlim([x_down, x_up])
    ax.set_ylim([y_down, y_up])

    if equal:
        ax.set_aspect('equal', 'box')

    if title is not None:
        ax.set_title(title)

    if legend is True:
        ax.legend(loc="lower right")

    if show is True:
        plt.show()
def poly_fit(X, xs, do_plot=False):
    """
    Fit the unique polynomial of degree N-1 that passes through the N points
    in X, and evaluate it at the coordinates in xs.
    This function assumes that all of the points have a unique x position
    (otherwise the interpolation system is singular).

    Parameters
    ----------
    X: ndarray(N, 2)
        2D coordinates [x, y] of the points to interpolate
    xs: ndarray(M)
        x coordinates at which to evaluate the polynomial
    do_plot: bool
        If True, plot the fitted curve and the input points and call plt.show()

    Returns
    -------
    Y: ndarray(M, 2)
        Column 0 is xs, column 1 is the polynomial evaluated at xs
    """
    x = X[:, 0]
    y = X[:, 1]
    N = X.shape[0]
    # Vandermonde matrix with columns x**0, x**1, ..., x**(N-1)
    A = np.vander(np.asarray(x, dtype=float), N, increasing=True)
    # Solve A b = y directly; this is better conditioned than forming inv(A)
    b = np.linalg.solve(A, np.asarray(y, dtype=float))
    M = xs.size
    Y = np.zeros((M, 2))
    Y[:, 0] = xs
    for i in range(N):
        Y[:, 1] += b[i] * (xs**i)
    if do_plot:
        # plt.hold was removed in matplotlib 3.0; overplotting is the default
        plt.plot(Y[:, 0], Y[:, 1], 'b')
        plt.scatter(X[:, 0], X[:, 1], 20, 'r')
        plt.show()
    return Y
def draw_curve(X, Y, linewidth):
    """
    Draw a parabolic curve between two 2D points with the current pyplot axes.
    The parabola passes through both endpoints and through a midpoint placed
    halfway in x and three quarters of the way toward the higher point in y.
    If both points share an x coordinate, a straight segment is drawn instead.

    Parameters
    ----------
    X: list of [x, y]
        First point
    Y: list of [x, y]
        Second point
    linewidth: int
        Thickness of line
    """
    # Order the endpoints so that (x1, y1) is the lower one
    if Y[1] < X[1]:
        X, Y = Y, X
    x1, y1 = X[0], X[1]
    x3, y3 = Y[0], Y[1]
    if x1 == x3:
        # A function y(x) cannot join two vertically stacked points; fitting
        # a parabola would raise LinAlgError, so draw a straight line.
        plt.plot([x1, x3], [y1, y3], 'k', linewidth=linewidth)
        return
    x2 = 0.5*x1 + 0.5*x3
    y2 = 0.25*y1 + 0.75*y3
    xs = np.linspace(x1, x3, 50)
    control = np.array([[x1, y1], [x2, y2], [x3, y3]])
    curve = poly_fit(control, xs, do_plot=False)
    plt.plot(curve[:, 0], curve[:, 1], 'k', linewidth=linewidth)
class MergeNode(object):
    """
    A node in an ordered merge tree.  MergeTree.init_from_timeseries creates
    a leaf for each point with no live neighbor and an internal node for each
    recorded merge of two components.
    """
    def __init__(self, y, x=None):
        """
        Parameters
        ----------
        y: float
            Height of node
        x: float
            x position of node (optional)
        """
        self.children = []
        self.x = x
        self.y = y
        self.idx = -1  # Inorder index, assigned by inorder()
        self.birth_death = []  # (birth, death) pair, set on leaves with a recorded death
        self.is_globalmin = False

    @staticmethod
    def _x_sort_key(node):
        """
        Sort key on a node's x coordinate that tolerates x being None.
        Nodes without an x sort after those with one (ties keep list order),
        so mixing specified and unspecified coordinates never raises.
        """
        return (node.x is None, 0 if node.x is None else node.x)

    def get_coords(self, use_inorder):
        """
        Return a float array of the [x, y] coordinates of this node

        Parameters
        ----------
        use_inorder: boolean
            If True, use the inorder coordinate for x.  If False,
            use a prespecified x coordinate if it exists (falling back
            to the inorder index otherwise)
        """
        # float dtype so a fractional x is not truncated when y is an int
        coords = np.array([self.idx, self.y], dtype=float)
        if not use_inorder and self.x is not None:
            coords[0] = self.x
        return coords

    def inorder(self, idx):
        """
        Perform a generalized inorder traversal, assigning consecutive
        indices to self.idx in order of x coordinate.
        NOTE: Nodes whose x coordinates have not been specified are placed
        in an arbitrary (but deterministic) order.

        Parameters
        ----------
        idx: list[1]
            A count, by reference
        """
        for node in sorted(self.children + [self], key=MergeNode._x_sort_key):
            if node is self:
                self.idx = idx[0]
                idx[0] += 1
            else:
                node.inorder(idx)

    def get_rep_timeseries(self, xs, ys, signs):
        """
        Create a piecewise linear function that is
        obtained from an inorder traversal of the y
        coordinates of the nodes in this tree

        Parameters
        ----------
        xs: list of float
            X coordinates of time series that I'm building
        ys: list of float
            Time series that I'm building
        signs: list of [-1, 1]
            A parallel list indicating local min (-1) or local max (+1)
        """
        if len(self.children) == 0:
            xs.append(self.x)
            ys.append(self.y)
            signs.append(-1)
        ordered = sorted(self.children, key=MergeNode._x_sort_key)
        for i, child in enumerate(ordered):
            child.get_rep_timeseries(xs, ys, signs)
            if i < len(ordered) - 1:
                # Put the max in between every adjacent pair of children
                xs.append(self.x)
                ys.append(self.y)
                signs.append(1)

    def persistence_simplify(self, eps):
        """
        Remove all leaves that are under a certain persistence threshold

        Parameters
        ----------
        eps: Persistence threshold

        Returns
        -------
        True if this node survives, False if its parent should drop it
        (a leaf with persistence < eps, or a node all of whose children
        were removed)
        """
        survived = True
        if not self.is_globalmin and len(self.birth_death) == 2:  # Leaf node
            birth, death = self.birth_death
            if death - birth < eps:
                survived = False
        elif len(self.children) > 0:
            self.children = [c for c in self.children if c.persistence_simplify(eps)]
            if len(self.children) == 0:
                survived = False
        return survived

    def delete_singletons(self):
        """
        Delete nodes with a single child

        Returns
        -------
        The node that should replace this one in its parent (this node, or
        the first descendant that does not have exactly one child)
        """
        ret = self
        if len(self.children) == 1:
            ret = self.children[0].delete_singletons()
        else:
            for i, c in enumerate(self.children):
                self.children[i] = c.delete_singletons()
        return ret

    def plot(self, use_inorder, params):
        """
        Recursive helper method for plotting

        Parameters
        ----------
        use_inorder: boolean
            If True, use the inorder coordinate for x.  If False,
            use a prespecified x coordinate if it exists
        params: dict {
            offset: [x, y]: Offset by which to plot this
            draw_curved: boolean: If true, draw parabolic curved lines between nodes
            linewidth: int: How thick to draw the edges
            pointsize: int: How big to draw the nodes
            plot_birthdeath: boolean: Whether to plot (birth, death) at leaf nodes
        }
        """
        offset = params.get('offset', np.array([0, 0]))
        draw_curved = params.get('draw_curved', True)
        linewidth = params.get('linewidth', 3)
        pointsize = params.get('pointsize', 200)
        plot_birthdeath = params.get('plot_birthdeath', False)
        X = self.get_coords(use_inorder) + offset
        plt.scatter(X[0], X[1], pointsize, 'k')
        if len(self.birth_death) > 0 and plot_birthdeath:
            plt.text(X[0], X[1], "{:.2f}, {:.2f}".format(*self.birth_death), c='r')
        for child in self.children:
            Y = child.get_coords(use_inorder) + offset
            if draw_curved:
                draw_curve(X, Y, linewidth)
            else:
                # lowercase 'linewidth': matplotlib no longer accepts 'lineWidth'
                plt.plot([X[0], Y[0]], [X[1], Y[1]], 'k', linewidth=linewidth)
            child.plot(use_inorder, params)
def unionfind_root(pointers, u):
    """
    Union find root operation with path-compression

    Implemented iteratively so that long pointer chains cannot exceed
    Python's recursion limit.

    Parameters
    ----------
    pointers: list
        A list of pointers to representative nodes (modified in place:
        every node on the path from u gets pointed directly at the root)
    u: int
        Index of the node to find

    Returns
    -------
    Index of the representative of the component of u
    """
    # First pass: walk up to the representative
    root = u
    while not (pointers[root] == root):
        root = pointers[root]
    # Second pass: point every node on the path directly at the root
    while not (u == root):
        parent = pointers[u]
        pointers[u] = root
        u = parent
    return pointers[root]
def unionfind_union(pointers, u, v, idxorder):
    """
    Union find "union" that always keeps the earlier-born representative.

    The two components containing u and v are merged by pointing the root
    that appears later in idxorder at the root that appears earlier
    (akin to rank-based merging, though the running-time guarantee may differ).

    Parameters
    ----------
    pointers: list
        A list of pointers to representative nodes (modified in place)
    u: int
        Index of first node
    v: int
        Index of the second node
    idxorder: list
        List of order in which each point shows up
    """
    root_u = unionfind_root(pointers, u)
    root_v = unionfind_root(pointers, v)
    if root_u == root_v:
        # Already in the same component; nothing to merge
        return
    if idxorder[root_v] < idxorder[root_u]:
        older, younger = root_v, root_u
    else:
        older, younger = root_u, root_v
    pointers[younger] = older
class MergeTree(object):
    """
    An ordered merge tree built from a 1D time series (interval or circular
    domain), together with its H0 persistence diagram (self.PD) and the
    indices of each (birth, death) pair (self.PDIdx).
    """
    def __init__(self, x=None):
        """
        Construct a new merge tree

        Parameters
        ----------
        x: ndarray(N)
            Time series with which to initialize a merge tree.
            If left blank (or empty), initialize an empty merge tree.
        """
        x = np.array([]) if x is None else np.asarray(x)
        if x.size > 0:
            self.init_from_timeseries(x)
        else:
            self.root = None
            # (0, 2) so that PD[:, 0] / PD[:, 1] index cleanly on an empty tree
            self.PD = np.zeros((0, 2))
            self.PDIdx = np.zeros((0, 2), dtype=int)

    def get_rep_timeseries(self):
        """
        Return a piecewise linear function that is
        obtained from an inorder traversal of the y
        coordinates of the nodes in this tree, as well as
        a parallel array that indicates whether the points
        are mins or maxes

        Returns
        -------
        {
            xs: ndarray(N): Coordinates of time series
            ys: ndarray(N): Time series representing piecewise linear function,
                            with as many samples as there are nodes in the tree,
            signs: ndarray(N): A parallel array of signs
        }
        """
        ys = []
        xs = []
        signs = []
        if self.root is not None:
            self.root.get_rep_timeseries(xs, ys, signs)
        return dict(xs=np.array(xs), ys=np.array(ys), signs=np.array(signs))

    def persistence_simplify(self, eps):
        """
        Remove all leaves that are under a certain persistence threshold,
        then collapse nodes left with a single child

        Parameters
        ----------
        eps: Persistence threshold
        """
        if self.root is not None:
            self.root.persistence_simplify(eps)
            # delete_singletons returns the replacement node; if the root
            # itself became a singleton it must be replaced here
            self.root = self.root.delete_singletons()

    def plot(self, use_inorder, params=None):
        """
        Draw this tree

        Parameters
        ----------
        use_inorder: boolean
            If True, use the inorder coordinate for x.  If False,
            use a prespecified x coordinate if it exists
        params: dict {
            offset: [x, y]: Offset by which to plot this
            draw_curved: boolean: If true, draw parabolic curved lines between nodes
            linewidth: int: How thick to draw the edges
            pointsize: int: How big to draw the nodes
            plot_birthdeath: boolean: Whether to plot (birth, death) at leaf nodes
        }
        """
        params = {} if params is None else params
        if self.root is None:
            return
        if use_inorder:
            idx = [0]
            self.root.inorder(idx)
        self.root.plot(use_inorder, params)

    def plot_with_pd(self, use_inorder, params=None):
        """
        Draw this tree alongside its persistence diagram

        Parameters
        ----------
        use_inorder: boolean
            If True, use the inorder coordinate for x.  If False,
            use a prespecified x coordinate if it exists
        params: dict {
            offset: [x, y]: Offset by which to plot this
            draw_curved: boolean: If true, draw parabolic curved lines between nodes
            linewidth: int: How thick to draw the edges
            pointsize: int: How big to draw the nodes
            plot_birthdeath: boolean: Whether to plot (birth, death) at leaf nodes
            use_grid: boolean: Whether to draw grid lines
            show_merge_xticks: Whether to show the x ticks for the merge tree
        }
        """
        params = {} if params is None else params
        if self.root is None:
            return
        use_grid = params.get('use_grid', False)
        show_merge_xticks = params.get('show_merge_xticks', False)
        yvals = np.sort(np.unique(self.get_rep_timeseries()['ys']))
        dy = yvals[-1] - yvals[0]
        plt.subplot(121)
        self.plot(use_inorder, params)
        plt.gca().set_yticks(yvals)
        plt.ylim(yvals[0]-0.1*dy, yvals[-1]+0.1*dy)
        if not show_merge_xticks:
            plt.gca().set_xticks([])
        if use_grid:
            plt.grid()
        plt.subplot(122)
        # Skip the diagram when there are no pairs (nothing to bound the axes)
        if self.PD.size > 0:
            plot_diagrams([self.PD])
            plt.gca().set_yticks(np.unique(self.PD[:, 1]))
            plt.gca().set_xticks(np.unique(self.PD[:, 0]))
        plt.ylim(yvals[0]-0.1*dy, yvals[-1]+0.1*dy)
        plt.xlim(yvals[0]-0.1*dy, yvals[-1]+0.1*dy)
        if use_grid:
            plt.grid()

    def init_from_timeseries(self, y, include_essential=False, circular=False):
        """
        Use union find to make a merge tree object from the time series y
        (NOTE: This code is pretty general and could work to create merge trees
        on any domain if the neighbor set was updated)

        Parameters
        ----------
        y: ndarray(N)
            1D array representing the time series
        include_essential: bool
            Whether to include the essential class (global min, global max)
        circular: boolean
            Whether to assume that the domain wraps around circularly

        Returns
        -------
        PD: ndarray(K, 2)
            H0 persistence diagram (birth, death) for this merge tree
        PDIdx: ndarray(K, 2)
            Indices into y of the birth and death of each pair
        Both are also stored as self.PD and self.PDIdx as a side effect.
        """
        y = np.asarray(y)
        N = len(y)
        if N == 0:
            self.root = None
            self.PD = np.zeros((0, 2))
            self.PDIdx = np.zeros((0, 2), dtype=int)
            return self.PD, self.PDIdx
        # Add points from the bottom up.  A stable sort makes ties
        # deterministic and guarantees idx[0] == np.argmin(y), so the global
        # min is always the first point processed (and therefore a leaf).
        idx = np.argsort(y, kind='stable')
        idxorder = np.zeros(N)
        idxorder[idx] = np.arange(N)
        pointers = np.arange(N)  # Pointer to oldest indices
        representatives = {}  # Topmost node of each connected component
        leaves = {}  # Leaf nodes
        I = []  # Persistence diagram
        IIdx = []  # Paired indices
        for i in idx:  # Go through each point in the time series in height order
            # Find the oldest representatives of the neighbors that
            # are already alive
            neighbs = []
            for di in (-1, 1):  # Neighbor set is simply left/right
                j = i + di
                if circular:
                    j = j % N
                elif j < 0 or j >= N:
                    continue
                if idxorder[j] < idxorder[i]:
                    neighbs.append(unionfind_root(pointers, j))
            if len(neighbs) == 0:
                # If none of this point's neighbors are alive yet, this
                # point will become alive with its own class
                leaves[i] = MergeNode(y[i], i)
                representatives[i] = leaves[i]
                continue
            # Find the oldest class, merge earlier classes with this class,
            # and record the merge events and birth/death times
            oldest_neighb = neighbs[np.argmin([idxorder[n] for n in neighbs])]
            # No matter what, the current node becomes part of the
            # oldest class to which it is connected
            unionfind_union(pointers, oldest_neighb, i, idxorder)
            if len(neighbs) == 2:  # A nontrivial merge
                for n in neighbs:
                    if n == oldest_neighb:
                        continue
                    # Create node and record persistence event if it's nontrivial
                    if y[i] > y[n]:
                        # Record persistence information
                        I.append([y[n], y[i]])
                        IIdx.append([n, i])
                        leaves[n].birth_death = (y[n], y[i])
                        # Create new node joining the two components, ordered by x
                        node = MergeNode(y[i], i)
                        left_right = [representatives[m] for m in neighbs]
                        if left_right[0].x > left_right[1].x:
                            left_right = left_right[::-1]
                        node.children = left_right
                        # Change the representative for this class to be the new node
                        representatives[oldest_neighb] = node
                    unionfind_union(pointers, oldest_neighb, n, idxorder)
        gmin = idx[0]
        leaves[gmin].is_globalmin = True
        # The tree root is the topmost node of the global min's component;
        # this is a single leaf when no merge ever happened (e.g. monotone y)
        self.root = representatives[unionfind_root(pointers, gmin)]
        # Add the essential class
        if include_essential:
            gmax = np.argmax(y)
            b, d = y[gmin], y[gmax]
            I.append([b, d])
            IIdx.append([gmin, gmax])
            leaves[gmin].birth_death = (b, d)
        # reshape keeps a (0, 2) shape when no pairs were recorded
        self.PD = np.array(I).reshape(-1, 2)
        self.PDIdx = np.array(IIdx, dtype=int).reshape(-1, 2)
        return self.PD, self.PDIdx
if __name__ == '__main__':
    # Demo: build a merge tree of a noisy trending cosine, then save one frame
    # per threshold eps showing the simplified tree, the persistence diagram
    # (pairs with persistence < eps marked with red x's), and the
    # representative time series of the simplified tree.
    circular = False
    np.random.seed(0)
    N = 200
    t = np.linspace(0.01, 0.98, N)
    x = np.cos(2*np.pi*t*10) + t*10
    x += 0.3*np.random.randn(N)
    MT = MergeTree()
    # Honor the domain setting above (it was previously defined but unused)
    MT.init_from_timeseries(x, circular=circular)
    rg = [np.min(x), np.max(x)]
    pad = 0.1*(rg[1]-rg[0])
    rg[0] -= pad
    rg[1] += pad
    fac = 0.6
    plt.figure(figsize=(fac*20, fac*6))
    for i, eps in enumerate(np.linspace(0, 1, 200)):
        plt.clf()
        plt.subplot(131)
        MT.persistence_simplify(eps)
        MT.plot(False, {'pointsize':10, 'linewidth':1})
        plt.title("Simplified $\\epsilon = {:.3f}$".format(eps))
        plt.ylim(rg)
        plt.xlim([-1, x.size+1])

        plt.subplot(132)
        PD = MT.PD
        plot_diagrams(PD, sizes=4)
        plt.plot([rg[0], rg[1]], [rg[0]+eps, rg[1]+eps], c='C3', linestyle='--')
        PDEps = PD[PD[:, 1]-PD[:, 0] < eps, :]
        plt.scatter(PDEps[:, 0], PDEps[:, 1], marker='x', c='C3')
        plt.xlim(rg)
        plt.ylim(rg)
        plt.title("Persistence diagram")

        plt.subplot(133)
        res = MT.get_rep_timeseries()
        plt.plot(res['xs'], res['ys'])
        plt.title("Representative Time Series")
        plt.ylim(rg)
        plt.xlim([-1, x.size+1])
        plt.savefig("MT{}.png".format(i))
@ctralie
Copy link
Author

ctralie commented Jul 26, 2022

MTSimplified

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment