Skip to content

Instantly share code, notes, and snippets.

@VincentRouvreau
Created February 2, 2022 09:49
Show Gist options
  • Save VincentRouvreau/db57e57ab12ba66dbc7a68b2242b6494 to your computer and use it in GitHub Desktop.
Save VincentRouvreau/db57e57ab12ba66dbc7a68b2242b6494 to your computer and use it in GitHub Desktop.
Non-square cubical complexes
# Standard scientific Python imports
import numpy as np
# Standard scikit-learn imports
from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn import metrics
# Import TDA pipeline requirements
from gudhi.sklearn.cubical_persistence import CubicalPersistence
from gudhi.representations import PersistenceImage, DiagramSelector
# Number of zeros appended to every flattened image. 56 = 2 * 28, which turns
# each 784-value sample into an 840-value (28 x 30) one, matching the
# dimensions=[28, 30] handed to CubicalPersistence below.
N_PAD_VALUES = 56

# Fetch MNIST as plain NumPy arrays (as_frame=False): one flattened image per row.
X_, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# Build the padded (non-square) samples in one vectorized step: start from an
# all-zero float array and copy the original pixels into the leading columns,
# which leaves the trailing N_PAD_VALUES columns at zero. The shape is derived
# from the fetched data instead of being hard-coded.
n_samples, n_pixels = X_.shape
X = np.zeros(shape=(n_samples, n_pixels + N_PAD_VALUES))
X[:, :n_pixels] = X_

# Target is: "is an eight ?"
# The labels are compared against the string "8"; multiplying the boolean mask
# by 1 converts it to 0/1 integers.
y = (y == "8") * 1
print("There are", np.sum(y), "eights out of", len(y), "numbers.")

# Hold out 40% of the samples for evaluation; the fixed seed keeps the split
# reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

pipe = Pipeline(
    [
        # Persistence of the 28 x 30 cubical complex built from each flattened sample.
        ("cub_pers", CubicalPersistence(persistence_dimension=0, dimensions=[28, 30], n_jobs=-2)),
        # Keep only the finite points of each diagram.
        ("finite_diags", DiagramSelector(use=True, point_type="finite")),
        # Vectorize each diagram as a 20 x 20 persistence image, weighting every
        # point by the square of its second coordinate.
        (
            "pers_img",
            PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
        ),
        # Classify the vectorized diagrams with a default-parameter SVC.
        ("svc", SVC()),
    ]
)

# Learn from the train subset
pipe.fit(X_train, y_train)

# Predict from the test subset
predicted = pipe.predict(X_test)

print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment