Last active
September 13, 2022 08:39
-
-
Save HiddenBeginner/d127063e5691805d2f51bc80304a9e6b to your computer and use it in GitHub Desktop.
A giotto-tda compatible implementation of persistence vectors, introduced in [C. Bresten & J.-H. Jung, 2019].
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from joblib import Parallel, delayed, effective_n_jobs | |
from gtda.utils.intervals import Interval | |
from gtda.utils.validation import validate_params, check_diagrams | |
from gtda.diagrams._utils import _subdiagrams, _bin, _make_homology_dimensions_mapping, _homology_dimensions_to_sorted_ints | |
from sklearn.base import BaseEstimator, TransformerMixin | |
from sklearn.utils import gen_even_slices | |
from sklearn.utils.validation import check_is_fitted | |
class PersistenceVector(BaseEstimator, TransformerMixin):
    """Persistence vectors of persistence diagrams, as introduced in [1]_.

    A giotto-tda compatible transformer. For each homology dimension stored
    at fit time, the persistences (death minus birth) of every diagram are
    sorted in descending order and the ``topk`` largest ones are kept as a
    fixed-length vector.

    Parameters
    ----------
    topk : int, default=20
        The number of top-K persistences taken as a vector. If a persistence
        diagram has fewer than ``topk`` points, the gap is filled with zeros.

    n_jobs : int or None, default=None
        Number of jobs handed to :class:`joblib.Parallel` in
        :meth:`transform`. It is also used (through
        :func:`joblib.effective_n_jobs`) to decide into how many slices the
        input collection is split.

    Attributes
    ----------
    homology_dimensions_ : sequence
        Homology dimensions found in the first diagram passed to
        :meth:`fit`, as returned by ``_homology_dimensions_to_sorted_ints``.

    References
    ----------
    .. [1] C. Bresten & J.-H. Jung, "Detection of gravitational waves using
           topological data analysis and convolutional neural network: An
           improved approach", arXiv:1910.08245 [astro-ph.IM]

    """

    _hyperparameters = {
        "topk": {"type": int, "in": Interval(1, np.inf, closed="left")}
    }

    def __init__(self, topk=20, n_jobs=None):
        self.topk = topk
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """Store all observed homology dimensions in
        :attr:`homology_dimensions_`. Then, return the estimator.

        This method is here to implement the usual scikit-learn API and hence
        work in pipelines.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features, 3)
            Input data. Array of persistence diagrams, each a collection of
            triples [b, d, q] representing persistent topological features
            through their birth (b), death (d) and homology dimension (q).
            It is important that, for each possible homology dimension, the
            number of triples for which q equals that homology dimension is
            constant across the entries of X.

        y : None
            There is no need for a target in a transformer, yet the pipeline
            API requires this parameter.

        Returns
        -------
        self : object

        """
        X = check_diagrams(X)
        # `n_jobs` has no entry in `_hyperparameters`, so it is excluded from
        # validation.
        validate_params(
            self.get_params(), self._hyperparameters, exclude=["n_jobs"])

        # Find the unique homology dimensions in the 3D array X passed to
        # `fit` assuming that they can all be found in its zero-th entry
        homology_dimensions_fit = np.unique(X[0, :, 2])
        self.homology_dimensions_ = \
            _homology_dimensions_to_sorted_ints(homology_dimensions_fit)
        self._n_dimensions = len(self.homology_dimensions_)

        return self

    def transform(self, X, y=None):
        """Compute the persistence vectors of the diagrams in `X`.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features, 3)
            Input data. Array of persistence diagrams, each a collection of
            triples [b, d, q] representing persistent topological features
            through their birth (b), death (d) and homology dimension (q).

        y : None
            Ignored. Present for compatibility with the scikit-learn API.

        Returns
        -------
        Xt : ndarray of shape (n_samples, n_homology_dimensions, topk)
            For each sample and each homology dimension in
            :attr:`homology_dimensions_`, the ``topk`` largest persistences
            in descending order.

        """
        check_is_fitted(self)
        X = check_diagrams(X)

        # Following the conventions of other classes in
        # gtda.diagrams.representations: the homology dimension is the OUTER
        # loop and the sample slices are the INNER loop. Parallel returns
        # results in submission order, so the concatenated rows are grouped
        # by dimension first, which is exactly the layout the reshape below
        # relies on. Iterating slices first would scramble samples and
        # dimensions as soon as more than one slice is generated.
        n_slices = effective_n_jobs(self.n_jobs)
        Xt = Parallel(n_jobs=self.n_jobs)(
            delayed(self.compute_persistence_vector)(
                _subdiagrams(X[s], [dim], remove_dim=True)
                )
            for dim in self.homology_dimensions_
            for s in gen_even_slices(len(X), n_slices)
            )
        Xt = np.concatenate(Xt).\
            reshape(self._n_dimensions, len(X), -1).\
            transpose((1, 0, 2))

        return Xt

    def compute_persistence_vector(self, diagrams):
        """Return the ``topk`` largest persistences of each diagram.

        Parameters
        ----------
        diagrams : ndarray of shape (n_samples, n_points, 2)
            Collection of diagrams restricted to a single homology dimension,
            with birth in column 0 and death in column 1.

        Returns
        -------
        persistences : ndarray of shape (n_samples, topk)
            Persistences (death - birth) sorted in descending order and
            truncated to the first ``topk`` entries.

        """
        # If a diagram has fewer than self.topk points, append dummy points
        # (0, 0); they have zero persistence.
        n_points = diagrams.shape[1]
        if n_points < self.topk:
            n_missing = self.topk - n_points
            diagrams = np.pad(
                diagrams, pad_width=((0, 0), (0, n_missing), (0, 0)))

        # Computing the lengths of the intervals
        persistences = diagrams[:, :, 1] - diagrams[:, :, 0]

        # Sorting persistences in descending order (np.sort sorts in
        # ascending order along the last axis, hence the double negation)
        persistences = -np.sort(-persistences)

        return persistences[:, :self.topk]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment