Skip to content

Instantly share code, notes, and snippets.

@d0choa
Last active August 5, 2021 11:24
Show Gist options
  • Save d0choa/16ebd8aaab6eea97feb189c1c135df3f to your computer and use it in GitHub Desktop.
Save d0choa/16ebd8aaab6eea97feb189c1c135df3f to your computer and use it in GitHub Desktop.
Find W2V similarities between terms in different EFO branches (e.g. disease vs. phenotype)
from pyspark.sql import SparkSession
from pyspark.context import SparkContext
from pyspark.sql import functions as F
from pyspark.ml.feature import Word2VecModel
from pyspark.sql.types import DoubleType
from pyspark.ml.feature import Normalizer
# Establish the Spark connection: a local session using all available cores
# (reused if one already exists).
# NOTE(review): maxResultSize '0' presumably disables the driver result-size
# limit -- confirm against the Spark configuration docs.
sparkBuilder = SparkSession.builder.master('local[*]')
for confKey, confValue in (
    ('spark.driver.maxResultSize', '0'),
    ('spark.debug.maxToStringFields', '2000'),
    ('spark.sql.execution.arrow.maxRecordsPerBatch', '500000'),
):
    sparkBuilder = sparkBuilder.config(confKey, confValue)
spark = sparkBuilder.getOrCreate()
sc = SparkContext.getOrCreate()
# Input locations, grouped so they can be changed in one place.
modelPath = "gs://open-targets-data-releases/21.04/output/literature/W2VModel/"
diseasePath = "gs://open-targets-data-releases/21.04/output/etl/parquet/diseases"
matchesPath = "gs://open-targets-data-releases/21.04/output/literature/matches"
noPath = "gs://ot-team/dochoa/eric_nonobvious_trait_relationships.csv"

# Trained W2V model
model = Word2VecModel.load(modelPath)

# Disease index and literature matches (both parquet)
disease = spark.read.format("parquet").load(diseasePath)
matches = spark.read.format("parquet").load(matchesPath)

# Non-obvious relationships (from Eric); the first CSV row is the header
no = spark.read.option("header", True).csv(noPath)
# Examples of non-obvious relationships from Eric Fauman.
# trait_id and valid_evidence each hold a "|"-separated list of ids. Splitting
# and exploding both columns gives one row per (trait_id, valid_evidence)
# combination; the inner joins then keep only ids present in the disease index
# and attach the matching names.
traitNames = disease.select(
    F.col("id").alias("trait_id"),
    F.col("name").alias("trait_name"),
)
evidenceNames = disease.select(
    F.col("id").alias("valid_evidence"),
    F.col("name").alias("evidence_name"),
)
test = (
    no
    .withColumn("trait_id", F.explode(F.split(F.col("trait_id"), "\\|")))
    .withColumn(
        "valid_evidence",
        F.explode(F.split(F.col("valid_evidence"), "\\|")),
    )
    .join(traitNames, on="trait_id", how="inner")
    .join(evidenceNames, on="valid_evidence", how="inner")
)
# Gold-standard pairs as (i, j) in both orientations, de-duplicated, so a pair
# can be matched whichever of its two ids ends up in column i.
evidenceToTrait = test.selectExpr("valid_evidence as i", "trait_id as j")
traitToEvidence = test.selectExpr("trait_id as i", "valid_evidence as j")
testOut = evidenceToTrait.union(traitToEvidence).distinct()
# Calculate term-term similarities using W2V.
# `model` is the Word2VecModel already loaded near the top of the script; it is
# reused here instead of being loaded from storage a second time.
data = (
    model
    .getVectors()
    # keep only the vectors whose token is an id in the disease index
    .join(disease.select(F.col("id").alias("word")),
          how="inner", on="word")
    # split(",") turns the id string into an array (a single element when the
    # id contains no comma); the arrays are compared below and exploded back
    .withColumn("word", F.split(F.col("word"), ","))
)
# Scale every vector with Normalizer (default p-norm) into column "norm".
normalizer = Normalizer(inputCol="vector", outputCol="norm")
normData = normalizer.transform(data)
# Dot product of two normalised vectors, returned as a Spark double.
dot_udf = F.udf(lambda left, right: float(left.dot(right)), DoubleType())
out = (
    normData.alias("i")
    # self-join on i.word < j.word: each unordered pair once, no self-pairs
    .join(normData.alias("j"), F.col("i.word") < F.col("j.word"))
    .select(
        F.col("i.word").alias("i"),
        F.col("j.word").alias("j"),
        dot_udf("i.norm", "j.norm").alias("similarity"))
    # turn the id arrays back into plain string ids
    .withColumn("i", F.explode("i"))
    .withColumn("j", F.explode("j"))
    .sort("i", "j")
    # persisted because the result is reused by the later joins
    .persist()
)
# Classify each term as disease, measurement, phenotype or process from the
# ids found in its therapeuticAreas array. The branches are tested in priority
# order: GO_0008150 (process) wins over EFO_0001444 (measurement), which wins
# over EFO_0000651 (phenotype); a term matching none of them is a disease.
categoryExpr = (
    F.when(F.array_contains("therapeuticAreas", "GO_0008150"),
           F.lit("process"))
    .when(F.array_contains("therapeuticAreas", "EFO_0001444"),
          F.lit("measurement"))
    .when(F.array_contains("therapeuticAreas", "EFO_0000651"),
          F.lit("phenotype"))
    .otherwise(F.lit("disease"))
)
diseaseMetadata = (
    disease
    .withColumn("category", categoryExpr)
    .select("id", "name", "category")
)
# Literature support per term: number of distinct pmid values per keywordId,
# restricted to matches of type "DS" (presumably disease matches -- confirm
# against the matches schema). Persisted because it is joined twice later.
matchesMetadata = (
    matches
    .where(F.col("type") == "DS")
    .groupBy("keywordId")
    .agg(F.countDistinct("pmid").alias("countPMID"))
    .withColumnRenamed("keywordId", "id")
    .persist()
)
# Relevant term-term similarities (between EFO branches).
# Each side of a pair (i and j) gets its name, category and literature count
# through left joins, so pairs with missing metadata are kept at this stage.
iTerms = diseaseMetadata.select(
    F.col("id").alias("i"),
    F.col("name").alias("iName"),
    F.col("category").alias("iCategory"),
)
jTerms = diseaseMetadata.select(
    F.col("id").alias("j"),
    F.col("name").alias("jName"),
    F.col("category").alias("jCategory"),
)
iCounts = matchesMetadata.select(
    F.col("id").alias("i"),
    F.col("countPMID").alias("iCountPMID"),
)
jCounts = matchesMetadata.select(
    F.col("id").alias("j"),
    F.col("countPMID").alias("jCountPMID"),
)
toprint = (
    out
    .join(iTerms, on="i", how="left")
    .join(jTerms, on="j", how="left")
    .join(iCounts, on="i", how="left")
    .join(jCounts, on="j", how="left")
    # keep only pairs whose two terms fall in different categories
    .where(F.col("iCategory") != F.col("jCategory"))
    # most similar pairs first
    .orderBy(F.desc("similarity"))
)
# Gold-standard list of non-obvious relationships and their similarities:
# keep only the cross-category pairs that are also in the curated list and
# print the 40 most similar ones without truncating cell contents.
goldStandard = toprint.join(testOut, on=["i", "j"], how="inner")
goldStandard.orderBy(F.desc("similarity")).show(40, truncate=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment