Skip to content

Instantly share code, notes, and snippets.

@bhaavanmerchant
Last active February 13, 2020 23:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bhaavanmerchant/0e43092177021ffc7edbd1116eef383d to your computer and use it in GitHub Desktop.
Save bhaavanmerchant/0e43092177021ffc7edbd1116eef383d to your computer and use it in GitHub Desktop.
from pyspark.sql import SparkSession
import pyspark.sql.types as T
import pyspark.sql.functions as func
from pyspark.sql.window import Window
def load_trades(spark):
    """Build the sample DataFrame of trade events.

    :param spark: session object used to create the DataFrame
    :return: A DataFrame with columns ``id``, ``timestamp``, ``price`` and
        ``quantity``.
    """
    # (column name, Spark type) pairs, in column order.
    columns = [
        ("id", T.LongType()),
        ("timestamp", T.LongType()),
        ("price", T.DoubleType()),
        ("quantity", T.DoubleType()),
    ]
    trade_schema = T.StructType(
        [T.StructField(name, dtype) for name, dtype in columns]
    )

    # One tuple per trade, in the same order as ``columns``.
    rows = [
        (10, 1546300800000, 37.50, 100.000),
        (10, 1546300801000, 37.51, 100.000),
        (20, 1546300804000, 12.67, 300.000),
        (10, 1546300807000, 37.50, 200.000),
    ]
    return spark.createDataFrame(rows, trade_schema)
def load_prices(spark):
    """Build the sample DataFrame of price (quote) events.

    :param spark: session object used to create the DataFrame
    :return: A DataFrame with columns ``id``, ``timestamp``, ``bid`` and
        ``ask``.
    """
    # (column name, Spark type) pairs, in column order.
    columns = [
        ("id", T.LongType()),
        ("timestamp", T.LongType()),
        ("bid", T.DoubleType()),
        ("ask", T.DoubleType()),
    ]
    price_schema = T.StructType(
        [T.StructField(name, dtype) for name, dtype in columns]
    )

    # One tuple per quote, in the same order as ``columns``.
    rows = [
        (10, 1546300799000, 37.50, 37.51),
        (10, 1546300802000, 37.51, 37.52),
        (10, 1546300806000, 37.50, 37.51),
    ]
    return spark.createDataFrame(rows, price_schema)
def fill(trades, prices):
    """
    Merge trade and price events and forward-fill the value columns per id.

    Every output row carries, for each of ``bid``, ``ask``, ``price`` and
    ``quantity``, the most recent non-null value seen so far (by timestamp)
    for the same ``id``. For the sample inputs in this file the expected
    output is:

    +---+-------------+-----+-----+-----+--------+
    | id|    timestamp|  bid|  ask|price|quantity|
    +---+-------------+-----+-----+-----+--------+
    | 10|1546300799000| 37.5|37.51| null|    null|
    | 10|1546300800000| 37.5|37.51| 37.5|   100.0|
    | 10|1546300801000| 37.5|37.51|37.51|   100.0|
    | 10|1546300802000|37.51|37.52|37.51|   100.0|
    | 20|1546300804000| null| null|12.67|   300.0|
    | 10|1546300806000| 37.5|37.51|37.51|   100.0|
    | 10|1546300807000| 37.5|37.51| 37.5|   200.0|
    +---+-------------+-----+-----+-----+--------+

    :param trades: DataFrame of trade events
    :param prices: DataFrame of price events
    :return: A DataFrame of the combined events and filled.
    """
    value_columns = ["bid", "ask", "price", "quantity"]

    # Per-id frame: every row from the first event up to the current one.
    history = (
        Window.partitionBy("id")
        .orderBy("timestamp")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )

    # Give each side placeholder null columns for the fields it lacks so the
    # two frames can be unioned by column name.
    padded_trades = trades
    for missing in ("bid", "ask"):
        padded_trades = padded_trades.withColumn(missing, func.lit(None))

    padded_prices = prices
    for missing in ("price", "quantity"):
        padded_prices = padded_prices.withColumn(missing, func.lit(None))

    combined = padded_trades.unionByName(padded_prices)

    # Replace each value column with its running last non-null value.
    for name in value_columns:
        carried = func.last(name, ignorenulls=True).over(history)
        combined = combined.withColumn(name, carried)

    return combined.orderBy("timestamp").select("id", "timestamp", *value_columns)
def pivot(trades, prices):
    """
    Combine trade and price events and add running per-id pivot columns.

    The output keeps each event's own ``bid``/``ask``/``price``/``quantity``
    (null where the event does not carry that field) and adds one
    ``<id>_<column>`` column for every distinct id and value column. Each of
    those holds the most recent non-null value of ``<column>`` seen for
    ``<id>`` up to and including the current row, in timestamp order. For the
    sample inputs in this file the expected output is:

    +---+-------------+-----+-----+-----+--------+------+------+--------+-----------+------+------+--------+-----------+
    | id|    timestamp|  bid|  ask|price|quantity|10_bid|10_ask|10_price|10_quantity|20_bid|20_ask|20_price|20_quantity|
    +---+-------------+-----+-----+-----+--------+------+------+--------+-----------+------+------+--------+-----------+
    | 10|1546300799000| 37.5|37.51| null|    null|  37.5| 37.51|    null|       null|  null|  null|    null|       null|
    | 10|1546300800000| null| null| 37.5|   100.0|  37.5| 37.51|    37.5|      100.0|  null|  null|    null|       null|
    | 10|1546300801000| null| null|37.51|   100.0|  37.5| 37.51|   37.51|      100.0|  null|  null|    null|       null|
    | 10|1546300802000|37.51|37.52| null|    null| 37.51| 37.52|   37.51|      100.0|  null|  null|    null|       null|
    | 20|1546300804000| null| null|12.67|   300.0| 37.51| 37.52|   37.51|      100.0|  null|  null|   12.67|      300.0|
    | 10|1546300806000| 37.5|37.51| null|    null|  37.5| 37.51|   37.51|      100.0|  null|  null|   12.67|      300.0|
    | 10|1546300807000| null| null| 37.5|   200.0|  37.5| 37.51|    37.5|      200.0|  null|  null|   12.67|      300.0|
    +---+-------------+-----+-----+-----+--------+------+------+--------+-----------+------+------+--------+-----------+

    Note: the distinct ids are collected to the driver to name the output
    columns, so calling this function runs a Spark job.

    :param trades: DataFrame of trade events
    :param prices: DataFrame of price events
    :return: A DataFrame of the combined events and pivoted columns.
    """
    value_columns = ["bid", "ask", "price", "quantity"]

    # Pad each side with typed null columns for the fields it lacks; the type
    # is taken from the other frame so both sides of the union agree.
    events = trades
    for missing in ("bid", "ask"):
        null_column = func.lit(None).cast(prices.schema[missing].dataType)
        events = events.withColumn(missing, null_column)

    quotes = prices
    for missing in ("price", "quantity"):
        null_column = func.lit(None).cast(trades.schema[missing].dataType)
        quotes = quotes.withColumn(missing, null_column)

    events = events.unionByName(quotes)

    # The set of ids determines the output schema, so it has to be known
    # before the select below can be built.
    id_rows = (
        events.select("id")
        .where(func.col("id").isNotNull())
        .distinct()
        .orderBy("id")
        .collect()
    )
    event_ids = [row["id"] for row in id_rows]

    # Running frame over ALL events (deliberately not partitioned by id): a
    # row for id 20 must still see the latest values recorded for id 10.
    # Spark evaluates an unpartitioned window in a single partition, which is
    # fine for small inputs but will not scale to large ones. ``id`` is only
    # a tie-breaker for events that share a timestamp.
    running = (
        Window.orderBy("timestamp", "id")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )

    pivoted = []
    for event_id in event_ids:
        for name in value_columns:
            # Null out rows belonging to other ids, then carry the last
            # non-null value forward.
            own_value = func.when(func.col("id") == event_id, func.col(name))
            pivoted.append(
                func.last(own_value, ignorenulls=True)
                .over(running)
                .alias("{}_{}".format(event_id, name))
            )

    selected = ["id", "timestamp"] + value_columns + pivoted
    return events.select(*selected).orderBy("timestamp", "id")
if __name__ == "__main__":
    # Local Spark session using all available cores.
    spark = (
        SparkSession.builder
        .master("local[*]")
        .getOrCreate()
    )

    trades = load_trades(spark)
    prices = load_prices(spark)

    # Print the raw inputs first ...
    for inputs in (trades, prices):
        inputs.show()

    # ... then each derived view, building and showing one at a time.
    for transform in (fill, pivot):
        transform(trades, prices).show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment