Skip to content

Instantly share code, notes, and snippets.

@graipher
Last active May 25, 2019 08:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save graipher/bc94156fe6e740a55b49dcc1e631c027 to your computer and use it in GitHub Desktop.
Save graipher/bc94156fe6e740a55b49dcc1e631c027 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
from functools import partial
import timeit
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from itertools import starmap, count
from string import digits
from random import choices
def get_time(func, x, *args):
    """Time a single call ``func(x, *args)`` repeated over several runs.

    Extra positional arguments are forwarded to ``func`` so that callers using
    ``starmap`` (``get_times(..., star=True)``) can benchmark multi-argument
    functions; with no extras this is exactly ``func(x)``.

    Returns a tuple ``(best, err)``: the minimum time in seconds over 5 runs of
    one call each, and the standard error of the mean of those 5 runs
    (``std / sqrt(5)``).
    """
    # partial keeps the per-call overhead in C rather than adding a Python frame.
    timer = timeit.Timer(partial(func, x, *args))
    t = timer.repeat(repeat=5, number=1)
    return np.min(t), np.std(t) / np.sqrt(len(t))
def get_times(func, inputs, star=False):
    """Run ``get_time`` for ``func`` on every element of ``inputs``.

    When ``star`` is true, each input is unpacked into ``get_time``'s positional
    arguments (via ``itertools.starmap``); otherwise it is passed as-is.
    Returns a numpy array with one ``(best, err)`` row per input.
    """
    timed = partial(get_time, func)
    mapper = starmap if star else map
    return np.array(list(mapper(timed, inputs)))
def get_df(funcs, inputs, key, star=False):
    """Collect benchmark timings for several functions into one DataFrame.

    The frame starts with an ``"x"`` column holding ``key(item)`` for each
    input. Each function then adds two columns: ``<label>`` (best time) and
    ``<label>_err`` (its error estimate). Here ``<label>`` is the function's
    ``__name__``, or its position in ``funcs`` for lambdas.
    ``inputs`` is iterated once per function, so it must be re-iterable.
    """
    df = pd.DataFrame([key(item) for item in inputs], columns=["x"])
    for idx, func in enumerate(funcs):
        name = func.__name__
        label = str(idx) if name == "<lambda>" else name
        timings = get_times(func, inputs, star=star)
        best, err = timings.T
        df[label] = best
        df[label + "_err"] = err
    return df
def counter():
    """Return a callable that ignores its positional arguments and returns
    0, 1, 2, ... on successive calls (e.g. to use the input index as x).
    """
    ticks = count()

    def wrapper(*_args):
        return next(ticks)

    return wrapper
def identity(x):
    """Return the argument unchanged (default x-axis key)."""
    result = x
    return result
def plot_times(funcs, inputs, key=identity, xlabel="x", ylabel="Time [s]",
               logx=False, logy=False, ratio=False, star=False):
    """Benchmark ``funcs`` on ``inputs`` and plot the timings with error bars.

    Parameters
    ----------
    funcs : sequence of callables
        One curve per function, labelled by ``__name__`` (index for lambdas).
    inputs : sequence
        Arguments fed to every function.
    key : callable
        Maps each input to its x-axis value.
    xlabel, ylabel : str
        Axis labels.
    logx, logy : bool
        Use a logarithmic x / y axis.
    ratio : bool
        Divide every curve by the first function's timings; the y label then
        reads ``"<ylabel> / <first function label>"``.
    star : bool
        Forwarded to ``get_df``; when true each input is unpacked via
        ``starmap`` before timing.
    """
    df = get_df(funcs, inputs, key, star)
    # Columns are x, f0, f0_err, f1, f1_err, ...: timing columns sit at the
    # odd positions.
    labels = list(df.columns[1::2])
    # The first function is the reference for ratio plots. Looked up once,
    # instead of transposing the whole frame on every iteration.
    reference = df[labels[0]] if ratio and labels else None
    for label in labels:
        x, y, yerr = df["x"], df[label], df[label + "_err"]
        if reference is not None:
            # Only the numerator's error is scaled; the reference's own
            # uncertainty is not propagated.
            y, yerr = y / reference, yerr / reference
        plt.errorbar(x, y, yerr, fmt='o-', label=label)
    plt.xlabel(xlabel)
    if ratio and labels:
        # Name the reference function, not the "x" column.
        ylabel = ylabel + " / " + labels[0]
    plt.ylabel(ylabel)
    if logx:
        plt.xscale("log")
    if logy:
        plt.yscale("log")
    plt.legend()
    plt.show()
def count_even_digits_spyr03_for(n):
    """Count the even digit characters in ``str(n)`` with an explicit loop.

    This is the plain ``for``-loop variant of the benchmark. The local
    counter is deliberately not named ``count``, so it does not shadow the
    module-level ``itertools.count``.
    """
    n_even = 0
    for c in str(n):
        if c in "02468":
            n_even += 1
    return n_even
def count_even_digits_spyr03_sum(n):
    """Count even digit characters of ``str(n)`` by summing membership booleans."""
    text = str(n)
    return sum(ch in "02468" for ch in text)
def count_even_digits_spyr03_sum2(n):
    """Count even digit characters of ``str(n)`` by summing 1 per filtered match."""
    text = str(n)
    return sum(1 for ch in text if ch in "02468")
def count_even_digits_spyr03_count_unrolled(n):
    """Count even digit characters of ``str(n)`` with five unrolled ``str.count`` calls."""
    text = str(n)
    return (text.count("0")
            + text.count("2")
            + text.count("4")
            + text.count("6")
            + text.count("8"))
if __name__ == "__main__":
    import sys

    # Python 3.11+ (and security backports) limit int<->str conversion to
    # 4300 digits by default; these inputs reach 10**5 digits.
    if hasattr(sys, "set_int_max_str_digits"):
        sys.set_int_max_str_digits(0)
    # 50 random integers whose digit strings are log-spaced from 10 to 100000
    # characters long (a leading "0" can make the int slightly shorter).
    x = [int("".join(choices(digits, k=n))) for n in np.logspace(1, 5, dtype=int)]
    funcs = [count_even_digits_spyr03_for, count_even_digits_spyr03_sum,
             count_even_digits_spyr03_sum2, count_even_digits_spyr03_count_unrolled]
    # Plot against the digit count (~log10 n). The raw integers are far too
    # large to convert to floats for plotting. The raw string avoids the
    # invalid "\l" escape warning.
    plot_times(funcs, x, key=lambda n: len(str(n)), xlabel=r"$\log_{10} n$",
               logx=True, ratio=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment