Skip to content

Instantly share code, notes, and snippets.

@jasonsahl
Created April 29, 2016 20:53
Show Gist options
  • Save jasonsahl/6dd3939d3bb8c83f74f5ec5eac665280 to your computer and use it in GitHub Desktop.
Calculate Kruskal-Wallis stat from BSR matrix
#!/usr/bin/env python
"""Compare groups with a Kruskal-Wallis
test to find associations
Requires Python >2.6 and < 3.0
Written by Jason Sahl
groups file looks like:
genome1 phenotype1
genome2 phenotype2
If there is no genotype, do not enter this information into the
groups file
"""
from __future__ import division
import sys
# Third-party numeric stack; exit with a friendly message when it is missing.
try:
    from scipy.stats.mstats import kruskalwallis as kw
    import numpy as np
    from scipy import stats
    # NOTE(review): the scipy.stats.stats submodule (and square_of_sums) was
    # removed in modern SciPy; on such versions this import fails and the
    # script exits here -- confirm the pinned SciPy version.
    from scipy.stats.stats import rankdata, square_of_sums, tiecorrect
# NOTE(review): bare except also swallows unrelated errors (e.g. a typo in
# the try body would be reported as "Numpy and Scipy must be installed").
except:
    print "Numpy and Scipy must be installed...exiting"
    sys.exit()
import optparse
import os
import subprocess
# Upper-tail p-value of the chi-square distribution (survival function).
chisqprob = stats.chi2.sf
def test_file(option, opt_str, value, parser):
try:
with open(value): setattr(parser.values, option.dest, value)
except IOError:
print 'genes file cannot be opened'
sys.exit()
def parse_groups_file(groups):
    """Parse a whitespace-delimited groups file into {phenotype: [genomes]}.

    Each line is ``genome phenotype``.  A phenotype of "0" marks missing
    data and the line is ignored, as are blank or one-column lines
    (the original crashed with IndexError on those).

    Note: the "U" open mode was dropped because it was removed in
    Python 3.11; universal newlines are the default in text mode.
    """
    meta_dict = {}
    for line in open(groups):
        fields = line.split()
        if len(fields) < 2:
            continue  # blank or malformed line
        if fields[1] == "0":
            continue  # 0 indicates missing data and is ignored
        meta_dict.setdefault(fields[1], []).append(fields[0])
    return meta_dict
def process_matrix(matrix, group_dict, permutations):
    """Run the Kruskal-Wallis test on every row of a BSR matrix.

    Parameters
    ----------
    matrix : str
        Path to a tab/space-delimited matrix whose header row lists genome
        names and whose first column is the marker/cluster name.
    group_dict : dict
        {phenotype: [genome names]} as produced by parse_groups_file().
    permutations : int
        Number of permutations forwarded to kruskal().

    Side effects: writes one result row per marker to the file
    "KW_results" in the current directory.

    Returns
    -------
    dict mapping phenotype -> list of per-marker mean BSR values (as str).

    Fixes vs. the original: files are now closed (the original had
    ``outfile.close()`` *after* ``return``, i.e. unreachable), and column
    lookup uses enumerate() instead of ``fields.index(x)``, which was O(n)
    per column and returned the first occurrence for duplicate names.
    """
    # Invert the groups mapping for O(1) genome -> phenotype lookups.
    genome_to_group = {}
    for group, genomes in group_dict.items():
        for genome in genomes:
            genome_to_group[genome] = group
    mean_dict = {}
    with open(matrix) as in_matrix, open("KW_results", "w") as outfile:
        outfile.write("marker_name"+"\t"+"H-stat"+"\t"+"p-chi2"+"\t"+"p-perm"+"\t"+"FDR-perms"+"\t"+"FDR-chi"+"\n")
        header = in_matrix.readline().split()
        # Header lacks a label for the first (marker-name) column; pad it
        # so header indexes line up with data-row indexes.
        header.insert(0, "cluster")
        group_indexes = {}
        for idx, name in enumerate(header):
            group = genome_to_group.get(name)
            if group is not None:
                group_indexes.setdefault(group, []).append(idx)
        for line in in_matrix:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            group_values = {}
            for group, indexes in group_indexes.items():
                for idx in indexes:
                    group_values.setdefault(group, []).append(fields[idx])
            # One list of values per phenotype; group identity is not
            # needed by the test itself, only for the means table.
            value_list = []
            for group, values in group_values.items():
                value_list.append(values)
                floats = [float(v) for v in values]
                mean_dict.setdefault(group, []).append(str(sum(floats) / len(floats)))
            results = kruskal(value_list, permutations)
            row = [fields[0]] + [str(r) for r in results]
            outfile.write("\t".join(row) + "\n")
    return mean_dict
def kruskal(data, permutations):
    """
    Compute the Kruskal-Wallis H-test for independent samples.

    The Kruskal-Wallis H-test tests the null hypothesis that the population
    median of all of the groups are equal. It is a non-parametric version of
    ANOVA. The test works on 2 or more independent samples, which may have
    different sizes. Note that rejecting the null hypothesis does not
    indicate which of the groups differs; post-hoc comparisons are required.

    Parameters
    ----------
    data : list of array_like
        Two or more arrays of sample measurements.
    permutations : int
        Number of random permutations used for the permutation p-value.
        If permutations <= 0 the permutation p-value is returned as nan.

    Returns
    -------
    (H, p_chi2, p_perm) : tuple of float
        H statistic corrected for ties, chi-square p-value, and
        permutation p-value.  All three are nan when every observation is
        identical (tie correction T == 0) -- the original fell through to
        a division by zero in that case; downstream code already skips
        "nan" rows.

    Notes
    -----
    The chi-square approximation assumes each group has at least ~5
    measurements.  scipy's removed ``square_of_sums(x)`` (== sum(x)**2) is
    inlined here, and ``stats.chi2.sf`` is used directly (it is exactly
    what the module-level ``chisqprob`` alias points at).

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Kruskal-Wallis_one-way_analysis_of_variance
    """
    args = [np.asarray(group) for group in data]
    na = len(args)  # number of groups
    if na < 2:
        raise ValueError("Need at least two groups in stats.kruskal()")
    n = [len(a) for a in args]
    alldata = np.concatenate(args)
    ranked = rankdata(alldata)  # rank the pooled data
    T = tiecorrect(ranked)      # tie correction factor
    if T == 0:
        # All observations identical: H is undefined; report nans instead
        # of crashing on the h / T division below.
        return np.nan, np.nan, np.nan
    # j[i]:j[i+1] slices the pooled ranks back into the original groups.
    j = np.insert(np.cumsum(n), 0, 0)
    totaln = np.sum(n)
    ssbn = 0.0
    for i in range(na):
        # sum(ranks)^2 / n_i for each group (inlined square_of_sums)
        ssbn += np.sum(ranked[j[i]:j[i + 1]]) ** 2 / float(n[i])
    h = (12.0 / (totaln * (totaln + 1)) * ssbn - 3 * (totaln + 1)) / float(T)
    df = na - 1
    pval = np.nan
    if permutations > 0:
        # Permutation p-value: shuffle the pooled ranks, recompute H, and
        # count how often the permuted statistic reaches the observed one.
        # (Fixed iteration count; convergence/time-based stopping would be
        # alternatives, as the original comments noted.)
        pas_count = 0
        for _ in range(int(permutations)):
            np.random.shuffle(ranked)
            ssbn_perm = 0.0
            for i in range(na):
                ssbn_perm += np.sum(ranked[j[i]:j[i + 1]]) ** 2 / float(n[i])
            htest = (12.0 / (totaln * (totaln + 1)) * ssbn_perm
                     - 3 * (totaln + 1)) / float(T)
            if htest >= h:
                pas_count += 1
        pval = float(pas_count) / float(permutations)
    return h, stats.chi2.sf(h, df), pval
def correct_pvalues_for_multiple_testing(pvalues, correction_type="Benjamini-Hochberg"):
    """Adjust p-values for multiple testing.

    consistent with R - print correct_pvalues_for_multiple_testing([0.0,
    0.01, 0.029, 0.03, 0.031, 0.05, 0.069, 0.07, 0.071, 0.09, 0.1])

    Parameters
    ----------
    pvalues : sequence of float
    correction_type : str
        "Bonferroni", "Bonferroni-Holm", or "Benjamini-Hochberg" (default).

    Returns
    -------
    numpy.ndarray of adjusted p-values, in the input order.

    Fix vs. the original: the array length was passed to numpy.empty() as a
    float, which modern NumPy rejects; it is kept as an int now (xrange was
    also replaced with range for Python 3 compatibility).
    """
    pvalues = np.array(pvalues, dtype=float)
    n = pvalues.shape[0]
    new_pvalues = np.empty(n)
    if correction_type == "Bonferroni":
        new_pvalues = float(n) * pvalues
    elif correction_type == "Bonferroni-Holm":
        order = sorted((p, i) for i, p in enumerate(pvalues))
        for rank, (pvalue, i) in enumerate(order):
            new_pvalues[i] = (n - rank) * pvalue
    elif correction_type == "Benjamini-Hochberg":
        # Walk p-values from largest to smallest, scaling by n/rank, then
        # enforce monotonicity before scattering back to input order.
        order = sorted(((p, i) for i, p in enumerate(pvalues)), reverse=True)
        adjusted = []
        for i, (pvalue, index) in enumerate(order):
            rank = n - i
            adjusted.append((n / float(rank)) * pvalue)
        for i in range(n - 1):
            if adjusted[i] < adjusted[i + 1]:
                adjusted[i + 1] = adjusted[i]
        for i, (pvalue, index) in enumerate(order):
            new_pvalues[index] = adjusted[i]
    return new_pvalues
def calculate_fdr(kw_output):
    """Compute FDR-corrected p-values from a KW results file.

    Reads *kw_output* (the "KW_results" file written by process_matrix),
    skips the header row and any marker whose H statistic is "nan"
    (missing data), then writes two tab-delimited tables to the current
    directory:

      chi.xyx.tmp   -- marker_name <TAB> BH-corrected chi-square p-value
      perms.xyx.tmp -- marker_name <TAB> BH-corrected permutation p-value

    Fixes vs. the original: output files are now explicitly closed
    (the original relied on interpreter GC to flush them before
    merge_results read them back), blank lines no longer crash the
    parser, and the duplicated "marker_name" check was collapsed.
    """
    chi_ps = {}
    perm_ps = {}
    for line in open(kw_output):
        fields = line.split()
        if not fields or fields[0] == "marker_name":
            continue  # header or blank line
        if fields[1] == "nan":
            continue  # marker with missing/undefined statistic
        chi_ps[fields[0]] = float(fields[2])
        perm_ps[fields[0]] = float(fields[3])
    # Freeze one iteration order per dict so names and corrected values
    # stay aligned.
    chi_items = list(chi_ps.items())
    perm_items = list(perm_ps.items())
    chi_fdr = correct_pvalues_for_multiple_testing([p for _, p in chi_items])
    perm_fdr = correct_pvalues_for_multiple_testing([p for _, p in perm_items])
    with open("chi.xyx.tmp", "w") as chi_out:
        for (name, _), corrected in zip(chi_items, chi_fdr):
            chi_out.write("%s\t%s\n" % (name, str(corrected)))
    with open("perms.xyx.tmp", "w") as perms_out:
        for (name, _), corrected in zip(perm_items, perm_fdr):
            perms_out.write("%s\t%s\n" % (name, str(corrected)))
def merge_results(KW, perms, chi):
    """Merge the KW results with their FDR-corrected p-value tables.

    For each row of *KW*, appends the matching corrected permutation
    p-value from *perms* and corrected chi-square p-value from *chi*
    (keyed by the marker name in column 0; rows with no match -- e.g.
    the header -- get no extra columns).  Writes "merged_results.txt"
    in the current directory.

    Fix vs. the original: the perms and chi files were re-read from disk
    once per KW row (O(rows^2) file scans); they are now loaded into
    dicts once.  Output is identical.
    """
    perm_lookup = {}
    for line in open(perms):
        parts = line.split()
        if len(parts) >= 2:
            perm_lookup[parts[0]] = parts[1]
    chi_lookup = {}
    for line in open(chi):
        parts = line.split()
        if len(parts) >= 2:
            chi_lookup[parts[0]] = parts[1]
    with open("merged_results.txt", "w") as outfile:
        for line in open(KW):
            fields = line.split()
            if not fields:
                continue
            row = list(fields)
            if fields[0] in perm_lookup:
                row.append(perm_lookup[fields[0]])
            if fields[0] in chi_lookup:
                row.append(chi_lookup[fields[0]])
            outfile.write("\t".join(row) + "\n")
def main(matrix, groups, permutations):
    """Drive the whole pipeline: parse groups, run KW tests, FDR-correct,
    merge, and leave a header-first, H-stat-sorted "merged_results.txt".

    Fixes vs. the original: the shell pipeline (``rm``/``paste``/``sort``/
    ``mv`` via os.system, including a duplicated ``>> a.tmp >> a.tmp``
    redirect typo) is replaced with portable Python file operations, and an
    empty mean table no longer raises IndexError.
    """
    my_groups = parse_groups_file(groups)
    mean_dict = process_matrix(matrix, my_groups, permutations)
    # Dump the per-group means into a temporary table, one column per
    # phenotype, one row per marker (same row order as KW_results).
    group_order = list(mean_dict.keys())
    length = len(mean_dict[group_order[0]]) if group_order else 0
    with open("mean.xyx.tmp", "w") as mean_file:
        for group in group_order:
            mean_file.write("%s_mean\t" % str(group))
        mean_file.write("\n")
        for i in range(length):
            mean_file.write("\t".join(mean_dict[g][i] for g in group_order) + "\n")
    calculate_fdr("KW_results")
    merge_results("KW_results", "perms.xyx.tmp", "chi.xyx.tmp")
    for tmp_name in ("perms.xyx.tmp", "chi.xyx.tmp", "KW_results"):
        os.remove(tmp_name)
    # Equivalent of `paste merged_results.txt mean.xyx.tmp`: join the two
    # files line by line with a tab (assumes equal row counts, as the
    # original paste call did).
    with open("merged_results.txt") as merged, open("mean.xyx.tmp") as means:
        pasted = [a.rstrip("\n") + "\t" + b.rstrip("\n")
                  for a, b in zip(merged, means)]
    os.remove("mean.xyx.tmp")
    if pasted:
        header, body = pasted[0], pasted[1:]

        def _hstat(row):
            # Sort key: column 2 (H-stat), descending, like `sort -gr -k 2,2`.
            # Unparseable or nan values sink to the bottom.
            try:
                value = float(row.split("\t")[1])
            except (IndexError, ValueError):
                return float("-inf")
            return value if value == value else float("-inf")

        body.sort(key=_hstat, reverse=True)
        with open("merged_results.txt", "w") as out:
            out.write(header + "\n")
            for row in body:
                out.write(row + "\n")
if __name__ == "__main__":
    # Command-line entry point: validate required options, then run main().
    usage = "usage: %prog [options]"
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("-b", "--matrix", dest="matrix",
                      help="/path/to/BSR_matrix [REQUIRED]",
                      type="string", action="callback", callback=test_file)
    parser.add_option("-g", "--groups", dest="groups",
                      help="/path/to/groups file [REQUIRED]",
                      type="string", action="callback", callback=test_file)
    parser.add_option("-p", "--permutations", dest="permutations",
                      help="number of perms, defaults to 100",
                      type="int", action="store", default=100)
    options, args = parser.parse_args()
    mandatories = ["matrix", "groups"]
    for m in mandatories:
        if not getattr(options, m, None):
            # print() is valid in both Python 2 and 3 with a single argument.
            print("\nMust provide %s.\n" % m)
            parser.print_help()
            # sys.exit instead of the site-injected exit(), which is not
            # guaranteed to exist (e.g. under python -S or when frozen).
            sys.exit(-1)
    main(options.matrix, options.groups, options.permutations)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment