# Benchmarking peak-finding annotation

1. Get the test dataset.
   1. Select studies with summary statistics.
   2. Read the data and apply the p-value filter.
   3. Save the filtered dataset so it can be read later.
2. Read the test dataset and apply the peak-finding algorithm.
3. Record the execution time.
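The benchmark compares wall-clock execution times, so a minimal timing helper is sketched below. It assumes that the timed callable is a Spark *action* (for example `.count()`), because transformations are lazy and return almost immediately. `time_action` is a hypothetical name, not part of the notebook.

```python
import time
from typing import Any, Callable, Tuple


def time_action(action: Callable[[], Any]) -> Tuple[Any, float]:
    """Run a zero-argument callable; return its result and the elapsed wall-clock seconds.

    Hypothetical helper: with Spark, wrap an action such as `lambda: df.count()`,
    since only actions trigger the actual computation.
    """
    start = time.perf_counter()
    result = action()
    elapsed = time.perf_counter() - start
    return result, elapsed
```

Usage: `n_rows, seconds = time_action(lambda: df.count())`. Persisting the input beforehand, as the notebook does, keeps the timed action from also paying for reading the source files.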


## Generating test set

The following summary statistics are read:

```
GCST90002392
GCST90018969
```

- The dataset is saved to: `gs://ot-team/dsuveges/sumstats_test_significant`
- Number of GWAS-significant associations: 538,829
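The p-value filter and the `negLogPValue` column used for this set can be illustrated in plain Python. This is a sketch, not the notebook's `parse_pvalue` / `calculate_neglog_pvalue` helpers. `split_pvalue`, `neglog_pvalue` and `is_significant` are hypothetical names. The 5e-8 threshold is the conventional genome-wide significance level and is assumed here because the notebook's `pv_threshold` value is not shown. Keeping the mantissa and exponent separate means p-values below the smallest representable double (e.g. `1e-400`) are still handled.

```python
import math
from typing import Tuple


def split_pvalue(p_value: str) -> Tuple[float, int]:
    """Split a p-value string such as '1.1E-11' into (mantissa, exponent).

    Hypothetical helper; assumes p > 0. Working on the string avoids the float
    underflow that very small p-values (e.g. '1e-400') would otherwise cause.
    """
    # Normalise the exponent marker so both 'e' and 'E' are accepted:
    text = p_value.strip().lower()
    if "e" in text:
        mantissa_text, exponent_text = text.split("e")
        return float(mantissa_text), int(exponent_text)
    # Plain decimal notation, e.g. '0.00012': derive the exponent from log10.
    value = float(text)
    exponent = math.floor(math.log10(value))
    return value / 10 ** exponent, exponent


def neglog_pvalue(mantissa: float, exponent: int) -> float:
    """-log10(p) from mantissa and exponent: -(log10(mantissa) + exponent)."""
    return -(math.log10(mantissa) + exponent)


def is_significant(mantissa: float, exponent: int, threshold: float = 5e-8) -> bool:
    """Compare on the -log10 scale, so the check works even when p itself underflows."""
    return neglog_pvalue(mantissa, exponent) >= -math.log10(threshold)
```

For example, `1.1E-11` splits into `(1.1, -11)`, and `-(log10(1.1) - 11)` rounds to 10.959, the `negLogPValue` of the first record printed below.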

## Conclusions


|Method |Execution time |
|:-------------------|:--------------------|
|UDF with vectors | 3 min |
```python
import pyspark.sql.functions as f
import pyspark.sql.types as t

# `spark`, `summary_stats` (base path of the ingested summary statistics), `pv_threshold`,
# `parse_pvalue` and `calculate_neglog_pvalue` are defined in earlier cells.

# Studies to ingest:
study_list = [
    "GCST90089135",
    "GCST90026279",
    "GCST90024477",
    "GCST90079246",
    "GCST90089628",
    "GCST90092955",
    "GCST90026202",
    "GCST90044125",
    "GCST90086294",
    "GCST90080742",
    "GCST90087657",
    "GCST90079633",
    "GCST90086745",
    "GCST90086824",
    "GCST90086743",
    "GCST90082237",
    "GCST90086249",
    "GCST90086305",
    "GCST90018775",
    "GCST90083295",
    "GCST90002392",
    "GCST90018969",
    "GCST90013976",
    "GCST90002346",
]

# Read data:
sumstats = (
    # Reading all summary statistics:
    spark.read.parquet(*[f'{summary_stats}/{study}' for study in study_list], recursiveFileLookup=True)
    .filter(
        # Filter for SNPs with position available:
        f.col('position').isNotNull() &
        # Filter SNPs for GWAS significance:
        (f.col('pValue').cast(t.DoubleType()) <= pv_threshold)
    )
    # Parse p-values provided as text:
    .select('*', *parse_pvalue(f.col('pValue')))
    # Select relevant columns + calculate neg-log p-value:
    .select(
        'studyId',
        'variantId',
        'chromosome',
        'position',
        'rsId',
        'pValue',
        'pValueMantissa',
        'pValueExponent',
        calculate_neglog_pvalue(f.col('pValueMantissa'), f.col('pValueExponent')).alias('negLogPValue'),
    )
    .repartition(200)
    .persist()
)

# Show one record vertically, then the schema and the row count:
sumstats.show(1, False, True)
sumstats.printSchema()
sumstats.count()
```
```
-RECORD 0------------------------
 studyId | GCST90002392 
 variantId | 1_65456862_G_C 
 chromosome | 1 
 position | 65456862 
 rsId | rs61779779 
 pValue | 1.1E-11 
 pValueMantissa | 1.1 
 pValueExponent | -11 
 negLogPValue | 10.959 
only showing top 1 row
```

```
root
 |-- studyId: string (nullable = true)
 |-- variantId: string (nullable = true)
 |-- chromosome: string (nullable = true)
 |-- position: integer (nullable = true)
 |-- rsId: string (nullable = true)
 |-- pValue: string (nullable = true)
 |-- pValueMantissa: double (nullable = true)
 |-- pValueExponent: integer (nullable = true)
 |-- negLogPValue: double (nullable = true)
```
```
+------------+
| studyId|
+------------+
|GCST90002392|
|GCST90002346|
|GCST90018775|
|GCST90018969|
|GCST90013976|
|GCST90092955|
|GCST90087657|
|GCST90086745|
|GCST90086249|
|GCST90024477|
|GCST90086305|
|GCST90086294|
|GCST90026202|
|GCST90026279|
|GCST90086743|
+------------+
```
## UDF with vectors

- Positions are collected over a window sorted by `negLogPValue`.
- Positions are converted to dense vectors. Note: the variant IDs need to be packed as well.
- Looped operations are executed on the vectors.
- Values are unpacked after processing.
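The greedy selection performed inside the vector UDF can be sketched in plain Python for a single cluster. `find_leads` is a hypothetical name. Like the UDF, it visits variants from most to least significant and treats a variant as a lead only when no already-selected lead is closer than `window_size` (strict `<`). Unlike the UDF, it takes unsorted `(negLogPValue, position)` pairs and returns the flags in input order. This sidesteps packing and unpacking the variant IDs.

```python
from typing import List, Sequence, Tuple


def find_leads(snps: Sequence[Tuple[float, int]], window_size: float) -> List[int]:
    """Flag the lead variants of one cluster (1 = lead, 0 = not a lead).

    `snps` holds (negLogPValue, position) pairs in any order. Variants are
    visited from most to least significant, and a variant becomes a lead only
    if no previously selected lead lies strictly closer than `window_size`.
    """
    # Indices ordered by decreasing significance (ties keep their input order):
    order = sorted(range(len(snps)), key=lambda i: snps[i][0], reverse=True)

    is_lead = [0] * len(snps)
    lead_positions: List[int] = []
    for i in order:
        position = snps[i][1]
        # Same test as the UDF: skip the variant if any lead is within the window.
        if all(abs(position - lead) >= window_size for lead in lead_positions):
            lead_positions.append(position)
            is_lead[i] = 1
    return is_lead
```

On the 17-variant test set used further down, with a 3 bp window, it flags positions 5, 10, 14, 18 and 25. These are exactly the rows marked `true` in that set's `isSemiIndex` column.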
```python
from typing import Callable

import numpy as np
import pyspark.sql.functions as f
from pyspark.ml.linalg import DenseVector, VectorUDT
from pyspark.sql import DataFrame, Window


def cluster_peaks(df: DataFrame, window_length: int) -> DataFrame:
    """Cluster GWAS-significant variants, where clusters are separated by a defined distance.

    !! Note that the length of a cluster can be arbitrarily large.

    Args:
        df (DataFrame): table with studyId, chromosome and position columns
        window_length (int): a new cluster starts when the distance (in base pairs) to the previous SNP exceeds this value

    Returns:
        DataFrame: with a cluster_id column added.

    Examples:
        >>> data = [
        ...     # Cluster 1:
        ...     ('s1', 'chr1', 2),
        ...     ('s1', 'chr1', 4),
        ...     ('s1', 'chr1', 12),
        ...     # Cluster 2 - Same chromosome:
        ...     ('s1', 'chr1', 31),
        ...     ('s1', 'chr1', 38),
        ...     ('s1', 'chr1', 42),
        ...     # Cluster 3 - New chromosome:
        ...     ('s1', 'chr2', 41),
        ...     ('s1', 'chr2', 44),
        ...     ('s1', 'chr2', 50),
        ...     # Cluster 4 - other study:
        ...     ('s2', 'chr2', 55),
        ...     ('s2', 'chr2', 62),
        ...     ('s2', 'chr2', 70),
        ... ]
        >>> window_length = 10
        >>> (
        ...     spark.createDataFrame(data, ['studyId', 'chromosome', 'position'])
        ...     .transform(lambda df: cluster_peaks(df, window_length))
        ...     .show()
        ... )
        +-------+----------+--------+----------+
        |studyId|chromosome|position|cluster_id|
        +-------+----------+--------+----------+
        |     s1|      chr1|       2| s1_chr1_2|
        |     s1|      chr1|       4| s1_chr1_2|
        |     s1|      chr1|      12| s1_chr1_2|
        |     s1|      chr1|      31|s1_chr1_31|
        |     s1|      chr1|      38|s1_chr1_31|
        |     s1|      chr1|      42|s1_chr1_31|
        |     s1|      chr2|      41|s1_chr2_41|
        |     s1|      chr2|      44|s1_chr2_41|
        |     s1|      chr2|      50|s1_chr2_41|
        |     s2|      chr2|      55|s2_chr2_55|
        |     s2|      chr2|      62|s2_chr2_55|
        |     s2|      chr2|      70|s2_chr2_55|
        +-------+----------+--------+----------+
        <BLANKLINE>
    """
    return (
        df
        # By adding the previous position, the cluster boundary can be identified:
        .withColumn(
            'previous_position',
            f.lag('position').over(Window.partitionBy('studyId', 'chromosome').orderBy('position'))
        )
        # A cluster boundary is called if subsequent SNPs are further apart than the defined window:
        .withColumn(
            'cluster_id',
            f.when(
                (f.col('previous_position').isNull()) |
                (f.col('position') - f.col('previous_position') > window_length),
                f.concat_ws('_', f.col('studyId'), f.col('chromosome'), f.col('position'))
            )
        )
        # The cluster identifier is propagated across every variant of the cluster:
        .withColumn(
            'cluster_id',
            f.when(
                f.col('cluster_id').isNull(),
                f.last('cluster_id', ignorenulls=True).over(
                    Window.partitionBy('studyId', 'chromosome')
                    .orderBy('position')
                    .rowsBetween(Window.unboundedPreceding, Window.currentRow))
            )
            .otherwise(f.col('cluster_id'))
        )
        .drop('previous_position')
    )


def find_peak_w_window(window_size: float) -> Callable:
    """Return a UDF that flags lead variants for a fixed window size."""

    @f.udf(VectorUDT())
    def find_peak(position: DenseVector) -> DenseVector:
        """Flag lead variants (1.0) in a vector of positions sorted by decreasing significance."""
        # Initializing the lead list with zeroes:
        is_lead = np.zeros(len(position))

        # List containing indices of leads:
        lead_indices = []

        # Looping through all positions:
        for index in range(len(position)):
            # Looping through leads to find out if they are within a window:
            for lead_index in lead_indices:
                # If any of the leads is within the window:
                if abs(position[lead_index] - position[index]) < window_size:
                    # Skipping further checks:
                    break
            else:
                # None of the leads were within the window:
                lead_indices.append(index)
                is_lead[index] = 1

        return DenseVector(is_lead)

    return find_peak


# stepped_window = (
#     Window
#     .partitionBy('window_id')
#     .orderBy(f.col('negLogPValue').desc())
# )

# complete_window = (
#     stepped_window
#     .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
# )

# sumstats = (
#     spark.read.parquet('gs://ot-team/dsuveges/ingested_sumstats_3000_studies_significant/')
#     # Parse p-values provided as text:
#     .select('*', *parse_pvalue(f.col('pValue')))
#     # Select relevant columns + calculate neg-log p-value:
#     .select(
#         'studyId',
#         'variantId',
#         'chromosome',
#         'position',
#         'rsId',
#         'pValue',
#         'pValueMantissa',
#         'pValueExponent',
#         calculate_neglog_pvalue(f.col('pValueMantissa'), f.col('pValueExponent')).alias('negLogPValue')
#     )
#     .repartition(200)
#     .persist()
# )

# window_length = 2.5e5

data = [
    # Cluster 1:
    ('s1', 'chr1', 2),
    ('s1', 'chr1', 4),
    ('s1', 'chr1', 12),
    # Cluster 2 - Same chromosome:
    ('s1', 'chr1', 31),
    ('s1', 'chr1', 38),
    ('s1', 'chr1', 42),
    # Cluster 3 - New chromosome:
    ('s1', 'chr2', 41),
    ('s1', 'chr2', 44),
    ('s1', 'chr2', 50),
    # Cluster 4 - other study:
    ('s2', 'chr2', 55),
    ('s2', 'chr2', 62),
    ('s2', 'chr2', 70),
]
window_length = 10
(
    spark.createDataFrame(data, ['studyId', 'chromosome', 'position'])
    .transform(lambda df: cluster_peaks(df, window_length))
    .show()
)
```
```
+-------+----------+--------+----------+
|studyId|chromosome|position|cluster_id|
+-------+----------+--------+----------+
| s1| chr1| 2| s1_chr1_2|
| s1| chr1| 4| s1_chr1_2|
| s1| chr1| 12| s1_chr1_2|
| s1| chr1| 31|s1_chr1_31|
| s1| chr1| 38|s1_chr1_31|
| s1| chr1| 42|s1_chr1_31|
| s1| chr2| 41|s1_chr2_41|
| s1| chr2| 44|s1_chr2_41|
| s1| chr2| 50|s1_chr2_41|
| s2| chr2| 55|s2_chr2_55|
| s2| chr2| 62|s2_chr2_55|
| s2| chr2| 70|s2_chr2_55|
+-------+----------+--------+----------+
```
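The window logic of `cluster_peaks` can be mirrored in plain Python, which makes it possible to check expectations without a Spark session. That logic is a `lag` over position followed by a forward-filling `last(..., ignorenulls=True)`. `assign_clusters` is a hypothetical helper with the following behavior:

- It orders rows by position within each study/chromosome.
- It starts a new cluster whenever the gap to the previous variant exceeds `window_length`.
- It names each cluster after its first variant (`studyId_chromosome_position`).

```python
from typing import List, Tuple


def assign_clusters(rows: List[Tuple[str, str, int]], window_length: int) -> List[str]:
    """Plain-Python analogue of cluster_peaks for (studyId, chromosome, position) rows.

    Returns one cluster ID per input row, in input order.
    """
    # Sorting the tuples orders by study, then chromosome, then position:
    order = sorted(range(len(rows)), key=lambda i: rows[i])
    cluster_ids = [""] * len(rows)

    previous_key = None      # (studyId, chromosome) of the previous row, like lag()
    previous_position = 0
    current_id = ""
    for i in order:
        study, chrom, pos = rows[i]
        # New study/chromosome or a gap larger than the window: start a new cluster.
        if (study, chrom) != previous_key or pos - previous_position > window_length:
            current_id = f"{study}_{chrom}_{pos}"
        # Otherwise the current ID is carried forward, like last(..., ignorenulls=True).
        cluster_ids[i] = current_id
        previous_key, previous_position = (study, chrom), pos
    return cluster_ids
```

With the docstring data and `window_length = 10`, it returns the same IDs as the table above.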
```python
import pyspark.ml.functions as mlf
import pyspark.sql.functions as f

window_length = 2.5e5

# Let's measure the time it takes to do the windowing/aggregation.
# This assumes cluster_peaks emits a `window_id` column (as in the later version) and that
# `complete_window` (partitioned by window_id, ordered by negLogPValue) is defined.
df = (
    sumstats
    .transform(lambda df: cluster_peaks(df, window_length))
    # Get all positions and variants:
    .withColumn('all_pos', mlf.array_to_vector(f.collect_list(f.col('position')).over(complete_window)))
    .withColumn('all_vars', f.collect_list(f.col('variantId')).over(complete_window))
    # Group by cluster:
    .groupBy('window_id')
    .agg(
        f.first('all_pos').alias('all_pos'),
        f.first('all_vars').alias('all_vars'),
        f.first('studyId').alias('studyId'),
        f.first('chromosome').alias('chromosome')
    )
    # Peak-finding logic:
    .withColumn('resolvedPeaks', mlf.vector_to_array(find_peak_w_window(window_length)(f.col('all_pos'))))
    # Combine the lead flags with the variant IDs:
    .withColumn('to_explode', f.explode_outer(f.arrays_zip(f.col('resolvedPeaks'), f.col('all_vars'))))
    .select(
        'studyId',
        'chromosome',
        f.col('to_explode.all_vars').alias('variantId'),
        f.col('to_explode.resolvedPeaks').alias('isLead'),
    )
    .join(sumstats, on=['studyId', 'chromosome', 'variantId'], how='inner')
    .persist()
)

print(df.count())
print(df.filter(f.col('isLead') == 1.0).count())
```
```
25061908
```

```
+------------+----------+----------------+------+---------+-----------+--------------------+--------------+--------------+------------+
| studyId|chromosome| variantId|isLead| position| rsId| pValue|pValueMantissa|pValueExponent|negLogPValue|
+------------+----------+----------------+------+---------+-----------+--------------------+--------------+--------------+------------+
|GCST90012009| 14| 14_55035084_C_T| 0.0| 55035084|rs112642967| 1.78e-106| 1.78| -106| 105.75|
|GCST90092924| 2| 2_21182108_A_G| 0.0| 21182108| rs538862| 5.5e-29| 5.5| -29| 28.26|
|GCST90092955| X| X_110603328_A_C| 0.0|110603328| rs12398815| 1.2e-18| 1.2| -18| 17.921|
|GCST90013976| 15| 15_75558952_C_T| 0.0| 75558952| rs4393535|2.69736675171327e-08| 2.697| -8| 7.569|
|GCST90018969| 15| 15_56320984_C_T| 0.0| 56320984| rs2725852| 1.745e-10| 1.745| -10| 9.758|
|GCST90092955| 1| 1_62523301_T_C| 0.0| 62523301| rs1781195| 1.9e-52| 1.9| -52| 51.721|
|GCST90013976| 14| 14_68775222_A_T| 0.0| 68775222| rs72731547|1.62854588682907e-14| 1.629| -14| 13.788|
|GCST90092924| 2| 2_43851714_G_A| 0.0| 43851714| rs56132765| 3.4e-13| 3.4| -13| 12.469|
|GCST90013917| 6| 6_28938508_C_G| 0.0| 28938508| rs3131069|2.22910230011868e-09| 2.229| -9| 8.652|
|GCST90002335| 12|12_111557048_A_C| 0.0|111557048| rs11065950| 2.43e-35| 2.43| -35| 34.614|
|GCST90092816| 19| 19_19342751_C_A| 0.0| 19342751| rs2301669| 6.9E-10| 6.9| -10| 9.161|
| GCST010703| 3| 3_89319479_C_A| 0.0| 89319479| rs7619025|1.551038907700539...| 1.551| -11| 10.809|
| GCST010703| 8| 8_92188626_C_T| 0.0| 92188626| rs2010637|4.829377385866375...| 4.829| -8| 7.316|
| GCST010703| 3| 3_17222313_C_T| 0.0| 17222313| rs2733510|5.144947796417256...| 5.145| -10| 9.289|
|GCST90018969| 15| 15_69562643_T_C| 0.0| 69562643| rs80076036| 6.52e-10| 6.52| -10| 9.186|
|GCST90092874| X| X_110631621_T_A| 0.0|110631621| rs7062458| 6.4e-19| 6.4| -19| 18.194|
|GCST90088724| 4| 4_15833163_T_A| 0.0| 15833163| rs72616185|1.86423217825249e-91| 1.864| -91| 90.73|
| GCST010703| 8| 8_8401202_T_C| 0.0| 8401202| rs2976929|8.932943964047577...| 8.933| -15| 14.049|
|GCST90092833| 2| 2_21068440_T_C| 0.0| 21068440| rs548145| 2.4E-79| 2.4| -79| 78.62|
|GCST90012113| 2| 2_27493627_G_A| 0.0| 27493627| rs704795| 8.7E-20| 8.7| -20| 19.06|
+------------+----------+----------------+------+---------+-----------+--------------------+--------------+--------------+------------+
only showing top 20 rows
```
```python
import pyspark.ml.functions as mlf
import pyspark.sql.functions as f
import pyspark.sql.types as t

window_length = 2.5e5

new_df = (
    # Reading all summary statistics:
    spark.read.parquet(f'{summary_stats}', recursiveFileLookup=True)
    .filter(
        # Filter for SNPs with position available:
        f.col('position').isNotNull() &
        # Filter SNPs for GWAS significance:
        (f.col('pValue').cast(t.DoubleType()) <= pv_threshold)
    )
    # Parse p-values provided as text:
    .select('*', *parse_pvalue(f.col('pValue')))
    # Select relevant columns + calculate neg-log p-value:
    .select(
        'studyId',
        'variantId',
        'chromosome',
        'position',
        'rsId',
        'pValue',
        'pValueMantissa',
        'pValueExponent',
        calculate_neglog_pvalue(f.col('pValueMantissa'), f.col('pValueExponent')).alias('negLogPValue'),
    )
    .repartition(200)
    .transform(lambda df: cluster_peaks(df, window_length))
    # Get all positions per cluster:
    .groupBy("window_id")
    # .groupBy('chromosome', 'studyId')
    # Collect negLogPValue and position into structs, sorted by negLogPValue (descending):
    .agg(
        f.sort_array(f.collect_list(f.struct("negLogPValue", "position")), False).alias("snps")
    )
    # Prepare dense vectors of positions for the peak-finding UDF:
    .withColumn(
        "snps2",
        mlf.vector_to_array(
            find_peak_w_window(window_length)(
                mlf.array_to_vector(f.transform(f.col("snps"), lambda x: x.position))
            )
        ),
    )
    # Pair each SNP with its lead flag:
    .withColumn(
        "snps3",
        f.zip_with(
            f.col("snps"),
            f.col("snps2"),
            lambda x, y: f.struct(
                x.negLogPValue.alias("negLogPValue"),
                x.position.alias("position"),
                y.alias("isLead"),
            ),
        ),
    )
    .withColumn('exploded', f.explode(f.col('snps3')))
    .select(
        # window_id is studyId_chromosome_position, so the first two tokens give study and chromosome:
        f.split(f.col('window_id'), '_').getItem(0).alias('studyId'),
        f.split(f.col('window_id'), '_').getItem(1).alias('chromosome'),
        # f.col('studyId'),
        # f.col('chromosome'),
        f.col('exploded.position').alias('position'),
        f.col('exploded.negLogPValue').alias('negLogPValue'),
        f.col('exploded.isLead').alias('isLead'),
    )
    # .drop("snps", "snps2")
    # .withColumn("test", f.transform(f.col("snps"), lambda x: x.position))
    # .show(100, False)
    # .printSchema()
    .persist()
)

new_df.filter(f.col('isLead') == 1.0).show()
print(new_df.filter(f.col('isLead') == 1.0).count())
new_df.write.mode('overwrite').parquet('gs://ot-team/dsuveges/cicaful')
```
```
+------------+----------+---------+------------+------+
| studyId|chromosome| position|negLogPValue|isLead|
+------------+----------+---------+------------+------+
|GCST90088244| 17| 31063653| 11.154| 1.0|
|GCST90088244| 17| 31434507| 7.904| 1.0|
|GCST90026654| 15| 73767737| 7.591| 1.0|
|GCST90086636| 20| 30305384| 12.494| 1.0|
|GCST90001553| 1| 67649696| 7.561| 1.0|
|GCST90025945| 2| 96521868| 13.886| 1.0|
|GCST90025945| 2| 96032437| 12.796| 1.0|
|GCST90025945| 2| 95725361| 9.0| 1.0|
|GCST90025945| 2| 96874843| 8.092| 1.0|
| GCST010703| 3| 17366438| 12.092| 1.0|
| GCST010703| 3| 17697055| 9.044| 1.0|
|GCST90013976| 4| 54643022| 23.79| 1.0|
|GCST90044763| 14| 64510603| 8.153| 1.0|
|GCST90002392| 7| 39702208| 7.319| 1.0|
|GCST90018969| 6|165769413| 7.621| 1.0|
|GCST90092955| 12| 20320824| 10.585| 1.0|
|GCST90025948| 8| 8319038| 8.319| 1.0|
|GCST90026106| 4| 68328725| 7.338| 1.0|
|GCST90026028| 18| 52847775| 7.646| 1.0|
|GCST90026029| 7| 20736710| 8.665| 1.0|
+------------+----------+---------+------------+------+
only showing top 20 rows

28359
```
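Two details of the aggregation cell above are easy to miss:

- `sort_array(..., False)` compares structs field by field. Ties in `negLogPValue` are therefore broken by `position`, also in descending order.
- `studyId` and `chromosome` are recovered by splitting `window_id` on `_`, which assumes that neither value contains an underscore.

A plain-Python sketch of both steps follows. It assumes no null values, and `sort_like_spark` / `explode_window` are hypothetical names.

```python
from typing import List, Sequence, Tuple

Snp = Tuple[float, int]  # (negLogPValue, position)


def sort_like_spark(snps: Sequence[Snp]) -> List[Snp]:
    """Mimic sort_array(collect_list(struct(negLogPValue, position)), False).

    Tuples, like Spark structs, compare field by field: descending order is by
    negLogPValue first and, on ties, by position (also descending).
    """
    return sorted(snps, reverse=True)


def explode_window(
    window_id: str, sorted_snps: Sequence[Snp], flags: Sequence[float]
) -> List[Tuple[str, str, int, float, float]]:
    """Mimic the zip_with + explode + split(window_id, '_') steps for one cluster."""
    # Assumes studyId and chromosome contain no underscores:
    study_id, chromosome = window_id.split("_")[:2]
    return [
        (study_id, chromosome, position, neg_log_p, flag)
        for (neg_log_p, position), flag in zip(sorted_snps, flags)
    ]
```

The tie-breaking is visible in the small test further down, where `{4.0, 8}` is listed before `{4.0, 5}`.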
```python
data = [
    # Cluster 1:
    ('s1', 'chr1', 'v1', 3, 2.0, False),
    ('s1', 'chr1', 'v2', 4, 3.0, False),
    ('s1', 'chr1', 'v3', 5, 4.0, True),
    ('s1', 'chr1', 'v4', 6, 2.0, False),
    ('s1', 'chr1', 'v5', 7, 3.0, False),
    ('s1', 'chr1', 'v6', 8, 4.0, False),
    ('s1', 'chr1', 'v7', 9, 4.5, False),
    ('s1', 'chr1', 'v8', 10, 6.0, True),
    ('s1', 'chr1', 'v9', 11, 5.0, False),
    ('s1', 'chr1', 'v10', 12, 3.0, False),
    ('s1', 'chr1', 'v11', 14, 2.0, True),
    ('s1', 'chr1', 'v12', 16, 2.5, False),
    ('s1', 'chr1', 'v13', 18, 3.0, True),
    ('s1', 'chr1', 'v14', 20, 1.5, False),
    # Cluster 2:
    ('s1', 'chr1', 'v15', 24, 2.0, False),
    ('s1', 'chr1', 'v16', 25, 4.0, True),
    ('s1', 'chr1', 'v17', 27, 3.0, False),
]

test_df = (
    spark.createDataFrame(
        data,
        ['studyId', 'chromosome', 'variantId', 'position', 'negLogPValue', 'isSemiIndex'])
    .persist()
)

window_length = 3

(
    test_df
    .show(1000, truncate=False)
)
```
```
+-------+----------+---------+--------+------------+-----------+
|studyId|chromosome|variantId|position|negLogPValue|isSemiIndex|
+-------+----------+---------+--------+------------+-----------+
|s1 |chr1 |v1 |3 |2.0 |false |
|s1 |chr1 |v2 |4 |3.0 |false |
|s1 |chr1 |v3 |5 |4.0 |true |
|s1 |chr1 |v4 |6 |2.0 |false |
|s1 |chr1 |v5 |7 |3.0 |false |
|s1 |chr1 |v6 |8 |4.0 |false |
|s1 |chr1 |v7 |9 |4.5 |false |
|s1 |chr1 |v8 |10 |6.0 |true |
|s1 |chr1 |v9 |11 |5.0 |false |
|s1 |chr1 |v10 |12 |3.0 |false |
|s1 |chr1 |v11 |14 |2.0 |true |
|s1 |chr1 |v12 |16 |2.5 |false |
|s1 |chr1 |v13 |18 |3.0 |true |
|s1 |chr1 |v14 |20 |1.5 |false |
|s1 |chr1 |v15 |24 |2.0 |false |
|s1 |chr1 |v16 |25 |4.0 |true |
|s1 |chr
```
```python
import pyspark.sql.functions as f
from pyspark.sql import DataFrame, Window


def cluster_peaks(df: DataFrame) -> DataFrame:
    # Window length hard-coded to match the test data above:
    window_length = 3
    return (
        df
        .withColumn('previous_position', f.lag('position').over(Window.partitionBy('studyId', 'chromosome').orderBy('position')))
        # Start a new window when the gap to the previous SNP exceeds window_length:
        .withColumn(
            'window_id',
            f.when(
                (f.col('previous_position').isNull()) |
                (f.col('position') - f.col('previous_position') > window_length),
                f.concat_ws('_', f.col('studyId'), f.col('chromosome'), f.col('position'))
            )
        )
        # Propagate the window ID forward to every SNP in the window:
        .withColumn(
            'window_id',
            f.when(
                f.col('window_id').isNull(),
                f.last('window_id', ignorenulls=True).over(
                    Window.partitionBy('studyId', 'chromosome')
                    .orderBy('position')
                    .rowsBetween(Window.unboundedPreceding, Window.currentRow))
            ).otherwise(f.col('window_id'))
        )
        .drop('previous_position')
    )


(
    test_df
    .transform(cluster_peaks)
    .show(1000, truncate=False)
)
```
```
+-------+----------+---------+--------+------------+-----------+----------+
|studyId|chromosome|variantId|position|negLogPValue|isSemiIndex|window_id |
+-------+----------+---------+--------+------------+-----------+----------+
|s1 |chr1 |v1 |3 |2.0 |false |s1_chr1_3 |
|s1 |chr1 |v2 |4 |3.0 |false |s1_chr1_3 |
|s1 |chr1 |v3 |5 |4.0 |true |s1_chr1_3 |
|s1 |chr1 |v4 |6 |2.0 |false |s1_chr1_3 |
|s1 |chr1 |v5 |7 |3.0 |false |s1_chr1_3 |
|s1 |chr1 |v6 |8 |4.0 |false |s1_chr1_3 |
|s1 |chr1 |v7 |9 |4.5 |false |s1_chr1_3 |
|s1 |chr1 |v8 |10 |6.0 |true |s1_chr1_3 |
|s1 |chr1 |v9 |11 |5.0 |false |s1_chr1_3 |
|s1 |chr1 |v10 |12 |3.0 |false |s1_chr1_3 |
|s1 |chr1 |v11 |14 |2.0 |true |s1_chr1_3 |
|s1 |chr1 |v12 |16 |2.5 |false |s1_chr1_3 |
|s1 |chr1 |v13 |18 |3.0 |true |s1_chr1_3 |
|s1 |chr1 |v14 |20 |1.5 |false |s1_chr1_3 |
|s1 |chr1 |v15 |24 |2.0 |false |s1_chr1_24|
|s1 |chr1 |v16 |25 |4.0 |true |s1_chr1_24|
|s1 |chr1 |v17 |27 |3.0 |false |s1_chr1_24|
+-------+----------+---------+--------+------------+-----------+----------+
```
```python
from typing import Callable

import pyspark.ml.functions as mlf
import pyspark.sql.functions as f
from pyspark.sql import Window


def pass_window_length(f: Callable, window_length: int) -> Callable:
    window_length = window_length
    return f


stepped_window = (
    Window
    .partitionBy('window_id')
    .orderBy(f.col('negLogPValue').desc())
)

complete_window = (
    stepped_window
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)


(
    test_df
    .transform(cluster_peaks)
    # Get all positions and variants:
    .groupBy("window_id")
    # Collect negLogPValue and position into structs, sorted by negLogPValue (descending):
    .agg(
        f.sort_array(f.collect_list(f.struct("negLogPValue", "position")), False).alias("snps")
    )
    # Prepare dense vectors for the peak-finding UDF.
    # NOTE: this passes negLogPValue, not position, to the peak finder, so the isLead flags
    # in the output below differ from test_df's isSemiIndex column (e.g. positions 5 and 14
    # are not flagged); use `lambda x: x.position` for position-based clumping.
    .withColumn(
        "snps2",
        mlf.vector_to_array(
            find_peak_w_window(window_length)(
                mlf.array_to_vector(f.transform(f.col("snps"), lambda x: x.negLogPvalue))
            )
        ),
    )
    # Pair each SNP with its lead flag:
    .withColumn(
        "snps3",
        f.zip_with(
            f.col("snps"),
            f.col("snps2"),
            lambda x, y: f.struct(
                x.negLogPValue.alias("negLogPValue"),
                x.position.alias("position"),
                y.alias("isLead"),
            ),
        ),
    )
    .drop("snps", "snps2")
    # Get all positions and variants:
    # .withColumn('all_pos', mlf.array_to_vector(f.collect_list(f.col('position')).over(complete_window)))
    # .withColumn('all_vars', f.collect_list(f.col('variantId')).over(complete_window))
    # # Group by study/chromosome:
    # .groupBy('window_id')
    # .agg(
    #     f.first('all_pos').alias('all_pos'),
    #     f.first('all_vars').alias('all_vars'),
    #     f.first('studyId').alias('studyId'),
    #     f.first('chromosome').alias('chromosome')
    # )
    # # Peak-finding logic:
    # .withColumn('resolvedPeaks', mlf.vector_to_array(find_peak_w_window(window_length)(f.col('all_pos'))))
    # # Combine the lead flags with the variant IDs:
    # .withColumn('to_explode', f.explode_outer(f.arrays_zip(f.col('resolvedPeaks'), f.col('all_vars'))))
    # .select(
    #     'studyId',
    #     'chromosome',
    #     f.col('to_explode.all_vars').alias('variantId'),
    #     f.col('to_explode.resolvedPeaks').alias('isLead'),
    # )
    # # Joining back to data:
    # .join(
    #     test_df,
    #     on=['studyId', 'chromosome', 'variantId'],
    #     how='outer'
    # )
    # .orderBy(f.col('position'))
    .show(1000, truncate=False)
)
```
```
+----------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|window_id |snps3 |
+----------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|s1_chr1_3 |[{6.0, 10, 1.0}, {5.0, 11, 0.0}, {4.5, 9, 0.0}, {4.0, 8, 0.0}, {4.0, 5, 0.0}, {3.0, 18, 1.0}, {3.0, 12, 0.0}, {3.0, 7, 0.0}, {3.0, 4, 0.0}, {2.5, 16, 0.0}, {2.0, 14, 0.0}, {2.0, 6, 0.0}, {2.0, 3, 0.0}, {1.5, 20, 0.0}]|
|s1_chr1_24|[{4.0, 25, 1.0}, {3.0, 27, 0.0}, {2.0, 24, 0.0}] |
+----------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
```
```python
(
    sumstats
    .filter(f.col('position').isNull())
    .count()
)
```
```
19049
```
```python
data = [
    # Cluster 1:
    ('s1', 'chr1', 2),
    ('s1', 'chr1', 4),
    ('s1', 'chr1', 12),
    # Cluster 2 - Same chromosome:
    ('s1', 'chr1', 31),
    ('s1', 'chr1', 38),
    ('s1', 'chr1', 42),
    # Cluster 3 - New chromosome:
    ('s1', 'chr2', 41),
    ('s1', 'chr2', 44),
    ('s1', 'chr2', 50),
    # Cluster 4 - other study:
    ('s2', 'chr2', 55),
    ('s2', 'chr2', 62),
    ('s2', 'chr2', 70),
]
window_length = 10
(
    spark.createDataFrame(data, ['studyId', 'chromosome', 'position'])
    .transform(lambda df: cluster_peaks(df, window_length))
    .show()
)
```
```
+-------+----------+--------+----------+
|studyId|chromosome|position| window_id|
+-------+----------+--------+----------+
| s1| chr1| 2| s1_chr1_2|
| s1| chr1| 4| s1_chr1_2|
| s1| chr1| 12| s1_chr1_2|
| s1| chr1| 31|s1_chr1_31|
| s1| chr1| 38|s1_chr1_31|
| s1| chr1| 42|s1_chr1_31|
| s1| chr2| 41|s1_chr2_41|
| s1| chr2| 44|s1_chr2_41|
| s1| chr2| 50|s1_chr2_41|
| s2| chr2| 55|s2_chr2_55|
| s2| chr2| 62|s2_chr2_55|
| s2| chr2| 70|s2_chr2_55|
+-------+----------+--------+----------+
```
