Last active
March 28, 2023 07:11
-
-
Save ireneisdoomed/29942887f259ab20634ab09aa09ea237 to your computer and use it in GitHub Desktop.
Translate feature matrix to a long format dataset of type L2GFeature
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Iterable, Optional

import pyspark.sql.functions as f
from pyspark.sql import DataFrame, SparkSession

# Reuse an existing Spark session if one is running, otherwise start one.
spark = SparkSession.builder.getOrCreate()

# Raw L2G feature matrix, wide format: one row per (study locus, gene),
# one column per feature.
fm = spark.read.parquet(
    "gs://genetics-portal-dev-staging/l2g/221107/features/output/features.raw.221107.parquet"
)
fm.printSchema()
"""
root
 |-- gene_id: string (nullable = true)
 |-- study_id: string (nullable = true)
 |-- chrom: string (nullable = true)
 |-- pos: long (nullable = true)
 |-- ref: string (nullable = true)
 |-- alt: string (nullable = true)
 |-- dhs_prmtr_max: double (nullable = true)
 ...
 |-- dist_tss_min_nbh: double (nullable = true)
 |-- dist_tss_ave: double (nullable = true)
 |-- dist_tss_ave_nbh: double (nullable = true)
 |-- gene_count_lte_50k: long (nullable = true)
 |-- gene_count_lte_100k: long (nullable = true)
 |-- gene_count_lte_250k: long (nullable = true)
"""
# Keep only the features generated so far, renaming columns to the new
# L2G naming convention and deriving stable identifiers.
with_variant = fm.withColumn(
    # Canonical variant identifier: chrom_pos_ref_alt.
    "variantId",
    f.concat_ws("_", "chrom", "pos", "ref", "alt"),
)
with_locus = with_variant.withColumn(
    # Deterministic 64-bit hash uniquely identifying the study/variant pair.
    "studyLocusId",
    f.xxhash64("study_id", "variantId"),
)
fm_filtered = with_locus.selectExpr(
    "gene_id as geneId",
    "studyLocusId",
    "dist_tss_min",
    "dist_tss_ave",
    "eqtl_pics_clpp_max as eqtl_max_coloc_clpp_local",
    "eqtl_pics_clpp_max_nhb as eqtl_max_coloc_clpp_nbh",
    "sqtl_pics_clpp_max as sqtl_max_coloc_clpp_local",
    "sqtl_pics_clpp_max_nhb as sqtl_max_coloc_clpp_nbh",
    "pqtl_pics_clpp_max as pqtl_max_coloc_clpp_local",
    "pqtl_pics_clpp_max_nhb as pqtl_max_coloc_clpp_nbh",
    "eqtl_coloc_llr_max as eqtl_max_coloc_llr_local",
    "eqtl_coloc_llr_max_nbh as eqtl_max_coloc_llr_nbh",
    "sqtl_coloc_llr_max as sqtl_max_coloc_llr_local",
    "sqtl_coloc_llr_max_nbh as sqtl_max_coloc_llr_nbh",
    "pqtl_coloc_llr_max as pqtl_max_coloc_llr_local",
    "pqtl_coloc_llr_max_nbh as pqtl_max_coloc_llr_nbh",
)
fm_filtered.printSchema()
"""
root
 |-- geneId: string (nullable = true)
 |-- studyLocusId: long (nullable = false)
 |-- dist_tss_min: double (nullable = true)
 |-- dist_tss_ave: double (nullable = true)
 |-- eqtl_max_coloc_clpp_local: double (nullable = true)
 |-- eqtl_max_coloc_clpp_nbh: double (nullable = true)
 |-- sqtl_max_coloc_clpp_local: double (nullable = true)
 |-- sqtl_max_coloc_clpp_nbh: double (nullable = true)
 |-- pqtl_max_coloc_clpp_local: double (nullable = true)
 |-- pqtl_max_coloc_clpp_nbh: double (nullable = true)
 |-- eqtl_max_coloc_llr_local: double (nullable = true)
 |-- eqtl_max_coloc_llr_nbh: double (nullable = true)
 |-- sqtl_max_coloc_llr_local: double (nullable = true)
 |-- sqtl_max_coloc_llr_nbh: double (nullable = true)
 |-- pqtl_max_coloc_llr_local: double (nullable = true)
 |-- pqtl_max_coloc_llr_nbh: double (nullable = true)
"""
# Convert to L2GFeature format
def _convert_from_wide_to_long(
    df: DataFrame,
    id_vars: Iterable[str],
    var_name: str,
    value_name: str,
    value_vars: Optional[Iterable[str]] = None,
) -> DataFrame:
    """Convert a dataframe from wide to long (melted) format in pure Spark.

    Every input row is exploded into one output row per melted column: the
    column name is stored under ``var_name`` and its value under ``value_name``.
    No data leaves the JVM — this does not use Pandas.

    Args:
        df (DataFrame): Dataframe to melt
        id_vars (Iterable[str]): List of fixed columns to keep
        var_name (str): Name of the column containing the variable names
        value_name (str): Name of the column containing the values
        value_vars (Optional[Iterable[str]]): List of columns to melt.
            Defaults to None, meaning all columns not listed in ``id_vars``.

    Returns:
        DataFrame: Melted dataframe with columns ``id_vars + [var_name, value_name]``

    Examples:
        >>> df = spark.createDataFrame([("a", 1, 2)], ["id", "feature_1", "feature_2"])
        >>> _convert_from_wide_to_long(df, ["id"], "feature", "value").show()
        +---+---------+-----+
        | id|  feature|value|
        +---+---------+-----+
        |  a|feature_1|    1|
        |  a|feature_2|    2|
        +---+---------+-----+
        <BLANKLINE>
    """
    # Materialise both iterables up front so one-shot iterators (e.g.
    # generators) behave correctly: a non-empty generator is always truthy,
    # which would defeat the emptiness check below, and id_vars is probed
    # for membership once per dataframe column.
    id_vars = list(id_vars)
    value_vars = list(value_vars) if value_vars is not None else []
    if not value_vars:
        value_vars = [c for c in df.columns if c not in id_vars]
    # One struct per melted column: (column name literal, column value).
    _vars_and_vals = f.array(
        *(
            f.struct(f.lit(c).alias(var_name), f.col(c).alias(value_name))
            for c in value_vars
        )
    )
    # Add to the DataFrame and explode to convert into rows
    _tmp = df.withColumn("_vars_and_vals", f.explode(_vars_and_vals))
    # Select only the id columns plus the two melted fields, dropping the
    # temporary struct column.
    cols = list(id_vars) + [
        f.col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]
    ]
    return _tmp.select(*cols)
# Melt the filtered matrix: one row per (studyLocusId, geneId, feature).
long_fm = _convert_from_wide_to_long(
    fm_filtered,
    id_vars=("studyLocusId", "geneId"),
    var_name="feature",
    value_name="value",
)
long_fm.count()
# 57_177_260
long_fm.printSchema()
"""
root
 |-- studyLocusId: long (nullable = false)
 |-- geneId: string (nullable = true)
 |-- feature: string (nullable = false)
 |-- value: double (nullable = true)
"""
# Repartition so the output is written as a reasonable number of files,
# then persist the long-format dataset.
(
    long_fm.repartition(400)
    .write
    .parquet(
        "gs://genetics_etl_python_playground/input/l2g_input/feature_matrix_221107_transformed"
    )
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment