BatchWriteParquetFromStructuredStreaming.py
This is not a perfect solution, but since a streaming approach seemed better suited to the problem, I am providing it as an option.
Adapted from the socket examples below:
https://github.com/abulbasar/pyspark-examples/blob/master/structured-streaming-socket.py
https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html (search for 'socket' on that page)
To figure out whether processing is finished, just check for this line in the logs:
`WARN TextSocketMicroBatchStream: Stream closed by localhost:9979`
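Alternatively, instead of watching the logs, you can poll the query's progress from the driver. This is a sketch, not part of the original gist: `query` is the streaming query started at the bottom of the script, and the zero-rows heuristic is an assumption (early batches can also report zero rows, so in practice you may want several consecutive empty batches before stopping). It would replace the blocking `query.awaitTermination()` call.

import time

# Rough heuristic: once the socket has closed, micro-batches report
# numInputRows == 0, so stop the query when no new rows arrive.
while query.isActive:
    progress = query.recentProgress
    if progress and progress[-1]["numInputRows"] == 0:
        query.stop()
        break
    time.sleep(5)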
One caveat: the number of rows per file may not be exactly `num_rows_per_batch`. You can set a trigger interval once you have gauged how long the iterator takes to generate 10,000 rows.
https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.streaming.DataStreamWriter.trigger.html
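For example, a processing-time trigger can be added to the `writeStream` call shown below. This is a sketch; the 10-second interval is an assumption you would tune to your iterator's throughput.

# Hypothetical trigger interval -- tune it to roughly how long the
# iterator takes to produce num_rows_per_batch rows.
query = df.writeStream \
    .format("parquet") \
    .option("path", output_hello) \
    .option("checkpointLocation", checkpoint_hello) \
    .trigger(processingTime="10 seconds") \
    .start()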
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_json
from pyspark.sql.types import StructType, StructField, StringType
import json
import threading
import socket
spark = SparkSession.builder \
    .appName("Example") \
    .getOrCreate()
schema = StructType([
    StructField("column1", StringType()),
    StructField("column2", StringType()),
])
def data_iterator():
    # Stand-in for a real data source: yields 100 small JSON-serializable rows.
    for i in range(100):
        yield {"column1": f"value1_{i}", "column2": f"value2_{i}"}
host_given, port_given = "localhost", 9979
def socket_server():
    # Serve the generated rows as newline-delimited JSON over TCP;
    # the "socket" streaming source below reads them line by line.
    host = host_given
    port = port_given
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, port))
        s.listen(1)
        conn, addr = s.accept()
        with conn:
            for row in data_iterator():
                data = json.dumps(row) + "\n"
                conn.sendall(data.encode())
# Run the server in a background thread so the streaming query can connect to it.
server_thread = threading.Thread(target=socket_server)
server_thread.start()
# Each line arriving on the socket is a JSON string; parse it into the schema.
df = spark.readStream \
    .format("socket") \
    .option("host", host_given) \
    .option("port", port_given) \
    .load() \
    .select(from_json(col("value"), schema).alias("data")) \
    .select("data.*")
output_hello = "/path/to/data_output/parquet_so/"
checkpoint_hello = "/path/to/data_output/parquet_checkpoint/"
num_rows_per_batch = 20
# Write each micro-batch as Parquet; "maxRecordsPerFile" (the standard Spark
# option name) caps the number of rows written per output file.
query = df.writeStream \
    .format("parquet") \
    .option("path", output_hello) \
    .option("checkpointLocation", checkpoint_hello) \
    .option("maxRecordsPerFile", num_rows_per_batch) \
    .start()
query.awaitTermination()
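Once the stream has stopped, a quick read-back confirms the batching. This is a sketch assuming the output path above; the expected count of 100 comes from the iterator in this example.

# Hypothetical sanity check after the stream finishes: read the output back
# and confirm the row count matches what the iterator produced (100 rows).
result = spark.read.parquet(output_hello)
print(result.count())
result.show(5)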