Created
March 21, 2019 18:28
-
-
Save mistotebe/fba1a509738efb0110ddcc711a38189b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Adapted from code in | |
# https://github.com/eclipse/paho.mqtt.python/tree/master/examples | |
# Can't remember which. | |
# LICENSE: | |
# This project is dual licensed under the Eclipse Public License 1.0 and the | |
# Eclipse Distribution License 1.0 as described in the epl-v10 and edl-v10 files. | |
"Test client from the paho docs" | |
import asyncio | |
import signal | |
import paho.mqtt.client as mqtt | |
class MQTTError(RuntimeError):
    """Raised when a paho-mqtt call reports a non-success result code.

    The offending result code is passed as the exception's only argument.
    """
class ClosableQueue(asyncio.Queue):
    """
    An asyncio.Queue that a paho message callback feeds and that can be
    consumed with ``async for`` until close() is called.

    After close(), iteration still yields every message that was already
    queued, then stops.
    """

    def __init__(self, topic):
        super().__init__()
        self.topic = topic
        self.closed = asyncio.Event()
        # Granted QoS, assigned by Client.subscribe() once the SUBACK arrives.
        self.qos = None

    def on_message(self, client, userdata, message):
        """paho message callback: enqueue *message* without blocking."""
        self.put_nowait(message)

    def close(self):
        """Ask iterators to finish once the already queued messages are drained."""
        self.closed.set()

    async def __aiter__(self):
        """
        Yield queued messages until the queue is closed, then drain leftovers.

        The helper tasks are always cancelled on exit.  Otherwise a cancelled
        or abandoned iteration would leave a pending get() behind that would
        silently consume (and lose) the next message put on the queue.
        """
        end = asyncio.create_task(self.closed.wait())
        next_item = asyncio.create_task(self.get())
        try:
            while not end.done():
                done, _ = await asyncio.wait(
                    {end, next_item}, return_when=asyncio.FIRST_COMPLETED)
                if next_item in done:
                    yield next_item.result()
                    next_item = asyncio.create_task(self.get())
            # The queue was closed; a get() that already finished still owns
            # a message, while a pending one must not take any more.
            if next_item.done():
                yield next_item.result()
            else:
                next_item.cancel()
            while not self.empty():
                yield self.get_nowait()
        finally:
            # cancel() is a no-op for tasks that have already finished.
            next_item.cancel()
            end.cancel()
class Client(mqtt.Client):
    """
    paho-mqtt client driven by an asyncio event loop.

    Instead of paho's own network loop, socket readiness is hooked into the
    asyncio loop through paho's ``on_socket_*`` callbacks, and ``loop_misc()``
    is called periodically from a background task.  ``connect()``,
    ``subscribe()``, ``unsubscribe()`` and ``publish()`` are coroutines that
    complete when the corresponding paho acknowledgement callback fires.
    """

    def __init__(self, *args, username=None, password=None, loop=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Whether a writer is currently registered for the socket.
        self.__writing = False
        # Explicit event loop, or None to use the running one.
        # NOTE(review): this instance attribute shadows paho's Client.loop()
        # method, so that method cannot be used on instances of this class.
        self.loop = loop
        if username:
            self.username_pw_set(username, password)
        # Seconds between loop_misc() calls made by __timer().
        self.__keepalive = 5
        self.__keepalive_task = None
        # topic -> ClosableQueue receiving that topic's messages.
        self.__queues = {}
        # mid -> future (subscribe/publish) or (future, topic) (unsubscribe).
        self.__in_flight = {}
        # mids that on_publish() reported before publish() could register them.
        self.__published_early = set()
        # Requests started but not yet acknowledged; identical concurrent
        # requests share the same future.
        self.__pending = dict(
            connect=None,
            subscribe={},
            unsubscribe={},
            publish={},
        )

    def __get_loop(self):
        "Return the configured event loop, or the currently running one."
        return self.loop or asyncio.get_running_loop()

    def __reset(self):
        "Stop housekeeping and drop our writer registration before reconnecting."
        if self.__keepalive_task is not None:
            self.__keepalive_task.cancel()
            self.__keepalive_task = None
        if self.__writing:
            sock = self.socket()
            # After a disconnect the socket is already gone.
            if sock is not None:
                self.__get_loop().remove_writer(sock)
            self.__writing = False

    async def connect(self, *args, **kwargs):
        """
        Initiate a connection if one not already in progress and wait until it
        finishes.

        Arguments are passed to paho's ``connect()``.  Returns the
        ``(flags, rc)`` pair handed to on_connect.  Raises MQTTError if the
        connection drops before the CONNACK arrives; errors raised by paho's
        ``connect()`` propagate unchanged.
        """
        # Are we connecting already?
        future = self.__pending.get('connect')
        if future is not None:
            print("we already have a connection attempt")
            return await future
        self.__reset()
        loop = self.__get_loop()
        print("preparing to connect")
        super().connect(*args, **kwargs)
        # Start housekeeping only once the socket exists, so a failing
        # connect() does not leave a timer running forever.
        self.__keepalive_task = loop.create_task(self.__timer())
        future = self.__pending['connect'] = loop.create_future()
        if self.want_write():
            loop.add_writer(self.socket(), self.__write)
            self.__writing = True
        try:
            return await future
        finally:
            # Clear the slot even on failure or cancellation; otherwise every
            # later connect() would wait on this stale future.
            if self.__pending.get('connect') is future:
                self.__pending['connect'] = None

    def on_connect(self, client, userdata, flags, rc):
        "CONNACK callback"
        print("Connected with result code "+str(rc))
        future = self.__pending.get('connect')
        # Nobody may be waiting, e.g. for a reconnect not started by connect().
        if future is not None and not future.done():
            future.set_result((flags, rc))

    async def subscribe(self, topic, *args, **kwargs):
        """
        Subscribe to a topic (or a bunch of them) and return an asynchronous
        iterator returning the messages for those topics.

        Extra arguments are passed to paho's ``subscribe()``.  The returned
        ClosableQueue carries the granted QoS in its ``qos`` attribute.
        Raises MQTTError if paho refuses the request.
        """
        queue = self.__queues.get(topic)
        if queue is not None:
            return queue
        pending = self.__pending['subscribe']
        future = pending.get(topic)
        if future is None:
            future = pending[topic] = self.__get_loop().create_future()
            try:
                result, mid = super().subscribe(topic, *args, **kwargs)
                if result != mqtt.MQTT_ERR_SUCCESS:
                    raise MQTTError(result)
            except BaseException:
                # Withdraw the slot so later calls do not wait on a request
                # that was never sent.
                pending.pop(topic, None)
                raise
            self.__in_flight[mid] = future
        try:
            qos = await future
        finally:
            if pending.get(topic) is future:
                del pending[topic]
        queue = self.__queues.get(topic)
        if queue is None:
            queue = ClosableQueue(topic)
            self.message_callback_add(topic, queue.on_message)
            self.__queues[topic] = queue
        queue.qos = qos
        return queue

    def on_subscribe(self, client, userdata, mid, granted_qos):
        "SUBACK callback"
        future = self.__in_flight.pop(mid, None)
        # Unknown mid, or the waiter was cancelled in the meantime.
        if future is not None and not future.done():
            future.set_result(granted_qos)

    async def unsubscribe(self, topic):
        """
        Unsubscribe from a topic, canceling the iterators that fully match it (best effort).

        Raises MQTTError if paho refuses the request.
        """
        pending = self.__pending['unsubscribe']
        future = pending.get(topic)
        if future is None:
            future = pending[topic] = self.__get_loop().create_future()
            try:
                result, mid = super().unsubscribe(topic)
                if result != mqtt.MQTT_ERR_SUCCESS:
                    raise MQTTError(result)
            except BaseException:
                # Withdraw the slot so later calls do not wait on a request
                # that was never sent.
                pending.pop(topic, None)
                raise
            self.__in_flight[mid] = (future, topic)
        try:
            await future
        finally:
            if pending.get(topic) is future:
                del pending[topic]

    def on_unsubscribe(self, client, userdata, mid):
        "UNSUBACK callback"
        entry = self.__in_flight.pop(mid, None)
        if entry is None:
            return
        future, subscription = entry
        if not future.done():
            future.set_result(None)
        # Close *and forget* every queue covered by the removed subscription,
        # so a later subscribe() builds a fresh queue instead of handing back
        # a closed one.
        matching = [
            topic for topic in self.__queues
            if subscription == topic or mqtt.topic_matches_sub(subscription, topic)
        ]
        for topic in matching:
            self.message_callback_remove(topic)
            self.__queues.pop(topic).close()

    async def publish(self, *args, **kwargs):
        """
        Publish handler

        Arguments are passed to paho's ``publish()``; waits until on_publish
        reports the message id.  Raises MQTTError if paho refuses the message.
        """
        future = self.__get_loop().create_future()
        print("publishing")
        # Only mids reported during this synchronous call are relevant below.
        self.__published_early.clear()
        result, mid = super().publish(*args, **kwargs)
        print("result", result, "mid", mid)
        if result != mqtt.MQTT_ERR_SUCCESS:
            raise MQTTError(result)
        if mid in self.__published_early:
            # on_publish already ran inside super().publish() (this can happen
            # when paho writes the message immediately, e.g. QoS 0), so there
            # is no acknowledgement left to wait for.
            self.__published_early.discard(mid)
            return None
        self.__in_flight[mid] = future
        return await future

    def on_publish(self, client, userdata, mid):
        "PUBACK handler"
        print("on_publish mid", mid)
        future = self.__in_flight.pop(mid, None)
        if future is None:
            # publish() has not registered this mid yet; let it know.
            self.__published_early.add(mid)
        elif not future.done():
            future.set_result(None)
        print("on_publish finished")

    def on_disconnect(self, client, userdata, rc):
        "Disconnection handler"
        print("Disconnected with code", str(rc))
        if self.__keepalive_task is not None:
            self.__keepalive_task.cancel()
        # A connect() still waiting for its CONNACK would otherwise hang.
        future = self.__pending.get('connect')
        if future is not None and not future.done():
            future.set_exception(MQTTError(rc))

    # The callback for when a PUBLISH message is received from the server.
    def on_message(self, client, userdata, msg):
        "An MQTT message received"
        print(msg.topic+" "+str(msg.payload))

    def __read(self):
        "React to read events"
        self.loop_read()

    def __write(self):
        "React to write events if writing"
        self.loop_write()

    def on_socket_open(self, client, userdata, sock):
        "paho opened *sock*: start watching it for readability."
        self.__get_loop().add_reader(sock, self.__read)

    def on_socket_close(self, client, userdata, sock):
        "paho is closing *sock*: stop watching it for readability."
        self.__get_loop().remove_reader(sock)

    def on_socket_register_write(self, client, userdata, sock):
        "paho has data to send: watch *sock* for writability."
        self.__get_loop().add_writer(sock, self.__write)
        self.__writing = True

    def on_socket_unregister_write(self, client, userdata, sock):
        "paho has nothing left to send: stop watching *sock* for writability."
        self.__get_loop().remove_writer(sock)
        self.__writing = False

    async def __timer(self):
        "Poor man's timeout handling"
        while True:
            self.loop_misc()
            await asyncio.sleep(self.__keepalive)
def on_signal():
    """
    React to SIGINT by cancelling every task on the running event loop.

    The handler first removes itself, so a second Ctrl-C is handled by
    Python's default SIGINT handling again.
    """
    print("Interrupt!")
    loop = asyncio.get_running_loop()
    loop.remove_signal_handler(signal.SIGINT)
    for task in asyncio.all_tasks():
        print("canceling task", task)
        task.cancel()
    # Deliberately no loop.stop(): the cancelled tasks still need loop
    # iterations to unwind, and stopping now makes asyncio.run() fail with
    # "Event loop stopped before Future completed." instead of shutting down.
async def main():
    """
    Main function

    Command line: ``[host [username [password]]]``.  The connection is
    anonymous unless a username is given; a username without a password
    prompts for one.  After connecting, a 'timeout' status is published to
    'status' every 5 seconds, and every integer received on 'button' is
    published back to 'status'.
    """
    import sys
    # NOTE(review): iot.eclipse.org may no longer run a public broker;
    # confirm the default host.
    host = "iot.eclipse.org"
    username = None
    password = None
    if len(sys.argv) > 1:
        host = sys.argv[1]
    if len(sys.argv) > 2:
        username = sys.argv[2]
        if len(sys.argv) > 3:
            password = sys.argv[3]
        if not password:
            import getpass
            password = getpass.getpass()

    loop = asyncio.get_running_loop()
    loop.add_signal_handler(signal.SIGINT, on_signal)
    tasks = []
    client = Client()
    if username:
        client.username_pw_set(username, password)

    async def timer(client):
        "Publish a 'timeout' status at most once every 5 seconds, forever."
        while True:
            print("Timer expired")
            await asyncio.gather(asyncio.sleep(5), client.publish("status", 'timeout', qos=2))

    _flags, rc = await client.connect(host, 1883)
    # A CONNACK return code of 0 means the broker accepted the connection.
    if rc != 0:
        raise MQTTError(rc)
    print("Client connected")
    task = asyncio.create_task(timer(client))
    tasks.append(task)

    async def reader(client, queue):
        "Publish every integer payload received on *queue* to 'status'."
        print("Waiting for first message from", queue.topic)
        async for msg in queue:
            try:
                counter = int(msg.payload.decode())
            except ValueError:
                # Payloads come from the network: skip anything that is not an
                # integer instead of letting one bad message end the reader.
                # (UnicodeDecodeError is a subclass of ValueError.)
                print("Ignoring non-integer payload on", msg.topic, msg.payload)
                continue
            print(msg.topic+" "+str(msg.payload))
            await client.publish('status', str(counter), qos=1)
        print("Finished for", queue.topic)

    queue = await client.subscribe("button")
    print("Have queue")
    task = asyncio.create_task(reader(client, queue))
    tasks.append(task)
    await asyncio.gather(*tasks)
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except asyncio.CancelledError:
        # SIGINT makes on_signal() cancel every task, including main();
        # treat that user-requested shutdown as a normal exit rather than
        # printing a traceback.
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment