Skip to content

Instantly share code, notes, and snippets.

@tfranzel
Created November 1, 2020 11:31
Show Gist options
  • Save tfranzel/c11382ea8e3801b701de0f4ef73fe0d5 to your computer and use it in GitHub Desktop.
Save tfranzel/c11382ea8e3801b701de0f4ef73fe0d5 to your computer and use it in GitHub Desktop.
from unittest import mock
import pytest
from rest_framework_mongoengine import routers
from rest_framework_mongoengine.serializers import DocumentSerializer
from rest_framework_mongoengine.viewsets import GenericViewSet
from drf_spectacular.contrib.rest_framework_mongoengine import get_mongoengine_extended_doc_excludes
from drf_spectacular.generators import SchemaGenerator
from tests import assert_schema
try:
    # Import mongoengine and open a connection to a ``mongomock://`` host,
    # i.e. a mocked in-memory backend rather than a real MongoDB server.
    # Presumably wrapped in try/except to tolerate a missing mongoengine /
    # mongomock installation -- TODO confirm intent.
    from mongoengine import connect, Document, EmbeddedDocument, fields
    connect('mongoenginetest', host='mongomock://localhost')
except Exception:
    # Best-effort setup. ``except Exception`` (rather than a bare ``except:``)
    # still lets KeyboardInterrupt / SystemExit propagate.
    # NOTE(review): if the import above fails, the Document subclasses below
    # reference undefined names (EmbeddedDocument, Document, fields), so
    # module import fails anyway -- swallowing here only defers the error.
    pass
class DumbEmbedded(EmbeddedDocument):
    # Minimal flat embedded document: one string and one integer field.
    # Referenced by RegularMongoModel.embedded and NestedEmbeddedDoc.embedded2.
    # NOTE: documented with comments rather than a docstring on purpose --
    # docstrings may surface as descriptions in the generated schema and
    # change the expected YAML snapshot; verify before adding one.
    name = fields.StringField()
    foo = fields.IntField()
class NestedEmbeddedDoc(EmbeddedDocument):
    # Embedded document that itself embeds a DumbEmbedded (two levels of
    # nesting).
    # NOTE(review): no field in the visible code references this class
    # (RegularMongoModel embeds DumbEmbedded directly), so presumably it does
    # not appear in the generated schema -- confirm whether it is meant to be
    # exercised by test_mongo. Adding it would change the YAML snapshot.
    name = fields.StringField()
    embedded2 = fields.EmbeddedDocumentField(DumbEmbedded)
class RegularMongoModel(Document):
    """
    A model class for testing regular flat fields.
    """
    # One field per mongoengine field type under test. Declaration order is
    # kept as-is: it presumably determines property order in the generated
    # schema, which is compared against a checked-in YAML snapshot.

    # String-like fields.
    str_field = fields.StringField()
    str_regex_field = fields.StringField(regex="^.*")
    url_field = fields.URLField()
    email_field = fields.EmailField()

    # Numeric and boolean fields.
    int_field = fields.IntField()
    long_field = fields.LongField()
    float_field = fields.FloatField()
    boolean_field = fields.BooleanField()
    nullboolean_field = fields.BooleanField(null=True)  # nullable boolean

    # Date/time fields.
    date_field = fields.DateTimeField()
    complexdate_field = fields.ComplexDateTimeField()

    # Identifier and decimal fields. ``id_field`` is an ObjectIdField declared
    # as an ordinary named field.
    uuid_field = fields.UUIDField()
    id_field = fields.ObjectIdField()
    decimal_field = fields.DecimalField()

    # Single-level embedded document.
    embedded = fields.EmbeddedDocumentField(DumbEmbedded)
class RegularMongoModelSerializer(DocumentSerializer):
    # Serializer exposing every field of RegularMongoModel ('__all__'); the
    # schema component under test is derived from it.
    # NOTE: no docstring on purpose -- a serializer docstring may become the
    # component description and change the expected YAML snapshot.
    class Meta:
        model = RegularMongoModel
        fields = '__all__'
class MongoViewSet(GenericViewSet):
    # Read-only viewset over RegularMongoModel that defines only ``list`` and
    # ``retrieve``. test_mongo only generates a schema from the routed URLs,
    # so these handlers are introspected, not called there.
    # NOTE: comments instead of docstrings on purpose -- view/action
    # docstrings may be used as operation descriptions and would change the
    # expected YAML snapshot.
    serializer_class = RegularMongoModelSerializer
    queryset = RegularMongoModel.objects.none()

    def list(self, request):
        # Stub: no response is produced.
        return None

    def retrieve(self, request, id):  # noqa: A002
        # Stub. ``id`` shadows the builtin but is kept: callers may pass it by
        # keyword (presumably the router's lookup URL kwarg) -- do not rename.
        return None
@mock.patch(
    'drf_spectacular.settings.spectacular_settings.GET_LIB_DOC_EXCLUDES',
    get_mongoengine_extended_doc_excludes
)
@pytest.mark.contrib('rest_framework_mongoengine')
def test_mongo(no_warnings):
    # Route MongoViewSet under "persons", generate the public OpenAPI schema
    # for those URLs, and compare it with the checked-in YAML snapshot.
    # GET_LIB_DOC_EXCLUDES is patched to get_mongoengine_extended_doc_excludes,
    # presumably so docstrings inherited from rest_framework_mongoengine base
    # classes are not emitted as descriptions. ``no_warnings`` is a fixture
    # defined elsewhere in the test suite (presumably it fails on warnings).
    persons_router = routers.SimpleRouter()
    persons_router.register('persons', MongoViewSet)

    generated = SchemaGenerator(patterns=persons_router.urls).get_schema(
        request=None,
        public=True,
    )
    assert_schema(generated, 'tests/contrib/test_rest_framework_mongoengine.yml')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment