Skip to content

Instantly share code, notes, and snippets.

@ahopkins
Last active February 5, 2022 14:46
Show Gist options
  • Star 12 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save ahopkins/9816b39aedb2d409ef8d1b85f62e8bec to your computer and use it in GitHub Desktop.
Save ahopkins/9816b39aedb2d409ef8d1b85f62e8bec to your computer and use it in GitHub Desktop.
Sanic based websocket pubsub feed

Sanic Websockets Feeds v2

This is an outdated version.

See the latest version.

import asyncio
import json
import random
import secrets
import string
from dataclasses import dataclass, field
from functools import partial
from typing import Optional, Set

import aioredis
import websockets
from sanic import Sanic
# The Sanic application that serves every websocket feed.
app = Sanic(__name__)

# Redis pub/sub backend location. The port is a string because it is only
# ever interpolated into the ``redis://`` address.
PUBSUB_HOST, PUBSUB_PORT = "localhost", "6379"

# Keep-alive tuning, in seconds: how long to wait for a pong, and how long
# to sleep between two pings.
TIMEOUT, INTERVAL = 10, 20
def generate_code(length=12, include_punctuation=False):
    """Return a random string of *length* characters.

    Characters are drawn from ASCII letters and digits, plus
    ``string.punctuation`` when *include_punctuation* is true. The result is
    used as a client session id, so it is drawn from the ``secrets`` CSPRNG
    rather than the predictable ``random`` module.
    A non-positive *length* yields an empty string.
    """
    characters = string.ascii_letters + string.digits
    if include_punctuation:
        characters += string.punctuation
    return "".join(secrets.choice(characters) for _ in range(length))
@dataclass
class Client:
    """One websocket connection attached to a ``Feed``.

    ``feed`` is not a dataclass field: ``Feed.register`` assigns it right
    after construction, and every coroutine below relies on it.
    """

    interface: websockets.server.WebSocketServerProtocol = field(repr=False)
    sid: str = field(default_factory=partial(generate_code, 36))

    def __hash__(self):
        # Hash on the session id alone. Equal clients (dataclass ``__eq__``
        # compares interface and sid) always share a sid, so this stays
        # consistent with equality without building the repr on every lookup.
        return hash(self.sid)

    async def keep_alive(self) -> None:
        """Ping the client every INTERVAL seconds until it goes away.

        A ping not answered within TIMEOUT seconds, or a closed connection,
        unregisters the client from its feed and ends the loop.
        """
        while True:
            try:
                pong_waiter = await self.interface.ping()
                await asyncio.wait_for(pong_waiter, timeout=TIMEOUT)
            except asyncio.TimeoutError:
                print("NO PONG!!")
                await self.feed.unregister(self)
                # Stop here: the client was just dropped, so pinging it
                # again would only hit a closed connection.
                break
            except websockets.exceptions.ConnectionClosed:
                print(f"broken connection: {self.sid} on <{self.feed.name}>")
                await self.feed.unregister(self)
                break
            print(f"ping: {self.sid} on <{self.feed.name}>")
            await asyncio.sleep(INTERVAL)

    async def shutdown(self) -> None:
        """Close the underlying websocket."""
        # ``close()`` is a coroutine; without ``await`` it never runs and the
        # connection stays open.
        await self.interface.close()

    async def receiver(self) -> None:
        """Publish every message from this client on its feed until it disconnects.

        Also starts the ``keep_alive`` pinger. The client is always
        unregistered when the websocket stops yielding messages.
        """
        try:
            self.feed.app.add_task(self.keep_alive())
            async for message in self.interface:
                await self.feed.publish(message)
        except websockets.exceptions.ConnectionClosed:
            print("connection closed")
        finally:
            await self.feed.unregister(self)
class FeedCache(dict):
    """Registry of live feeds keyed by feed name.

    The repr shows only how many clients each feed currently has, which keeps
    debug output short.
    """

    def __repr__(self):
        counts = {}
        for feed_name, entry in self.items():
            counts[feed_name] = len(entry.clients)
        return repr(counts)
class Feed:
    """A named broadcast channel backed by a Redis pub/sub subscription.

    Every websocket client registered on a feed receives each message
    published to the Redis channel with the same name. Feeds are shared
    through the class-level ``cache``, so all connections to one name use a
    single instance, Redis pool and ``receiver`` task.

    ``lock`` works as an "in use" flag rather than as a mutex:
    ``acquire_lock`` never waits on a lock that is already held, and
    ``destroy`` refuses to tear the feed down while the lock is held.
    """

    name: str
    app: Sanic  # assigned by the route handler when the feed is first created
    clients: Set[Client]
    cache = FeedCache()  # class-level registry shared by every Feed: name -> Feed
    lock: asyncio.Lock
    # ``pool`` (the aioredis pool) and ``pubsub`` (the subscribed channel) are
    # assigned in ``get`` when a new feed is created.

    def __init__(self, name):
        self.name = name
        self.clients = set()
        self.lock = asyncio.Lock()

    @classmethod
    async def get(cls, name: str):
        """Return ``(feed, is_existing)`` for *name*, creating the feed if needed.

        An existing feed is taken from ``cache``. A new feed opens its own
        Redis pool and registers itself in ``cache``. It then subscribes to
        the channel *name* exactly once and starts its ``receiver`` task.
        Either way the feed's lock is marked as held, so a pending
        ``destroy`` aborts.
        """
        is_existing = name in cls.cache
        if is_existing:
            feed = cls.cache[name]
            await feed.acquire_lock()
            return feed, is_existing

        feed = cls(name=name)
        await feed.acquire_lock()
        # NOTE(review): two simultaneous first connections for the same name
        # can both get here and each build a Feed; the later one wins the
        # cache slot. Confirm whether that matters for your deployment.
        feed.pool = await aioredis.create_redis_pool(
            address=f"redis://{PUBSUB_HOST}:{PUBSUB_PORT}",
            minsize=2,
            maxsize=500,
            encoding="utf-8",
        )
        cls.cache[name] = feed

        # Subscribe once per feed: the single receiver task started below
        # reads this channel for the feed's whole lifetime.
        channels = await feed.pool.subscribe(name)
        if channels:
            feed.pubsub = channels[0]

        loop = asyncio.get_event_loop()
        loop.create_task(feed.receiver())
        return feed, is_existing

    async def acquire_lock(self) -> None:
        """Mark the feed as in use, never waiting on an already-held lock."""
        if not self.lock.locked():
            print("Lock acquired")
            await self.lock.acquire()
        else:
            print("Lock already acquired")

    async def receiver(self) -> None:
        """Relay every message on the Redis channel to all connected clients.

        Runs until the channel is closed, which happens when ``destroy``
        unsubscribes.
        """
        while True:
            try:
                # wait_message() reports False once the channel is closed;
                # get() would then return None, so stop before using it.
                if not await self.pubsub.wait_message():
                    print(f">>> PUBSUB closed <{self.name}>")
                    break
                raw = await self.pubsub.get(encoding="utf-8")
            except aioredis.errors.ChannelClosedError:
                print(f">>> PUBSUB closed <{self.name}>")
                break
            if raw is None:
                continue
            print(f">>> PUBSUB rcvd <{self.name}>: length=={len(raw)}")
            if not raw:
                continue
            # Iterate over a snapshot: register/unregister may change
            # self.clients while this loop is suspended inside send().
            for client in tuple(self.clients):
                try:
                    print(f"\tSending to {client.sid}")
                    await client.interface.send(raw)
                except websockets.exceptions.ConnectionClosed:
                    print(f"ConnectionClosed. Client {client.sid}")

    async def register(
        self, websocket: websockets.server.WebSocketServerProtocol
    ) -> Optional[Client]:
        """Wrap *websocket* in a new Client, add it to the feed and announce it."""
        client = Client(interface=websocket)
        print(f">>> register {client} on {self.name}")
        client.feed = self
        self.clients.add(client)
        await self.publish("New client has joined.")
        print(f"\nAll clients on {self.name}\n{self.clients}\n\n")
        return client

    async def unregister(self, client: Client) -> None:
        """Remove *client*, close its websocket, and tear the feed down if empty.

        Safe to call more than once for the same client: both ``keep_alive``
        and ``receiver`` call it, and later calls are no-ops.
        """
        print(f">>> unregister {client} on <{self.name}>")
        if client not in self.clients:
            return
        # Remove before the first await so a concurrent call for the same
        # client sees it as gone instead of racing to remove it twice.
        self.clients.discard(client)
        await client.shutdown()
        print(
            f"\nAll remaining clients on <{self.name}>\n{self.clients}\n\n"
        )
        if self.clients:
            await self.publish("Client has left.")
            return
        # Last client is gone: clear the "in use" flag so destroy() may run.
        if self.lock.locked():
            self.lock.release()
        await self.destroy()

    async def destroy(self) -> None:
        """Unsubscribe, close the Redis pool and drop the feed from ``cache``.

        Aborts if the lock is held, i.e. a new connection claimed the feed.
        """
        if self.lock.locked():
            print(f">>> <{self.name}> is locked. ABORT DESTROY.")
            return
        print(f">>> DESTROYING <{self.name}>")
        # Leave the cache before any await, so a new connection for this name
        # builds a fresh Feed. Never evict a different (newer) instance.
        if self.cache.get(self.name) is self:
            del self.cache[self.name]
        # Unsubscribe while the pool is still open; this also closes the
        # channel, which ends this feed's receiver task.
        await self.pool.unsubscribe(self.name)
        self.pool.close()
        await self.pool.wait_closed()
        print(f">>> DESTROYED <{self.name}>")

    async def publish(self, message: str) -> None:
        """Publish *message* on this feed's Redis channel."""
        await self.pool.execute("publish", self.name, message)
@app.websocket("/<feed_name:[A-z][A-z0-9]+>")
async def feed(request, ws, feed_name):
    """Attach websocket *ws* to the feed named *feed_name* and serve it.

    The connection is registered on the (possibly new) feed and then handled
    by the client's own ``receiver`` until it ends.
    """
    # NOTE(review): the range ``[A-z]`` is wider than ``[A-Za-z]``; it also
    # admits ``[``, backslash, ``]``, ``^``, ``_`` and the backtick. Confirm
    # whether that is intended before tightening the route.
    channel, already_open = await Feed.get(feed_name)
    if not already_open:
        # Only a freshly created feed still needs the app reference.
        channel.app = app
    member = await channel.register(ws)
    await member.receiver()
if __name__ == "__main__":
    # Development entry point: Sanic debug mode on port 7777.
    app.run(port=7777, debug=True)
@ahopkins
Copy link
Author

@sjquant
Copy link

sjquant commented Jul 25, 2019

Can I ask why you used `asyncio.Lock`? Isn't it blocking other connections?

@ahopkins
Copy link
Author

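The purpose of the lock is to avoid race conditions and guarantee that a feed is not destroyed prematurely. For example, an existing feed may be about to be destroyed (after checking that it is safe to do so) just as another connection is made. It does not block other connections: `acquire_lock` only calls `acquire()` when the lock is not already held, so the lock acts as an "in use" flag that `destroy` checks before tearing the feed down.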

@sjquant
Copy link

sjquant commented Aug 13, 2019

@ahopkins Thanks for the comment! : ) This example is really helpful.

@ahopkins
Copy link
Author

I hope it helps

@ohld
Copy link

ohld commented Sep 17, 2019

I have a question on https://gist.github.com/ahopkins/9816b39aedb2d409ef8d1b85f62e8bec#file-sanic-websockets-py-L61
Why do you send the message (which was received by the WebSocket) back to the pubsub queue? Isn't it going to be sent back to the websocket on line https://gist.github.com/ahopkins/9816b39aedb2d409ef8d1b85f62e8bec#file-sanic-websockets-py-L137, which would create an infinite loop of:

  1. receive message from pubsub
  2. send message to ws
  3. receive message from ws
  4. send message to the same pubsub
  5. ....

@ahopkins
Copy link
Author

@ohld So sorry for the absurd delay. I didn't see this comment until recently.

Line 61 is part of the Client receiver. It listens to messages coming from the client and then does something. In this case, it pushes the message to the Feed to publish it. That pushes to the pubsub.

Yes, that will then be received by the Feed receiver, whose job it is to send to the ws for each connected client.

The only way you would end up in your proposed infinite loop is if the client were set to send every message it receives back over the ws. With a JS client, that would look something like this:

socket.onmessage = e => socket.send(e.data)

That would be a problem.

I do not claim this is a production-worthy implementation. Typically I would add a mechanism to skip broadcasting a message back to the client that originally sent it. But sometimes you may want that behavior.

This should work as written without any loops though.

I hope this answers the question (even if late).

@Scharfsinnig
Copy link

Scharfsinnig commented Oct 14, 2021

@dataclass
class Client:
    """One websocket connection attached to a ``Feed``.

    ``feed`` is not a dataclass field: ``Feed.register`` assigns it right
    after construction, and every coroutine below relies on it.
    """

    interface: websockets.server.WebSocketServerProtocol = field(repr=False)
    sid: str = field(default_factory=partial(generate_code, 36))

    def __hash__(self):
        # Hash on the session id alone; equal clients always share a sid.
        return hash(self.sid)

    async def keep_alive(self) -> None:
        """Ping the client every INTERVAL seconds until it goes away.

        A ping not answered within TIMEOUT seconds, or a closed connection,
        unregisters the client and ends the loop.
        """
        while True:
            try:
                pong_waiter = await self.interface.ping()
                await asyncio.wait_for(pong_waiter, timeout=TIMEOUT)
            except asyncio.TimeoutError:
                print("NO PONG!!")
                await self.feed.unregister(self)
                # The client was just dropped; do not keep pinging it.
                break
            except websockets.exceptions.ConnectionClosed:
                print(f"broken connection: {self.sid} on <{self.feed.name}>")
                await self.feed.unregister(self)
                break
            print(f"ping: {self.sid} on <{self.feed.name}>")
            await asyncio.sleep(INTERVAL)

    async def shutdown(self) -> None:
        """Close the underlying websocket."""
        # ``close()`` is a coroutine and must be awaited to take effect.
        await self.interface.close()

    async def receiver(self) -> None:
        """Publish every message from this client on its feed until it disconnects."""
        try:
            self.feed.app.add_task(self.keep_alive())
            async for message in self.interface:
                await self.feed.publish(message)
        except websockets.exceptions.ConnectionClosed:
            print("connection closed")
        finally:
            await self.feed.unregister(self)

    async def wait_close_connection(self) -> None:
        """Unregister this client once the websocket reports the connection lost."""
        # ``logger`` was referenced without ever being defined; bind one here.
        import logging

        logger = logging.getLogger(__name__)
        logger.info(">>>>> wait connection_lost......")
        # NOTE(review): assumes wait_for_connection_lost() resolves to a bool
        # (True once the connection is lost) — confirm for your websockets
        # version.
        lost_connection = await self.interface.wait_for_connection_lost()
        if lost_connection:
            logger.info(
                ">>>>> lost_connection flag is %s, unregister......",
                lost_connection,
            )
            await self.feed.unregister(self)

Added a `wait_close_connection` method that unregisters the client once its connection is lost.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment