Last active
October 25, 2019 21:59
-
-
Save kingbuzzman/d7859d9734b590e52fad787d19c34b52 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
# Stolen from: https://mlvin.xyz/django-single-file-project.html | |
import inspect | |
import os | |
import sys | |
from types import ModuleType | |
import django | |
from django.conf import settings | |
# Directory holding this file; the sqlite database and the migrations package live next to it
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# The current name of the file (without its extension), which will be the name of our app
APP_LABEL, _ = os.path.splitext(os.path.basename(os.path.abspath(__file__)))

# A migrations folder needs to exist, and Django needs to be told where it is
APP_MIGRATION_MODULE = '%s_migrations' % APP_LABEL
APP_MIGRATION_PATH = os.path.join(BASE_DIR, APP_MIGRATION_MODULE)

# Create the folder and its __init__.py if they don't exist. The two are handled
# independently so a folder that lost its __init__.py is repaired as well.
os.makedirs(APP_MIGRATION_PATH, exist_ok=True)
# Append mode creates the file when it is missing and never truncates an existing one
with open(os.path.join(APP_MIGRATION_PATH, '__init__.py'), 'a'):
    pass

# Hack to trick Django into thinking this file is actually a package
sys.modules[APP_LABEL] = sys.modules[__name__]
sys.modules[APP_LABEL].__path__ = [os.path.abspath(__file__)]
# Configure Django in-process instead of through a settings module
settings.configure(
    DEBUG=True,
    # URL patterns are served from the fake ``<app>.urls`` module registered at the bottom of the file
    ROOT_URLCONF='%s.urls' % APP_LABEL,
    MIDDLEWARE=(
        'django.middleware.common.CommonMiddleware',
        'django.contrib.sessions.middleware.SessionMiddleware',
        'django.contrib.auth.middleware.AuthenticationMiddleware',
        'django.contrib.messages.middleware.MessageMiddleware',
    ),
    INSTALLED_APPS=[
        'django.contrib.admin',
        'django.contrib.auth',
        'django.contrib.contenttypes',
        'django.contrib.sessions',
        'django.contrib.sites',
        'django.contrib.staticfiles',
        'rest_framework',
        # This very file, registered as an app
        APP_LABEL,
    ],
    # Point Django at the migrations package created next to this file
    MIGRATION_MODULES={APP_LABEL: APP_MIGRATION_MODULE},
    SITE_ID=1,
    DATABASES=dict(
        default=dict(
            ENGINE='django.db.backends.sqlite3',
            NAME=os.path.join(BASE_DIR, 'db.sqlite3'),
        ),
    ),
    # Echo every SQL statement to the console
    LOGGING=dict(
        version=1,
        disable_existing_loggers=False,
        formatters=dict(
            simple=dict(format="%(levelname)s %(message)s"),
        ),
        handlers=dict(
            console={
                'level': 'DEBUG',
                'class': 'logging.StreamHandler',
                'formatter': 'simple',
            },
        ),
        loggers={
            'django.db.backends': dict(handlers=['console'], level='DEBUG', propagate=False),
            # Causes sql logs to duplicate -- really annoying
            'django.db.backends.schema': dict(level='ERROR'),
        },
    ),
    STATIC_URL='/static/',
    # No authentication and no permission checks on the API
    REST_FRAMEWORK=dict(
        DEFAULT_PERMISSION_CLASSES=(),
        DEFAULT_AUTHENTICATION_CLASSES=(),
    ),
)
django.setup()
from django.apps import apps # noqa: E402 isort:skip | |
# Patch the app registry so classes defined in ``__main__`` resolve to this file's
# AppConfig; that way we don't have to add the app_label to all our models
def get_containing_app_config(module):
    """Return the AppConfig for ``module``, mapping ``'__main__'`` to this file's app.

    Any other module name is delegated to the registry's original lookup, which
    is kept on ``apps._get_containing_app_config``.
    """
    if module != '__main__':
        return apps._get_containing_app_config(module)
    return apps.get_app_config(APP_LABEL)


# The right-hand side is evaluated first, so the stock lookup is saved under the
# private name before the public name is replaced by the wrapper above
apps._get_containing_app_config, apps.get_containing_app_config = (
    apps.get_containing_app_config,
    get_containing_app_config,
)
# Your code below this line | |
# ############################################################################## | |
from django.db import models, connection # noqa: E402 isort:skip | |
from django.test import TestCase # noqa: E402 isort:skip | |
from django.test.utils import CaptureQueriesContext # noqa: E402 isort:skip | |
from django.urls import path # noqa: E402 isort:skip | |
from rest_framework import serializers, viewsets # noqa: E402 isort:skip | |
class TestData(models.Model):
    """Model with eight short text columns, used to exercise the ``?fields=`` filtering."""

    # Eight identical columns, each limited to 10 characters
    field1 = models.CharField(max_length=10)
    field2 = models.CharField(max_length=10)
    field3 = models.CharField(max_length=10)
    field4 = models.CharField(max_length=10)
    field5 = models.CharField(max_length=10)
    field6 = models.CharField(max_length=10)
    field7 = models.CharField(max_length=10)
    field8 = models.CharField(max_length=10)
class DynamicFieldListSerializer(serializers.ListSerializer):
    """List serializer that restricts the SQL SELECT to the columns the child will output."""

    def to_representation(self, data):
        """
        List of object instances -> List of dicts of primitive datatypes.

        When ``data`` is (or yields) a queryset, it is narrowed with ``.only()``
        to the child serializer's fields that are real model columns. Any other
        iterable (for example a plain list of instances) is serialized as is.
        """
        # Imported locally on purpose: the bottom of this file sets a ``models``
        # attribute on this very module (the fake ``<app>.models`` module), which
        # replaces the module-level ``django.db.models`` name by the time this runs.
        from django.db import models  # noqa: E402 isort:skip

        # Dealing with nested relationships, data can be a Manager,
        # so, first get a queryset from the Manager if needed
        iterable = data.all() if isinstance(data, models.Manager) else data

        # Only querysets support ``.only()``; lists must be left untouched
        if isinstance(iterable, models.QuerySet):
            # Names ``.only()`` can resolve: concrete columns, by name and attname, plus ``pk``
            selectable = {'pk'}
            for model_field in iterable.model._meta.concrete_fields:
                selectable.update((model_field.name, model_field.attname))

            # Serializer-only fields (e.g. method fields) have no column, so skip them
            columns = [name for name in self.child.get_fields() if name in selectable]
            if columns:
                iterable = iterable.only(*columns)

        return [self.child.to_representation(item) for item in iterable]
class DynamicSerializerFieldsMixin:
    """Serializer mixin that trims the output fields to those named in ``?fields=a,b,c``."""

    def get_fields(self):
        """Return the serializer fields, limited to the ones requested via ``?fields=``.

        Unknown names are ignored. If no requested name matches a real field, or
        the serializer has no ``request`` in its context, every field is returned.
        """
        fields = super().get_fields()

        # Serializers built without a request cannot be filtered
        request = self.context.get('request')
        if request is None:
            return fields

        # Tolerate spaces around the commas: "?fields=a, b"
        requested = {name.strip() for name in request.GET.get('fields', '').split(',')}

        # If the query param ?fields= doesn't evaluate to anything, default to original
        validated_fields = requested & set(fields)
        if not validated_fields:
            return fields
        return {key: value for key, value in fields.items() if key in validated_fields}

    @classmethod
    def many_init(cls, *args, **kwargs):
        """Default ``Meta.list_serializer_class`` to ``DynamicFieldListSerializer`` for ``many=True``.

        An explicitly configured ``list_serializer_class`` is left alone.
        """
        meta = getattr(cls, 'Meta', None)
        if meta is None:
            # No Meta at all: create one rather than failing on ``None``
            cls.Meta = type('Meta', (), {'list_serializer_class': DynamicFieldListSerializer})
        elif not hasattr(meta, 'list_serializer_class'):
            meta.list_serializer_class = DynamicFieldListSerializer
        return super().many_init(*args, **kwargs)
class TestSerializer(DynamicSerializerFieldsMixin, serializers.ModelSerializer):
    """Model serializer for ``TestData`` exposing all fields, filterable through the mixin."""

    class Meta:
        model = TestData
        fields = '__all__'
# Read-only API over every TestData row, serialized with TestSerializer
class TestSimpleViewSet(viewsets.ReadOnlyModelViewSet):
    queryset = TestData.objects.all()
    serializer_class = TestSerializer
# Only the list action is routed: GET /test_simple/ -> TestSimpleViewSet.list
urlpatterns = [
    path('test_simple/', TestSimpleViewSet.as_view({'get': 'list'})),
]
class APITestCase(TestCase):
    """Checks that ``?fields=`` trims both the JSON payload and the SQL SELECT list."""

    # Every field the serializer exposes, in the sorted order the assertions compare against
    ALL_FIELDS = ['field1', 'field2', 'field3', 'field4', 'field5', 'field6', 'field7', 'field8', 'id']

    def setUp(self):
        # Create 20 records of test data
        row = {'field%d' % index: 'value%d' % index for index in range(1, 9)}
        for _ in range(20):
            TestData.objects.create(**row)

    @staticmethod
    def get_sql_fields_selected(sql):
        """Return the sorted, unqualified column names selected by a ``SELECT ... FROM`` statement.

        Raises ``ValueError`` when ``sql`` is not a SELECT statement.
        """
        import re

        # The SQL is lowercased below, so the table name must be lowercased too
        table_name = TestData._meta.db_table.lower()
        # Get all the fields: everything between the leading "select" and the first
        # standalone "from" keyword (a bare substring split would also cut inside identifiers)
        match = re.match(r'\s*select\s+(.*?)\s+from\s', sql.lower(), re.DOTALL)
        if match is None:
            raise ValueError('Not a SELECT statement: %r' % sql)
        columns = match.group(1)
        # Remove the table name
        columns = columns.replace('"%s".' % table_name, '')
        # last bit of clean up -- sorting for predictability
        return sorted(columns.replace('"', '').replace(' ', '').split(','))

    def _get_with_queries(self, url):
        """GET ``url``, assert a 200, and return ``(decoded JSON, captured queries)``."""
        with CaptureQueriesContext(connection) as context:
            response = self.client.get(url, content_type='application/json')
        self.assertEqual(response.status_code, 200)
        return response.json(), context.captured_queries

    def test_normal_fields(self):
        response = self.client.get('/test_simple/', content_type='application/json')
        self.assertEqual(response.status_code, 200)
        data = response.json()
        self.assertEqual(20, len(data))
        self.assertEqual(self.ALL_FIELDS, sorted(data[0].keys()))

    def test_filtered_fields(self):
        data, queries = self._get_with_queries('/test_simple/?fields=field1,field2')
        self.assertEqual(['field1', 'field2'], sorted(data[0].keys()))
        self.assertEqual(1, len(queries))
        # The primary key is selected in addition to the two requested columns
        selected_fields = self.get_sql_fields_selected(queries[0]['sql'])
        self.assertEqual(['field1', 'field2', 'id'], selected_fields)

    def test_filtered_incorrect_fields(self):
        data, queries = self._get_with_queries('/test_simple/?fields=field1,field2,nope1')
        # Keeps working with fields that do exist
        self.assertEqual(['field1', 'field2'], sorted(data[0].keys()))
        self.assertEqual(1, len(queries))
        selected_fields = self.get_sql_fields_selected(queries[0]['sql'])
        self.assertEqual(['field1', 'field2', 'id'], selected_fields)

    def test_filtered_all_incorrect_fields(self):
        data, queries = self._get_with_queries('/test_simple/?fields=nope1,nope2,nope3')
        # No requested field exists, so the full field list is used
        self.assertEqual(self.ALL_FIELDS, sorted(data[0].keys()))
        self.assertEqual(1, len(queries))
        selected_fields = self.get_sql_fields_selected(queries[0]['sql'])
        self.assertEqual(self.ALL_FIELDS, selected_fields)
# Your code above this line | |
# ############################################################################## | |
# Build in-memory ``<name of file>.models`` / ``.tests`` / ``.urls`` modules,
# used so you can do 'from <name of file>.models import *'
models_module, tests_module, urls_module = (
    ModuleType('%s.%s' % (APP_LABEL, suffix)) for suffix in ('models', 'tests', 'urls')
)
urls_module.urlpatterns = urlpatterns

# Sort every class defined (or imported) in this file into the matching fake module.
# Iterate over a snapshot because the loop variables themselves become module globals.
for variable_name, value in list(globals().items()):
    if not inspect.isclass(value):
        continue
    # We are only interested in models
    if issubclass(value, models.Model):
        setattr(models_module, variable_name, value)
    # We are only interested in tests
    if issubclass(value, TestCase):
        setattr(tests_module, variable_name, value)

# Setup the fake modules: importable by dotted name and reachable as attributes
# of the app "package" (which is this very file)
sys.modules[models_module.__name__] = models_module
sys.modules[APP_LABEL].models = models_module
sys.modules[tests_module.__name__] = tests_module
sys.modules[APP_LABEL].tests = tests_module
sys.modules[urls_module.__name__] = urls_module
sys.modules[APP_LABEL].urls = urls_module
if __name__ == "__main__":
    # Hack to fix tests: when invoked as "<this file> test" with no other
    # positional argument, pass this app's label so its tests are the ones run
    argv = [arg for arg in sys.argv if not arg.startswith('-')]
    if len(argv) == 2 and argv[1] == 'test':
        sys.argv.append(APP_LABEL)

    from django.core.management import execute_from_command_line
    execute_from_command_line(sys.argv)
else:
    from django.core.wsgi import get_wsgi_application
    # Keep a reference under the conventional name so a WSGI server importing
    # this file has a callable to serve; the result used to be discarded
    application = get_wsgi_application()
@LeMeteore PS: the Disqus
integration on your site is down.
Yeah, I'm actually reading the whole thing. It's so nice 👌 And yeah, Disqus is down because I planned to move from a centralized and proprietary comments system and never actually replaced it :(
@LeMeteore Your site has been down since Aug 22 :/ I'm sure you're aware, just saying...
Hi Javier, really hope you're doing fine! Thanks a lot for reaching. Yeah, I lost my domain name for a reason that is so long and so stupid, you don't want to hear it. After spending too much time (w/o success) trying to retrieve it, I'm now slowly redeploying behind nskm.xyz. Once again, thanks a lot for reaching o/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@LeMeteore haha, I've made more improvements btw: https://gist.github.com/kingbuzzman/ac2ada9c27196fc90c1b75f2d01a6271 mostly geared towards tests and testing