Last active
February 20, 2023 15:25
-
-
Save ireneisdoomed/9553e111e458159b826a602183173194 to your computer and use it in GitHub Desktop.
Function to perform clumping in Spark using GraphFrames — implementation by @DSuveges, extracted from this commit: https://github.com/opentargets/genetics_etl_python/commit/31cab8de2d0aa206211e578aa0fb701dd5e064b2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def resolve_graph(df: DataFrame) -> DataFrame:
    """Graph resolver for clumping.

    It takes a dataframe with a list of variants and their explained variants,
    and returns a dataframe with a list of variants and their resolved roots.

    Args:
        df (DataFrame): Input with `studyId`, `variantId` and `explained`
            columns. A row where `variantId == explained` marks that variant
            as a root (it explains itself); other rows are "explained-by"
            relations between two variants of the same study.

    Returns:
        DataFrame: The input dataframe with an extra `resolved_roots` array
            column. The column will be null for variants that had no
            connection in the graph (dropped before the Pregel run and
            re-attached by the final left join).
    """
    # Convert to vertices: one node per distinct (study, variant) pair.
    nodes = df.select(
        "studyId",
        "variantId",
        # Generating node identifier column (has to be unique):
        f.concat_ws("_", f.col("studyId"), f.col("variantId")).alias("id"),
        # Generating the original root list. This is the original message which
        # is propagated across nodes. Only self-explaining variants (roots) get
        # a non-null value here; all other nodes start out with null.
        f.when(f.col("variantId") == f.col("explained"), f.col("variantId")).alias(
            "origin_root"
        ),
    ).distinct()
    # Convert to edges (more significant points to less significant):
    # messages flow from the explaining variant (src) to the explained variant (dst).
    edges = (
        df.filter(f.col("variantId") != f.col("explained"))
        .select(
            f.concat_ws("_", f.col("studyId"), f.col("variantId")).alias("dst"),
            f.concat_ws("_", f.col("studyId"), f.col("explained")).alias("src"),
            f.lit("explains").alias("edgeType"),
        )
        .distinct()
    )
    # Building graph:
    graph = GraphFrame(nodes, edges)
    # Extracting only nodes that have at least one edge (non-null in- or
    # out-degree): the outer degree join keeps any node with a degree entry,
    # and the inner join back to `nodes` restores their attributes. Isolated
    # nodes are excluded from the Pregel run and re-attached at the end.
    filtered_nodes = (
        graph.outDegrees.join(graph.inDegrees, on="id", how="outer")
        .drop("outDegree", "inDegree")
        .join(nodes, on="id", how="inner")
        .repartition("studyId", "variantId")
    )
    # Re-building the graph over the connected nodes only:
    graph = GraphFrame(filtered_nodes, edges)
    # Pregel resolver: iteratively propagate root identifiers along the edges.
    # NOTE(review): setMaxIter(5) bounds propagation depth — explanation chains
    # longer than 5 hops would not fully resolve; confirm this is sufficient.
    resolved_nodes = (
        graph.pregel.setMaxIter(5)
        # Column holding the message a node forwards in the next superstep.
        # Initialized to the node's own root (null for non-roots); afterwards
        # it is overwritten by whichever message was last received.
        .withVertexColumn(
            "message",
            f.when(f.col("origin_root").isNotNull(), f.col("origin_root")),
            f.when(Pregel.msg().isNotNull(), Pregel.msg()),
        )
        .withVertexColumn(
            "resolved_roots",
            # The value is initialized by the original root value:
            f.when(
                f.col("origin_root").isNotNull(), f.array(f.col("origin_root"))
            ).otherwise(f.array()),
            # When a new value arrives at the node, it gets merged with the
            # existing list. Incoming messages are space-separated strings
            # (see aggMsgs below), so they are split back into individual
            # roots before the set-union:
            f.when(
                Pregel.msg().isNotNull(),
                f.array_union(f.split(Pregel.msg(), " "), f.col("resolved_roots")),
            ).otherwise(f.col("resolved_roots")),
        )
        # We need to reinforce the message in both direction:
        # NOTE(review): the comment above mentions both directions, but only
        # src -> dst messages are sent here — confirm whether a sendMsgToSrc
        # counterpart was intended.
        .sendMsgToDst(Pregel.src("message"))
        # Once the messages are delivered, all incoming messages for a node
        # are aggregated into a single space-separated string, which is then
        # merged into the node's root list above:
        .aggMsgs(f.concat_ws(" ", f.collect_set(Pregel.msg())))
        .run()
        .orderBy("studyId", "id")
        # Persist: the result is reused by the join below.
        .persist()
    )
    # Joining back the dataset:
    return df.join(
        # The `resolved_roots` column will be null for nodes, with no connection.
        resolved_nodes.select("resolved_roots", "studyId", "variantId"),
        on=["studyId", "variantId"],
        how="left",
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment