@d0choa
Last active March 30, 2022 18:36
Experimenting with coloc in pyspark
from functools import reduce

import pyspark.sql.functions as F
from pyspark import SparkConf
from pyspark.sql import SparkSession
sparkConf = (
    SparkConf()
    .set('spark.hadoop.fs.gs.requester.pays.mode', 'AUTO')
    .set('spark.hadoop.fs.gs.requester.pays.project.id', 'open-targets-eu-dev')
)
# Establish the Spark connection:
spark = (
    SparkSession.builder
    .config(conf=sparkConf)
    .master('local[*]')
    .getOrCreate()
)
# Full credible set (kept for reference):
# credSetPath = "gs://genetics-portal-dev-staging/finemapping/220228_merged/credset/"
# credSet = (
#     spark.read.json(credSetPath)
#     .distinct()
#     .withColumn("studyKey", F.concat_ws('_', *['study_id', 'phenotype_id', 'bio_feature']))
# )
# Test dataset: credible sets restricted to chromosome 22:
credSet22Path = "gs://genetics-portal-dev-analysis/dsuveges/test_credible_set_chr22.parquet"
credSet = (
    spark.read.parquet(credSet22Path)
    .distinct()
    # Composite key identifying each study/phenotype/bio_feature combination:
    .withColumn("studyKey", F.concat_ws('_', *['type', 'study_id', 'phenotype_id', 'bio_feature']))
)
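# Optional sanity check (a sketch only): the pipeline below relies on these
# fields being present in the credible-set parquet:
# credSet.select(
#     "type", "study_id", "phenotype_id", "bio_feature",
#     "lead_variant_id", "tag_variant_id", "logABF"
# ).show(5, truncate=False)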
# Priors (the coloc defaults):
#   priorc1  - prior probability that a variant is causal for trait 1
#   priorc2  - prior probability that a variant is causal for trait 2
#   priorc12 - prior probability that a variant is causal for both traits
priors = spark.createDataFrame([(1e-4, 1e-4, 1e-5)], ("priorc1", "priorc2", "priorc12"))
## TODO: calculate logABF from the data, because FinnGen studies (sumstats) don't have logABF
## https://github.com/tobyjohnson/gtx/blob/9afa9597a51d0ff44536bc5c8eddd901ab3e867c/R/abf.R#L53
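# A possible approach for the TODO above (a sketch only): Wakefield's
# approximate Bayes factor can be computed from summary statistics. The
# `beta`/`se` column names and the prior standard deviation are assumptions
# here, not fields guaranteed by the FinnGen sumstats:
def with_log_abf(sumstats, prior_sd=0.15):
    """Add a Wakefield log-ABF column computed from beta and se.

    logABF = 0.5 * log(V / (V + W)) + 0.5 * z^2 * W / (V + W)
    where V = se^2, W = prior_sd^2 and z = beta / se.
    """
    V = F.col("se") ** 2
    W = F.lit(prior_sd) ** 2
    z2 = (F.col("beta") / F.col("se")) ** 2
    return sumstats.withColumn(
        "logABF", 0.5 * F.log(V / (V + W)) + 0.5 * z2 * (W / (V + W))
    )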
columnsToJoin = [
    "studyKey", "tag_variant_id", "lead_variant_id", "type", "logABF"
]
rename_columns = [
    "studyKey", "lead_variant_id", "type", "logABF"
]
## TODO: Resolve so that bio_features always end up on the right side
# Overlapping signals (exploded at the tag-variant level), with columns
# prefixed left_/right_ for the two sides of the self-join:
leftDf = reduce(
    lambda DF, col: DF.withColumnRenamed(col, 'left_' + col),
    rename_columns,
    credSet.select(columnsToJoin).distinct()
)
rightDf = reduce(
    lambda DF, col: DF.withColumnRenamed(col, 'right_' + col),
    rename_columns,
    credSet.select(columnsToJoin).distinct()
)
overlappingPeaks = (
    leftDf
    # Keep GWAS on the left, so molecular traits can only appear on the right:
    .filter(F.col("left_type") == "gwas")
    # Get all study/peak pairs where at least one tag variant overlaps:
    .join(rightDf, on='tag_variant_id', how='inner')
    # Remove pairs of identical studies:
    .filter(F.col('left_studyKey') != F.col('right_studyKey'))
    # For GWAS-GWAS pairs, keep only the upper triangle to avoid mirrored duplicates:
    .filter(
        (F.col('right_type') != 'gwas') |
        (F.col('left_studyKey') > F.col('right_studyKey'))
    )
    # Remove overlapping tag-variant info:
    .drop("left_logABF", "right_logABF", "tag_variant_id")
    # Distinct to keep one row per study pair:
    .distinct()
    .persist()
)
# Re-attach the full credible sets (all tag variants with their logABFs) to
# each side of the overlapping study pairs:
overlappingLeft = (
    overlappingPeaks
    .join(
        leftDf.select("left_studyKey", "left_lead_variant_id", "tag_variant_id", "left_logABF"),
        on=["left_studyKey", "left_lead_variant_id"],
        how='inner'
    )
)
overlappingRight = (
    overlappingPeaks
    .join(
        rightDf.select("right_studyKey", "right_lead_variant_id", "tag_variant_id", "right_logABF"),
        on=["right_studyKey", "right_lead_variant_id"],
        how='inner'
    )
)
# Full outer join so tag variants present in only one of the two credible
# sets are kept as well (their missing logABF is filled with 0 below):
overlappingSignals = (
    overlappingLeft.alias("a")
    .join(
        overlappingRight.alias("b"),
        on=[
            "tag_variant_id",
            "left_lead_variant_id",
            "right_lead_variant_id",
            "left_studyKey",
            "right_studyKey",
            "right_type",
            "left_type"
        ],
        how='outer'
    )
)
signalPairsCols = ["studyKey", "lead_variant_id", "type"]
# Colocalisation analysis
coloc = (
    overlappingSignals
    # Before summing, nulls in the logABF columns need to be filled with 0:
    .fillna(0, subset=['left_logABF', 'right_logABF'])
    # Sum of the two logABFs (the log of the product of the two ABFs):
    .withColumn('sum_logABF', F.col('left_logABF') + F.col('right_logABF'))
    # TODO: group by key columns and keep the rest of the columns
    # Group by signal pair and collect the logABFs across tag variants:
    .groupBy(*["left_" + col for col in signalPairsCols] + ["right_" + col for col in signalPairsCols])
    .agg(
        F.count('*').alias('coloc_n_vars'),
        F.collect_list(F.col('left_logABF')).alias('left_logABF_array'),
        F.collect_list(F.col('right_logABF')).alias('right_logABF_array'),
        F.collect_list(F.col('sum_logABF')).alias('sum_logABF_array')
    )
    # Log-sum-exp of the left_logABF array (the max is factored out for numerical stability):
    .withColumn("max", F.array_max(F.col('left_logABF_array')))
    .withColumn("exp", F.transform(F.col('left_logABF_array'), lambda x: F.exp(x - F.col("max"))))
    .withColumn("left_logsum", F.col("max") + F.log(F.expr('AGGREGATE(exp, DOUBLE(0), (acc, x) -> acc + x)')))
    .drop('left_logABF_array', 'max', 'exp')
    # Log-sum-exp of the right_logABF array:
    .withColumn("max", F.array_max(F.col('right_logABF_array')))
    .withColumn("exp", F.transform(F.col('right_logABF_array'), lambda x: F.exp(x - F.col("max"))))
    .withColumn("right_logsum", F.col("max") + F.log(F.expr('AGGREGATE(exp, DOUBLE(0), (acc, x) -> acc + x)')))
    .drop('right_logABF_array', 'max', 'exp')
    # Log-sum-exp of the sum_logABF array:
    .withColumn("max", F.array_max(F.col('sum_logABF_array')))
    .withColumn("exp", F.transform(F.col('sum_logABF_array'), lambda x: F.exp(x - F.col("max"))))
    .withColumn("logsum_left_right", F.col("max") + F.log(F.expr('AGGREGATE(exp, DOUBLE(0), (acc, x) -> acc + x)')))
    .drop('sum_logABF_array', 'max', 'exp')
    # Add priors
    .crossJoin(priors)
    # log-ABFs for H0 (no association) and H1/H2 (association with one trait only):
    .withColumn("lH0abf", F.lit(0))
    .withColumn("lH1abf", F.log(F.col("priorc1")) + F.col("left_logsum"))
    .withColumn("lH2abf", F.log(F.col("priorc2")) + F.col("right_logsum"))
    # H3 (two distinct causal variants): needs log(exp(sumlogsum) - exp(logsum_left_right)),
    # computed stably by factoring out the larger of the two exponents:
    .withColumn("sumlogsum", F.col("left_logsum") + F.col("right_logsum"))
    .withColumn("max", F.greatest("sumlogsum", "logsum_left_right"))
    .withColumn("logdiff", (
        F.col("max") +
        F.log(F.exp(F.col("sumlogsum") - F.col("max")) -
              F.exp(F.col("logsum_left_right") - F.col("max")))
    ))
    .withColumn("lH3abf", F.log(F.col("priorc1")) + F.log(F.col("priorc2")) + F.col("logdiff"))
    .drop("sumlogsum", "max", "logdiff")
    # H4 (one shared causal variant):
    .withColumn("lH4abf", F.log(F.col("priorc12")) + F.col("logsum_left_right"))
    # Posterior probabilities: a softmax over the five hypothesis log-ABFs:
    .withColumn("allABF", F.array(
        F.col("lH0abf"),
        F.col("lH1abf"),
        F.col("lH2abf"),
        F.col("lH3abf"),
        F.col("lH4abf")
    ))
    .withColumn("max", F.array_max(F.col("allABF")))
    .withColumn("exp", F.transform("allABF", lambda x: F.exp(x - F.col("max"))))
    .withColumn("mydenom", F.col("max") + F.log(F.expr('AGGREGATE(exp, DOUBLE(0), (acc, x) -> acc + x)')))
    # TODO: write this more nicely?
    .withColumn("coloc_h0", F.exp(F.col("lH0abf") - F.col("mydenom")))
    .withColumn("coloc_h1", F.exp(F.col("lH1abf") - F.col("mydenom")))
    .withColumn("coloc_h2", F.exp(F.col("lH2abf") - F.col("mydenom")))
    .withColumn("coloc_h3", F.exp(F.col("lH3abf") - F.col("mydenom")))
    .withColumn("coloc_h4", F.exp(F.col("lH4abf") - F.col("mydenom")))
    .withColumn("coloc_h4_h3", F.col("coloc_h4") / F.col("coloc_h3"))
    .withColumn("coloc_log2_h4_h3", F.log2(F.col("coloc_h4_h3")))
    # Cleanup
    .drop("allABF", "mydenom", "max", "exp")
)
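# The log-sum-exp pattern above repeats four times; one possible refactoring
# (a sketch only, the helper name is ours; assumes Spark >= 3.1 for
# F.transform/F.aggregate, which the pipeline already relies on):
def logsum(array_col):
    """Numerically stable log(sum(exp(x))) over an array column."""
    mx = F.array_max(array_col)
    total = F.aggregate(
        F.transform(array_col, lambda x: F.exp(x - mx)),
        F.lit(0.0),
        lambda acc, x: acc + x,
    )
    return mx + F.log(total)
# e.g. .withColumn("left_logsum", logsum(F.col("left_logABF_array")))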
## TODO: Add alphas?
## https://github.com/tobyjohnson/gtx/blob/9afa9597a51d0ff44536bc5c8eddd901ab3e867c/R/coloc.R#L91
# For debugging:
# (
#     coloc
#     .filter(
#         (F.col("left_studyKey") == "NEALE2_20003_1140909872") &
#         (F.col("right_studyKey") == "GTEx-sQTL_chr22:17791301:17806239:clu_21824:ENSG00000243156_Ovary") &
#         (F.col("left_lead_variant_id") == "22:16590692:CAA:C") &
#         (F.col("right_lead_variant_id") == "22:17806438:G:A")
#     )
#     .show(vertical=True)
# )