@ireneisdoomed
Last active February 20, 2023 15:25
Function to perform clumping in Spark using GraphFrames. Implementation by @DSuveges, extracted from this commit: https://github.com/opentargets/genetics_etl_python/commit/31cab8de2d0aa206211e578aa0fb701dd5e064b2
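GraphFrames is not bundled with PySpark, so the function below assumes a Spark session with the graphframes package available. A minimal setup sketch (the package coordinates and checkpoint directory are assumptions, not part of the original gist; pick the build matching your Spark and Scala versions):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("clumping")
    # Assumed coordinates from the Spark Packages repository - adjust to your Spark/Scala build:
    .config("spark.jars.packages", "graphframes:graphframes:0.8.2-spark3.2-s_2.12")
    .getOrCreate()
)
# GraphFrames' Pregel may require a checkpoint directory depending on the checkpoint interval:
spark.sparkContext.setCheckpointDir("/tmp/graphframes-checkpoints")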
from graphframes import GraphFrame
from graphframes.lib import Pregel
from pyspark.sql import DataFrame
import pyspark.sql.functions as f


def resolve_graph(df: DataFrame) -> DataFrame:
    """Graph resolver for clumping.

    It takes a dataframe with a list of variants and the variants that explain them,
    and returns the same dataframe with the resolved root(s) of each variant.

    Args:
        df (DataFrame): DataFrame with `studyId`, `variantId` and `explained` columns.

    Returns:
        DataFrame: the input dataframe with an additional `resolved_roots` column.
    """
    # Convert to vertices:
    nodes = df.select(
        "studyId",
        "variantId",
        # Generating node identifier column (has to be unique):
        f.concat_ws("_", f.col("studyId"), f.col("variantId")).alias("id"),
        # Generating the original root list. This is the original message which is propagated across nodes:
        f.when(f.col("variantId") == f.col("explained"), f.col("variantId")).alias(
            "origin_root"
        ),
    ).distinct()
    # Convert to edges (more significant points to less significant):
    edges = (
        df.filter(f.col("variantId") != f.col("explained"))
        .select(
            f.concat_ws("_", f.col("studyId"), f.col("variantId")).alias("dst"),
            f.concat_ws("_", f.col("studyId"), f.col("explained")).alias("src"),
            f.lit("explains").alias("edgeType"),
        )
        .distinct()
    )
    # Building graph:
    graph = GraphFrame(nodes, edges)

    # Extracting nodes with at least one edge - unconnected nodes are not needed for message passing:
    filtered_nodes = (
        graph.outDegrees.join(graph.inDegrees, on="id", how="outer")
        .drop("outDegree", "inDegree")
        .join(nodes, on="id", how="inner")
        .repartition("studyId", "variantId")
    )

    # Rebuilding the graph on the connected nodes only:
    graph = GraphFrame(filtered_nodes, edges)
    # Pregel resolver:
    resolved_nodes = (
        graph.pregel.setMaxIter(5)
        # Vertex column carrying the message (the root) that gets propagated:
        .withVertexColumn(
            "message",
            f.when(f.col("origin_root").isNotNull(), f.col("origin_root")),
            f.when(Pregel.msg().isNotNull(), Pregel.msg()),
        )
        .withVertexColumn(
            "resolved_roots",
            # The value is initialized by the original root value:
            f.when(
                f.col("origin_root").isNotNull(), f.array(f.col("origin_root"))
            ).otherwise(f.array()),
            # When a new value arrives at the node, it gets merged with the existing list:
            f.when(
                Pregel.msg().isNotNull(),
                f.array_union(f.split(Pregel.msg(), " "), f.col("resolved_roots")),
            ).otherwise(f.col("resolved_roots")),
        )
        # The message is sent from the more significant node (src) to the less significant node (dst):
        .sendMsgToDst(Pregel.src("message"))
        # All messages arriving at a node are collapsed into a single space-separated string:
        .aggMsgs(f.concat_ws(" ", f.collect_set(Pregel.msg())))
        .run()
        .orderBy("studyId", "id")
        .persist()
    )
    # Joining back the dataset:
    return df.join(
        # The `resolved_roots` column will be null for nodes with no connection.
        resolved_nodes.select("resolved_roots", "studyId", "variantId"),
        on=["studyId", "variantId"],
        how="left",
    )
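A minimal usage sketch with toy data, using the `spark` session from the setup sketch above (the study and variant identifiers are made up for illustration; a variant that explains itself is treated as a clump lead):

# Toy input: each row says that `variantId` is explained by `explained` within `studyId`.
data = [
    ("GCST001", "1_100_A_T", "1_100_A_T"),  # lead variant, explains itself
    ("GCST001", "1_200_G_C", "1_100_A_T"),  # tagged by the lead
    ("GCST001", "1_300_T_A", "1_100_A_T"),  # tagged by the lead
    ("GCST001", "2_500_C_G", "2_500_C_G"),  # independent lead with no dependants
]
df = spark.createDataFrame(data, ["studyId", "variantId", "explained"])

resolved = resolve_graph(df)
resolved.show(truncate=False)
# The output keeps the original columns and adds `resolved_roots`: an array with the
# lead variant(s) each row resolves to, or null for variants with no graph connections.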