Skip to content

Instantly share code, notes, and snippets.

@leetschau
Last active July 7, 2016 03:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leetschau/12f342f13df4998fea314f78671e8c2e to your computer and use it in GitHub Desktop.
Save leetschau/12f342f13df4998fea314f78671e8c2e to your computer and use it in GitHub Desktop.
在 Anaconda 环境下运行；详细分析见《Data Science from Scratch》第 10 章读书笔记。
import math
import random
import numpy as np
import matplotlib.pyplot as plt
def normal_cdf(x, mu=0, sigma=1):
    """Return the normal cumulative distribution function evaluated at x.

    Computes P(X <= x) for X ~ N(mu, sigma**2) via the error function.

    Args:
        x: Point at which to evaluate the CDF.
        mu: Mean of the distribution (default 0).
        sigma: Standard deviation of the distribution (default 1).
            Passing 0 raises ZeroDivisionError.

    Returns:
        The cumulative probability as a float.
    """
    return (1 + math.erf((x - mu) / math.sqrt(2) / sigma)) / 2
def inverse_normal_cdf(p, mu=0, sigma=1, tolerance=0.00001):
    """Find an approximate inverse of the normal CDF using binary search.

    Args:
        p: Target cumulative probability. Values outside (0, 1) are not
            rejected; the search simply drifts toward the nearer end of
            the [-10, 10] search interval.
        mu: Mean of the distribution (default 0).
        sigma: Standard deviation of the distribution (default 1).
        tolerance: The search stops once the bracketing interval is no
            wider than this.

    Returns:
        A z value (rescaled by mu and sigma) whose CDF is approximately p.
    """
    # if not standard, compute standard and rescale
    if mu != 0 or sigma != 1:
        return mu + sigma * inverse_normal_cdf(p, tolerance=tolerance)

    # normal_cdf(-10) is (very close to) 0, normal_cdf(10) is (very close to) 1
    low_z, hi_z = -10.0, 10.0

    # Seed the answer with the midpoint so a value is still returned when
    # tolerance is at least the full interval width and the loop never runs.
    mid_z = (low_z + hi_z) / 2

    while hi_z - low_z > tolerance:
        mid_z = (low_z + hi_z) / 2     # consider the midpoint
        mid_p = normal_cdf(mid_z)      # and the cdf's value there
        if mid_p < p:
            # midpoint is still too low, search above it
            low_z = mid_z
        elif mid_p > p:
            # midpoint is still too high, search below it
            hi_z = mid_z
        else:
            # exact hit
            break

    return mid_z
def random_normal():
    """Return a random draw from a standard normal distribution.

    Feeds a uniform sample from random.random() through
    inverse_normal_cdf, so the result depends on the state of the
    global `random` generator.
    """
    return inverse_normal_cdf(random.random())
def make_scatterplot_matrix(num_points=100):
    """Generate correlated random data and show it in two figures.

    The first figure is a scatterplot matrix of the four generated
    columns; the second plots column 3 against column 2 together with the
    -2 threshold used to derive column 3. The column values of every
    off-diagonal subplot are also printed to stdout.

    Reseeds the global `random` generator with 0, so repeated calls draw
    the same data. Both figures are displayed with plt.show().

    Args:
        num_points: Number of rows to generate (default 100). Must be at
            least 1; an empty data set has no column dimension to plot.
    """
    # first, generate some random data
    def random_row():
        """Build one row of four interdependent values."""
        base = random_normal()
        anti = -5 * base + random_normal()
        noisy_sum = base + anti + 5 * random_normal()
        flag = 6 if noisy_sum > -2 else 0
        return [base, anti, noisy_sum, flag]

    random.seed(0)
    data = np.array([random_row() for _ in range(num_points)])

    # then plot it
    num_columns = data.shape[1]
    _, ax = plt.subplots(num_columns, num_columns)

    for i in range(num_columns):
        for j in range(num_columns):
            # scatter column_i on the x-axis vs column_j on the y-axis
            if i != j:
                print('column %s on the x-axis:' % i)
                print(data[:, i])
                print('column %s on the y-axis:' % j)
                print(data[:, j])
                print('subplot ax[%s][%s]:' % (i, j))
                ax[i][j].scatter(data[:, i], data[:, j])
            # unless i == j, in which case show the series name
            else:
                ax[i][j].annotate("series " + str(i),
                                  (0.5, 0.5),
                                  xycoords='axes fraction',
                                  ha="center",
                                  va="center")

            # then hide axis labels except left and bottom charts
            if i < num_columns - 1:
                ax[i][j].xaxis.set_visible(False)
            if j > 0:
                ax[i][j].yaxis.set_visible(False)

    # fix the bottom right and top left axis labels, which are wrong because
    # their charts only have text in them. With column_i on x and column_j
    # on y, subplots in one grid row share their x data and subplots in one
    # grid column share their y data, so borrow the limits from a neighbour
    # in the same grid row (for x) and the same grid column (for y).
    ax[-1][-1].set_xlim(ax[-1][0].get_xlim())
    ax[0][0].set_ylim(ax[1][0].get_ylim())
    plt.show()

    # second figure: column 3 (x) against column 2 (y); the axis labels
    # name the column actually drawn on each axis
    plt.scatter(data[:, 3], data[:, 2], label='row2 vs row3')
    plt.xlabel('row3')
    plt.ylabel('row2')
    # dotted horizontal line at the -2 threshold that decides column 3
    xs = range(-1, 8)
    plt.plot(xs, [-2 for _ in xs], ':', label='y=-2')
    plt.legend(loc=9)
    plt.show()
# Runs at module load: generates the data and displays both figures.
make_scatterplot_matrix()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment