Skip to content

Instantly share code, notes, and snippets.

@mistotebe
Created March 21, 2019 18:28
Show Gist options
  • Save mistotebe/fba1a509738efb0110ddcc711a38189b to your computer and use it in GitHub Desktop.
Save mistotebe/fba1a509738efb0110ddcc711a38189b to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# Adapted from one of the examples in
# https://github.com/eclipse/paho.mqtt.python/tree/master/examples
# (which one exactly was not recorded).
# LICENSE:
# This project is dual licensed under the Eclipse Public License 1.0 and the
# Eclipse Distribution License 1.0 as described in the epl-v10 and edl-v10 files.
"Test client from the paho docs"
import asyncio
import signal
import paho.mqtt.client as mqtt
class MQTTError(RuntimeError):
    """Raised when a paho-mqtt call reports a result other than MQTT_ERR_SUCCESS.

    The offending paho result code is the exception's only argument.
    """
class ClosableQueue(asyncio.Queue):
    """An unbounded asyncio.Queue fed by an MQTT message callback.

    ``async for msg in queue`` yields messages as they arrive. After
    :meth:`close` is called, the iteration yields whatever is still queued
    and then stops.
    """

    def __init__(self, topic):
        """Create an empty queue associated with *topic* (kept as ``self.topic``)."""
        super().__init__()
        self.topic = topic
        # Set by close(); the async iterator watches it to know when to stop.
        self.closed = asyncio.Event()
        # Granted QoS for the subscription; filled in by whoever subscribes.
        self.qos = None

    def on_message(self, client, userdata, message):
        """paho-mqtt message callback: enqueue *message* without blocking.

        The queue is created without a maxsize, so put_nowait() never fills up.
        """
        self.put_nowait(message)

    def close(self):
        """Ask iterators to finish once the already-queued items are drained."""
        self.closed.set()

    async def __aiter__(self):
        """Yield queued messages until the queue is closed, then drain the rest.

        The two helper tasks are cancelled however the generator ends,
        including when the consumer breaks out early or is cancelled while
        waiting. Without this, a pending ``get()`` task would outlive the
        iteration and silently swallow the next message put on the queue.
        """
        end = asyncio.create_task(self.closed.wait())
        next_item = asyncio.create_task(self.get())
        try:
            while not end.done():
                done, _ = await asyncio.wait(
                    {end, next_item}, return_when=asyncio.FIRST_COMPLETED)
                if next_item in done:
                    yield next_item.result()
                    next_item = asyncio.create_task(self.get())
            # Closed. An item fetched concurrently with the close still counts.
            if next_item.done():
                yield next_item.result()
            else:
                next_item.cancel()
            while not self.empty():
                yield self.get_nowait()
        finally:
            # No-ops for tasks that already finished.
            next_item.cancel()
            end.cancel()
class Client(mqtt.Client):
    """paho-mqtt client driven by an asyncio event loop.

    How paho is driven:

    * paho's ``on_socket_*`` callbacks register its socket with the loop, so
      ``loop_read`` / ``loop_write`` run on readiness.
    * A background task calls ``loop_misc`` periodically.
    * ``connect`` / ``subscribe`` / ``unsubscribe`` / ``publish`` are
      coroutines. Each finishes when paho invokes the matching
      ``on_connect`` / ``on_subscribe`` / ``on_unsubscribe`` / ``on_publish``
      callback.
    """

    def __init__(self, *args, username=None, password=None, loop=None, **kwargs):
        """
        :param username: broker user name. When given, ``username_pw_set`` is
            called with it and *password*.
        :param loop: event loop to use. When ``None``, the loop running at the
            time of each call is used.

        Other arguments are passed to :class:`paho.mqtt.client.Client`.
        """
        super().__init__(*args, **kwargs)
        # Whether this object has a writer callback registered for the socket.
        self.__writing = False
        self.loop = loop
        if username:
            self.username_pw_set(username, password)
        # Seconds between the loop_misc() calls made by __timer().
        self.__keepalive = 5
        self.__keepalive_task = None
        # subscription filter -> ClosableQueue handed out by subscribe()
        self.__queues = {}
        # message id -> future (unsubscribe stores a (future, topic) pair)
        self.__in_flight = {}
        # Requests awaiting acknowledgement. Concurrent identical requests
        # share one future instead of sending a second packet.
        self.__pending = dict(
            connect=None,
            subscribe={},
            unsubscribe={},
            publish={},
        )

    def __event_loop(self):
        "Return the configured event loop, or the currently running one."
        return self.loop or asyncio.get_running_loop()

    @staticmethod
    def __settle(future, result=None, exception=None):
        """Complete *future* unless it is missing or already done.

        A future is already done when, for example, its awaiting task was
        cancelled, which cancels the awaited future too.
        """
        if future is None or future.done():
            return
        if exception is not None:
            future.set_exception(exception)
        else:
            future.set_result(result)

    def __reset(self):
        "Stop the housekeeping task and drop any writer registration."
        if self.__keepalive_task:
            self.__keepalive_task.cancel()
            self.__keepalive_task = None
        if self.__writing:
            sock = self.socket()
            # After a disconnect paho no longer has a socket. There is then
            # nothing to unregister, and remove_writer(None) would raise.
            if sock is not None:
                self.__event_loop().remove_writer(sock)
            self.__writing = False

    async def connect(self, *args, **kwargs):
        """
        Initiate a connection if one is not already in progress and wait until
        it finishes.

        Arguments are passed to paho's ``connect``. Returns the
        ``(flags, rc)`` pair received with CONNACK.

        :raises MQTTError: if the connection is lost before CONNACK arrives.
        """
        # Are we connecting already?
        future = self.__pending.get('connect')
        if future is not None:
            print("we already have a connection attempt")
            return await future
        self.__reset()
        loop = self.__event_loop()
        print("preparing to connect")
        # NOTE(review): paho's connect() is synchronous and presumably blocks
        # the loop during DNS/TCP setup.
        super().connect(*args, **kwargs)
        # Start housekeeping only once paho accepted the request, so a failing
        # connect() does not leave a task behind.
        self.__keepalive_task = loop.create_task(self.__timer())
        future = self.__pending['connect'] = loop.create_future()
        if self.want_write():
            loop.add_writer(self.socket(), self.__write)
            self.__writing = True
        try:
            return await future
        finally:
            # Clear the slot even if we were cancelled. Otherwise every later
            # connect() would await this dead future.
            if self.__pending.get('connect') is future:
                self.__pending.pop('connect', None)

    def on_connect(self, client, userdata, flags, rc):
        "CONNACK callback"
        print("Connected with result code "+str(rc))
        # Nobody may be waiting any more (e.g. the connect() caller was
        # cancelled), so only a live future is resolved.
        self.__settle(self.__pending.get('connect'), (flags, rc))

    async def subscribe(self, topic, *args, **kwargs):
        """
        Subscribe to a topic filter and return a ClosableQueue that
        asynchronously iterates over the messages matching it.

        *topic* keys the internal tables and is passed to
        ``message_callback_add``, so in practice it is a single filter
        string. Extra arguments (e.g. ``qos``) go to paho's ``subscribe``.
        The granted QoS from SUBACK is stored on ``queue.qos``.

        :raises MQTTError: if paho refuses to send the SUBSCRIBE.
        """
        queue = self.__queues.get(topic)
        if queue is not None:
            return queue
        future = self.__pending['subscribe'].get(topic)
        if not future:
            future = self.__event_loop().create_future()
            result, mid = super().subscribe(topic, *args, **kwargs)
            if result != mqtt.MQTT_ERR_SUCCESS:
                # Nothing registered yet, so a retry starts cleanly.
                raise MQTTError(result)
            self.__pending['subscribe'][topic] = future
            self.__in_flight[mid] = future
        try:
            qos = await future
        finally:
            if self.__pending['subscribe'].get(topic) is future:
                del self.__pending['subscribe'][topic]
        queue = self.__queues.get(topic)
        if queue is None:
            queue = ClosableQueue(topic)
            self.message_callback_add(topic, queue.on_message)
            self.__queues[topic] = queue
        queue.qos = qos
        return queue

    def on_subscribe(self, client, userdata, mid, granted_qos):
        "SUBACK callback"
        self.__settle(self.__in_flight.pop(mid, None), granted_qos)

    async def unsubscribe(self, topic):
        """
        Unsubscribe from a topic, canceling the iterators that fully match it (best effort).

        Queues whose filter equals *topic*, or matches it per
        ``mqtt.topic_matches_sub``, are closed and forgotten. A later
        ``subscribe`` therefore gets a fresh queue.

        :raises MQTTError: if paho refuses to send the UNSUBSCRIBE.
        """
        future = self.__pending['unsubscribe'].get(topic)
        if not future:
            future = self.__event_loop().create_future()
            result, mid = super().unsubscribe(topic)
            if result != mqtt.MQTT_ERR_SUCCESS:
                raise MQTTError(result)
            self.__pending['unsubscribe'][topic] = future
            self.__in_flight[mid] = (future, topic)
        try:
            await future
        finally:
            if self.__pending['unsubscribe'].get(topic) is future:
                del self.__pending['unsubscribe'][topic]

    def on_unsubscribe(self, client, userdata, mid):
        "UNSUBACK callback"
        entry = self.__in_flight.pop(mid, None)
        if entry is None:
            return
        future, subscription = entry
        self.__settle(future)
        # Iterate over a snapshot, because matching queues are removed.
        for topic, queue in list(self.__queues.items()):
            if subscription == topic or mqtt.topic_matches_sub(subscription, topic):
                self.message_callback_remove(topic)
                queue.close()
                del self.__queues[topic]

    async def publish(self, *args, **kwargs):
        """Publish handler: publish via paho and wait for its on_publish callback.

        Arguments are passed to paho's ``publish``.

        :raises MQTTError: if paho refuses the publish.
        """
        future = self.__event_loop().create_future()
        print("publishing")
        result, mid = super().publish(*args, **kwargs)
        print("result", result, "mid", mid)
        if result != mqtt.MQTT_ERR_SUCCESS:
            raise MQTTError(result)
        self.__in_flight[mid] = future
        return await future

    def on_publish(self, client, userdata, mid):
        "PUBACK handler"
        print("on_publish mid", mid)
        self.__settle(self.__in_flight.pop(mid, None))
        print("on_publish finished")

    def on_disconnect(self, client, userdata, rc):
        "Disconnection handler"
        print("Disconnected with code", str(rc))
        if self.__keepalive_task:
            self.__keepalive_task.cancel()
            self.__keepalive_task = None
        # A drop before CONNACK would otherwise leave connect() waiting forever.
        self.__settle(self.__pending.get('connect'), exception=MQTTError(rc))

    # The callback for when a PUBLISH message is received from the server.
    def on_message(self, client, userdata, msg):
        "An MQTT message received"
        print(msg.topic+" "+str(msg.payload))

    def __read(self):
        "React to read events"
        self.loop_read()

    def __write(self):
        "React to write events if writing"
        self.loop_write()

    def on_socket_open(self, client, userdata, sock):
        "paho opened *sock*: run loop_read() whenever it becomes readable."
        self.__event_loop().add_reader(sock, self.__read)

    def on_socket_close(self, client, userdata, sock):
        "paho is closing *sock*: stop watching it."
        loop = self.__event_loop()
        loop.remove_reader(sock)
        if self.__writing:
            # remove_writer() is a harmless no-op if nothing is registered.
            loop.remove_writer(sock)
            self.__writing = False

    def on_socket_register_write(self, client, userdata, sock):
        "paho wants to write: run loop_write() once *sock* is writable."
        self.__event_loop().add_writer(sock, self.__write)
        self.__writing = True

    def on_socket_unregister_write(self, client, userdata, sock):
        "paho no longer needs write readiness on *sock*."
        self.__event_loop().remove_writer(sock)
        self.__writing = False

    async def __timer(self):
        "Poor man's timeout handling"
        # Call loop_misc() every self.__keepalive seconds until cancelled.
        while True:
            self.loop_misc()
            await asyncio.sleep(self.__keepalive)
def on_signal():
    """SIGINT handler: unregister itself, cancel every task, then stop the loop."""
    print("Interrupt!")
    running_loop = asyncio.get_running_loop()
    # Unregister so this handler only ever fires once.
    running_loop.remove_signal_handler(signal.SIGINT)
    for pending_task in asyncio.all_tasks():
        print("canceling task", pending_task)
        pending_task.cancel()
    running_loop.stop()
async def main():
    """Run the demo client.

    Command line: ``[host [username [password]]]``. When a username is given
    without a password, the password is prompted for interactively.

    The client connects to *host* on port 1883 and does two things:

    * publishes ``'timeout'`` to ``status`` in a loop;
    * publishes each integer payload received on ``button`` back to
      ``status``.
    """
    import sys
    host = "iot.eclipse.org"
    # Defaults for an anonymous connection. Both names are read further down
    # even when they are not supplied on the command line.
    username = None
    password = None
    if len(sys.argv) > 1:
        host = sys.argv[1]
    if len(sys.argv) > 2:
        username = sys.argv[2]
        if len(sys.argv) > 3:
            password = sys.argv[3]
        if not password:
            import getpass
            password = getpass.getpass()
    loop = asyncio.get_running_loop()
    loop.add_signal_handler(signal.SIGINT, on_signal)
    tasks = []
    client = Client()
    if username:
        client.username_pw_set(username, password)

    async def timer(client):
        "Publish 'timeout' to 'status'; each round lasts >= 5 s and until the publish completes."
        while True:
            print("Timer expired")
            await asyncio.gather(asyncio.sleep(5), client.publish("status", 'timeout', qos=2))

    await client.connect(host, 1883)
    print("Client connected")
    task = asyncio.create_task(timer(client))
    tasks.append(task)

    async def reader(client, queue):
        "Publish every integer payload from *queue* back to 'status' with QoS 1."
        print("Waiting for first message from", queue.topic)
        async for msg in queue:
            try:
                counter = int(msg.payload.decode())
            except ValueError:  # UnicodeDecodeError is a ValueError subclass
                # A stray non-numeric message must not kill this task (and with
                # it the gather() below).
                print("Ignoring non-integer payload on", msg.topic, msg.payload)
                continue
            print(msg.topic+" "+str(msg.payload))
            await client.publish('status', str(counter), qos=1)
        print("Finished for", queue.topic)

    queue = await client.subscribe("button")
    print("Have queue")
    task = asyncio.create_task(reader(client, queue))
    tasks.append(task)
    await asyncio.gather(*tasks)
if __name__ == "__main__":
    # Script entry point: drive the main() coroutine with asyncio.run().
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment