Some prototype of metrics
import argparse
import sys
from functools import reduce
from typing import Iterable

import pyspark.sql.functions as F
from pyspark import SparkConf
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import ArrayType, StringType, StructType
# Required to flatten the schema
def flatten(schema, prefix=None):
    """Flatten a (possibly nested) schema into dotted column names."""
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        # Descend into arrays of structs as if they were plain structs
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType
        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(name)
    return fields
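# A minimal sketch of what flatten() returns, on a hypothetical schema
# (the disease struct below is illustrative, not part of this pipeline):
#
#   df = spark.createDataFrame(
#       [(("EFO_0000270", "asthma"), "uniprot")],
#       "disease struct<id: string, name: string>, sourceId string")
#   flatten(df.schema)
#   # ['disease.id', 'disease.name', 'sourceId']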
def melt(
        df: DataFrame,
        id_vars: Iterable[str], value_vars: Iterable[str],
        var_name: str = "variable", value_name: str = "value") -> DataFrame:
    """Convert :class:`DataFrame` from wide to long format."""
    # Create array<struct<variable: str, value: ...>>
    _vars_and_vals = F.array(*(
        F.struct(F.lit(c).alias(var_name), F.col(c).alias(value_name))
        for c in value_vars))
    # Add to the DataFrame and explode
    _tmp = df.withColumn("_vars_and_vals", F.explode(_vars_and_vals))
    cols = id_vars + [
        F.col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]]
    return _tmp.select(*cols)
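# A minimal sketch of melt() on toy data (hypothetical column names):
#
#   df = spark.createDataFrame([(1, 10, 20)], ["id", "a", "b"])
#   melt(df, id_vars=["id"], value_vars=["a", "b"]).show()
#   # +---+--------+-----+
#   # | id|variable|value|
#   # +---+--------+-----+
#   # |  1|       a|   10|
#   # |  1|       b|   20|
#   # +---+--------+-----+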
def documentTotalCount(
        df: DataFrame,
        var_name: str) -> DataFrame:
    '''Count total documents'''
    out = df.groupBy().count().alias("count")
    out = out.withColumn("sourceId", F.lit("all"))
    out = out.withColumn("variable", F.lit(var_name))
    out = out.withColumn("field", F.lit(None).cast(StringType()))
    return out
def documentCountBy(
        df: DataFrame,
        column: str,
        var_name: str) -> DataFrame:
    '''Count documents by grouping column'''
    out = df.groupBy(column).count().alias("count")
    out = out.withColumn("variable", F.lit(var_name))
    out = out.withColumn("field", F.lit(None).cast(StringType()))
    return out
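# Both count helpers emit the same long-format columns (count, sourceId,
# variable, field), which is what lets main() union everything with
# unionByName later. A hypothetical row from documentCountBy (the figures
# are illustrative only):
#
#   documentCountBy(evd, "sourceId", "evidenceCountByDatasource")
#   # sourceId | count | variable                  | field
#   # uniprot  |  1234 | evidenceCountByDatasource | null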
def evidenceNotNullFieldsCount(
        df: DataFrame,
        var_name: str) -> DataFrame:
    '''Counts number of evidences with not null values in variable.'''
    # flatten dataframe schema
    flatDf = df.select([F.col(c).alias(c) for c in flatten(df.schema)])
    # counting not-null evidence per field (array fields count when they
    # contain at least one element)
    exprs = [F.sum(F.when(F.col(f.name).getItem(0).isNotNull(), F.lit(1))
                   .otherwise(F.lit(0))).alias(f.name)
             if isinstance(f.dataType, ArrayType)
             else
             F.sum(F.when(F.col(f.name).isNotNull(), F.lit(1))
                   .otherwise(F.lit(0))).alias(f.name)
             for f in list(filter(lambda x: x.name != "sourceId",
                                  flatDf.schema))]
    # aggregate on the original df: dotted names resolve as nested access
    out = df.groupBy(F.col("sourceId")).agg(*exprs)
    # Clean column names
    out_cleaned = out.toDF(*(c.replace('.', '_') for c in out.columns))
    # wide to long format
    cols = [c.name for c in filter(lambda x: x.name != "sourceId",
                                   out_cleaned.schema.fields)]
    melted = melt(out_cleaned,
                  id_vars=["sourceId"],
                  var_name="field",
                  value_vars=cols,
                  value_name="count")
    melted = melted.withColumn("variable", F.lit(var_name))
    return melted
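# The result is long format: one row per (sourceId, field) pair with the
# number of evidence strings carrying a non-null value for that field.
# Hypothetical example row:
#
#   # sourceId | field    | count | variable
#   # uniprot  | targetId |  1234 | evidenceFieldNotNullCountByDatasource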
def evidenceDistinctFieldsCount(
        df: DataFrame,
        var_name: str) -> DataFrame:
    '''Counts unique values in variable (e.g. targetId) and datasource.'''
    # flatten dataframe schema
    flatDf = df.select([F.col(c).alias(c) for c in flatten(df.schema)])
    # Unique counts per column field
    exprs = [F.countDistinct(F.col(f.name)).alias(f.name)
             for f in list(filter(lambda x: x.name != "sourceId",
                                  flatDf.schema))]
    out = df.groupBy(F.col("sourceId")).agg(*exprs)
    # Clean column names
    out_cleaned = out.toDF(*(c.replace('.', '_') for c in out.columns))
    # wide to long format
    cols = [c.name for c in filter(lambda x: x.name != "sourceId",
                                   out_cleaned.schema.fields)]
    melted = melt(out_cleaned,
                  id_vars=["sourceId"],
                  var_name="field",
                  value_vars=cols,
                  value_name="count")
    melted = melted.withColumn("variable", F.lit(var_name))
    return melted
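# Same long format as above, but with distinct-value counts per field.
# countDistinct is expensive, so main() passes a column subset rather than
# the full frame, e.g. (hypothetical figures):
#
#   evidenceDistinctFieldsCount(evd.select("sourceId", "targetId"),
#                               "evidenceDistinctFieldsCountByDatasource")
#   # sourceId | field    | count | variable
#   # uniprot  | targetId |   321 | evidenceDistinctFieldsCountByDatasource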
def parse_args():
    """Load command line args"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--runId',
                        help='Pipeline run identifier',
                        type=str,
                        required=True)
    parser.add_argument('--disease',
                        metavar="<path>",
                        help='Disease path',
                        type=str,
                        required=False)
    parser.add_argument('--evidence',
                        metavar="<path>",
                        help='Evidence path',
                        type=str,
                        required=False)
    parser.add_argument('--failedEvidence',
                        metavar="<path>",
                        help='Failed evidence path',
                        type=str,
                        required=False)
    parser.add_argument('--directAssociations',
                        metavar="<path>",
                        help='Direct associations path',
                        type=str,
                        required=False)
    parser.add_argument('--indirectAssociations',
                        metavar="<path>",
                        help='Indirect associations path',
                        type=str,
                        required=False)
    parser.add_argument('--out',
                        metavar="<path>",
                        help="Output path",
                        type=str,
                        required=True)
    # NOTE: with default=True this flag is effectively always on; flip the
    # default to run against a cluster
    parser.add_argument('--local',
                        help="run on local[*]",
                        action='store_true',
                        required=False,
                        default=True)
    args = parser.parse_args()
    return args
def main(args):
    sparkConf = SparkConf()
    if args.local:
        spark = (
            SparkSession.builder
            .config(conf=sparkConf)
            .master('local[*]')
            .getOrCreate()
        )
    else:
        spark = (
            SparkSession.builder
            .config(conf=sparkConf)
            .getOrCreate()
        )
    # Load data
    evd = spark.read.parquet(args.evidence)
    evdBad = spark.read.parquet(args.failedEvidence)
    assDirect = spark.read.parquet(args.directAssociations)
    assIndirect = spark.read.parquet(args.indirectAssociations)
    disease = spark.read.parquet(args.disease)
# "/Users/ochoa/Datasets/iter13/parquet/associations/indirect/overall/"
    columnsToReport = ["sourceId", "targetId", "diseaseId", "drugId",
                       "variantId", "literature"]
    datasets = [
        # VALID EVIDENCE
        # Total evidence count
        documentTotalCount(evd, "evidenceTotalCount"),
        # Evidence count by datasource
        documentCountBy(evd, "sourceId", "evidenceCountByDatasource"),
        # Number of evidences that have a not-null value in the given field
        evidenceNotNullFieldsCount(evd,
                                   "evidenceFieldNotNullCountByDatasource"),
        # Number of distinct values in selected fields
        # distinctCount takes some time on all columns: subsetting them
        evidenceDistinctFieldsCount(evd.select(columnsToReport),
                                    "evidenceDistinctFieldsCountByDatasource"),
        # INVALID EVIDENCE
        # Total invalids
        documentTotalCount(evdBad, "evidenceInvalidTotalCount"),
        # Evidence count (duplicates)
        documentTotalCount(evdBad.filter(F.col("markedDuplicate")),
                           "evidenceDuplicateTotalCount"),
        # Evidence count (targets not resolved)
        documentTotalCount(evdBad.filter(~F.col("resolvedTarget")),
                           "evidenceUnresolvedTargetTotalCount"),
        # Evidence count (diseases not resolved)
        documentTotalCount(evdBad.filter(~F.col("resolvedDisease")),
                           "evidenceUnresolvedDiseaseTotalCount"),
        # Evidence count by datasource (invalids)
        documentCountBy(evdBad, "sourceId",
                        "evidenceInvalidCountByDatasource"),
        # Evidence count by datasource (duplicates)
        documentCountBy(evdBad.filter(F.col("markedDuplicate")), "sourceId",
                        "evidenceDuplicateCountByDatasource"),
        # Evidence count by datasource (targets not resolved)
        documentCountBy(evdBad.filter(~F.col("resolvedTarget")),
                        "sourceId",
                        "evidenceUnresolvedTargetCountByDatasource"),
        # Evidence count by datasource (diseases not resolved)
        documentCountBy(evdBad.filter(~F.col("resolvedDisease")),
                        "sourceId",
                        "evidenceUnresolvedDiseaseCountByDatasource"),
        # Distinct values in selected fields (invalid evidence)
        evidenceDistinctFieldsCount(
            evdBad
            .select(columnsToReport),
            "evidenceInvalidDistinctFieldsCountByDatasource"),
        # Distinct values in selected fields (duplicates)
        evidenceDistinctFieldsCount(
            evdBad
            .filter(F.col("markedDuplicate"))
            .select(columnsToReport),
            "evidenceDuplicateDistinctFieldsCountByDatasource"),
        # Distinct values in selected fields (targets not resolved)
        evidenceDistinctFieldsCount(
            evdBad
            .filter(~F.col("resolvedTarget"))
            .select(columnsToReport),
            "evidenceUnresolvedTargetDistinctFieldsCountByDatasource"),
        # Distinct values in selected fields (diseases not resolved)
        evidenceDistinctFieldsCount(
            evdBad
            .filter(~F.col("resolvedDisease"))
            .select(columnsToReport),
            "evidenceUnresolvedDiseaseDistinctFieldsCountByDatasource"),
        # DIRECT ASSOCIATIONS
        # Total association count
        documentTotalCount(assDirect, "associationsDirectTotalCount"),
        # Associations by datasource
        documentCountBy(
            assDirect
            .select(
                "targetId",
                "diseaseId",
                F.explode(
                    F.col("overallDatasourceHarmonicScoreDSs.datasourceId"))
                .alias("sourceId")),
            "sourceId",
            "associationsDirectByDatasource"),
        # INDIRECT ASSOCIATIONS
        # Total association count
        documentTotalCount(assIndirect,
                           "associationsIndirectTotalCount"),
        # Associations by datasource
        documentCountBy(
            assIndirect
            .select("targetId",
                    "diseaseId",
                    F.explode(
                        F.col(
                            "overallDatasourceHarmonicScoreDSs.datasourceId"))
                    .alias("sourceId")),
            "sourceId",
            "associationsIndirectByDatasource"),
        # TODO: DISEASE
        documentTotalCount(disease, "diseaseTotalCount")
        # TODO: DRUG
    ]
    metrics = reduce(DataFrame.unionByName, datasets)
    # withColumn returns a new DataFrame, so the result must be reassigned
    metrics = metrics.withColumn("runId", F.lit(args.runId))
    # Write output
    # metrics.coalesce(1).write.options(header='true').csv(args.out)
    metrics.write.json(args.out)
    # clean up
    spark.stop()
    return 0
if __name__ == '__main__':
    args = parse_args()
    sys.exit(main(args))
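# Example invocation (the file name and bucket paths below are
# hypothetical; substitute the paths of your own pipeline run):
#
#   spark-submit metrics.py \
#     --runId 21.02 \
#     --evidence gs://bucket/evidence \
#     --failedEvidence gs://bucket/evidenceFailed \
#     --directAssociations gs://bucket/associationDirect \
#     --indirectAssociations gs://bucket/associationIndirect \
#     --disease gs://bucket/diseases \
#     --out gs://bucket/metrics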