Created February 27, 2022 02:07
Save Micky774/2d4260db54e45d5904f655c09a546e82 to your computer and use it in GitHub Desktop.
FastICA Whiten Benchmarks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
import pandas as pd | |
import argparse | |
from scipy import linalg | |
# import streamlit as st | |
# import altair as alt | |
def _build_cli_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the optional ``--save``/``--load`` paths.

    Both options take a path string and default to ``""``, which the script
    treats as "not provided".
    """
    cli = argparse.ArgumentParser(
        description="Determine whether to save/load the dataframe."
    )
    option_specs = (
        ("--save", "s", "Saves the dataframe to the path if provided"),
        ("--load", "l", "Loads the dataframe at the path if provided"),
    )
    for flag, short_meta, help_text in option_specs:
        cli.add_argument(
            flag,
            type=str,
            default="",
            metavar=short_meta,
            help=help_text,
        )
    return cli


parser = _build_cli_parser()
# Parsed once, at import time; ``main`` reads this module-level namespace.
args = parser.parse_args()
def _time_call(func, *func_args, **func_kwargs) -> float:
    """Return the wall-clock seconds spent in one ``func(*func_args, **func_kwargs)`` call.

    Uses ``time.perf_counter`` (a high-resolution clock) so that fast calls are
    not rounded down to 0.0 the way ``time.time`` can round them.
    """
    start = time.perf_counter()
    func(*func_args, **func_kwargs)
    return time.perf_counter() - start


def main() -> None:
    """Benchmark SVD- vs. eigh-based whitening and optionally save the results.

    If ``--load`` was given, the results table is read from that CSV via
    ``on_load`` instead of being recomputed. Otherwise, for every
    ``(n_samples, n_features)`` shape in the grid
    ``{10**2, 10**3, 10**4} x {10**1, 10**2, 10**3}`` a random matrix is drawn,
    and two whitening strategies are timed on its transpose ``XT``:

    * ``svd``:  thin SVD of ``XT`` (``full_matrices=False``);
    * ``eigh``: eigendecomposition of ``XT @ XT.T`` (the product is timed too).

    The resulting DataFrame has columns ``shape`` (``str(shape)``), ``svd`` and
    ``eigh`` (seconds), and ``eigh/svd`` (NaN if the SVD time is not positive).
    If ``--save`` was given and ``--load`` was not, the table is written there
    as CSV.
    """
    if args.load:
        df = on_load(args.load)
    else:
        X_shapes = [
            (10 ** (2 + i), 10 ** (1 + j)) for i in range(3) for j in range(3)
        ]
        solvers = {
            # Thin SVD: the default full_matrices=True would also materialise
            # square full-size singular-vector matrices that whitening does not
            # need and that dominate the cost when n_samples is large.
            "svd": lambda XT: linalg.svd(XT, full_matrices=False),
            # Forming the Gram matrix is part of the eigh strategy, so it is
            # deliberately included in the measured time.
            "eigh": lambda XT: linalg.eigh(XT.dot(XT.T)),
        }
        total_reps = len(solvers) * len(X_shapes)
        rng = np.random.default_rng(0)  # fixed seed -> reproducible inputs
        count = 0
        data = []
        for shape in X_shapes:
            XT = rng.random(shape).T
            times = {}
            for name, solve in solvers.items():
                times[name] = _time_call(solve, XT)
                count += 1
                # Progress is printed only after the timer has stopped, so
                # console I/O does not leak into the measurement.
                print(f"Progress: {count}/{total_reps}")
            svd_time, eigh_time = times["svd"], times["eigh"]
            data.append(
                {
                    "shape": str(shape),
                    "svd": svd_time,
                    "eigh": eigh_time,
                    "eigh/svd": eigh_time / svd_time if svd_time > 0 else float("nan"),
                }
            )
        df = pd.DataFrame(data)
    if args.save and not args.load:
        df.to_csv(args.save, index=False)
        print(f"Dataframe saved to {args.save}")
    # Optional visualization (needs streamlit as ``st`` and altair as ``alt``):
    # chart = (
    #     alt.Chart(df, width=300)
    #     .mark_bar()
    #     .encode(x="shape", y=["svd", "eigh"], column="shape")
    #     .properties(title="time by shape")
    # ).resolve_scale(y="independent")
    # st.altair_chart(chart)
def on_load(pth: str) -> pd.DataFrame:
    """Load a previously saved benchmark table from the CSV file at ``pth``.

    The ``shape`` column (values like ``"(100, 10)"``) is converted to pandas'
    ``string`` dtype. The timing columns keep whatever dtype ``read_csv`` infers.

    Args:
        pth: Path of a CSV file, such as one written with ``--save``.

    Returns:
        The loaded DataFrame.
    """
    df = pd.read_csv(pth)
    df["shape"] = df["shape"].astype("string")
    # Report the path that was actually read (not a global CLI value), and only
    # after the read has succeeded.
    print(f"Dataframe loaded from {pth}")
    return df
# Script entry point: run the benchmark only when executed directly, not on import.
# ``main`` returns None, so SystemExit(None) exits with status 0, as before.
if __name__ == "__main__":
    raise SystemExit(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| shape | svd | eigh | eigh/svd |
|---|---|---|---|
| (100, 10) | 0.0090076923370361 | 0.0010011196136474 | 0.11114052036737407 |
| (100, 100) | 0.0030028820037841 | 0.0020010471343994 | 0.6663755458515415 |
| (100, 1000) | 0.017014741897583 | 0.1161057949066162 | 6.8238352133398745 |
| (1000, 10) | 0.0100126266479492 | 0.0 | 0.0 |
| (1000, 100) | 0.0160133838653564 | 0.002002477645874 | 0.12505024938583972 |
| (1000, 1000) | 0.2785623073577881 | 0.1808667182922363 | 0.6492864020541349 |
| (10000, 10) | 1.422199249267578 | 0.0 | 0.0 |
| (10000, 100) | 12.595245122909546 | 0.0060043334960937 | 0.0004767143026992298 |
| (10000, 1000) | 9.553653001785278 | 0.2872602939605713 | 0.03006811048160231 |

Note: these timings were measured with `time.time()`, so entries of 0.0 (and the resulting 0.0 ratios) are below that clock's resolution rather than true zero cost. The `svd` column used `linalg.svd` with its default `full_matrices=True`.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment