Last active
July 7, 2016 03:26
-
-
Save leetschau/12f342f13df4998fea314f78671e8c2e to your computer and use it in GitHub Desktop.
在Anaconda环境下运行,详细分析见"Data Science from Scratch"第10章读书笔记
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import random | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def normal_cdf(x, mu=0, sigma=1):
    """Return the cumulative distribution function of a normal at x.

    mu is the mean and sigma the standard deviation of the distribution;
    the defaults give the standard normal.
    """
    # Argument handed to the error function: (x - mu) / (sqrt(2) * sigma).
    erf_argument = (x - mu) / math.sqrt(2) / sigma
    return (1 + math.erf(erf_argument)) / 2
def inverse_normal_cdf(p, mu=0, sigma=1, tolerance=0.00001):
    """Find an approximate inverse of normal_cdf using binary search.

    p is the target probability, mu and sigma are the mean and standard
    deviation of the distribution, and tolerance is the width of the
    search interval at which the bisection stops.

    Returns a z for which normal_cdf(z, mu, sigma) is approximately p.
    The search is confined to [-10, 10] standard deviations.
    """
    # if not standard, compute standard and rescale
    if mu != 0 or sigma != 1:
        return mu + sigma * inverse_normal_cdf(p, tolerance=tolerance)

    low_z = -10.0  # normal_cdf(-10) is (very close to) 0
    hi_z = 10.0    # normal_cdf(10) is (very close to) 1

    # Bind mid_z up front so the return below is defined even when the
    # tolerance is already wider than the initial interval and the loop
    # body never executes.
    mid_z = (low_z + hi_z) / 2
    while hi_z - low_z > tolerance:
        mid_z = (low_z + hi_z) / 2   # consider the midpoint
        mid_p = normal_cdf(mid_z)    # and the cdf's value there
        if mid_p < p:
            # midpoint is still too low, search above it
            low_z = mid_z
        elif mid_p > p:
            # midpoint is still too high, search below it
            hi_z = mid_z
        else:
            # exact hit
            break
    return mid_z
def random_normal():
    """Draw one sample from the standard normal distribution.

    Feeds a uniform draw from random.random() through inverse_normal_cdf.
    """
    uniform_draw = random.random()
    return inverse_normal_cdf(uniform_draw)
def make_scatterplot_matrix():
    """Show a scatterplot matrix of four related random series.

    Seeds the random module with 0, generates 100 rows of 4 values, and
    displays two figures in turn:

    1. a 4 x 4 grid in which the chart at grid row i, grid column j plots
       data column j on the x-axis against data column i on the y-axis
       (the diagonal shows the series name instead);
    2. data column 2 plotted against data column 3, together with the
       horizontal line y = -2 that is the threshold used to derive
       column 3 from column 2.

    Each call to plt.show() is made exactly as in a plain script, so the
    function returns after both figures have been shown.
    """
    # first, generate some random data
    num_points = 100

    def random_row():
        """Build one row: a normal draw and three values derived from it."""
        first = random_normal()
        second = -5 * first + random_normal()
        third = first + second + 5 * random_normal()
        fourth = 6 if third > -2 else 0
        return [first, second, third, fourth]

    random.seed(0)
    data = np.array([random_row() for _ in range(num_points)])

    # then plot it
    num_columns = data.shape[1]
    fig, ax = plt.subplots(num_columns, num_columns)

    for i in range(num_columns):
        for j in range(num_columns):
            if i != j:
                # Scatter column j on the x-axis vs column i on the y-axis.
                # Every chart in a grid column then shares its x series and
                # every chart in a grid row shares its y series, which is
                # what the axis hiding and limit copying below rely on.
                print('column %s on the x-axis:' % j)
                print(data[:, j])
                print('column %s on the y-axis:' % i)
                print(data[:, i])
                print('subplot ax[%s][%s]:' % (i, j))
                ax[i][j].scatter(data[:, j], data[:, i])
            else:
                # on the diagonal, show the series name instead
                ax[i][j].annotate("series " + str(i),
                                  (0.5, 0.5),
                                  xycoords='axes fraction',
                                  ha="center",
                                  va="center")

            # then hide axis labels except left and bottom charts
            if i < num_columns - 1:
                ax[i][j].xaxis.set_visible(False)
            if j > 0:
                ax[i][j].yaxis.set_visible(False)

    # fix the bottom right and top left axis labels, which are wrong because
    # their charts only have text in them
    ax[-1][-1].set_xlim(ax[0][-1].get_xlim())
    ax[0][0].set_ylim(ax[0][1].get_ylim())
    plt.show()

    # Second figure: column 3 is on the x-axis and column 2 on the y-axis,
    # so the axis labels must name them in that order.
    plt.scatter(data[:, 3], data[:, 2], label='row2 vs row3')
    plt.xlabel('row3')
    plt.ylabel('row2')
    xs = range(-1, 8)
    # threshold on column 2 that decides whether column 3 is 6 or 0
    plt.plot(xs, [-2 for _ in xs], ':', label='y=-2')
    plt.legend(loc=9)
    plt.show()
# Script entry point: build the data and display both figures as soon as
# this module is executed (this call also runs if the module is imported).
make_scatterplot_matrix()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment