#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Gist: kingbuzzman/d7859d9734b590e52fad787d19c34b52
# Stolen from: https://mlvin.xyz/django-single-file-project.html
"""Single-file Django + Django REST framework project.

Demonstrates limiting the columns a DRF list endpoint SELECTs through a ``?fields=`` query
parameter. The whole project (settings, app, models, views, URLs and tests) lives in this file.
"""
import inspect
import os
import sys
from types import ModuleType

import django
from django.conf import settings

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# The current name of the file, which will be the name of our app
APP_LABEL, _ = os.path.splitext(os.path.basename(os.path.abspath(__file__)))
# Migrations folder need to be created, and django needs to be told where it is
APP_MIGRATION_MODULE = '%s_migrations' % APP_LABEL
APP_MIGRATION_PATH = os.path.join(BASE_DIR, APP_MIGRATION_MODULE)

# Create the folder and its __init__.py independently of each other: a folder that exists but
# lost its __init__.py would otherwise never be turned back into an importable package.
os.makedirs(APP_MIGRATION_PATH, exist_ok=True)
_MIGRATION_INIT = os.path.join(APP_MIGRATION_PATH, '__init__.py')
if not os.path.exists(_MIGRATION_INIT):
    with open(_MIGRATION_INIT, 'w'):
        pass  # an empty file is all that is needed

# Hack to trick Django into thinking this file is actually a package
sys.modules[APP_LABEL] = sys.modules[__name__]
sys.modules[APP_LABEL].__path__ = [os.path.abspath(__file__)]
settings.configure(
    DEBUG=True,
    # Development-only placeholder: Django raises ImproperlyConfigured when an empty
    # SECRET_KEY is accessed (recent versions), and nothing here needs a real secret.
    SECRET_KEY='django-insecure-single-file-project',
    ROOT_URLCONF='%s.urls' % APP_LABEL,
    MIDDLEWARE=(
        'django.middleware.common.CommonMiddleware',
        'django.contrib.sessions.middleware.SessionMiddleware',
        'django.contrib.auth.middleware.AuthenticationMiddleware',
        'django.contrib.messages.middleware.MessageMiddleware',
    ),
    INSTALLED_APPS=[
        'django.contrib.admin',
        'django.contrib.auth',
        'django.contrib.contenttypes',
        # Required by django.contrib.admin (and by MessageMiddleware's context processor).
        'django.contrib.messages',
        'django.contrib.sessions',
        'django.contrib.sites',
        'django.contrib.staticfiles',
        'rest_framework',
        APP_LABEL,
    ],
    # The admin requires a DjangoTemplates backend with these context processors; APP_DIRS also
    # lets DRF's browsable API find its templates.
    TEMPLATES=[
        {
            'BACKEND': 'django.template.backends.django.DjangoTemplates',
            'APP_DIRS': True,
            'OPTIONS': {
                'context_processors': [
                    'django.template.context_processors.request',
                    'django.contrib.auth.context_processors.auth',
                    'django.contrib.messages.context_processors.messages',
                ],
            },
        },
    ],
    MIGRATION_MODULES={APP_LABEL: APP_MIGRATION_MODULE},
    # Explicit implicit-primary-key type: keeps generated migrations stable and silences the
    # models.W042 warning on Django 3.2+ (older versions simply ignore the setting).
    DEFAULT_AUTO_FIELD='django.db.models.AutoField',
    SITE_ID=1,
    DATABASES={
        'default': {
            'ENGINE': 'django.db.backends.sqlite3',
            'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
        }
    },
    LOGGING={
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'simple': {
                'format': "%(levelname)s %(message)s",
            },
        },
        'handlers': {
            'console': {
                'level': 'DEBUG',
                'class': 'logging.StreamHandler',
                'formatter': 'simple',
            }
        },
        'loggers': {
            'django.db.backends': {'handlers': ['console'], 'level': 'DEBUG', 'propagate': False},
            'django.db.backends.schema': {'level': 'ERROR'},  # Causes sql logs to duplicate -- really annoying
        }
    },
    STATIC_URL='/static/',
    REST_FRAMEWORK={
        'DEFAULT_PERMISSION_CLASSES': (),
        'DEFAULT_AUTHENTICATION_CLASSES': (),
    }
)
django.setup()
from django.apps import apps  # noqa: E402 isort:skip


# Setup the AppConfig so we don't have to add the app_label to all our models
def get_containing_app_config(module):
    """Return the app config owning ``module``, mapping ``'__main__'`` to this file's app.

    When the file is executed directly its models live in ``__main__``, a module name that does
    not start with any installed app's name. Every other module name is delegated to the
    registry's own lookup, saved as ``apps._get_containing_app_config``.
    """
    if module != '__main__':
        return apps._get_containing_app_config(module)
    return apps.get_app_config(APP_LABEL)


# Keep the registry's own lookup reachable, then route every lookup through the wrapper above.
apps._get_containing_app_config, apps.get_containing_app_config = (
    apps.get_containing_app_config,
    get_containing_app_config,
)
# Your code below this line
# ##############################################################################
from django.db import connection, models  # noqa: E402 isort:skip
from django.test import TestCase  # noqa: E402 isort:skip
from django.test.utils import CaptureQueriesContext  # noqa: E402 isort:skip
from django.urls import path  # noqa: E402 isort:skip
from rest_framework import serializers, viewsets  # noqa: E402 isort:skip


class TestData(models.Model):
    """Demo table: eight short text columns, plus Django's implicit ``id`` primary key."""

    # Field order matters: it defines the column order in the generated migration.
    field1 = models.CharField(max_length=10)
    field2 = models.CharField(max_length=10)
    field3 = models.CharField(max_length=10)
    field4 = models.CharField(max_length=10)
    field5 = models.CharField(max_length=10)
    field6 = models.CharField(max_length=10)
    field7 = models.CharField(max_length=10)
    field8 = models.CharField(max_length=10)
class DynamicFieldListSerializer(serializers.ListSerializer):
    """List serializer that only loads the columns its child serializer will render.

    Intended for children using ``DynamicSerializerFieldsMixin``: the child's (possibly
    ``?fields=``-filtered) field set is turned into a ``QuerySet.only()`` call, so unused columns
    are not fetched from the database.
    """

    def to_representation(self, data):
        """
        List of object instances -> List of dicts of primitive datatypes.

        ``data`` may be a Manager, a QuerySet, or any other iterable of instances (e.g. a plain
        list handed over by a paginator). Column narrowing is only applied to QuerySets.
        """
        # ONLY HERE BECAUSE OF THE NATURE OF EMBEDDING DJANGO INTO A SINGLE FILE. The module-level
        # `models` name is later rebound to this file's fake `models` module, NOT django.db.models.
        from django.db import models  # noqa: E402 isort:skip

        # Dealing with nested relationships, data can be a Manager,
        # so, first get a queryset from the Manager if needed
        iterable = data.all() if isinstance(data, models.Manager) else data
        # Lists (e.g. paginated pages) have no .only(); render them as they are.
        if isinstance(iterable, models.QuerySet):
            only_fields = self._get_only_fields(iterable.model)
            if only_fields:
                iterable = iterable.only(*only_fields)
        return [self.child.to_representation(item) for item in iterable]

    def _get_only_fields(self, model):
        """Map the child's readable fields to model field names for ``QuerySet.only()``.

        Returns None (meaning "load every column") when a field cannot be mapped to a model
        field, i.e. ``source='*'`` fields such as SerializerMethodField, or properties/methods on
        the model. Such fields may read any attribute, and deferring columns could then cost one
        extra query per row.
        """
        from django.core.exceptions import FieldDoesNotExist

        only_fields = []
        # Use the child's bound (and cached) fields: binding fills in `source_attrs`.
        for field in self.child.fields.values():
            if field.write_only:
                continue  # never rendered, so its column is not needed
            if not field.source_attrs:
                return None  # source='*': the field receives the whole instance
            name = field.source_attrs[0]
            if name == 'pk':
                name = model._meta.pk.name
            try:
                model_field = model._meta.get_field(name)
            except FieldDoesNotExist:
                return None
            # Only columns stored on this table can be selected; many-to-many and reverse
            # relations are fetched by separate queries anyway.
            if model_field.concrete and not model_field.many_to_many:
                only_fields.append(model_field.name)
        return only_fields
class DynamicSerializerFieldsMixin:
    """Serializer mixin letting clients choose the output fields with ``?fields=a,b,c``.

    Unknown names are ignored. If none of the requested names exist, if no ``fields``
    parameter is given, or if no request is in the serializer context, every field is returned.
    With ``many=True`` it also defaults the list serializer to ``DynamicFieldListSerializer``.
    """

    def get_fields(self):
        """Return the parent's fields, narrowed to those named in the ``fields`` query parameter.

        The parent's field order is preserved.
        """
        fields = super().get_fields()
        # Serializers built outside a view (shell, scripts) have no request: expose everything.
        request = self.context.get('request')
        if request is None:
            return fields
        # Tolerate whitespace around names, e.g. ?fields=field1, field2
        requested = {name.strip() for name in request.GET.get('fields', '').split(',')}
        # If querysparams ?fields= doesn't evaluate to anything, default to original
        available = set(fields)
        selected = requested & available or available
        return {key: value for key, value in fields.items() if key in selected}

    @classmethod
    def many_init(cls, *args, **kwargs):
        """Build the ``many=True`` list serializer, defaulting it to ``DynamicFieldListSerializer``."""
        meta = getattr(cls, 'Meta', None)
        # Respect an explicitly configured list_serializer_class. Without a Meta there is nothing
        # to attach it to, so DRF's default list serializer is used.
        if meta is not None and not hasattr(meta, 'list_serializer_class'):
            meta.list_serializer_class = DynamicFieldListSerializer
        return super().many_init(*args, **kwargs)
class TestSerializer(DynamicSerializerFieldsMixin, serializers.ModelSerializer):
    """Serializes every ``TestData`` column; clients may narrow the output with ``?fields=``."""

    class Meta:
        fields = '__all__'
        model = TestData
class TestSimpleViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only endpoint over all ``TestData`` rows, rendered with ``TestSerializer``."""

    serializer_class = TestSerializer
    queryset = TestData.objects.all()
# URLconf served through the fake '<app>.urls' module named by ROOT_URLCONF.
# Only the viewset's 'list' action is routed, for GET requests.
test_simple_list_view = TestSimpleViewSet.as_view({'get': 'list'})
urlpatterns = [path('test_simple/', test_simple_list_view)]
class APITestCase(TestCase):
    """End-to-end checks for ``?fields=`` filtering on ``/test_simple/``.

    Besides the JSON payload, the filtered cases inspect the captured SQL. This proves that only
    the requested columns, plus the primary key, are SELECTed.
    """

    # Every key/column of TestData, sorted (``id`` is Django's implicit primary key).
    ALL_FIELDS = ['field1', 'field2', 'field3', 'field4', 'field5', 'field6', 'field7', 'field8', 'id']

    @classmethod
    def setUpTestData(cls):
        # Create 20 records of test data once for the whole class. The tests only read them, and
        # each test runs in its own transaction that is rolled back afterwards.
        TestData.objects.bulk_create([
            TestData(**{'field%d' % i: 'value%d' % i for i in range(1, 9)})
            for _ in range(20)
        ])

    @staticmethod
    def get_sql_fields_selected(sql):
        """Return the sorted, lower-cased column names in the SELECT list of ``sql``.

        Table qualifiers and quotes are dropped, so ``SELECT "t"."id", "t"."field1" FROM "t"``
        yields ``['field1', 'id']``. ``select``/``from`` are matched as whole keywords, so column
        names that merely contain those letters (e.g. ``from_date``) stay intact.

        Raises:
            ValueError: if ``sql`` is not a SELECT statement.
        """
        import re

        match = re.match(r'\s*select\s+(.*?)\s+from\s', sql, re.IGNORECASE | re.DOTALL)
        if match is None:
            raise ValueError('Not a SELECT statement: %r' % sql)
        # '"app_testdata"."field1"' -> 'field1'
        columns = [
            column.strip().rsplit('.', 1)[-1].replace('"', '').lower()
            for column in match.group(1).split(',')
        ]
        return sorted(columns)

    def _get_with_queries(self, url):
        """GET ``url`` and return ``(response, captured_queries)``."""
        with CaptureQueriesContext(connection) as context:
            response = self.client.get(url, content_type='application/json')
        return response, context.captured_queries

    def _assert_filtered(self, url, expected_keys, expected_columns):
        """Check the payload keys and that exactly one query SELECTs ``expected_columns``."""
        response, queries = self._get_with_queries(url)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(expected_keys, sorted(response.json()[0].keys()))
        self.assertEqual(1, len(queries))
        self.assertEqual(expected_columns, self.get_sql_fields_selected(queries[0]['sql']))

    def test_normal_fields(self):
        response = self.client.get('/test_simple/', content_type='application/json')
        self.assertEqual(response.status_code, 200)
        data = response.json()
        self.assertEqual(20, len(data))
        self.assertEqual(self.ALL_FIELDS, sorted(data[0].keys()))

    def test_filtered_fields(self):
        self._assert_filtered('/test_simple/?fields=field1,field2',
                              ['field1', 'field2'], ['field1', 'field2', 'id'])

    def test_filtered_incorrect_fields(self):
        # Keeps working with fields that do exist
        self._assert_filtered('/test_simple/?fields=field1,field2,nope1',
                              ['field1', 'field2'], ['field1', 'field2', 'id'])

    def test_filtered_all_incorrect_fields(self):
        # No requested field exists: falls back to every field
        self._assert_filtered('/test_simple/?fields=nope1,nope2,nope3',
                              self.ALL_FIELDS, self.ALL_FIELDS)
# Your code above this line
# ##############################################################################

# Used so you can do 'from <name of file>.models import *'; the fake '<app>.urls' module also
# backs ROOT_URLCONF.
models_module = ModuleType('%s.models' % APP_LABEL)
tests_module = ModuleType('%s.tests' % APP_LABEL)
urls_module = ModuleType('%s.urls' % APP_LABEL)
urls_module.urlpatterns = urlpatterns

# Copy models and test cases into their fake submodules. Iterate over a snapshot because the
# loop variables are module globals themselves, so the namespace changes while looping.
for variable_name, value in list(globals().items()):
    if not inspect.isclass(value):
        continue
    if issubclass(value, models.Model):
        setattr(models_module, variable_name, value)
    if issubclass(value, TestCase):
        setattr(tests_module, variable_name, value)

# Setup the fake modules: register each for import and expose it as an attribute of the app
# "package". NOTE: the app package is this very module, so assigning `.models` rebinds the
# global name `models` to the fake module from here on.
for fake_module in (models_module, tests_module, urls_module):
    sys.modules[fake_module.__name__] = fake_module
    setattr(sys.modules[APP_LABEL], fake_module.__name__.rpartition('.')[2], fake_module)
if __name__ == "__main__":
    # Hack to fix tests: with no test label Django discovers test*.py files under the current
    # directory, which would miss the tests defined in this file. So for a bare
    # `python <file> test` (options aside), test this file's app explicitly.
    argv = [arg for arg in sys.argv if not arg.startswith('-')]
    if len(argv) == 2 and argv[1] == 'test':
        sys.argv.append(APP_LABEL)
    from django.core.management import execute_from_command_line
    execute_from_command_line(sys.argv)
else:
    # Imported by a WSGI server: bind the callable to the conventional `application` name
    # (e.g. `gunicorn <file>:application`) so the server can actually find it.
    from django.core.wsgi import get_wsgi_application
    application = get_wsgi_application()
Hi o/ I really hope you're doing fine :) This is the author of the Django single-file project article. I would like to say thanks a lot for improving the original source code 🥇 And in turn, I will take this gist and run with it too 🏃‍♂️
@LeMeteore haha, I've made more improvements btw: https://gist.github.com/kingbuzzman/ac2ada9c27196fc90c1b75f2d01a6271 mostly geared towards tests and testing
@LeMeteore P.S. the Disqus integration on your site is down.
Yeah, I'm actually reading the whole thing. It's so nice 👌 And yeah, Disqus is down because I planned to move away from a centralized, proprietary comment system and never actually replaced it :(
@LeMeteore Your site has been down since Aug 22 :/ I'm sure you're aware, just saying...
Hi Javier, really hope you're doing fine! Thanks a lot for reaching out. Yeah, I lost my domain name for a reason that is so long and so stupid, you don't want to hear it. After spending too much time (w/o success) trying to retrieve it, I'm now slowly redeploying behind nskm.xyz. Once again, thanks a lot for reaching out o/
To run the code above, run the following:
If you want the REAL copy and paste version:
Answers the Stack Overflow question: https://stackoverflow.com/questions/56276747/limit-django-fields-queried-in-sql-call-to-database-by-queryset