Created July 4, 2023 20:30
-
-
Save d0choa/82ea3157a877e2d54fd9bce064c15e4b to your computer and use it in GitHub Desktop.
Distance-based clumping using nested structures and dense vectors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Prototype of distance based clumping.""" | |
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pyspark.ml.functions as fml
import pyspark.sql.functions as f
from pyspark.ml.linalg import DenseVector, Vectors, VectorUDT
from pyspark.sql import SparkSession

if TYPE_CHECKING:
    from numpy import ndarray
# Spark session used by the rest of this prototype script.
spark = SparkSession.builder.getOrCreate()

# Toy input: a single study ("s1") on "chr1"; each tuple below is
# (position, negLogPValue, isSemiIndex) for one row.
data = [
    ("s1", "chr1", position, neg_log_p, is_semi_index)
    for position, neg_log_p, is_semi_index in (
        (3, 2.0, False),
        (4, 3.0, False),
        (5, 4.0, True),
        (6, 2.0, False),
        (7, 3.0, False),
        (8, 4.0, False),
        (9, 4.5, False),
        (10, 6.0, True),
        (11, 5.0, False),
        (12, 3.0, False),
        (14, 2.0, True),
        (16, 2.5, False),
        (18, 3.0, True),
        (20, 1.5, False),
    )
]

# Column types are inferred from the Python values; the frame is persisted
# because the pipeline below reads it.
df = spark.createDataFrame(
    data,
    schema=["studyId", "chromosome", "position", "negLogPValue", "isSemiIndex"],
).persist()
def _daniel_magic(vect: "DenseVector | ndarray | None") -> "DenseVector | None":
    """Flag every element of a vector that is greater than 3.

    This is placeholder ("made up") logic standing in for the real
    per-group computation.

    Annotations are quoted because ``ndarray`` is only imported under
    ``TYPE_CHECKING``; an unquoted annotation would be evaluated at
    definition time and raise ``NameError``.

    Args:
        vect: Values to test. When called through the ``daniel_magic`` UDF
            on a ``VectorUDT`` column this is presumably a ``DenseVector``;
            any array-like accepted by ``numpy.asarray`` works as well.

    Returns:
        A ``DenseVector`` of the same length as ``vect`` whose entries encode
        ``element > 3``, or ``None`` when ``vect`` is ``None``.
    """
    # Spark hands a SQL NULL to a Python UDF as None; propagate it rather
    # than failing inside numpy.
    if vect is None:
        return None
    # made up logic: the boolean mask is handed to Vectors.dense, which
    # presumably stores it as float 1.0/0.0 (DenseVector holds floats).
    return Vectors.dense(np.asarray(vect) > 3)
# Spark UDF wrapper around _daniel_magic; its result column has an ML
# vector (VectorUDT) type.
daniel_magic = f.udf(f=_daniel_magic, returnType=VectorUDT())
(
    df.groupBy("studyId", "chromosome")
    # Collect (negLogPValue, position) structs per group and sort them in
    # descending order. Structs compare field by field, so this orders by
    # negLogPValue first, with position breaking ties.
    .agg(
        f.sort_array(
            f.collect_list(f.struct("negLogPValue", "position")), asc=False
        ).alias("snps")
    )
    # Pull the negLogPValue array out of the structs, wrap it in an ML dense
    # vector for the vector UDF, then convert the UDF result back to an array.
    # The field name must match the struct field exactly ("negLogPValue");
    # a different casing only resolves while spark.sql.caseSensitive is false.
    .withColumn(
        "snps2",
        fml.vector_to_array(
            daniel_magic(
                fml.array_to_vector(
                    f.transform(f.col("snps"), lambda x: x.negLogPValue)
                )
            )
        ),
    )
    # Re-attach the UDF output to the sorted structs element by element.
    .withColumn(
        "snps3",
        f.zip_with(
            f.col("snps"),
            f.col("snps2"),
            lambda x, y: f.struct(
                x.negLogPValue.alias("negLogPValue"),
                x.position.alias("position"),
                # vector_to_array yields doubles (1.0/0.0 here); cast back to
                # the boolean type the input isSemiIndex column uses.
                y.cast("boolean").alias("isSemiIndex"),
            ),
        ),
    )
    .drop("snps", "snps2")
    # Print up to 100 rows without truncating the nested arrays.
    .show(n=100, truncate=False)
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment