Skip to content

Instantly share code, notes, and snippets.

@suhrmann
Created November 27, 2019 17:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save suhrmann/85dbc8ba13a45cb7544bc342a8bb2f0e to your computer and use it in GitHub Desktop.
Save suhrmann/85dbc8ba13a45cb7544bc342a8bb2f0e to your computer and use it in GitHub Desktop.
Fixed @sryza/spark-ts-examples/python/Stocks.py
from datetime import datetime
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, TimestampType, DoubleType, StringType
from sparkts.datetimeindex import uniform, BusinessDayFrequency
from sparkts.timeseriesrdd import time_series_rdd_from_observations
def lineToRow(line):
    """Parse one tab-separated ticker line into a (timestamp, symbol, price) tuple.

    The line must contain exactly six tab-separated fields in the order
    year, month, day, symbol, volume, price. The volume field is read but
    not included in the result.

    Raises:
        ValueError: if the line does not have exactly six fields, or if the
            date or price fields cannot be converted to numbers.
    """
    # The volume column is present in the input but unused here.
    (year, month, day, symbol, _volume, price) = line.split("\t")
    # Python 2.x compatible timestamp generation
    dt = datetime(int(year), int(month), int(day))
    return (dt, symbol, float(price))
def loadObservations(sparkContext, sqlContext, path):
    """Load ticker observations from a text file into a DataFrame.

    Each line of the file at ``path`` is converted with ``lineToRow`` and the
    resulting rows are given a three-column schema: ``timestamp``, ``symbol``
    and ``price``, all nullable.
    """
    # (column name, column type) pairs, in the order lineToRow emits them.
    columns = (
        ('timestamp', TimestampType()),
        ('symbol', StringType()),
        ('price', DoubleType()),
    )
    schema = StructType(
        [StructField(name, dataType, nullable=True) for name, dataType in columns]
    )
    observations = sparkContext.textFile(path).map(lineToRow)
    return sqlContext.createDataFrame(observations, schema)
# Driver script: set up Spark, load the ticker observations, align them on a
# date-time index and derive per-series return rates.

# Create the Spark and SQL contexts from a default configuration.
conf = SparkConf()
sc = SparkContext(conf = conf)
sqlContext = SQLContext(sc)
print("\nLoading data...")
# The path is relative, so it is resolved against the working directory the
# script is launched from.
tickerObs = loadObservations(sc, sqlContext, "../data/ticker.tsv")
# Create a business-day DateTimeIndex spanning 2015-08-03 to 2015-09-22
# NOTE(review): the two leading arguments are presumably the number of business
# days per step and the first day of the week - confirm against the sparkts API.
freq = BusinessDayFrequency(1, 1, sc)
dtIndex = uniform(start='2015-08-03T00:00-07:00', end='2015-09-22T00:00-07:00', freq=freq, sc=sc)
# Align the ticker data on the DateTimeIndex to create a TimeSeriesRDD,
# naming the timestamp, key and value columns of the observations DataFrame.
tickerTsrdd = time_series_rdd_from_observations(dtIndex, tickerObs, "timestamp", "symbol", "price")
print("Created time series RDD!")
# Cache it in memory
tickerTsrdd.cache()
# Count the number of series (number of symbols)
print(tickerTsrdd.count())
# Impute missing values using linear interpolation
filled = tickerTsrdd.fill("linear")
# Compute return rates from the filled series
returnRates = filled.return_rates()
# Durbin-Watson test for serial correlation, ported from TimeSeriesStatisticalTests.scala
def dwtest(residuals):
    """Return the Durbin-Watson statistic of a sequence of residuals.

    The statistic is the sum of squared differences between consecutive
    residuals divided by the sum of squared residuals.

    Args:
        residuals: an indexable sequence of numbers that supports ``len()``.

    Returns:
        The ratio described above, or NaN when it is undefined: for an empty
        sequence, or when every residual is zero (zero denominator).
    """
    count = len(residuals)
    # Use an explicit length test: truth-testing a multi-element array-like
    # can raise instead of answering "is it empty?".
    if count == 0:
        return float("nan")

    previous = residuals[0]
    residsSum = previous * previous
    diffsSum = 0.0
    for i in range(1, count):
        current = residuals[i]
        residsSum += current * current
        diff = current - previous
        diffsSum += diff * diff
        previous = current

    # 0/0 is undefined; report NaN rather than raising ZeroDivisionError.
    if residsSum == 0:
        return float("nan")
    return diffsSum / residsSum
# Compute Durbin-Watson stats for each series.
# map_series pairs row[0] with a one-element list holding dwtest(row[1]);
# presumably row[0] is the ticker symbol and row[1] its values - confirm against
# the sparkts TimeSeriesRDD API.
# The second map then swaps each pair so the statistic comes first: min and max
# compare the statistic value, not the ticker names.
dwStats = returnRates.map_series(lambda row: (row[0], [dwtest(row[1])])).map(lambda x: (x[1], x[0]))
# Print the pairs with the smallest and the largest statistic.
print(dwStats.min())
print(dwStats.max())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment