Skip to content

Instantly share code, notes, and snippets.

@kangks
Last active May 29, 2020 01:41
Show Gist options
  • Save kangks/6c2aa4b766738b099cf9ed1c00dc8b1e to your computer and use it in GitHub Desktop.
AWS IoT Persistent client test
from datetime import datetime
import time
import threading
import json
import logging
import paho.mqtt.client as mqtt
import paho.mqtt.publish as mqtt_publish
import boto3
# Keepalive interval (seconds) the subscriber requests when it connects; the
# script at the bottom also derives its offline sleep from this value.
subscriber_keepalive_sec = 60 * 40  # 40 minutes

_LOG_FORMAT = '%(asctime)s %(levelname)-8s %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

# Root handler: timestamped lines with the level name padded to 8 columns,
# showing INFO and above.
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
logger = logging.getLogger(__name__)
class AWSIoTConnection:
    """Shared AWS IoT plumbing: endpoint discovery, TLS settings and logging.

    Subclasses build their paho-mqtt client from ``self.tls``,
    ``self.iot_endpoint``, ``self.iot_endpoint_port`` and ``self.client_id``.
    """

    # MQTT-over-TLS port used to reach the AWS IoT data endpoint.
    iot_endpoint_port = 8883

    def log(self, message, level=logging.INFO):
        """Log *message* through ``self.logger`` at stdlib ``logging`` *level*."""
        self.logger.log(level, message)

    def get_iot_endpoint(self):
        """Return the account's ATS data endpoint hostname via boto3.

        Raises:
            ValueError: if the DescribeEndpoint response has no
                ``endpointAddress`` key.
        """
        response = boto3.client('iot').describe_endpoint(endpointType='iot:Data-ATS')
        try:
            return response["endpointAddress"]
        except KeyError:
            raise ValueError("No endpointAddress found") from None

    def __init__(self, client_id, ca_cert_path, private_cert_path, key_cert_path):
        """Store TLS file paths, resolve the endpoint and remember the client id.

        Args:
            client_id: MQTT client identifier used by the subclass's client.
            ca_cert_path: CA bundle used to verify the broker.
            private_cert_path: device private key file (becomes TLS ``keyfile``).
            key_cert_path: device certificate file (becomes TLS ``certfile``).
        """
        # Despite the names, key_cert_path is the *certificate* and
        # private_cert_path is the *private key*; map them accordingly.
        self.tls = {
            'ca_certs': ca_cert_path,
            'certfile': key_cert_path,
            'keyfile': private_cert_path,
        }
        # Default target for log()/on_log(), so subclasses that never assign
        # their own logger do not hit AttributeError. Same name as the
        # module-level logger, hence the same Logger object.
        self.logger = logging.getLogger(__name__)
        self.iot_endpoint = self.get_iot_endpoint()
        self.client_id = client_id

    def on_log(self, mqttc, obj, level, string):
        """paho ``on_log`` callback: forward paho's log line to ``log``.

        paho passes its own MQTT_LOG_* constants as *level*, not stdlib
        ``logging`` levels, so translate them through ``mqtt.LOGGING_LEVEL``
        (unknown values fall back to DEBUG).

        NOTE: the subclasses in this file also call ``enable_logger()``, so
        paho lines at INFO and above may be emitted twice.
        """
        self.log(string, mqtt.LOGGING_LEVEL.get(level, logging.DEBUG))
class MqttSubscriber(AWSIoTConnection):
    """MQTT client that subscribes to ``topics`` on AWS IoT and signals arrivals.

    Every received message is logged and sets the module-level
    ``received_all_event`` so the driving script can stop waiting.
    """

    def on_message(self, mqttc, userdata, msg):
        """paho callback for an incoming PUBLISH: log it and set the event."""
        self.log("received [{}]: {}".format(msg.topic, msg.payload))
        # received_all_event is a module-level threading.Event created by the
        # script that drives this class.
        received_all_event.set()

    def on_connect(self, mqttc, obj, flags, rc):
        """paho CONNACK callback: subscribe on success, log and skip on refusal.

        Runs on every (re)connect, so the subscription is re-issued after
        ``reconnect()``.
        """
        if rc != 0:
            self.log("connection refused: {}".format(mqtt.connack_string(rc)), logging.ERROR)
            return
        # Whether the broker resumed a stored session is what this test probes.
        self.log("connected, session present: {}".format(flags.get("session present")))
        self.mqtt_client.subscribe(self.topics, qos=1)

    def on_publish(self, mqttc, userdata, message_id):
        """Log the id of a completed publish (not wired to the client below)."""
        self.log("published with message id {}".format(message_id))

    def __init__(self, client_id, clean_session: bool, timeout_sec: int,
                 ca_cert_path, private_cert_path, key_cert_path, topics):
        """Create the paho client, configure TLS and callbacks, and connect.

        Args:
            client_id: MQTT client id; reuse the same id to resume a
                persistent session.
            clean_session: ``False`` requests a persistent session (the
                script passes ``False``).
            timeout_sec: keepalive interval in seconds, passed to
                ``mqtt.Client.connect``.
            ca_cert_path: CA bundle used to verify the broker.
            private_cert_path: device private key file (TLS ``keyfile``).
            key_cert_path: device certificate file (TLS ``certfile``).
            topics: topic filter subscribed with QoS 1 on every connect.
        """
        super().__init__(client_id, ca_cert_path, private_cert_path, key_cert_path)
        self.timeout_sec = timeout_sec
        self.topics = topics
        self.mqtt_client = mqtt.Client(client_id=self.client_id, clean_session=clean_session)
        self.mqtt_client.enable_logger(logger)
        self.logger = logger
        self.mqtt_client.on_connect = self.on_connect
        # paho's own log lines, routed through AWSIoTConnection.log; remove
        # this line for quieter output.
        self.mqtt_client.on_log = self.on_log
        self.mqtt_client.on_message = self.on_message
        self.mqtt_client.tls_set(
            ca_certs=self.tls["ca_certs"],
            certfile=self.tls["certfile"],
            keyfile=self.tls["keyfile"],
        )
        self.connect()

    def connect(self):
        """Open the TLS connection and start paho's background network loop."""
        self.mqtt_client.connect(self.iot_endpoint, self.iot_endpoint_port, self.timeout_sec)
        self.mqtt_client.loop_start()

    def reconnect(self):
        """Reconnect with the stored settings and restart the network loop."""
        self.mqtt_client.reconnect()
        self.mqtt_client.loop_start()

    def disconnect(self):
        """Stop the network loop, then send DISCONNECT."""
        self.mqtt_client.loop_stop()
        self.mqtt_client.disconnect()
class SinglePublisher(AWSIoTConnection):
    """One-shot publisher: connect, publish one shadow update, disconnect.

    The publish is triggered from paho's CONNACK callback; afterwards the
    client disconnects and stops its network loop.
    """

    # Class-level default; each instance gets its own counter in __init__.
    publish_message_count = 0

    def __init__(self, client_id, clean_session: bool, timeout_sec: int,
                 ca_cert_path, private_cert_path, key_cert_path,
                 thing_name="shadow_subscriber"):
        """Create the paho client, configure TLS and callbacks, and connect.

        Args:
            client_id: MQTT client id.
            clean_session: passed to ``mqtt.Client`` (the script passes True).
            timeout_sec: keepalive interval in seconds, passed to ``connect``.
            ca_cert_path: CA bundle used to verify the broker.
            private_cert_path: device private key file (TLS ``keyfile``).
            key_cert_path: device certificate file (TLS ``certfile``).
            thing_name: thing whose shadow update topic is published to;
                the default is the thing used by this test.
        """
        super().__init__(client_id, ca_cert_path, private_cert_path, key_cert_path)
        # on_log -> AWSIoTConnection.log needs a logger; use the same one as
        # the subscriber.
        self.logger = logger
        self.timeout_sec = timeout_sec
        self.thing_name = thing_name
        self.publish_message_count = 0
        self.mqtt_client = mqtt.Client(client_id=self.client_id, clean_session=clean_session)
        self.mqtt_client.enable_logger(logger)
        self.mqtt_client.on_connect = self.on_connect
        # paho's own log lines, routed through AWSIoTConnection.log; remove
        # this line for quieter output.
        self.mqtt_client.on_log = self.on_log
        self.mqtt_client.tls_set(
            ca_certs=self.tls["ca_certs"],
            certfile=self.tls["certfile"],
            keyfile=self.tls["keyfile"],
        )
        self.connect()

    def on_connect(self, mqttc, obj, flags, rc):
        """paho CONNACK callback: perform the single publish."""
        self.publish_data(mqttc, obj, flags, rc)

    def connect(self):
        """Open the TLS connection and start paho's background network loop."""
        self.mqtt_client.connect(self.iot_endpoint, self.iot_endpoint_port, self.timeout_sec)
        self.mqtt_client.loop_start()

    def publish_data(self, mqttc, obj, flags, rc):
        """Publish one shadow update if the connection was accepted, then tear down.

        The update sets ``reported.value`` to ``10 + n`` and
        ``desired.value`` to ``n``, where ``n`` is this instance's publish
        count. A refused connection or a failed publish is logged at ERROR.
        The client is disconnected and its loop stopped in every case.
        """
        try:
            if rc != 0:
                self.log(
                    "connection refused, nothing published: {}".format(mqtt.connack_string(rc)),
                    logging.ERROR,
                )
                return
            self.publish_message_count += 1
            payload = {
                "state": {
                    "reported": {"value": 10 + self.publish_message_count},
                    "desired": {"value": self.publish_message_count},
                }
            }
            topic = "$aws/things/{}/shadow/update".format(self.thing_name)
            info = mqttc.publish(topic, json.dumps(payload))
            if info.rc != mqtt.MQTT_ERR_SUCCESS:
                self.log("publish failed: {}".format(mqtt.error_string(info.rc)), logging.ERROR)
        finally:
            # Single publish: always disconnect and stop the loop started in connect().
            mqttc.disconnect()
            mqttc.loop_stop()
# Set by MqttSubscriber.on_message whenever a message arrives; each phase
# below waits on it, and it is cleared before the next phase.
received_all_event = threading.Event()

# Upper bound for each wait below, so a lost message ends the run with a
# warning instead of blocking forever.
message_wait_timeout_sec = 120

# Every client in this test presents the same device certificate.
cert_paths = {
    "ca_cert_path": "./certs/ca.pem",
    "private_cert_path": "./certs/private.pem.key",
    "key_cert_path": "./certs/cert.pem.crt",
}


def wait_for_message(timeout_sec=message_wait_timeout_sec):
    """Wait up to *timeout_sec* seconds for ``received_all_event`` to be set.

    Returns True if the event was set, False if the wait timed out (a
    warning is logged in that case).
    """
    if not received_all_event.is_set():
        logger.info("Waiting for all messages to be received...")
    received = received_all_event.wait(timeout_sec)
    if not received:
        logger.warning("No message received within {} seconds".format(timeout_sec))
    return received


# Phase 1: connect a subscriber with clean_session=False and publish one
# shadow update while it is online.
subscriber = MqttSubscriber(
    client_id="subscriber01",
    clean_session=False,
    timeout_sec=subscriber_keepalive_sec,
    topics="$aws/things/+/shadow/update/delta",
    **cert_paths
)
# NOTE(review): MqttSubscriber subscribes from its on_connect callback on
# paho's network thread, so this publish may race the subscription; consider
# waiting for an on_subscribe signal first. The bounded wait keeps a lost
# message from hanging the run.
publisher = SinglePublisher(
    client_id="publisher01",
    clean_session=True,
    timeout_sec=1,
    **cert_paths
)
wait_for_message()
subscriber.disconnect()
logger.info("subscriber disconnected")

# Phase 2: publish while the subscriber is offline, then reconnect it with the
# same client id and wait for the update to arrive.
received_all_event.clear()
publisher = SinglePublisher(
    client_id="publisher02",
    clean_session=True,
    timeout_sec=1,
    **cert_paths
)

subscriber_sleep = subscriber_keepalive_sec - 30  # reconnect before actual times out
logger.info("subscriber will reconnect after {} seconds".format(subscriber_sleep))
for remaining in range(subscriber_sleep, 0, -1):
    # Log before sleeping so the reported countdown matches the time left.
    if remaining % 20 == 0:
        logger.info("subscriber will reconnect after {} seconds".format(remaining))
    time.sleep(1)
logger.info("\n")
subscriber.reconnect()
wait_for_message()
subscriber.disconnect()
logger.info("subscriber disconnected")
awscrt==0.5.15
awsiotsdk==1.1.0
boto3==1.13.18
botocore==1.16.18
docutils==0.15.2
jmespath==0.10.0
paho-mqtt==1.5.0
python-dateutil==2.8.1
s3transfer==0.3.3
six==1.15.0
urllib3==1.25.9
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment