"""Gist to test inheritance."""
from __future__ import annotations
from typing import TYPE_CHECKING, Type, Optional
from typing import List
import importlib.resources as pkg_resources
import json
from dataclasses import dataclass
from pathlib import Path
from pyspark.sql import SparkSession
import pyspark.sql.functions as f
from etl.json import schemas
if TYPE_CHECKING:
from pyspark.sql.types import StructType
from pyspark.sql import DataFrame
etl = SparkSession.builder.master("local[*]").appName("spark").getOrCreate()


def parse_spark_schema(schema_json: str) -> StructType:
    """Load a JSON schema file from the package resources and convert it to a StructType."""
    core_schema = json.loads(
        pkg_resources.read_text(schemas, schema_json, encoding="utf-8")
    )
    return StructType.fromJson(core_schema)
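
# For reference, the schema files parsed above follow the JSON layout produced by
# StructType.jsonValue(). A minimal, hypothetical example (the real
# variant_annotation.json ships inside the etl.json.schemas package):
#
#   {"type": "struct", "fields": [
#       {"name": "id", "type": "string", "nullable": false, "metadata": {}},
#       {"name": "chromosome", "type": "string", "nullable": true, "metadata": {}}
#   ]}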


@dataclass
class Dataset:
    """Main class that defines a Dataset."""

    schema: StructType
    df: DataFrame

    @classmethod
    def from_parquet(cls: Type[Dataset], etl: SparkSession, input_path: str) -> Dataset:
        """Loads a Parquet file/directory into a schema-validated Spark DataFrame."""
        schema = getattr(cls, "schema", None)
        if schema is not None:
            # Enforce the expected schema when reading
            df = etl.read.schema(schema).parquet(input_path)
        else:
            # If no schema was provided, fall back to the schema observed in the data
            df = etl.read.parquet(input_path)
            schema = df.schema
        # Instantiate a Dataset object to validate the data against the expected schema
        dataset = cls(schema=schema, df=df)
        dataset.validate_schema()
        return dataset

    def to_parquet(
        self: Dataset, output_path: str, partition: List[str], write_mode: str
    ) -> None:
        """Exports the content of the DataFrame to Parquet."""
        self.df.write.partitionBy(partition).mode(write_mode).parquet(output_path)
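
    # A usage sketch (the output path and partition column are hypothetical,
    # not part of this gist):
    #   dataset.to_parquet("out/va.parquet", partition=["chromosome"], write_mode="overwrite")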

    def validate_schema(self: Dataset) -> None:
        """Validate the DataFrame schema against the expected JSON-derived schema.

        Raises:
            ValueError: if the DataFrame schema is not valid
        """
        expected_schema = self.schema
        observed_schema = self.df.schema
        # Observed fields that are not defined in the expected schema
        unexpected_struct_fields = [
            x for x in observed_schema if x not in expected_schema
        ]
        if unexpected_struct_fields:
            raise ValueError(
                f"The {unexpected_struct_fields} StructFields are not included in the expected schema: {expected_schema}"
            )
        # Required (non-nullable) fields missing from the observed data
        required_fields = [x for x in expected_schema if not x.nullable]
        missing_required_fields = [
            x for x in required_fields if x not in observed_schema
        ]
        if missing_required_fields:
            raise ValueError(
                f"The {missing_required_fields} StructFields are required but missing from the DataFrame schema: {expected_schema}"
            )
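
    # A quick way to see the validation fail (a sketch; "extra_col" is a
    # hypothetical column that is not part of the expected schema):
    #   bad_df = etl.createDataFrame([("x",)], ["extra_col"])
    #   Dataset(schema=parse_spark_schema("variant_annotation.json"), df=bad_df).validate_schema()
    #   # -> raises ValueError because extra_col is not in the expected schema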

    def generate_schema(self: Dataset, schema_json: Optional[str] = None) -> None:
        """Print the DataFrame schema as JSON; optionally save it to the schemas folder."""
        new_schema = self.df.schema.jsonValue()
        print(new_schema)
        if schema_json:
            # The file handle is named out_file to avoid shadowing pyspark.sql.functions (f)
            with open(Path(f"src/etl/json/schemas/{schema_json}"), "w") as out_file:
                json.dump(new_schema, out_file, indent=4)


@dataclass
class VariantAnnotation(Dataset):
    """Class that defines a Variant Annotation Dataset."""

    schema = parse_spark_schema("variant_annotation.json")

    @classmethod
    def from_gnomad(
        cls: Type[VariantAnnotation],
        etl: SparkSession,
        gnomad_file: str,
        chain_file: str,
        populations: List[str],
    ) -> VariantAnnotation:
        """Generate the VariantAnnotation dataset from the GnomAD source data."""
        df = etl.read.parquet(gnomad_file).withColumn("new_column", f.lit("value"))
        # ...
        # Instantiate the dataset to validate the data against the expected schema
        dataset = cls(schema=cls.schema, df=df)
        dataset.validate_schema()
        return dataset


# Read the data from the file using the from_parquet class method
input_path = "va.parquet"
old_df = VariantAnnotation.from_parquet(etl, input_path)
new_df = VariantAnnotation.from_gnomad(etl, input_path, chain_file="", populations=[])

old_df.df.printSchema()  # to operate with the underlying DataFrame
old_df.generate_schema()  # to print the schema of the df as JSON
old_df.generate_schema(
    "name.json"
)  # to save the schema of the df in the schemas folder
old_df.validate_schema()  # to validate the schema of the df
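
# Because the schema is the only subclass-specific piece, adding a new dataset type
# is a short sketch like the one below ("study_index.json" and "studies.parquet" are
# hypothetical names, not part of this gist):
#
#   @dataclass
#   class StudyIndex(Dataset):
#       """Class that defines a Study Index Dataset."""
#
#       schema = parse_spark_schema("study_index.json")
#
#   studies = StudyIndex.from_parquet(etl, "studies.parquet")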