
@tilakpatidar
Last active May 10, 2022 13:34
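Gist to perform count() on JDBC sources without re-reading the DataFrame: an accumulator counts rows as they stream through the write, so no separate count() query goes back to the database.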

Postgres snippet

create database test_db;
-- switch to the new database (psql meta-command) so the table is created there
\c test_db

create table t_random as select s, md5(random()::text) from generate_series(1, 5000) s;

PySpark snippet

In [1]: df=spark.read.jdbc(url="jdbc:postgresql://localhost:5432/test_db", table="t_random", properties={"driver": "org.postgresql.Driver"}).repartition(10)

In [2]: row_count = spark.sparkContext.accumulator(0)

In [3]: def onEachPart(part):
   ...:     count = 0
   ...:     for row in part:
   ...:         count += 1
   ...:         yield row
   ...:     print("Add " + str(count))
   ...:     row_count.add(count)
   ...:

In [4]: df = df.rdd.mapPartitions(onEachPart).toDF(df.schema)  # pass the schema so toDF() does not run a job to infer it

In [5]: df.write.parquet("/tmp/t_random12345")
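The snippet works because `onEachPart` is a generator: it yields every row unchanged and only calls `row_count.add` after its partition iterator is exhausted, which happens when the write in In [5] drains it. The Spark-free sketch below shows that pattern under stated assumptions: plain lists stand in for partitions, and `FakeAccumulator` and `counting_passthrough` are hypothetical names, not Spark APIs.

```python
from typing import Iterable, Iterator, List


class FakeAccumulator:
    """Hypothetical stand-in for a Spark accumulator: supports add() and value."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def add(self, term: int) -> None:
        self.value += term


def counting_passthrough(part: Iterable, acc: FakeAccumulator) -> Iterator:
    """Yield every row unchanged; add the partition's row count once it is drained."""
    count = 0
    for row in part:
        count += 1
        yield row
    # Only reached after the consumer has pulled every row from this partition.
    acc.add(count)


row_count = FakeAccumulator()
partitions: List[List[int]] = [list(range(i * 500, (i + 1) * 500)) for i in range(10)]

# Creating the generators does no work yet, like a lazy transformation.
lazy = [counting_passthrough(p, row_count) for p in partitions]
print(row_count.value)  # 0: nothing has been consumed yet

# Draining them plays the role of the write action.
written = [row for gen in lazy for row in gen]
print(len(written), row_count.value)  # 5000 5000
```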

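Reusable helper: row count plus approximate in-memory size

import sys
from typing import Dict, Tuple

from pyspark.accumulators import Accumulator
from pyspark.sql import DataFrame, SparkSession


def get_df_stats(df: DataFrame, spark: SparkSession) -> Tuple[DataFrame, Dict[str, Accumulator]]:
    # The counters are only populated after an action (e.g. a write) consumes the returned DataFrame.
    row_count: Accumulator = spark.sparkContext.accumulator(0)
    size: Accumulator = spark.sparkContext.accumulator(0)

    def onEachPart(part):
        count = 0
        size_in_bytes = 0
        for row in part:
            # Shallow in-memory size of each field's Python object, not bytes on disk.
            size_in_bytes += sum(sys.getsizeof(x) for x in row)
            count += 1
            yield row
        row_count.add(count)
        size.add(size_in_bytes)

    counters = {
        "row_count": row_count,
        "size": size
    }
    # Reuse the original schema so toDF() does not run a job to infer it.
    return df.rdd.mapPartitions(onEachPart).toDF(df.schema), counters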
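The `size` counter in `get_df_stats` adds up `sys.getsizeof` for every field of every row. That is the shallow in-memory size of the deserialized Python objects (nested contents are not followed), not the number of bytes read over JDBC or written to Parquet, so it is best treated as a rough estimate. A minimal sketch of the per-row computation, assuming plain tuples in place of `Row` objects (PySpark's `Row` subclasses `tuple`); `row_size` and the sample values are hypothetical:

```python
import sys
from typing import Iterable


def row_size(row: Iterable) -> int:
    """Sum of the shallow sizes, in bytes, of each field object in the row."""
    return sum(sys.getsizeof(field) for field in row)


# Hypothetical sample rows shaped like t_random: (s, md5 text).
rows = [
    (1, "c4ca4238a0b923820dcc509a6f75849b"),
    (2, "c81e728d9d4c2f636f067f89cc14862c"),
]
total = sum(row_size(r) for r in rows)
print(total)  # exact value depends on the Python build
```

Summing a generator also handles a row with no fields by returning 0, whereas `reduce` without an initial value raises `TypeError` on an empty sequence.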




Output of In [5] (one "Add 500" per partition), then reading the accumulator in the same session

Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
Add 500


In [6]: row_count.value
Out[6]: 5000
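One caveat when reading `row_count.value`: the counting runs inside a transformation, so every action that recomputes the mapped `df` (for example a later `df.count()` on the uncached DataFrame) runs `onEachPart` again and keeps adding to the accumulator. Spark's programming guide likewise notes that accumulator updates made in transformations may be applied more than once if tasks or stages are re-executed. The pure-Python simulation below illustrates the effect; `run_action` and the per-partition dictionary are hypothetical illustrations, not part of the gist or of Spark's API.

```python
from typing import Dict, Iterator, List

partitions: List[List[int]] = [list(range(500)) for _ in range(10)]

additive_total = 0                  # mimics row_count.add(count)
per_partition: Dict[int, int] = {}  # hypothetical: one slot per partition index


def on_each_part(index: int, part: List[int]) -> Iterator[int]:
    """Pass rows through; record the partition's count once it is drained."""
    global additive_total
    count = 0
    for row in part:
        count += 1
        yield row
    additive_total += count
    per_partition[index] = count  # a re-run overwrites instead of adding


def run_action() -> int:
    """Stand-in for an action: drain every partition, recomputing from scratch."""
    return sum(1 for i, p in enumerate(partitions) for _ in on_each_part(i, p))


run_action()  # e.g. df.write.parquet(...)
run_action()  # e.g. a later df.count() on the same, uncached DataFrame
print(additive_total)               # 10000: every partition counted twice
print(sum(per_partition.values()))  # 5000
```

Keying the count by partition index makes a re-run overwrite rather than add; doing that in Spark would need extra machinery (such as `mapPartitionsWithIndex` plus a custom accumulator), which this sketch does not attempt.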