Last active
December 20, 2022 07:53
-
-
Save ireneisdoomed/e70420355c1ca222ba0099c4697c5a2a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Gist to test inheritance.""" | |
from __future__ import annotations | |
from typing import TYPE_CHECKING, Type, Optional | |
from typing import List | |
import importlib.resources as pkg_resources | |
import json | |
from dataclasses import dataclass | |
from pathlib import Path | |
from pyspark.sql import SparkSession | |
import pyspark.sql.functions as f | |
from etl.json import schemas | |
if TYPE_CHECKING: | |
from pyspark.sql.types import StructType | |
from pyspark.sql import DataFrame | |
etl = SparkSession.builder.master("local[*]").appName("spark").getOrCreate() | |
def parse_spark_schema(schema_json: str) -> StructType:
    """Load a packaged JSON schema file and convert it to a Spark StructType.

    Args:
        schema_json (str): File name of the JSON schema inside ``etl.json.schemas``.

    Returns:
        StructType: The parsed Spark schema.
    """
    # Runtime-local import: at module level StructType is only imported under
    # TYPE_CHECKING, so it is not available here at runtime.
    from pyspark.sql.types import StructType

    schema_dict = json.loads(
        pkg_resources.read_text(schemas, schema_json, encoding="utf-8")
    )
    # BUG FIX: the original returned the raw dict, despite the docstring (and all
    # callers, which type the result as StructType) expecting a StructType.
    return StructType.fromJson(schema_dict)
@dataclass
class Dataset:
    """Base class pairing a Spark DataFrame with the schema it must conform to."""

    # Expected schema of ``df``; subclasses typically provide this as a class attribute.
    schema: StructType
    # The wrapped Spark DataFrame.
    df: DataFrame

    @classmethod
    def from_parquet(
        cls: Type[Dataset], etl: SparkSession, input_path: str
    ) -> Dataset:
        """Load a Parquet file/directory into a schema-validated Dataset.

        Args:
            etl (SparkSession): Active Spark session used for reading.
            input_path (str): Path to the Parquet file or directory.

        Returns:
            Dataset: A validated instance of ``cls``.

        Raises:
            ValueError: If the data does not match the expected schema.
        """
        # BUG FIX: the original stored df/schema on the class itself, leaking
        # mutable state across calls and subclasses, and called
        # ``cls.generate_schema(cls, cls.df)`` (passing the DataFrame as the
        # ``schema_json`` argument and assigning its None return to cls.schema).
        df = etl.read.parquet(input_path)
        # ``schema`` is only an annotation on the base class, so use getattr
        # rather than a direct attribute access that would raise AttributeError.
        schema = getattr(cls, "schema", None)
        if not schema:
            # No schema declared for this class: adopt the one observed in the data.
            schema = df.schema
        dataset = cls(schema=schema, df=df)
        dataset.validate_schema()
        return dataset

    def to_parquet(
        self: Dataset, output_path: str, partition: List[str], write_mode: str
    ) -> None:
        """Export the content of the DataFrame to Parquet.

        Args:
            output_path (str): Destination path.
            partition (List[str]): Columns to partition the output by.
            write_mode (str): Spark write mode (e.g. ``"overwrite"``).
        """
        self.df.write.partitionBy(partition).mode(write_mode).parquet(output_path)

    def validate_schema(self: Dataset) -> None:
        """Validate the DataFrame's schema against the expected schema.

        Raises:
            ValueError: If the DataFrame contains fields not in the expected
                schema, or is missing fields the expected schema marks as
                required (non-nullable).
        """
        expected_schema = self.schema  # type: ignore[attr-defined]
        observed_schema = self.df.schema  # type: ignore[attr-defined]
        # Observed fields that are not part of the expected schema.
        unexpected_struct_fields = [
            x for x in observed_schema if x not in expected_schema
        ]
        if unexpected_struct_fields:
            # BUG FIX: the original message described the mismatch backwards
            # (these fields are missing from the *expected* schema).
            raise ValueError(
                f"The {unexpected_struct_fields} StructFields are not included in the expected schema: {expected_schema}"
            )
        # Required (non-nullable) fields missing from the observed data.
        required_fields = [x for x in expected_schema if not x.nullable]
        missing_required_fields = [
            x for x in required_fields if x not in observed_schema
        ]
        if missing_required_fields:
            raise ValueError(
                f"The {missing_required_fields} StructFields are required but missing from the DataFrame schema: {expected_schema}"
            )

    def generate_schema(self: Dataset, schema_json: Optional[str] = None) -> dict:
        """Print the DataFrame's schema as JSON and optionally persist it.

        Args:
            schema_json (Optional[str]): If given, file name under
                ``src/etl/json/schemas`` to write the schema to.

        Returns:
            dict: The JSON representation of the DataFrame's schema.
        """
        new_schema = self.df.schema.jsonValue()
        print(new_schema)
        if schema_json:
            # ``fh`` rather than ``f``, which would shadow pyspark.sql.functions.
            with open(Path(f"src/etl/json/schemas/{schema_json}"), "w") as fh:
                json.dump(new_schema, fh, indent=4)
        return new_schema
@dataclass
class VariantAnnotation(Dataset):
    """Dataset of variant annotations derived from gnomAD."""

    # Plain class attribute (deliberately NOT annotated: annotating it would turn
    # it into a defaulted dataclass field ordered before the defaultless ``df``).
    schema = parse_spark_schema("variant_annotation.json")

    @classmethod
    def from_gnomad(
        cls: Type[VariantAnnotation],
        etl: SparkSession,
        gnomad_file: str,
        chain_file: str,
        populations: Optional[List[str]] = None,
    ) -> VariantAnnotation:
        """Generate a VariantAnnotation dataset from the gnomAD source.

        Args:
            etl (SparkSession): Active Spark session.
            gnomad_file (str): Path to the gnomAD parquet input.
            chain_file (str): Liftover chain file (unused in this stub).
            populations (Optional[List[str]]): Populations to keep (unused in
                this stub; defaults to None for backward-compatible calls).

        Returns:
            VariantAnnotation: A schema-validated dataset.
        """
        # BUG FIX: use locals instead of mutating class attributes, which leaked
        # shared mutable state across calls.
        df = etl.read.parquet(gnomad_file).withColumn("new_column", f.lit("value"))
        # ...
        dataset = cls(schema=cls.schema, df=df)
        dataset.validate_schema()
        return dataset
# --- Example usage -----------------------------------------------------------
# Read the data from the file using the ``from_parquet`` class method
# (the original comment referred to a nonexistent ``read_index`` method).
input_path = "va.parquet"
old_df = VariantAnnotation.from_parquet(etl, input_path)
# ``populations`` expects a list of population labels, not a string.
new_df = VariantAnnotation.from_gnomad(etl, input_path, "", [])
old_df.df.printSchema()  # to operate with the df
old_df.generate_schema()  # to print the schema of the df in json
old_df.generate_schema(
    "name.json"
)  # to save the schema of the df in the schemas folder
old_df.validate_schema()  # to validate the schema of the df
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment