Skip to content

Instantly share code, notes, and snippets.

@zbyte64
Created Nov 1, 2020
Embed
What would you like to do?
Example: Django-Graphene subscriptions with graphene_subscriptions and django_lifecycle
from enum import Enum, auto
class NotificationEvents(Enum):
    """Operations carried by subscription events for messages."""

    # Values are spelled out; they match what ``auto()`` assigns (1, 2, ...).
    NEW_MESSAGE = 1
    UPDATE_MESSAGE = 2
from django.db import models
from org.models import User
from django_lifecycle import LifecycleModelMixin, hook, AFTER_CREATE, AFTER_UPDATE
from graphene_subscriptions.events import SubscriptionEvent
from .events import NotificationEvents
class Message(LifecycleModelMixin, models.Model):
    """A message sent by ``owner`` that publishes subscription events on save.

    ``LifecycleModelMixin`` must be part of the bases: django_lifecycle only
    runs ``@hook``-decorated methods on models that include it. Without the
    mixin the hooks below are silently ignored and no event is ever sent.
    """

    owner = models.ForeignKey(
        User, related_name="sent_messages", on_delete=models.CASCADE
    )
    title = models.CharField(max_length=100)
    text = models.TextField()
    to = models.ForeignKey(
        User,
        null=True,
        blank=True,
        on_delete=models.CASCADE,
        related_name="to_received_messages",
    )
    # NOTE(review): many-to-many rows can only be added after the first save,
    # so this set is presumably still empty when AFTER_CREATE fires - confirm
    # that consumers filtering on participants account for that.
    participants = models.ManyToManyField(
        User, related_name="received_messages"
    )

    def _send_event(self, operation):
        """Publish a ``SubscriptionEvent`` for this instance with ``operation``."""
        SubscriptionEvent(operation=operation, instance=self).send()

    @hook(AFTER_CREATE)
    def notify_new_message(self):
        """Send a ``NEW_MESSAGE`` event after the row is first created."""
        self._send_event(NotificationEvents.NEW_MESSAGE)

    @hook(AFTER_UPDATE)
    def notify_update_message(self):
        """Send an ``UPDATE_MESSAGE`` event after an existing row is saved."""
        self._send_event(NotificationEvents.UPDATE_MESSAGE)
from django.test import TransactionTestCase
from snapshottest.unittest import TestCase
from graphene_django.utils.testing import graphql_query
from graphql_relay import to_global_id
import json
from asgiref.sync import sync_to_async, async_to_sync
from channels.testing import WebsocketCommunicator
from graphene_subscriptions.consumers import GraphqlSubscriptionConsumer
from org.models import User
from .models import UserMessage
from .forms import SendMessageForm
class UnreadMessageCountTestCase(TestCase, TransactionTestCase):
    """Websocket tests for the ``unreadMessageCount`` GraphQL subscription.

    The tests connect to ``GraphqlSubscriptionConsumer`` as the user
    ``adminUser@graphql.com`` (loaded through the ``test_data`` fixture) and
    compare the received payloads against stored snapshots.
    """

    fixtures = ["test_data"]
    maxDiff = 2000

    SUBSCRIPTION = """
    subscription {
        unreadMessageCount
    }
    """

    def setUp(self):
        assert self.client.login(email="adminUser@graphql.com", password="password")

    @classmethod
    def tearDownClass(cls):
        """Run the normal class teardown, then close this thread's DB connections.

        Workaround for async tests not cleaning up connections
        (probably fixed in Django 3.1), see
        https://stackoverflow.com/questions/8242837/django-multiprocessing-and-database-connections
        """
        super().tearDownClass()
        # Imported locally so the block does not depend on a file-level import.
        from django.db import connections

        # Close every configured alias, not only the default connection.
        connections.close_all()

    async def _subscribe_as_receiver(self):
        """Open the websocket as the admin user and start the subscription.

        Returns:
            A ``(communicator, receiver, response)`` tuple where ``response``
            is the first payload pushed by the subscription (already checked
            to carry no errors).
        """
        communicator = WebsocketCommunicator(GraphqlSubscriptionConsumer, "/graphql")
        receiver = await sync_to_async(User.objects.get)(email="adminUser@graphql.com")
        communicator.scope["user"] = receiver
        connected, _ = await communicator.connect()
        assert connected

        await communicator.send_json_to(
            {
                "id": 1,
                "type": "start",
                "payload": {"query": self.SUBSCRIPTION, "variables": None},
            }
        )
        response = await communicator.receive_json_from()
        assert not response["payload"]["errors"], str(response["payload"]["errors"])
        return communicator, receiver, response

    @async_to_sync
    async def test_unread_message_count(self):
        """The subscription immediately emits the current count."""
        _, _, response = await self._subscribe_as_receiver()
        self.assertMatchSnapshot(response)

    @async_to_sync
    async def test_update_unread_message_count_on_new_message(self):
        """Sending a message to the subscriber bumps the emitted count by one."""
        communicator, receiver, response = await self._subscribe_as_receiver()

        # send message
        send_form = SendMessageForm(
            data={
                "to": receiver.id,
                "title": "hello world",
                "text": "text",
                "description": "description",
            }
        )
        owner = await sync_to_async(User.objects.get)(email="Jason@graphql.com")
        await sync_to_async(send_form.save)(owner=owner)

        response_two = await communicator.receive_json_from()
        # Check the *second* payload for errors; the first one was already
        # validated when the subscription was started.
        assert not response_two["payload"]["errors"], str(
            response_two["payload"]["errors"]
        )
        assert (
            response_two["payload"]["data"]["unreadMessageCount"]
            == response["payload"]["data"]["unreadMessageCount"] + 1
        ), str((response, response_two))
        self.assertMatchSnapshot(response, "initial")
        self.assertMatchSnapshot(response_two, "result")
import graphene
import rx
from .events import NotificationEvents
class ActiveMessageCounter:
    """Callable wrapper around a user's received-message count.

    The count is read once on construction and cached in ``count``; calling
    the instance re-reads it, updates the cache and returns the new value.
    Positional arguments passed to the call are ignored, so the instance can
    be used directly as a mapping function over a stream of events.
    """

    def __init__(self, user):
        self.user = user
        self.count = self.get_count()

    def get_count(self):
        """Return the current number of entries in ``user.received_messages``."""
        received = self.user.received_messages
        return received.count()

    def __call__(self, *args):
        """Refresh the cached count and return it; ``args`` are ignored."""
        latest = self.get_count()
        self.count = latest
        return self.count
class UnreadMessageCountSubscription(graphene.ObjectType):
    """GraphQL subscription root exposing ``unreadMessageCount``."""

    unread_message_count = graphene.Int(test=graphene.Boolean())

    def resolve_unread_message_count(root, info, test=False):
        """Return an observable of the requesting user's message count.

        Args:
            root: Observable of subscription events; it is filtered down to
                message events whose instance lists the user as a participant.
            info: Resolver info; the user is read from ``info.context.user``.
            test: When true, the count is additionally re-read on a fixed
                interval, independent of incoming events.

        Returns:
            An observable that first emits the current count and then emits a
            refreshed count on every relevant event, with consecutive
            duplicate values suppressed.
        """
        user = info.context.user
        active_counter = ActiveMessageCounter(user)
        watched_operations = (
            NotificationEvents.NEW_MESSAGE,
            NotificationEvents.UPDATE_MESSAGE,
        )

        def concerns_user(event):
            # The cheap operation check comes first so the database is only
            # queried for message events.
            return (
                event.operation in watched_operations
                and event.instance.participants.filter(id=user.id).exists()
            )

        active_increments = root.filter(concerns_user).map(active_counter)

        # Source order matches the original merge calls.
        sources = [rx.Observable.of(active_counter.count)]
        if test:
            # NOTE(review): the period is presumably in milliseconds (RxPY 1.x)
            # - confirm against the installed rx version.
            sources.append(rx.Observable.interval(3000).map(active_counter))
        sources.append(active_increments)

        # NOTE(review): if debounce also takes milliseconds, 0.1 is close to
        # no debounce at all - confirm the intended window.
        return rx.Observable.merge(*sources).debounce(0.1).distinct_until_changed()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment