Last active
March 30, 2022 18:36
-
-
Save d0choa/f7476b18f603fe9d4222af7b83c94b0b to your computer and use it in GitHub Desktop.
Experimenting with coloc in pyspark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pyspark.sql.functions as F | |
from pyspark import SparkConf | |
from pyspark.sql import SparkSession | |
from functools import reduce | |
# Spark configuration: enable reading requester-pays GCS buckets, billing
# the open-targets-eu-dev project for the transfer.
spark_settings = (
    SparkConf()
    .set('spark.hadoop.fs.gs.requester.pays.mode', 'AUTO')
    .set('spark.hadoop.fs.gs.requester.pays.project.id', 'open-targets-eu-dev')
)

# Establish a local Spark session using all available cores.
spark = (
    SparkSession.builder
    .config(conf=spark_settings)
    .master('local[*]')
    .getOrCreate()
)
# NOTE: full credible-set input (all chromosomes), kept for reference:
# credSetPath = "gs://genetics-portal-dev-staging/finemapping/220228_merged/credset/"
# credSet = (
#     spark.read.json(credSetPath)
#     .distinct()
#     .withColumn("studyKey", F.concat_ws('_', *['study_id', 'phenotype_id', 'bio_feature']))
# )

# Chromosome-22 subset of the credible sets, used for experimentation.
credSet22Path = "gs://genetics-portal-dev-analysis/dsuveges/test_credible_set_chr22.parquet"

# Deduplicated credible sets with a composite study identifier: a signal's
# study is keyed by its type, study id, phenotype and bio-feature.
credSet = (
    spark.read.parquet(credSet22Path)
    .distinct()
    .withColumn(
        "studyKey",
        F.concat_ws(
            '_',
            F.col('type'), F.col('study_id'), F.col('phenotype_id'), F.col('bio_feature'),
        ),
    )
)
# Priors on a variant being causal (standard coloc defaults):
#   priorc1  - causal for trait 1 only
#   priorc2  - causal for trait 2 only
#   priorc12 - causal for both traits
priors = spark.createDataFrame([(1e-4, 1e-4, 1e-5)], ("priorc1", "priorc2", "priorc12"))

## TODO: calculate logABF from data, because Finngen studies (Sumstats) don't have logABF
## https://github.com/tobyjohnson/gtx/blob/9afa9597a51d0ff44536bc5c8eddd901ab3e867c/R/abf.R#L53

# Columns carried into the overlap self-join below...
columnsToJoin = ["studyKey", "tag_variant_id", "lead_variant_id", "type", "logABF"]
# ...of which all but the shared tag_variant_id get a left_/right_ prefix.
rename_columns = ["studyKey", "lead_variant_id", "type", "logABF"]
## TODO: Resolve so that biofeatures are always right

# Build the two sides of the overlap self-join. Every column listed in
# rename_columns is prefixed so that the join below stays unambiguous;
# tag_variant_id keeps its name since it is the join key.
def _prefixed(df, prefix):
    return df.select(
        [F.col(c).alias(prefix + c) if c in rename_columns else F.col(c) for c in columnsToJoin]
    )

_tagged = credSet.select(columnsToJoin).distinct()
leftDf = _prefixed(_tagged, 'left_')
rightDf = _prefixed(_tagged, 'right_')
# Overlapping peaks: pairs of study/lead-variant signals that share at
# least one tagging variant, collapsed to study-pair resolution.
overlappingPeaks = (
    leftDf
    # molecular traits always on the right-side
    .filter(F.col("left_type") == "gwas")
    # All study/peak pairs where at least one tagging variant overlaps:
    .join(rightDf, on='tag_variant_id', how='inner')
    .filter(
        # Drop pairs with an identical study on both sides...
        (F.col('left_studyKey') != F.col('right_studyKey'))
        # ...and, when both studies are GWAS, keep only the upper triangle
        # so each GWAS-GWAS pair appears exactly once:
        & (
            (F.col('right_type') != 'gwas')
            | (F.col('left_studyKey') > F.col('right_studyKey'))
        )
    )
    # Remove tag-variant level info...
    .drop("left_logABF", "right_logABF", "tag_variant_id")
    # ...and deduplicate to get one row per study pair:
    .distinct()
    .persist()
)
# Re-attach the tagging variants (and their logABFs) of the left peak...
overlappingLeft = overlappingPeaks.join(
    leftDf.select("left_studyKey", "left_lead_variant_id", "tag_variant_id", "left_logABF"),
    on=["left_studyKey", "left_lead_variant_id"],
    how='inner',
)

# ...and of the right peak.
overlappingRight = overlappingPeaks.join(
    rightDf.select("right_studyKey", "right_lead_variant_id", "tag_variant_id", "right_logABF"),
    on=['right_studyKey', 'right_lead_variant_id'],
    how='inner',
)
# Full outer join of the two sides on the pair identity plus the tagging
# variant: a tag present in only one signal's credible set gets nulls on
# the other side (the original code fills these with 0 before summing).
_pair_keys = [
    "tag_variant_id",
    "left_lead_variant_id",
    "right_lead_variant_id",
    "left_studyKey",
    "right_studyKey",
    "right_type",
    "left_type",
]
overlappingSignals = overlappingLeft.alias("a").join(
    overlappingRight.alias("b"),
    on=_pair_keys,
    how='outer',
)

# Columns that identify one signal (peak) of a pair.
signalPairsCols = ["studyKey", "lead_variant_id", "type"]
# Colocalisation analysis
#
# Posterior probabilities for the five colocalisation hypotheses, computed
# from the per-tag-variant log approximate Bayes factors (logABF) of each
# overlapping pair of signals:
#   H0: no causal variant for either trait
#   H1 / H2: a causal variant for the left / right trait only
#   H3: two distinct causal variants
#   H4: one shared causal variant

def _logsum(arr):
    """Numerically stable log(sum(exp(x))) over an array column.

    The array maximum is factored out before exponentiating so that large
    logABF values do not overflow (same max-shift trick the original
    repeated inline four times with temporary `max`/`exp` columns).
    """
    mx = F.array_max(arr)
    return mx + F.log(
        F.aggregate(
            F.transform(arr, lambda x: F.exp(x - mx)),
            F.lit(0.0),
            lambda acc, x: acc + x,
        )
    )

coloc = (
    overlappingSignals
    # Before summarizing logABF columns nulls need to be filled with 0:
    .fillna(0, subset=['left_logABF', 'right_logABF'])
    # Per-tag-variant joint logABF (feeds the H4 term):
    .withColumn('sum_logABF', F.col('left_logABF') + F.col('right_logABF'))
    # TODO: group by key column and keep rest of columns
    .groupBy(*["left_" + col for col in signalPairsCols] + ["right_" + col for col in signalPairsCols])
    .agg(
        F.count('*').alias('coloc_n_vars'),
        F.collect_list(F.col('left_logABF')).alias('left_logABF_array'),
        F.collect_list(F.col('right_logABF')).alias('right_logABF_array'),
        F.collect_list(F.col('sum_logABF')).alias('sum_logABF_array')
    )
    # Stable log-sum-exp of each logABF array (see _logsum above):
    .withColumn("left_logsum", _logsum(F.col('left_logABF_array')))
    .drop('left_logABF_array')
    .withColumn("right_logsum", _logsum(F.col('right_logABF_array')))
    .drop('right_logABF_array')
    .withColumn("logsum_left_right", _logsum(F.col('sum_logABF_array')))
    .drop('sum_logABF_array')
    # Attach the single-row priors table to every pair:
    .crossJoin(priors)
    # Log Bayes factors for H0-H2 (H0 is the reference, hence 0):
    .withColumn("lH0abf", F.lit(0))
    .withColumn("lH1abf", F.log(F.col("priorc1")) + F.col("left_logsum"))
    .withColumn("lH2abf", F.log(F.col("priorc2")) + F.col("right_logsum"))
    # H3: log(exp(left+right) - exp(joint)), again computed with the
    # max factored out for numerical stability:
    .withColumn("sumlogsum", F.col("left_logsum") + F.col("right_logsum"))
    .withColumn("max", F.greatest("sumlogsum", "logsum_left_right"))
    .withColumn("logdiff", (
        F.col("max")
        + F.log(
            F.exp(F.col("sumlogsum") - F.col("max"))
            - F.exp(F.col("logsum_left_right") - F.col("max"))
        )
    ))
    .withColumn("lH3abf", F.log(F.col("priorc1")) + F.log(F.col("priorc2")) + F.col("logdiff"))
    .drop("sumlogsum", "max", "logdiff")
    # H4:
    .withColumn("lH4abf", F.log(F.col("priorc12")) + F.col("logsum_left_right"))
    # Posteriors: normalise the five hypotheses with a stable softmax:
    .withColumn("allABF", F.array(
        F.col("lH0abf"),
        F.col("lH1abf"),
        F.col("lH2abf"),
        F.col("lH3abf"),
        F.col("lH4abf"))
    )
    .withColumn("mydenom", _logsum(F.col("allABF")))
    .withColumn("coloc_h0", F.exp(F.col("lH0abf") - F.col("mydenom")))
    .withColumn("coloc_h1", F.exp(F.col("lH1abf") - F.col("mydenom")))
    .withColumn("coloc_h2", F.exp(F.col("lH2abf") - F.col("mydenom")))
    .withColumn("coloc_h3", F.exp(F.col("lH3abf") - F.col("mydenom")))
    .withColumn("coloc_h4", F.exp(F.col("lH4abf") - F.col("mydenom")))
    .withColumn("coloc_h4_h3", F.col("coloc_h4") / F.col("coloc_h3"))
    .withColumn("coloc_log2_h4_h3", F.log2(F.col("coloc_h4_h3")))
    # cleanup
    .drop("allABF", "mydenom")
)
## TODO: Add alphas?
## https://github.com/tobyjohnson/gtx/blob/9afa9597a51d0ff44536bc5c8eddd901ab3e867c/R/coloc.R#L91

# For debugging: inspect a single study pair
# (
#     coloc
#     .filter(
#         (F.col("left_studyKey") == "NEALE2_20003_1140909872") &
#         (F.col("right_studyKey") == "GTEx-sQTL_chr22:17791301:17806239:clu_21824:ENSG00000243156_Ovary") &
#         (F.col("left_lead_variant_id") == "22:16590692:CAA:C") &
#         (F.col("right_lead_variant_id") == "22:17806438:G:A"))
#     .show(vertical=True)
# )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment