Skip to content

Instantly share code, notes, and snippets.

@anuragkapur
Created November 26, 2015 12:21
Show Gist options
  • Save anuragkapur/90c65f35fa47c2d9195c to your computer and use it in GitHub Desktop.
Save anuragkapur/90c65f35fa47c2d9195c to your computer and use it in GitHub Desktop.
import sys
import json
from datetime import datetime
from dateutil.parser import parse
import paho.mqtt.client as mqtt
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.mqtt import MQTTUtils
# MQTT broker connection settings, the topic to subscribe to,
# and the directory handed to ssc.checkpoint() below.
brokerHost = "mqtt.freo.me"
brokerPort = 1883
brokerUrl = "".join(["tcp://", brokerHost, ":", str(brokerPort)])
listenTopic = "/tfl/"
cpDir = "/home/oxclo/cp"
def update(ds, state):
    """Fold a batch of arrival predictions into the per-train state.

    Used as the update function passed to ``updateStateByKey``.

    Args:
        ds: iterable of new entries (from MQTT). Each entry is a dictionary
            with the keys ``trainNumber``, ``stationId`` and ``expArrival``
            (plus others, which are ignored here).
        state: the previously calculated state, or ``None`` when there is
            none yet. A dictionary of
            train -> dict(stationId, expArrival, delayed, delay).

    Returns:
        The state dictionary: the one passed in (updated in place), or a
        new dictionary when ``state`` was ``None``.

    For each entry:
      * a train not seen before is recorded as not delayed;
      * if the station has changed, delayed is reset and the expected time
        is updated;
      * if the station is the same, the new expected time is compared with
        the previous one and the train is marked delayed when it has
        slipped by more than a minute.
    """
    if state is None:
        state = {}

    for current in ds:
        train_number = current['trainNumber']
        station_id = current['stationId']
        expected = parse(current['expArrival'])

        delay = 0
        previous = state.get(train_number)
        if previous is not None:
            print(previous)  # debug trace of the prior record
            if previous['stationId'] == station_id:
                # total_seconds() keeps the sign and the days component;
                # timedelta.seconds would turn "10 seconds earlier" into
                # 86390 and wrongly flag the train as delayed.
                slip = int((expected - previous['expArrival']).total_seconds())
                if slip > 60:  # anything less than a minute is not "delayed"
                    delay = slip

        # Always store a record, so a train seen for the first time seeds
        # the state and later batches have something to compare against.
        state[train_number] = dict(stationId=station_id, expArrival=expected,
                                   delayed=delay > 0, delay=delay)
    return state
# Build and run the streaming job: read JSON messages from MQTT, track each
# train's state with update(), and print per-station counts of delayed trains.
sc = SparkContext(appName="TFLStreaming")
ssc = StreamingContext(sc, 5)  # batch interval 5 sec
ssc.checkpoint(cpDir)
lines = MQTTUtils.createStream(ssc, brokerUrl, listenTopic)
# the last 10 minutes worth with a sliding window of 5 seconds
# NOTE(review): `windowed` is never consumed below -- the pipeline is built
# from `lines`. Kept as-is; confirm whether the window was meant to be used.
windowed = lines.window(600, 5)
dicts = lines.map(json.loads)  # convert from json into a Python dict
mapped = dicts.map(lambda d: (d['trainNumber'], d))  # make the train number the key
ds = mapped.updateStateByKey(update)  # compare against previous data
# Each element from here on is a (key, value) pair; index it rather than
# using tuple-parameter lambdas, which are Python-2-only syntax.
info = ds.filter(lambda kv: bool(kv[1]))  # ignore if the state is empty
# the state from the update is a dict (train -> info)
# this is then mapped with a key so we have (train, (train -> info))
# so let's get rid of the redundancy
unpack = info.map(lambda kv: (kv[0], kv[1][kv[0]]))
# emit (station, 1) for a delayed train and (station, 0) otherwise
mapOnTime = unpack.map(
    lambda kv: (kv[1]['stationId'], 1 if kv[1]['delayed'] else 0))
counts = mapOnTime.reduceByKey(lambda a, b: a + b)
# and print the result to the console
counts.pprint()
# start the processing
ssc.start()
# keep running forever (until Ctrl-C)
ssc.awaitTermination()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment