{
"cells": [
{
"metadata": {},
"id": "945849c9",
"cell_type": "markdown",
"source": "# Benchmarking peak-finding annotation\n\n1. Get test dataset.\n 1. Select studies w. summary statistics\n 2. Read data + apply p-value filter.\n 3. Save filtered dataset for read later.\n2. Read test dataset, apply peak finder algorithm\n3. Record time complexity.\n\n\n## Generating test set\n\nThe following summary statistics are read:\n\n```\nGCST90002392\nGCST90018969\n```\n\n- The datast is saved to: `gs://ot-team/dsuveges/sumstats_test_significant`\n- Number of GWAS significant associations: 538,829\n\n## Conclusions\n\n\n|Method |Execution time |\n|:-------------------|:--------------------|\n|UDF with vectors | 3min |\n"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-04T20:48:52.540345Z",
"start_time": "2023-07-04T20:48:39.750039Z"
},
"trusted": false
},
"id": "4322e964",
"cell_type": "code",
"source": "from pyspark.sql import SparkSession, functions as f, types as t, DataFrame, Column\nfrom pyspark.sql.window import Window\nimport numpy as np\nfrom pyspark.ml import functions as mlf\nfrom pyspark.ml.linalg import DenseVector, VectorUDT\nfrom numpy import ndarray\n\nimport sys\nfrom typing import List, Callable\nimport pandas as pd\nfrom copy import deepcopy\n\nspark = SparkSession.builder.getOrCreate()\n\n# Input data:\nsummary_stats = 'gs://open-targets-gwas-summary-stats/studies'\ntest_dataset_output = 'gs://ot-team/dsuveges/sumstats_test_significant'\npv_threshold = 5e-8\ndistance_threshold = 2.5e5\n\n",
"execution_count": 1,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": "Setting default log level to \"WARN\".\nTo adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n23/07/04 20:48:44 INFO SparkEnv: Registering MapOutputTracker\n23/07/04 20:48:44 INFO SparkEnv: Registering BlockManagerMaster\n23/07/04 20:48:44 INFO SparkEnv: Registering BlockManagerMasterHeartbeat\n23/07/04 20:48:44 INFO SparkEnv: Registering OutputCommitCoordinator\n"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-04T20:49:10.021283Z",
"start_time": "2023-07-04T20:49:10.012390Z"
},
"trusted": false
},
"id": "2b6fc015",
"cell_type": "code",
"source": "def parse_pvalue(pv: Column) -> List[Column]:\n \"\"\"This function takes a p-value string and returns two columns mantissa (float), exponent (integer).\n\n Args:\n pv (Column): P-value as string\n\n Returns:\n Column: p-value mantissa (float)\n Column: p-value exponent (integer)\n\n Examples:\n >>> d = [(\"0.01\",),(\"4.2E-45\",),(\"43.2E5\",),(\"0\",),(\"1\",)]\n >>> spark.createDataFrame(d, ['pval']).select('pval',*parse_pvalue(f.col('pval'))).show()\n +-------+--------------+--------------+\n | pval|pValueExponent|pValueMantissa|\n +-------+--------------+--------------+\n | 0.01| -2| 1.0|\n |4.2E-45| -45| 4.2|\n | 43.2E5| 5| 43.2|\n | 0| -308| 2.225|\n | 1| 0| 1.0|\n +-------+--------------+--------------+\n <BLANKLINE> \n \"\"\"\n # Making sure there's a number in the string:\n pv = f.when(pv == \"0\", f.lit(sys.float_info.min).cast(t.StringType())).otherwise(pv)\n \n # Get exponent:\n exponent = f.when(\n f.upper(pv).contains('E'),\n f.split(f.upper(pv), 'E').getItem(1).cast(t.IntegerType())\n ).otherwise(\n f.log10(pv).cast(t.IntegerType())\n )\n\n # Get mantissa:\n mantissa = f.when(\n f.upper(pv).contains('E'),\n f.split(f.upper(pv), 'E').getItem(0).cast(t.FloatType())\n ).otherwise(\n pv / (10 ** exponent)\n )\n \n # Round value:\n mantissa = f.round(mantissa, 3)\n \n return [\n mantissa.alias('pValueMantissa'),\n exponent.alias('pValueExponent'), \n ]\n\ndef calculate_neglog_pvalue(\n p_value_mantissa: Column, p_value_exponent: Column\n) -> Column:\n \"\"\"Compute the negative log p-value.\n\n Args:\n p_value_mantissa (Column): P-value mantissa\n p_value_exponent (Column): P-value exponent\n\n Returns:\n Column: Negative log p-value\n\n Examples:\n >>> d = [(1, 1), (5, -2), (1, -1000)]\n >>> df = spark.createDataFrame(d).toDF(\"p_value_mantissa\", \"p_value_exponent\")\n >>> df.withColumn(\"neg_log_p\", calculate_neglog_pvalue(f.col(\"p_value_mantissa\"), f.col(\"p_value_exponent\"))).show()\n +----------------+----------------+------------------+\n |p_value_mantissa|p_value_exponent| neg_log_p|\n +----------------+----------------+------------------+\n | 1| 1| -1.0|\n | 5| -2|1.3010299956639813|\n | 1| -1000| 1000.0|\n +----------------+----------------+------------------+\n <BLANKLINE>\n \"\"\"\n return f.round(-1 * (f.log10(p_value_mantissa) + p_value_exponent), 3)\n\n",
"execution_count": 2,
"outputs": []
},
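{
"metadata": {},
"id": "b7f1a0e1",
"cell_type": "code",
"source": "# A quick sanity check (added sketch, not part of the original benchmark run):\n# mantissa * 10**exponent should reconstruct the parsed p-value.\n# The toy p-value strings below are made up for illustration.\ncheck = (\n    spark.createDataFrame([('0.01',), ('4.2E-45',), ('1',)], ['pval'])\n    .select('pval', *parse_pvalue(f.col('pval')))\n    .withColumn(\n        'reconstructed',\n        f.col('pValueMantissa') * f.pow(f.lit(10.0), f.col('pValueExponent'))\n    )\n)\ncheck.show()",
"execution_count": null,
"outputs": []
},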
{
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-04T18:44:21.692855Z",
"start_time": "2023-07-04T18:43:33.999045Z"
},
"trusted": false
},
"id": "7abf6c87",
"cell_type": "code",
"source": "\n\n# Studies to ingest:\nstudy_list = [\n \"GCST90089135\",\n \"GCST90026279\",\n \"GCST90024477\",\n \"GCST90079246\",\n \"GCST90089628\",\n \"GCST90092955\",\n \"GCST90026202\",\n \"GCST90044125\",\n \"GCST90086294\",\n \"GCST90080742\",\n \"GCST90087657\",\n \"GCST90079633\",\n \"GCST90086745\",\n \"GCST90086824\",\n \"GCST90086743\",\n \"GCST90082237\",\n \"GCST90086249\",\n \"GCST90086305\",\n \"GCST90018775\",\n \"GCST90083295\",\n \"GCST90002392\",\n \"GCST90018969\",\n \"GCST90013976\",\n \"GCST90002346\"\n]\n\n# Read data:\nsumstats = (\n # Reading all summary statistics:\n spark.read.parquet(*[f'{summary_stats}/{study}' for study in study_list], recursiveFileLookup=True)\n .filter(\n # Filter for snps with position available:\n f.col('position').isNotNull() & \n # Filter snps for GWAS significance:\n (f.col('pValue').cast(t.DoubleType()) <= pv_threshold)\n )\n # Parser p-values provided as text:\n .select('*', *parse_pvalue(f.col('pValue')))\n # Select relevant columns + calculate neg-log p-value:\n .select(\n 'studyId', \n 'variantId', \n 'chromosome', \n 'position', \n 'rsId', \n 'pValue', \n 'pValueMantissa',\n 'pValueExponent',\n calculate_neglog_pvalue(f.col('pValueMantissa'), f.col('pValueExponent')).alias('negLogPValue')\n )\n .repartition(200)\n .persist()\n)\n\nsumstats.show(1, False, True)\nsumstats.printSchema()\nsumstats.count()",
"execution_count": 2,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": " \r"
},
{
"name": "stdout",
"output_type": "stream",
"text": "-RECORD 0------------------------\n studyId | GCST90002392 \n variantId | 1_65456862_G_C \n chromosome | 1 \n position | 65456862 \n rsId | rs61779779 \n pValue | 1.1E-11 \n pValueMantissa | 1.1 \n pValueExponent | -11 \n negLogPValue | 10.959 \nonly showing top 1 row\n\nroot\n |-- studyId: string (nullable = true)\n |-- variantId: string (nullable = true)\n |-- chromosome: string (nullable = true)\n |-- position: integer (nullable = true)\n |-- rsId: string (nullable = true)\n |-- pValue: string (nullable = true)\n |-- pValueMantissa: double (nullable = true)\n |-- pValueExponent: integer (nullable = true)\n |-- negLogPValue: double (nullable = true)\n\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": " \r"
},
{
"data": {
"text/plain": "538829"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-04T11:42:40.945928Z",
"start_time": "2023-07-04T11:42:05.929734Z"
},
"trusted": false
},
"id": "227b64b3",
"cell_type": "code",
"source": "sumstats.select('studyId').distinct().show(1000)",
"execution_count": 5,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": "[Stage 12:=====================================================>(106 + 1) / 107]\r"
},
{
"name": "stdout",
"output_type": "stream",
"text": "+------------+\n| studyId|\n+------------+\n|GCST90002392|\n|GCST90002346|\n|GCST90018775|\n|GCST90018969|\n|GCST90013976|\n|GCST90092955|\n|GCST90087657|\n|GCST90086745|\n|GCST90086249|\n|GCST90024477|\n|GCST90086305|\n|GCST90086294|\n|GCST90026202|\n|GCST90026279|\n|GCST90086743|\n+------------+\n\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "\r \r"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-04T14:35:17.915418Z",
"start_time": "2023-07-04T14:35:03.352626Z"
},
"trusted": false
},
"id": "5a84fdbe",
"cell_type": "code",
"source": "sumstats.write.mode('overwrite').parquet(test_dataset_output)",
"execution_count": 14,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": " \r"
}
]
},
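{
"metadata": {},
"id": "c2d3e4f5",
"cell_type": "code",
"source": "# Optional read-back check (added sketch, not executed in the recorded run):\n# the filtered test set can be re-loaded from the path written above, avoiding\n# a re-run of the expensive ingestion + filtering step.\nsumstats_cached = spark.read.parquet(test_dataset_output)\nsumstats_cached.count()  # expected: 538,829 GWAS-significant associations",
"execution_count": null,
"outputs": []
},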
{
"metadata": {},
"id": "1b09e42e",
"cell_type": "markdown",
"source": "## UDF with vectors\n\n- Positions are collected on a window sorted by negLogPV.\n- Positions are converted to dense vectors. <- we need to pack the variant ids as well.\n- Looped operations are executed on vectors.\n- Unpack values after processing."
},
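{
"metadata": {},
"id": "d4e5f6a7",
"cell_type": "code",
"source": "# Minimal sketch of the pack -> UDF -> unpack pattern described above.\n# Everything here (toy_df, double_vector) is illustrative and not part of the benchmark.\n@f.udf(VectorUDT())\ndef double_vector(v: DenseVector) -> DenseVector:\n    # The looped/NumPy work happens on the whole packed vector in a single call:\n    return DenseVector(v.toArray() * 2)\n\ntoy_df = (\n    spark.createDataFrame([(1, [1.0, 2.0, 3.0])], ['id', 'values'])\n    # Pack: array column -> dense vector:\n    .withColumn('vec', mlf.array_to_vector(f.col('values')))\n    # Apply the vectorised UDF, then unpack back to an array column:\n    .withColumn('doubled', mlf.vector_to_array(double_vector(f.col('vec'))))\n)\ntoy_df.show(truncate=False)",
"execution_count": null,
"outputs": []
},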
{
"metadata": {
"ExecuteTime": {
"start_time": "2023-07-05T14:46:53.706681Z",
"end_time": "2023-07-05T14:46:54.255924Z"
},
"trusted": true
},
"id": "d8e65307",
"cell_type": "code",
"source": "def cluster_peaks(df: DataFrame, window_length: int) -> DataFrame:\n \"\"\"Cluster GWAS significant variants, were clusters are speparated by a defined distance.\n \n !! Important to note that the length of the clusters can be arbitrarily big.\n \n Args:\n df (DataFrame): table with studyId, chromosome and position columns\n window_length (int): a minimal basepair distance required between snps to call for a new cluster\n\n Returns:\n DataFrame: with cluster_id column added. \n\n Examples:\n >>> data = [\n ... # Cluster 1:\n ... ('s1', 'chr1', 2),\n ... ('s1', 'chr1', 4),\n ... ('s1', 'chr1', 12),\n ... # Cluster 2 - Same chromosome:\n ... ('s1', 'chr1', 31),\n ... ('s1', 'chr1', 38),\n ... ('s1', 'chr1', 42),\n ... # Cluster 3 - New chromosome:\n ... ('s1', 'chr2', 41),\n ... ('s1', 'chr2', 44),\n ... ('s1', 'chr2', 50),\n ... # Cluster 4 - other study:\n ... ('s2', 'chr2', 55),\n ... ('s2', 'chr2', 62),\n ... ('s2', 'chr2', 70),\n ... ]\n >>> window_length = 10\n >>> (\n ... spark.createDataFrame(data, ['studyId', 'chromosome', 'position'])\n ... .transform(lambda df: cluster_peaks(df, window_length))\n ... .show()\n ... ) \n +-------+----------+--------+----------+\n |studyId|chromosome|position|cluster_id|\n +-------+----------+--------+----------+\n | s1| chr1| 2| s1_chr1_2|\n | s1| chr1| 4| s1_chr1_2|\n | s1| chr1| 12| s1_chr1_2|\n | s1| chr1| 31|s1_chr1_31|\n | s1| chr1| 38|s1_chr1_31|\n | s1| chr1| 42|s1_chr1_31|\n | s1| chr2| 41|s1_chr2_41|\n | s1| chr2| 44|s1_chr2_41|\n | s1| chr2| 50|s1_chr2_41|\n | s2| chr2| 55|s2_chr2_55|\n | s2| chr2| 62|s2_chr2_55|\n | s2| chr2| 70|s2_chr2_55|\n +-------+----------+--------+----------+\n <BLANKLINE>\n \"\"\"\n return (\n df\n # By adding previous position, the cluster boundary can be identified: \n .withColumn(\n 'previous_position', \n f.lag('position').over(Window.partitionBy('studyId', 'chromosome').orderBy('position'))\n )\n # We consider a cluster boudary if subsequent snps are further than the defined window:\n .withColumn(\n 'cluster_id',\n f.when(\n (f.col('previous_position').isNull()) | \n (f.col('position')- f.col('previous_position') > window_length), \n f.concat_ws('_', f.col('studyId'), f.col('chromosome'), f.col('position'))\n )\n )\n # The cluster identifier is propagated across every variant of the cluster:\n .withColumn(\n 'cluster_id',\n f.when(\n f.col('cluster_id').isNull(),\n f.last(\"cluster_id\", ignorenulls = True).over(\n Window.partitionBy('studyId', 'chromosome')\n .orderBy('position')\n .rowsBetween(Window.unboundedPreceding, Window.currentRow))\n )\n .otherwise(\n f.col('cluster_id')\n )\n )\n .drop('previous_position')\n )\n\n\n\n# def find_peak_w_window(window_size: int) -> callable:\n\n@f.udf(VectorUDT())\ndef find_peak(position: ndarray, window_size: int) -> DenseVector:\n \"\"\"\n \n \"\"\"\n # Initializing the lead list with zeroes:\n is_lead = np.zeros(len(position))\n\n # List containing indices of leads:\n lead_indices = []\n\n # Looping through all positions:\n for index in range(len(position)):\n # Looping through leads to find out if they are within a window:\n for lead_index in lead_indices:\n # If any of the leads within the window:\n if abs(position[lead_index] - position[index]) < window_size:\n # Skipping further checks:\n break\n else:\n # None of the leads were within the window:\n lead_indices.append(index)\n is_lead[index] = 1\n\n return DenseVector(is_lead)\n\n# return find_peak\n\n\n\n# stepped_window = (\n# Window\n# .partitionBy('window_id')\n# 
.orderBy(f.col('negLogPValue').desc())\n# )\n\n# complete_window = (\n# stepped_window\n# .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)\n# )\n\n# sumstats = (\n# spark.read.parquet('gs://ot-team/dsuveges/ingested_sumstats_3000_studies_significant/')\n# # Parser p-values provided as text:\n# .select('*', *parse_pvalue(f.col('pValue')))\n# # Select relevant columns + calculate neg-log p-value:\n# .select(\n# 'studyId', \n# 'variantId', \n# 'chromosome', \n# 'position', \n# 'rsId', \n# 'pValue', \n# 'pValueMantissa',\n# 'pValueExponent',\n# calculate_neglog_pvalue(f.col('pValueMantissa'), f.col('pValueExponent')).alias('negLogPValue')\n# )\n# .repartition(200)\n# .persist()\n# )\n\n# window_length = 2.5e5\n\ndata = [\n # Cluster 1:\n ('s1', 'chr1', 2),\n ('s1', 'chr1', 4),\n ('s1', 'chr1', 12),\n # Cluster 2 - Same chromosome:\n ('s1', 'chr1', 31),\n ('s1', 'chr1', 38),\n ('s1', 'chr1', 42),\n # Cluster 3 - New chromosome:\n ('s1', 'chr2', 41),\n ('s1', 'chr2', 44),\n ('s1', 'chr2', 50),\n # Cluster 4 - other study:\n ('s2', 'chr2', 55),\n ('s2', 'chr2', 62),\n ('s2', 'chr2', 70),\n]\nwindow_length = 10\n(\n spark.createDataFrame(data, ['studyId', 'chromosome', 'position'])\n .transform(lambda df: cluster_peaks(df, window_length))\n .show()\n)",
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"text": "+-------+----------+--------+----------+\n|studyId|chromosome|position|cluster_id|\n+-------+----------+--------+----------+\n| s1| chr1| 2| s1_chr1_2|\n| s1| chr1| 4| s1_chr1_2|\n| s1| chr1| 12| s1_chr1_2|\n| s1| chr1| 31|s1_chr1_31|\n| s1| chr1| 38|s1_chr1_31|\n| s1| chr1| 42|s1_chr1_31|\n| s1| chr2| 41|s1_chr2_41|\n| s1| chr2| 44|s1_chr2_41|\n| s1| chr2| 50|s1_chr2_41|\n| s2| chr2| 55|s2_chr2_55|\n| s2| chr2| 62|s2_chr2_55|\n| s2| chr2| 70|s2_chr2_55|\n+-------+----------+--------+----------+\n\n",
"name": "stdout"
}
]
},
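{
"metadata": {},
"id": "e5f6a7b8",
"cell_type": "code",
"source": "# Plain-Python mirror of the greedy logic inside find_peak (added sketch, handy for\n# unit-testing without Spark). Assumes positions arrive sorted by significance,\n# most significant first, exactly as the pipeline collects them.\ndef find_peak_py(positions: list, window_size: int) -> list:\n    is_lead, lead_indices = [0.0] * len(positions), []\n    for index, pos in enumerate(positions):\n        # A variant becomes a lead only if no stronger lead sits within the window:\n        if all(abs(positions[lead] - pos) >= window_size for lead in lead_indices):\n            lead_indices.append(index)\n            is_lead[index] = 1.0\n    return is_lead\n\n# Positions ordered by descending significance; window of 5 bp:\nassert find_peak_py([10, 11, 9, 18, 30], 5) == [1.0, 0.0, 0.0, 1.0, 1.0]",
"execution_count": null,
"outputs": []
},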
{
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-04T20:54:14.721448Z",
"start_time": "2023-07-04T20:50:02.065017Z"
},
"trusted": false
},
"id": "a4c3dc57",
"cell_type": "code",
"source": "window_length = 2.5e5\n\n# Let's measure the time that it takes to do the windowing/aggregation:\ndf = (\n sumstats\n .transform(lambda df: cluster_peaks(df, window_length))\n # Get all positions and variants:\n .withColumn('all_pos', mlf.array_to_vector(f.collect_list(f.col('position')).over(complete_window)))\n .withColumn('all_vars', f.collect_list(f.col('variantId')).over(complete_window))\n # Group by study/chromosome:\n .groupBy('window_id')\n .agg(\n f.first('all_pos').alias('all_pos'),\n f.first('all_vars').alias('all_vars'),\n f.first('studyId').alias('studyId'),\n f.first('chromosome').alias('chromosome')\n )\n # This is a spacer here to add the peak finding logic:\n .withColumn('resolvedPeaks', mlf.vector_to_array(find_peak_w_window(window_length)(f.col('all_pos'))))\n # Once all good combine positions with variants:\n .withColumn('to_explode', f.explode_outer(f.arrays_zip(f.col('resolvedPeaks'), f.col('all_vars'))))\n .select(\n 'studyId', \n 'chromosome', \n f.col('to_explode.all_vars').alias('variantId'),\n f.col('to_explode.resolvedPeaks').alias('isLead'),\n )\n .join(sumstats, on=['studyId', 'chromosome', 'variantId'], how='inner')\n .persist()\n)\n\nprint(df.count())\nprint(df.filter(f.col('isLead') == 1.0).count())\ndf.show()",
"execution_count": 4,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": " \r"
},
{
"name": "stdout",
"output_type": "stream",
"text": "25061908\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "\r[Stage 11:=============================================> (172 + 16) / 200]\r\r \r"
},
{
"name": "stdout",
"output_type": "stream",
"text": "28338\n+------------+----------+----------------+------+---------+-----------+--------------------+--------------+--------------+------------+\n| studyId|chromosome| variantId|isLead| position| rsId| pValue|pValueMantissa|pValueExponent|negLogPValue|\n+------------+----------+----------------+------+---------+-----------+--------------------+--------------+--------------+------------+\n|GCST90012009| 14| 14_55035084_C_T| 0.0| 55035084|rs112642967| 1.78e-106| 1.78| -106| 105.75|\n|GCST90092924| 2| 2_21182108_A_G| 0.0| 21182108| rs538862| 5.5e-29| 5.5| -29| 28.26|\n|GCST90092955| X| X_110603328_A_C| 0.0|110603328| rs12398815| 1.2e-18| 1.2| -18| 17.921|\n|GCST90013976| 15| 15_75558952_C_T| 0.0| 75558952| rs4393535|2.69736675171327e-08| 2.697| -8| 7.569|\n|GCST90018969| 15| 15_56320984_C_T| 0.0| 56320984| rs2725852| 1.745e-10| 1.745| -10| 9.758|\n|GCST90092955| 1| 1_62523301_T_C| 0.0| 62523301| rs1781195| 1.9e-52| 1.9| -52| 51.721|\n|GCST90013976| 14| 14_68775222_A_T| 0.0| 68775222| rs72731547|1.62854588682907e-14| 1.629| -14| 13.788|\n|GCST90092924| 2| 2_43851714_G_A| 0.0| 43851714| rs56132765| 3.4e-13| 3.4| -13| 12.469|\n|GCST90013917| 6| 6_28938508_C_G| 0.0| 28938508| rs3131069|2.22910230011868e-09| 2.229| -9| 8.652|\n|GCST90002335| 12|12_111557048_A_C| 0.0|111557048| rs11065950| 2.43e-35| 2.43| -35| 34.614|\n|GCST90092816| 19| 19_19342751_C_A| 0.0| 19342751| rs2301669| 6.9E-10| 6.9| -10| 9.161|\n| GCST010703| 3| 3_89319479_C_A| 0.0| 89319479| rs7619025|1.551038907700539...| 1.551| -11| 10.809|\n| GCST010703| 8| 8_92188626_C_T| 0.0| 92188626| rs2010637|4.829377385866375...| 4.829| -8| 7.316|\n| GCST010703| 3| 3_17222313_C_T| 0.0| 17222313| rs2733510|5.144947796417256...| 5.145| -10| 9.289|\n|GCST90018969| 15| 15_69562643_T_C| 0.0| 69562643| rs80076036| 6.52e-10| 6.52| -10| 9.186|\n|GCST90092874| X| X_110631621_T_A| 0.0|110631621| rs7062458| 6.4e-19| 6.4| -19| 18.194|\n|GCST90088724| 4| 4_15833163_T_A| 0.0| 15833163| rs72616185|1.86423217825249e-91| 1.864| -91| 90.73|\n| GCST010703| 8| 8_8401202_T_C| 0.0| 8401202| rs2976929|8.932943964047577...| 8.933| -15| 14.049|\n|GCST90092833| 2| 2_21068440_T_C| 0.0| 21068440| rs548145| 2.4E-79| 2.4| -79| 78.62|\n|GCST90012113| 2| 2_27493627_G_A| 0.0| 27493627| rs704795| 8.7E-20| 8.7| -20| 19.06|\n+------------+----------+----------------+------+---------+-----------+--------------------+--------------+--------------+------------+\nonly showing top 20 rows\n\n"
}
]
},
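{
"metadata": {},
"id": "f6a7b8c9",
"cell_type": "code",
"source": "# Sketch of how the wall-clock figure in the conclusions table ('UDF with vectors: 3min')\n# can be captured explicitly. The recorded run relied on the notebook's ExecuteTime\n# metadata instead; this timer is an added convenience, not part of that run.\nimport time\n\nstart = time.time()\ndf.count()  # force full evaluation of the persisted peak-finding pipeline\nprint(f'UDF-with-vectors pass took {time.time() - start:.1f}s')",
"execution_count": null,
"outputs": []
},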
{
"metadata": {
"ExecuteTime": {
"start_time": "2023-07-05T13:58:27.668267Z",
"end_time": "2023-07-05T14:31:30.688739Z"
},
"scrolled": false,
"trusted": true
},
"id": "87fda45a",
"cell_type": "code",
"source": "\nwindow_length = 2.5e5\n\nnew_df = (\n # Reading all summary statistics:\n spark.read.parquet(f'{summary_stats}', recursiveFileLookup=True)\n .filter(\n # Filter for snps with position available:\n f.col('position').isNotNull() & \n # Filter snps for GWAS significance:\n (f.col('pValue').cast(t.DoubleType()) <= pv_threshold)\n )\n # Parser p-values provided as text:\n .select('*', *parse_pvalue(f.col('pValue')))\n # Select relevant columns + calculate neg-log p-value:\n .select(\n 'studyId', \n 'variantId', \n 'chromosome', \n 'position', \n 'rsId', \n 'pValue', \n 'pValueMantissa',\n 'pValueExponent',\n calculate_neglog_pvalue(f.col('pValueMantissa'), f.col('pValueExponent')).alias('negLogPValue')\n )\n .repartition(200)\n .transform(lambda df: cluster_peaks(df, window_length))\n # Get all positions and variants:\n .groupBy(\"window_id\")\n# .groupBy('chromosome', 'studyId')\n # collect the position and negLogPValue into a struct\n .agg(\n f.sort_array(f.collect_list(f.struct(\"negLogPValue\", \"position\")), False).alias(\n \"snps\"\n )\n )\n # prepare dense vectors for daniel's magic\n .withColumn(\n \"snps2\",\n mlf.vector_to_array(\n find_peak_w_window(window_length)(\n mlf.array_to_vector(\n f.transform(f.col(\"snps\"), lambda x: x.position)\n )\n )\n ),\n )\n .withColumn(\n \"snps3\",\n f.zip_with(\n f.col(\"snps\"),\n f.col(\"snps2\"),\n lambda x, y: f.struct(\n x.negLogPValue.alias(\"negLogPValue\"),\n x.position.alias(\"position\"),\n y.alias(\"isLead\"),\n ),\n ),\n )\n .withColumn('exploded', f.explode(f.col('snps3')))\n .select(\n f.split(f.col('window_id'), '_').getItem(0).alias('studyId'),\n f.split(f.col('window_id'), '_').getItem(1).alias('chromosome'),\n# f.col('studyId'),\n# f.col('chromosome'),\n f.col('exploded.position').alias('position'),\n f.col('exploded.negLogPValue').alias('negLogPValue'),\n f.col('exploded.isLead').alias('isLead'),\n )\n# .drop(\"snps\", \"snps2\")\n # sort the struct by negLogPValue (descending)\n # .withColumn(\"test\", f.transform(f.col(\"snps\"), lambda x: x.position))\n # print the result\n# .show(100, False)\n # .pripersistchema()\n .persist()\n)\n\nnew_df.filter(f.col('isLead') == 1.0).show()\nprint(new_df.filter(f.col('isLead') == 1.0).count())\nnew_df.write.mode('overwrite').parquet('gs://ot-team/dsuveges/cicaful')",
"execution_count": 35,
"outputs": [
{
"output_type": "stream",
"text": " \r",
"name": "stderr"
},
{
"output_type": "stream",
"text": "+------------+----------+---------+------------+------+\n| studyId|chromosome| position|negLogPValue|isLead|\n+------------+----------+---------+------------+------+\n|GCST90088244| 17| 31063653| 11.154| 1.0|\n|GCST90088244| 17| 31434507| 7.904| 1.0|\n|GCST90026654| 15| 73767737| 7.591| 1.0|\n|GCST90086636| 20| 30305384| 12.494| 1.0|\n|GCST90001553| 1| 67649696| 7.561| 1.0|\n|GCST90025945| 2| 96521868| 13.886| 1.0|\n|GCST90025945| 2| 96032437| 12.796| 1.0|\n|GCST90025945| 2| 95725361| 9.0| 1.0|\n|GCST90025945| 2| 96874843| 8.092| 1.0|\n| GCST010703| 3| 17366438| 12.092| 1.0|\n| GCST010703| 3| 17697055| 9.044| 1.0|\n|GCST90013976| 4| 54643022| 23.79| 1.0|\n|GCST90044763| 14| 64510603| 8.153| 1.0|\n|GCST90002392| 7| 39702208| 7.319| 1.0|\n|GCST90018969| 6|165769413| 7.621| 1.0|\n|GCST90092955| 12| 20320824| 10.585| 1.0|\n|GCST90025948| 8| 8319038| 8.319| 1.0|\n|GCST90026106| 4| 68328725| 7.338| 1.0|\n|GCST90026028| 18| 52847775| 7.646| 1.0|\n|GCST90026029| 7| 20736710| 8.665| 1.0|\n+------------+----------+---------+------------+------+\nonly showing top 20 rows\n\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": " \r",
"name": "stderr"
},
{
"output_type": "stream",
"text": "28359\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": " \r",
"name": "stderr"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-04T21:00:14.175771Z",
"start_time": "2023-07-04T21:00:12.198861Z"
},
"trusted": false
},
"id": "2b83bec4",
"cell_type": "code",
"source": "data = [\n # Cluster 1:\n ('s1', 'chr1', 'v1', 3, 2.0, False),\n ('s1', 'chr1', 'v2', 4, 3.0, False),\n ('s1', 'chr1', 'v3', 5, 4.0, True),\n ('s1', 'chr1', 'v4', 6, 2.0, False),\n ('s1', 'chr1', 'v5', 7, 3.0, False),\n ('s1', 'chr1', 'v6', 8, 4.0, False),\n ('s1', 'chr1', 'v7', 9, 4.5, False),\n ('s1', 'chr1', 'v8', 10, 6.0, True),\n ('s1', 'chr1', 'v9', 11, 5.0, False),\n ('s1', 'chr1', 'v10', 12, 3.0, False),\n ('s1', 'chr1', 'v11', 14, 2.0, True),\n ('s1', 'chr1', 'v12', 16, 2.5, False),\n ('s1', 'chr1', 'v13', 18, 3.0, True),\n ('s1', 'chr1', 'v14', 20, 1.5, False),\n # Cluster 2:\n ('s1', 'chr1', 'v15', 24, 2.0, False),\n ('s1', 'chr1', 'v16', 25, 4.0, True),\n ('s1', 'chr1', 'v17', 27, 3.0, False),\n]\n\ntest_df = (\n spark.createDataFrame(\n data, \n ['studyId', 'chromosome', 'variantId', 'position', 'negLogPValue', 'isSemiIndex'])\n .persist()\n)\n\nwindow_length = 3\n\n(\n test_df\n .show(1000, truncate=False)\n)",
"execution_count": 7,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": " \r"
},
{
"name": "stdout",
"output_type": "stream",
"text": "+-------+----------+---------+--------+------------+-----------+\n|studyId|chromosome|variantId|position|negLogPValue|isSemiIndex|\n+-------+----------+---------+--------+------------+-----------+\n|s1 |chr1 |v1 |3 |2.0 |false |\n|s1 |chr1 |v2 |4 |3.0 |false |\n|s1 |chr1 |v3 |5 |4.0 |true |\n|s1 |chr1 |v4 |6 |2.0 |false |\n|s1 |chr1 |v5 |7 |3.0 |false |\n|s1 |chr1 |v6 |8 |4.0 |false |\n|s1 |chr1 |v7 |9 |4.5 |false |\n|s1 |chr1 |v8 |10 |6.0 |true |\n|s1 |chr1 |v9 |11 |5.0 |false |\n|s1 |chr1 |v10 |12 |3.0 |false |\n|s1 |chr1 |v11 |14 |2.0 |true |\n|s1 |chr1 |v12 |16 |2.5 |false |\n|s1 |chr1 |v13 |18 |3.0 |true |\n|s1 |chr1 |v14 |20 |1.5 |false |\n|s1 |chr1 |v15 |24 |2.0 |false |\n|s1 |chr1 |v16 |25 |4.0 |true |\n|s1 |chr1 |v17 |27 |3.0 |false |\n+-------+----------+---------+--------+------------+-----------+\n\n"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-04T21:00:43.698477Z",
"start_time": "2023-07-04T21:00:43.283179Z"
},
"trusted": false
},
"id": "8c78788f",
"cell_type": "code",
"source": "def cluster_peaks(df: DataFrame) -> DataFrame:\n window_length = 3\n return (\n df\n .withColumn('previous_position', f.lag('position').over(Window.partitionBy('studyId', 'chromosome').orderBy('position')))\n .withColumn(\n 'window_id',\n f.when(\n (f.col('previous_position').isNull()) | \n (f.col('position')- f.col('previous_position') > window_length), \n f.concat_ws('_', f.col('studyId'), f.col('chromosome'), f.col('position'))\n )\n )\n .withColumn(\n 'window_id',\n f.when(\n f.col('window_id').isNull(),\n f.last(\"window_id\", ignorenulls = True).over(\n Window.partitionBy('studyId', 'chromosome')\n .orderBy('position')\n .rowsBetween(Window.unboundedPreceding, Window.currentRow))\n ).otherwise(f.col('window_id'))\n )\n .drop('previous_position')\n \n )\n\n\n(\n test_df\n .transform(cluster_peaks)\n .show(1000, truncate=False)\n)",
"execution_count": 8,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "+-------+----------+---------+--------+------------+-----------+----------+\n|studyId|chromosome|variantId|position|negLogPValue|isSemiIndex|window_id |\n+-------+----------+---------+--------+------------+-----------+----------+\n|s1 |chr1 |v1 |3 |2.0 |false |s1_chr1_3 |\n|s1 |chr1 |v2 |4 |3.0 |false |s1_chr1_3 |\n|s1 |chr1 |v3 |5 |4.0 |true |s1_chr1_3 |\n|s1 |chr1 |v4 |6 |2.0 |false |s1_chr1_3 |\n|s1 |chr1 |v5 |7 |3.0 |false |s1_chr1_3 |\n|s1 |chr1 |v6 |8 |4.0 |false |s1_chr1_3 |\n|s1 |chr1 |v7 |9 |4.5 |false |s1_chr1_3 |\n|s1 |chr1 |v8 |10 |6.0 |true |s1_chr1_3 |\n|s1 |chr1 |v9 |11 |5.0 |false |s1_chr1_3 |\n|s1 |chr1 |v10 |12 |3.0 |false |s1_chr1_3 |\n|s1 |chr1 |v11 |14 |2.0 |true |s1_chr1_3 |\n|s1 |chr1 |v12 |16 |2.5 |false |s1_chr1_3 |\n|s1 |chr1 |v13 |18 |3.0 |true |s1_chr1_3 |\n|s1 |chr1 |v14 |20 |1.5 |false |s1_chr1_3 |\n|s1 |chr1 |v15 |24 |2.0 |false |s1_chr1_24|\n|s1 |chr1 |v16 |25 |4.0 |true |s1_chr1_24|\n|s1 |chr1 |v17 |27 |3.0 |false |s1_chr1_24|\n+-------+----------+---------+--------+------------+-----------+----------+\n\n"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2023-07-05T08:20:22.661728Z",
"end_time": "2023-07-05T08:20:23.726492Z"
},
"trusted": true
},
"id": "e9e313a2",
"cell_type": "code",
"source": "def pass_window_length(f:Callable, window_length: int) -> Callable:\n window_length = window_length\n return f\n \n\n\nstepped_window = (\n Window\n .partitionBy('window_id')\n .orderBy(f.col('negLogPValue').desc())\n)\n\ncomplete_window = (\n stepped_window\n .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)\n)\n\n\n(\n test_df\n .transform(cluster_peaks)\n # Get all positions and variants:\n .groupBy(\"window_id\")\n # collect the position and negLogPValue into a struct\n .agg(\n f.sort_array(f.collect_list(f.struct(\"negLogPValue\", \"position\")), False).alias(\n \"snps\"\n )\n )\n # prepare dense vectors for daniel's magic\n .withColumn(\n \"snps2\",\n mlf.vector_to_array(\n find_peak_w_window(window_length)(\n mlf.array_to_vector(\n f.transform(f.col(\"snps\"), lambda x: x.negLogPvalue)\n )\n )\n ),\n )\n .withColumn(\n \"snps3\",\n f.zip_with(\n f.col(\"snps\"),\n f.col(\"snps2\"),\n lambda x, y: f.struct(\n x.negLogPValue.alias(\"negLogPValue\"),\n x.position.alias(\"position\"),\n y.alias(\"isLead\"),\n ),\n ),\n )\n .drop(\"snps\", \"snps2\")\n # Get all positions and variants:\n# .withColumn('all_pos', mlf.array_to_vector(f.collect_list(f.col('position')).over(complete_window)))\n# .withColumn('all_vars', f.collect_list(f.col('variantId')).over(complete_window))\n# # Group by study/chromosome:\n# .groupBy('window_id')\n# .agg(\n# f.first('all_pos').alias('all_pos'),\n# f.first('all_vars').alias('all_vars'),\n# f.first('studyId').alias('studyId'),\n# f.first('chromosome').alias('chromosome')\n# )\n# # This is a spacer here to add the peak finding logic:\n# .withColumn('resolvedPeaks', mlf.vector_to_array(find_peak_w_window(window_length)(f.col('all_pos'))))\n# # Once all good combine positions with variants:\n# .withColumn('to_explode', f.explode_outer(f.arrays_zip(f.col('resolvedPeaks'), f.col('all_vars'))))\n# .select(\n# 'studyId', \n# 'chromosome', \n# f.col('to_explode.all_vars').alias('variantId'),\n# f.col('to_explode.resolvedPeaks').alias('isLead'),\n# )\n# # Joining back to data:\n# .join(\n# test_df,\n# on=['studyId', 'chromosome', 'variantId'],\n# how='outer'\n# )\n# .orderBy(f.col('position'))\n .show(1000, truncate=False)\n)",
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": "+----------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n|window_id |snps3 |\n+----------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n|s1_chr1_3 |[{6.0, 10, 1.0}, {5.0, 11, 0.0}, {4.5, 9, 0.0}, {4.0, 8, 0.0}, {4.0, 5, 0.0}, {3.0, 18, 1.0}, {3.0, 12, 0.0}, {3.0, 7, 0.0}, {3.0, 4, 0.0}, {2.5, 16, 0.0}, {2.0, 14, 0.0}, {2.0, 6, 0.0}, {2.0, 3, 0.0}, {1.5, 20, 0.0}]|\n|s1_chr1_24|[{4.0, 25, 1.0}, {3.0, 27, 0.0}, {2.0, 24, 0.0}] |\n+----------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "\r[Stage 56:> (0 + 1) / 1]\r\r \r",
"name": "stderr"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2023-07-05T13:11:49.286529Z",
"end_time": "2023-07-05T13:11:49.685419Z"
},
"trusted": true
},
"id": "d9ddb84d",
"cell_type": "code",
"source": "(\n sumstats\n .filter(f.col('position').isNull())\n .count()\n)",
"execution_count": 29,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 29,
"data": {
"text/plain": "19049"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/DSuveges/9dd7a701e73462e6e36df587496b5c94"
},
"gist": {
"id": "9dd7a701e73462e6e36df587496b5c94",
"data": {
"description": "GCS/Window_benchmark 2023.07.04.ipynb",
"public": true
}
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.10.8",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
}
},
"nbformat": 4,
"nbformat_minor": 5
}