Skip to content

Instantly share code, notes, and snippets.

@d0choa
Created July 4, 2023 20:30
Show Gist options
  • Save d0choa/82ea3157a877e2d54fd9bce064c15e4b to your computer and use it in GitHub Desktop.
Save d0choa/82ea3157a877e2d54fd9bce064c15e4b to your computer and use it in GitHub Desktop.
Distance clumping based on nested structures and dense vectors
"""Prototype of distance based clumping."""
from typing import TYPE_CHECKING
import numpy as np
import pyspark.ml.functions as fml
import pyspark.sql.functions as f
from pyspark.ml.linalg import DenseVector, Vectors, VectorUDT
from pyspark.sql import SparkSession
if TYPE_CHECKING:
from numpy import ndarray
# Reuse the active Spark session, or start one if none exists.
spark = SparkSession.builder.getOrCreate()

# Toy association results for a single study ("s1") on a single chromosome
# ("chr1"). Each inner tuple is (position, negLogPValue, isSemiIndex); the
# study and chromosome identifiers are prepended to every row.
data = [
    ("s1", "chr1", position, neg_log_p, semi_index)
    for position, neg_log_p, semi_index in (
        (3, 2.0, False),
        (4, 3.0, False),
        (5, 4.0, True),
        (6, 2.0, False),
        (7, 3.0, False),
        (8, 4.0, False),
        (9, 4.5, False),
        (10, 6.0, True),
        (11, 5.0, False),
        (12, 3.0, False),
        (14, 2.0, True),
        (16, 2.5, False),
        (18, 3.0, True),
        (20, 1.5, False),
    )
]

# Column types are inferred from the Python values
# (string, string, long, double, boolean).
df = spark.createDataFrame(
    data,
    schema=["studyId", "chromosome", "position", "negLogPValue", "isSemiIndex"],
)
# Mark the frame for caching so repeated actions on it reuse the data.
df = df.persist()
def _daniel_magic(vect: "DenseVector | ndarray") -> DenseVector:
    """Flag every element of a vector that is greater than 3.

    Placeholder ("made up") logic standing in for the real clumping step. It
    only demonstrates the vector-in / vector-out contract expected by the
    ``daniel_magic`` UDF.

    Args:
        vect (DenseVector | ndarray): Numeric vector. When called through the
            UDF in this script, it is the DenseVector built (via
            ``array_to_vector``) from one group's negLogPValue values.

    Returns:
        DenseVector: Same length as ``vect``. Each element is 1.0 where the
        input element is > 3 and 0.0 otherwise; the boolean mask is stored
        as float64 by ``Vectors.dense``.
    """
    # The annotation is quoted because ``ndarray`` is imported only under
    # ``TYPE_CHECKING``. Unquoted, it would be evaluated when this ``def``
    # executes and raise NameError, since the module has no
    # ``from __future__ import annotations``.
    # made up logic
    return Vectors.dense(np.array(vect) > 3)
# Wrap the vector function as a Spark UDF whose result is an ML vector column.
daniel_magic = f.udf(_daniel_magic, VectorUDT())

(
    df.groupBy("studyId", "chromosome")
    # Collect (negLogPValue, position) structs per study/chromosome and sort
    # them in descending order. Structs compare field by field, so the order
    # is by negLogPValue first; ties are broken by position (also descending).
    .agg(
        f.sort_array(
            f.collect_list(f.struct("negLogPValue", "position")), asc=False
        ).alias("snps")
    )
    # Pull the sorted negLogPValues out as a dense vector, run the UDF on it,
    # and convert the resulting vector back into an array<double>.
    .withColumn(
        "snps2",
        fml.vector_to_array(
            daniel_magic(
                fml.array_to_vector(
                    # Use the exact field name. A mis-cased name (e.g.
                    # "negLogPvalue") resolves only while
                    # spark.sql.caseSensitive is false (the default).
                    f.transform(f.col("snps"), lambda x: x.negLogPValue)
                )
            )
        ),
    )
    # Re-attach the UDF output element-wise to the sorted structs. snps2 has
    # the same length and order as snps because it was derived from it.
    # NOTE(review): the UDF output is double (1.0/0.0), while the input
    # isSemiIndex column is boolean; consider casting if a flag is intended.
    .withColumn(
        "snps3",
        f.zip_with(
            f.col("snps"),
            f.col("snps2"),
            lambda x, y: f.struct(
                x.negLogPValue.alias("negLogPValue"),
                x.position.alias("position"),
                y.alias("isSemiIndex"),
            ),
        ),
    )
    .drop("snps", "snps2")
    # Print up to 100 rows without truncating the nested arrays.
    .show(100, False)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment