Skip to content

Instantly share code, notes, and snippets.

@jhillacre
Created May 31, 2021 16:42
Show Gist options
  • Save jhillacre/ae87a5f038d619728bb0326b879e3f51 to your computer and use it in GitHub Desktop.
Save jhillacre/ae87a5f038d619728bb0326b879e3f51 to your computer and use it in GitHub Desktop.
from channels.db import database_sync_to_async
from djangochannelsrestframework.decorators import action as dcrf_action
from djangochannelsrestframework.generics import GenericAsyncAPIConsumer
from djangochannelsrestframework.mixins import CreateModelMixin
from djangochannelsrestframework.mixins import DeleteModelMixin
from djangochannelsrestframework.mixins import ListModelMixin
from djangochannelsrestframework.mixins import PaginatedModelListMixin
from djangochannelsrestframework.mixins import PatchModelMixin
from djangochannelsrestframework.mixins import UpdateModelMixin
from djangochannelsrestframework.observer import ModelObserver
from djangochannelsrestframework.observer.generics import ObserverModelInstanceMixin
from djangochannelsrestframework.observer.generics import _GenericModelObserver
from rest_framework import status
from rest_framework.exceptions import NotFound
# our client specific stuff.
from tagos.utils.consumers import DCRFDjangoFilterBackend
from tagos.utils.consumers import ModelPermissionMixin
from tagos.utils.consumers import NiceConsumerMixin
from tagos.utils.consumers import OurPaginator
class ModelConsumer(
    PaginatedModelListMixin,
    ListModelMixin,
    PatchModelMixin,
    UpdateModelMixin,
    CreateModelMixin,
    DeleteModelMixin,
    ObserverModelInstanceMixin,
    NiceConsumerMixin,
    ModelPermissionMixin,
    GenericAsyncAPIConsumer,
):
    """Generic DCRF consumer exposing list/CRUD and subscription actions for a model.

    Combines DCRF's paginated-list, list, patch, update, create and delete
    mixins with per-instance observation (``ObserverModelInstanceMixin``), and
    adds a model-wide "activity" subscription (``subscribe_activity`` /
    ``unsubscribe_activity``). Subclasses must set ``queryset``; ``get_model``
    derives the model class from it.
    """

    # Action name -> permission verb. Presumably consumed by
    # ModelPermissionMixin to build Django permission codenames such as
    # "<app>.<verb>_<model>" -- TODO confirm against that mixin.
    perm_names = {
        "list": "view",
        "retrieve": "view",
        "patch": "change",
        "update": "change",
        "create": "add",
        "delete": "delete",
        "subscribe_instance": "view",
        "unsubscribe_instance": "view",
        "subscribe_activity": "view",
        "unsubscribe_activity": "view",
    }
    filter_backends = (DCRFDjangoFilterBackend,)
    pagination_class = OurPaginator

    def get_model(self):
        """Return the Django model class backing ``self.queryset``."""
        return self.queryset.model

    @dcrf_action()
    async def unsubscribe_instance(self, request_id=None, **kwargs):
        """Cancel a per-instance subscription.

        Args:
            request_id: Client-supplied id of the subscription to cancel.
            **kwargs: Lookup arguments forwarded to ``get_object``.

        Returns:
            ``(None, 204)`` on success.

        Raises:
            ValueError: if ``request_id`` is not provided.
            NotFound: if ``request_id`` is not recorded under any group.
        """
        if request_id is None:
            raise ValueError("request_id must have a value set")
        # get_object does synchronous DB access, so run it off the event loop.
        instance = await database_sync_to_async(self.get_object)(**kwargs)
        # NOTE: the observer-level unsubscribe runs before the bookkeeping
        # check below, so it still happens when NotFound is ultimately raised.
        await self.handle_instance_change.unsubscribe(instance=instance)
        try:
            self._unsubscribe(request_id)
        except KeyError as err:
            # Chain explicitly so the traceback shows the KeyError as the cause
            # rather than "another exception occurred during handling".
            raise NotFound(detail="Subscription not found.") from err
        return None, status.HTTP_204_NO_CONTENT

    @_GenericModelObserver
    async def model_activity(self, message, observer=None, action=None, **kwargs):
        """Relay a model-wide change notification to the client.

        ``message`` is unpacked as keyword arguments into
        ``handle_observed_action`` together with ``action``. Presumably it
        carries the payload produced by the observer -- TODO confirm its keys
        never include ``action`` (that would be a duplicate keyword).
        """
        await self.handle_observed_action(
            action=action,
            **message,
        )

    @dcrf_action()
    async def subscribe_activity(self, request_id=None, **kwargs):
        """Subscribe this consumer to change notifications for the whole model.

        Returns:
            ``(None, 201)`` on success.

        Raises:
            ValueError: if ``request_id`` is not provided.
        """
        if request_id is None:
            raise ValueError("request_id must have a value set")
        # Annotation only (no assignment, not evaluated inside a function):
        # hints to type checkers that the decorated attribute acts as a
        # ModelObserver here.
        self.model_activity: ModelObserver
        groups = set(await self.model_activity.subscribe())
        # Presumably records request_id against these groups in
        # ``subscribed_requests``, which ``_unsubscribe`` later undoes.
        self._subscribe(request_id, groups)
        return None, status.HTTP_201_CREATED

    @dcrf_action()
    async def unsubscribe_activity(self, request_id=None, **kwargs):
        """Cancel a model-wide activity subscription.

        Returns:
            ``(None, 204)`` on success.

        Raises:
            ValueError: if ``request_id`` is not provided.
            NotFound: if ``request_id`` is not recorded under any group.
        """
        if request_id is None:
            raise ValueError("request_id must have a value set")
        # Annotation only; see subscribe_activity.
        self.model_activity: ModelObserver
        # NOTE: as in unsubscribe_instance, this runs before the bookkeeping
        # check, so it still happens when NotFound is ultimately raised.
        await self.model_activity.unsubscribe()
        try:
            self._unsubscribe(request_id)
        except KeyError as err:
            raise NotFound(detail="Subscription not found.") from err
        return None, status.HTTP_204_NO_CONTENT

    def _unsubscribe(self, request_id: str):
        """Remove ``request_id`` from every group it is recorded under.

        Patches DCRF's ``_unsubscribe`` to avoid unnecessary key errors, which
        prevent proper cleanup of subscriptions. Only groups that actually
        contain ``request_id`` are touched. Groups left with no request ids
        are dropped from ``subscribed_requests``.

        Raises:
            KeyError: if ``request_id`` is not recorded under any group.
        """
        request_id_found = False
        emptied_groups = []
        for group, request_ids in self.subscribed_requests.items():
            if request_id not in request_ids:
                continue
            request_id_found = True
            # Remove every occurrence: if the same request_id was recorded more
            # than once, a single ``remove`` would leave a stale entry and keep
            # the group alive. Works for both list and set containers.
            while request_id in request_ids:
                request_ids.remove(request_id)
            if not request_ids:
                emptied_groups.append(group)
        if not request_id_found:
            raise KeyError(request_id)
        # Drop empty groups after iteration; resizing the dict mid-loop would
        # raise RuntimeError.
        for group in emptied_groups:
            self.subscribed_requests.pop(group)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment