Created
April 27, 2012 14:36
-
-
Save beaucronin/2509755 to your computer and use it in GitHub Desktop.
Examples of dependence beyond correlation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy.random import ( | |
uniform as runif, | |
multivariate_normal as rmvn, | |
normal as rnorm | |
) | |
from numpy import inner, linspace, array, vstack | |
from math import pi, cos, sin, pow, sqrt | |
import matplotlib.pyplot as plt | |
import json | |
import csv | |
""" | |
This is a Python port of the R script at
http://en.wikipedia.org/wiki/File:Correlation_examples2.svg | |
""" | |
def generate(n=1000):
    """
    Create the 21 example datasets from the Wikipedia correlation figure.

    Returns a list of 21 arrays of (x, y) points, in panel order:

    * 0-6:   bivariate standard normals with correlation
             1, .8, .4, 0, -.4, -.8, -1
    * 7-13:  a perfectly correlated normal (y == x) rotated by
             0, pi/12, pi/6, pi/4, pi/2 - pi/6, pi/2 - pi/12, pi/2
    * 14-20: non-linear relationships (quartic, two rotated squares,
             parabola, mirrored parabola, circle, four clusters)

    Parameters
    ----------
    n : int, optional
        Points per dataset (default 1000). The four-cluster dataset draws
        n // 4 points per cluster.
    """
    datasets = []
    datasets.extend(_correlated_normals(n))
    datasets.extend(_rotated_normals(n))
    datasets.extend(_nonlinear(n))
    return datasets


def _correlated_normals(n):
    """Return bivariate standard normal samples, one per correlation value."""
    return [rmvn([0., 0.], [[1., corr], [corr, 1.]], n)
            for corr in [1., .8, .4, 0., -.4, -.8, -1.]]


def _rotated_normals(n):
    """Return fresh y == x normal samples, each rotated by one of 7 angles."""
    angles = [0., pi/12., pi/6., pi/4., pi/2. - pi/6., pi/2. - pi/12., pi/2.]
    return [rotate(phi, rmvn([0., 0.], [[1., 1.], [1., 1.]], n))
            for phi in angles]


def _nonlinear(n):
    """Return the seven non-linear example datasets (panels 14-20)."""
    a = linspace(-1, 1, n)
    out = []
    # 14: W-shaped quartic plus uniform noise. runif() is called without a
    # size argument so it yields a scalar; runif(lo, hi, 1) returns a
    # length-1 array, which made every row ragged.
    out.append(array([(x0, 4. * pow(x0 * x0 - .5, 2.) + runif(-1./3., 1./3.))
                      for x0 in a]))
    # 15: uniform square passed through rotate(-pi/8, ...);
    # 16: that same square rotated by a further -pi/8.
    square = rotate(-pi/8., array([(x0, runif(-1., 1.)) for x0 in a]))
    out.append(square)
    out.append(rotate(-pi/8., square))
    # 17: parabola with uniform noise.
    out.append(array([(x0, x0 * x0 + runif(-.5, .5)) for x0 in a]))
    # 18: parabola mirrored about the x axis by a random sign per point.
    signs = [1. if runif() < .5 else -1. for _ in range(n)]
    out.append(array([(x0, (x0 * x0 + runif(0., .5)) * sign)
                      for x0, sign in zip(a, signs)]))
    # 19: noisy circle.
    out.append(array([(sin(x0 * pi) + rnorm(0., .125),
                       cos(x0 * pi) + rnorm(0., .125))
                      for x0 in a]))
    # 20: four identity-covariance clusters centred at (+-3, +-3). Floor
    # division keeps the sample size an int; n / 4 is a float on Python 3,
    # which numpy rejects as a size.
    quarter = n // 4
    eye = [[1., 0.], [0., 1.]]
    centres = ([3., 3.], [-3., 3.], [-3., -3.], [3., -3.])
    out.append(vstack([rmvn(centre, eye, quarter) for centre in centres]))
    return out
def plot_datasets(datasets, path='out.pdf'):
    """
    Plot the datasets on a 3 x 7 grid, mimicking the original Wikipedia
    figure, and save the figure to *path* (default 'out.pdf').

    Panels 14-20 use the axis limits of the R original. Every other panel
    uses [-4, 4] on both axes. At most 21 datasets fit on the grid.
    """
    z = sqrt(2. + sqrt(2.)) / sqrt(2.)
    r2 = sqrt(2.)
    # (xlim, ylim) per panel index; panels not listed use `default`.
    limits = {
        14: ([-1, 1], [-1./3., 1. + 1./3.]),
        15: ([-z, z], [-z, z]),
        16: ([-r2, r2], [-r2, r2]),
        17: ([-1, 1], [-.5, 1.5]),
        18: ([-1.5, 1.5], [-1.5, 1.5]),
        19: ([-1.5, 1.5], [-1.5, 1.5]),
        20: ([-7, 7], [-7, 7]),
    }
    default = ([-4, 4], [-4, 4])
    plt.figure()
    for i, dataset in enumerate(datasets):
        plt.subplot(3, 7, i + 1)
        xs = [p[0] for p in dataset]
        ys = [p[1] for p in dataset]
        plt.plot(xs, ys, '.', markersize=1.)
        ax = plt.gca()
        # set_axis_off() already hides ticks and tick labels.
        ax.set_axis_off()
        xlim, ylim = limits.get(i, default)
        plt.xlim(xlim)
        plt.ylim(ylim)
        ax.set_aspect('equal', adjustable='datalim')
    plt.savefig(path)
def save_datasets_to_json(datasets, path='correlation_datasets.json'):
    """
    Write the datasets to a JSON file at *path*.

    The output is a list with one entry per dataset. Each entry is a list of
    ``{'_id': '0000', 'x': ..., 'y': ...}`` records, where ``_id`` is the
    point's index within its dataset, zero-padded to four digits.
    """
    # float() makes numpy scalars (e.g. integer dtypes, which json cannot
    # encode) serialize as plain JSON numbers.
    obj = [[{'_id': '{0:04}'.format(i), 'x': float(p[0]), 'y': float(p[1])}
            for i, p in enumerate(dataset)]
           for dataset in datasets]
    # Text mode: json produces str, which a 'wb' file rejects on Python 3.
    with open(path, 'w') as fd:
        json.dump(obj, fd, indent=2)
def save_datasets_to_csv(datasets, path='correlation_datasets.csv'):
    """
    Write the datasets to a CSV file at *path*.

    The file has a header row ``dataset,x,y``, then one row per point. The
    ``dataset`` column holds the index of the dataset the point belongs to.
    """
    import sys
    if sys.version_info[0] >= 3:
        # Python 3's csv module needs a text file opened with newline=''.
        fd = open(path, 'w', newline='')
    else:
        # Python 2's csv module expects a binary-mode file.
        fd = open(path, 'wb')
    with fd:
        wr = csv.writer(fd, delimiter=',')
        wr.writerow(['dataset', 'x', 'y'])
        wr.writerows([i, p[0], p[1]]
                     for i, dataset in enumerate(datasets)
                     for p in dataset)
def rotate(phi, x):
    """
    Rotate an array of 2D points by the angle *phi* (radians).

    Each row (px, py) of *x* maps to
    (px*cos(phi) + py*sin(phi), -px*sin(phi) + py*cos(phi)),
    i.e. a clockwise rotation for positive *phi*.
    """
    c, s = cos(phi), sin(phi)
    rotation = [[c, s],
                [-s, c]]
    # inner(x, m) dots every point with each row of m (i.e. x @ m.T).
    return inner(x, rotation)
def main():
    """Generate the example datasets and export them as JSON and CSV."""
    datasets = generate()
    # plot_datasets(datasets) would render the 3 x 7 figure; it is
    # intentionally not called here.
    for writer in (save_datasets_to_json, save_datasets_to_csv):
        writer(datasets)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#From http://en.wikipedia.org/wiki/File:Correlation_examples2.svg | |
#Title: An example of the correlation of x and y for various distributions of (x,y) pairs | |
#Tags: Mathematics; Statistics; Correlation | |
#Author: Denis Boigelot | |
#Packages needed: mvtnorm (rmvnorm), RSVGTipsDevice (devSVGTips)
#How to use: output() | |
# | |
#This is an R translation of Mathematica 6 code by Imagecreator.
library(mvtnorm) | |
library(RSVGTipsDevice) | |
# Scatter-plot the two columns of xy, titled with their Pearson correlation
# rounded to one decimal. The title is left blank when the correlation is
# undefined, i.e. when either column is (numerically) constant or has no
# defined standard deviation (fewer than two points).
#   xy:         two-column numeric matrix of (x, y) points
#   xlim, ylim: axis limits
#   eps:        standard deviation below which a column counts as constant
MyPlot <- function(xy, xlim = c(-4, 4), ylim = c(-4, 4), eps = 1e-15) {
  # Check the spread of BOTH columns before calling cor(): a constant column
  # makes cor() warn and return NA, and sd() is NA for fewer than 2 points.
  spread <- c(sd(xy[, 1]), sd(xy[, 2]))
  if (any(is.na(spread)) || any(spread < eps)) {
    title <- ""
  } else {
    title <- round(cor(xy[, 1], xy[, 2]), 1)
  }
  plot(xy, main = title, xlab = "", ylab = "",
       col = "darkblue", pch = 16, cex = 0.2,
       xaxt = "n", yaxt = "n", bty = "n",
       xlim = xlim, ylim = ylim)
}
# Draw and plot one bivariate standard normal sample per correlation value.
#   n:   number of points per sample
#   cor: vector of correlation coefficients; one panel is drawn per value
MvNormal <- function(n = 1000, cor = 0.8) {
  for (rho in cor) {
    sigma <- matrix(c(1, rho, rho, 1), ncol = 2)  # covariance matrix
    MyPlot(rmvnorm(n, c(0, 0), sigma))
  }
}
# Rotate the rows of the two-column matrix X by angle t (radians).
# Computes X %*% R with R = [cos t, -sin t; sin t, cos t], so a point (x, y)
# maps to (x cos t + y sin t, -x sin t + y cos t): clockwise for t > 0.
rotation <- function(t, X) {
  ct <- cos(t)
  st <- sin(t)
  R <- matrix(c(ct, st, -st, ct), ncol = 2)  # filled column by column
  X %*% R
}
# Draw ONE perfectly correlated normal sample (y == x) and plot it once per
# rotation angle in t, so every panel shows the same points rotated.
#   n: number of points
#   t: vector of rotation angles in radians
RotNormal <- function(n = 1000, t = pi/2) {
  sigma <- matrix(1, nrow = 2, ncol = 2)  # unit variances, covariance 1
  pts <- rmvnorm(n, c(0, 0), sigma)
  for (angle in t) MyPlot(rotation(angle, pts))
}
# Plot the seven non-linear examples (bottom row of the figure).
# The random draws below must stay in this order to reproduce the original
# sequence of random numbers.
#   n: points per panel (the last panel uses 4 clusters of n/4 points)
Others <- function(n = 1000) {
  sym <- function(l) c(-l, l)  # symmetric axis range [-l, l]
  x <- runif(n, -1, 1)
  # W-shaped quartic with uniform noise
  y <- 4 * (x^2 - 1/2)^2 + runif(n, -1, 1)/3
  MyPlot(cbind(x, y), xlim = c(-1, 1), ylim = c(-1/3, 1 + 1/3))
  # uniform square, rotated by -pi/8, then by a further -pi/8
  y <- runif(n, -1, 1)
  xy <- rotation(-pi/8, cbind(x, y))
  lim <- sqrt(2 + sqrt(2)) / sqrt(2)
  MyPlot(xy, xlim = sym(lim), ylim = sym(lim))
  xy <- rotation(-pi/8, xy)
  MyPlot(xy, xlim = sym(sqrt(2)), ylim = sym(sqrt(2)))
  # parabola with uniform noise
  y <- 2 * x^2 + runif(n, -1, 1)
  MyPlot(cbind(x, y), xlim = c(-1, 1), ylim = c(-1, 3))
  # parabola mirrored about the x axis by a random sign per point
  y <- (x^2 + runif(n, 0, 1/2)) * sample(seq(-1, 1, 2), n, replace = TRUE)
  MyPlot(cbind(x, y), xlim = sym(1.5), ylim = sym(1.5))
  # noisy circle: y is computed before x is overwritten
  y <- cos(x * pi) + rnorm(n, 0, 1/8)
  x <- sin(x * pi) + rnorm(n, 0, 1/8)
  MyPlot(cbind(x, y), xlim = sym(1.5), ylim = sym(1.5))
  # four clusters centred at (+-3, +-3), using rmvnorm's default covariance
  centres <- list(c(3, 3), c(-3, 3), c(-3, -3), c(3, -3))
  clusters <- lapply(centres, function(m) rmvnorm(n/4, m))
  MyPlot(do.call(rbind, clusters), xlim = sym(3 + 4), ylim = sym(3 + 4))
}
# Render the full 3 x 7 figure: correlated normals (row 1), rotated normals
# (row 2) and the non-linear examples (row 3).
# Remove the devSVGTips() and dev.off() calls to draw on the current
# graphics device instead of exporting an SVG.
output <- function() {
  correlations <- c(1.0, 0.8, 0.4, 0.0, -0.4, -0.8, -1.0)
  angles <- c(0, pi/12, pi/6, pi/4, pi/2 - pi/6, pi/2 - pi/12, pi/2)
  devSVGTips(width = 7, height = 3.2)
  par(mfrow = c(3, 7), oma = c(0, 0, 0, 0), mar = c(2, 2, 2, 0))
  MvNormal(800, correlations)
  RotNormal(200, angles)
  Others(800)
  dev.off()
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import veritable | |
import json | |
import os | |
import sys | |
from time import sleep | |
from random import random, shuffle | |
from math import floor | |
from veritable.utils import wait_for_analysis, split_rows | |
import matplotlib.pyplot as plt | |
# JSON file of datasets to model (the name written by the generator script).
DATA_FILE = 'correlation_datasets.json'
# Connection to the Veritable service, opened when the module is imported.
API = veritable.connect()
# Both columns are modelled as real-valued.
SCHEMA = {
    'x': {'type': 'real'},
    'y': {'type': 'real'},
}
TRAIN_FRAC = .75  # NOTE(review): not used by the alternating split in main()
PRED_COUNT = 100  # predictive samples requested per held-out point
# y-axis limits for each dataset's prediction plot, indexed by dataset.
# The first 14 datasets share [-4, 4]; the remaining 7 get their own ranges.
ylim = (
    [[-4., 4.] for _ in range(14)]
    + [[-1., 2.], [-2., 2.], [-2., 2.], [-1., 2.], [-2., 2.], [-2., 2.]]
    + [[-7., 7.]]
)
def main():
    """
    Fit a Veritable analysis to half of each dataset and plot predictions.

    For each dataset in DATA_FILE:

    1. Sort the points by x and split them alternately into train and test.
    2. Upload the training rows to a new table and run an analysis with
       SCHEMA.
    3. Request PRED_COUNT predictive samples of y at every held-out x.
    4. Hand everything to make_plots().
    """
    with open(DATA_FILE, 'r') as fd:
        datasets = json.load(fd)
    # Iterate over the 21 datasets
    for i, dataset in enumerate(datasets):
        _log('Dataset %d' % i)
        train_dataset, test_dataset = _split_alternate(dataset)
        # Create a table and add the training rows
        _log(' Creating table and adding rows')
        table = API.create_table()
        table.batch_upload_rows(train_dataset)
        # Create an analysis and wait for it to complete
        _log(' Creating analysis')
        analysis = table.create_analysis(SCHEMA)
        analysis.wait()
        # Perform predictions on the test set
        _log(' Predictions')
        x_test = [row['x'] for row in test_dataset]
        y_actual = [row['y'] for row in test_dataset]
        y_pred = []
        for x in x_test:
            result = analysis.predict({'x': x, 'y': None}, count=PRED_COUNT)
            y_pred.append([sample['y'] for sample in result])
        make_plots(i, x_test, y_actual, y_pred)


def _log(message):
    """Write one progress line to stdout.

    Replaces Python 2 print statements, which are a SyntaxError on Python 3.
    """
    sys.stdout.write(message + '\n')


def _split_alternate(dataset):
    """Sort rows by 'x' and return (even-indexed rows, odd-indexed rows).

    Alternating after sorting keeps the held-out points spread over the
    whole x range.
    """
    ordered = sorted(dataset, key=lambda row: row['x'])
    return ordered[0::2], ordered[1::2]
def make_plots(i, x_test, y_actual, y_pred):
    """
    Plot predictive samples against the held-out data for dataset *i*.

    Each test x gets one translucent red marker per predicted y sample, and
    the actual held-out (x, y) points are drawn as blue dots. The y range
    comes from ylim[i]. The figure is saved to 'out_NN.pdf', with NN being
    *i* zero-padded to two digits, and then closed.

    i        -- dataset index (selects ylim[i] and the output file name)
    x_test   -- held-out x values
    y_actual -- held-out y values, aligned with x_test
    y_pred   -- y_pred[j] is the list of predicted y samples at x_test[j]
    """
    plt.figure()
    for x, samples in zip(x_test, y_pred):
        # Repeat x once per sample actually returned, rather than assuming
        # exactly PRED_COUNT, so both sequences always have equal length.
        plt.plot([x] * len(samples), samples, 'or', alpha=.02, ms=8, mew=0)
    plt.plot(x_test, y_actual, '.b')
    plt.ylim(ylim[i])
    plt.xticks([])
    plt.yticks([])
    plt.savefig('out_{:02}.pdf'.format(i))
    # Release the figure: this runs once per dataset, and matplotlib warns
    # (and keeps memory) once more than 20 figures are left open.
    plt.close()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment