Skip to content

Instantly share code, notes, and snippets.

@jarshwah
Last active August 29, 2015 14:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jarshwah/76dbc87577b7fec05807 to your computer and use it in GitHub Desktop.
Save jarshwah/76dbc87577b7fec05807 to your computer and use it in GitHub Desktop.
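Fixes problems with Oracle and cast issues with PostgreSQL, and removes the dependency on the compiler type (the `SQLUpdateCompiler` checks in `Value.as_sql()` and `BaseCaseExpression.as_postgresql()`).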
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 64b5ba0..8b69406 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -487,11 +487,7 @@ class Value(ExpressionNode):
def as_sql(self, compiler, connection):
val = self.value
if self._output_field_or_none is not None:
- from django.db.models.sql.compiler import SQLUpdateCompiler
- if isinstance(compiler, SQLUpdateCompiler):
- val = self.output_field.get_db_prep_save(val, connection=connection)
- else:
- val = self.output_field.get_db_prep_value(val, connection=connection)
+ val = self.output_field.get_db_prep_value(val, connection=connection)
return '%s', [val]
@@ -635,10 +631,12 @@ class BaseCaseExpression(ExpressionNode):
def as_postgresql(self, compiler, connection):
sql, params = self.as_sql(compiler, connection)
if self._output_field_or_none is not None:
- from django.db.models.sql.compiler import SQLUpdateCompiler
- if isinstance(compiler, SQLUpdateCompiler):
- # cast expression for postgres
- return 'CAST(%s AS %s)' % (sql, self.output_field.db_type(connection)), params
+ # cast expression for postgres - removing components of the type
+ # within brackets: varchar(255) -> varchar. Required for values
+ # that look like strings but are more specific types like uuid or
+ # inet.
+ cast_type = self.output_field.db_type(connection).split('(')[0]
+ return 'CAST(%s AS %s)' % (sql, cast_type), params
return sql, params
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index cef0c97..d697496 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -1874,7 +1874,7 @@ class IPAddressField(Field):
class GenericIPAddressField(Field):
- empty_strings_allowed = True
+ empty_strings_allowed = False
description = _("IP address")
default_error_messages = {}
diff --git a/tests/expressions_case/models.py b/tests/expressions_case/models.py
index ff1e927..2a63e40 100644
--- a/tests/expressions_case/models.py
+++ b/tests/expressions_case/models.py
@@ -11,14 +11,14 @@ class CaseTestModel(models.Model):
string = models.CharField(max_length=100)
big_integer = models.BigIntegerField(null=True)
- binary = models.BinaryField(null=True)
+ binary = models.BinaryField(default=b'')
boolean = models.BooleanField(default=False)
- comma_separated_integer = models.CommaSeparatedIntegerField(max_length=100, null=True)
+ comma_separated_integer = models.CommaSeparatedIntegerField(max_length=100, default='')
date = models.DateField(null=True)
date_time = models.DateTimeField(null=True)
decimal = models.DecimalField(max_digits=2, decimal_places=1, null=True)
duration = models.DurationField(null=True)
- email = models.EmailField(null=True)
+ email = models.EmailField(default='')
file = models.FileField(null=True)
file_path = models.FilePathField(null=True)
float = models.FloatField(null=True)
@@ -28,11 +28,11 @@ class CaseTestModel(models.Model):
null_boolean = models.NullBooleanField()
positive_integer = models.PositiveIntegerField(null=True)
positive_small_integer = models.PositiveSmallIntegerField(null=True)
- slug = models.SlugField(null=True)
+ slug = models.SlugField(default='')
small_integer = models.SmallIntegerField(null=True)
- text = models.TextField(null=True)
+ text = models.TextField(default='')
time = models.TimeField(null=True)
- url = models.URLField(null=True)
+ url = models.URLField(default='')
uuid = models.UUIDField(null=True)
fk = models.ForeignKey('self', null=True)
diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py
index 1bf4422..d34f318 100644
--- a/tests/expressions_case/tests.py
+++ b/tests/expressions_case/tests.py
@@ -10,7 +10,7 @@ from django.db import models
from django.db.models import F, Q, Value
from django.db.models.expressions import SearchedCase, SimpleCase
from django.test import TestCase
-from django.utils.six import binary_type
+from django.utils.six import binary_type, text_type
from .models import CaseTestModel, FKCaseTestModel
@@ -236,12 +236,13 @@ class BaseCaseExpressionTests(TestCase):
# set explicitly
[(Value(1), Value(b'one')),
(Value(2), Value(b'two'))],
+ default=Value(b''),
output_field=models.BinaryField()))
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, b'one'), (2, b'two'), (3, None), (2, b'two'), (3, None), (3, None), (4, None)],
- transform=lambda o: (o.integer, None if o.binary is None else binary_type(o.binary)))
+ [(1, b'one'), (2, b'two'), (3, b''), (2, b'two'), (3, b''), (3, b''), (4, b'')],
+ transform=lambda o: (o.integer, binary_type(o.binary)))
def test_update_boolean(self):
CaseTestModel.objects.update(
@@ -261,11 +262,11 @@ class BaseCaseExpressionTests(TestCase):
comma_separated_integer=self.create_expression(
'integer',
[(Value(1), Value('1')),
- (Value(2), Value('2,2'))]))
+ (Value(2), Value('2,2'))], default=Value('')))
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, '1'), (2, '2,2'), (3, None), (2, '2,2'), (3, None), (3, None), (4, None)],
+ [(1, '1'), (2, '2,2'), (3, ''), (2, '2,2'), (3, ''), (3, ''), (4, '')],
transform=attrgetter('integer', 'comma_separated_integer'))
def test_update_date(self):
@@ -325,12 +326,12 @@ class BaseCaseExpressionTests(TestCase):
email=self.create_expression(
'integer',
[(Value(1), Value('1@example.com')),
- (Value(2), Value('2@example.com'))]))
+ (Value(2), Value('2@example.com'))], default=Value('')))
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, '1@example.com'), (2, '2@example.com'), (3, None), (2, '2@example.com'), (3, None), (3, None),
- (4, None)],
+ [(1, '1@example.com'), (2, '2@example.com'), (3, ''), (2, '2@example.com'), (3, ''), (3, ''),
+ (4, '')],
transform=attrgetter('integer', 'email'))
def test_update_file(self):
@@ -342,19 +343,19 @@ class BaseCaseExpressionTests(TestCase):
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, '~/1'), (2, '~/2'), (3, None), (2, '~/2'), (3, None), (3, None), (4, None)],
- transform=attrgetter('integer', 'file'))
+ [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')],
+ transform=lambda o: (o.integer, text_type(o.file)))
def test_update_file_path(self):
CaseTestModel.objects.update(
file_path=self.create_expression(
'integer',
[(Value(1), Value('~/1')),
- (Value(2), Value('~/2'))]))
+ (Value(2), Value('~/2'))], default=Value('')))
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, '~/1'), (2, '~/2'), (3, None), (2, '~/2'), (3, None), (3, None), (4, None)],
+ [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')],
transform=attrgetter('integer', 'file_path'))
def test_update_float(self):
@@ -375,11 +376,10 @@ class BaseCaseExpressionTests(TestCase):
'integer',
[(Value(1), Value('~/1')),
(Value(2), Value('~/2'))]))
-
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, '~/1'), (2, '~/2'), (3, None), (2, '~/2'), (3, None), (3, None), (4, None)],
- transform=attrgetter('integer', 'image'))
+ [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')],
+ transform=lambda o: (o.integer, text_type(o.image)))
def test_update_ip_address(self):
CaseTestModel.objects.update(
@@ -450,11 +450,11 @@ class BaseCaseExpressionTests(TestCase):
slug=self.create_expression(
'integer',
[(Value(1), Value('1')),
- (Value(2), Value('2'))]))
+ (Value(2), Value('2'))], default=Value('')))
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, '1'), (2, '2'), (3, None), (2, '2'), (3, None), (3, None), (4, None)],
+ [(1, '1'), (2, '2'), (3, ''), (2, '2'), (3, ''), (3, ''), (4, '')],
transform=attrgetter('integer', 'slug'))
def test_update_small_integer(self):
@@ -469,16 +469,28 @@ class BaseCaseExpressionTests(TestCase):
[(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)],
transform=attrgetter('integer', 'small_integer'))
+ def test_update_string(self):
+ CaseTestModel.objects.filter(string__in=['1', '2']).update(
+ string=self.create_expression(
+ 'integer',
+ [(Value(1), Value('1', output_field=models.CharField())),
+ (Value(2), Value('2', output_field=models.CharField()))]))
+
+ self.assertQuerysetEqual(
+ CaseTestModel.objects.filter(string__in=['1', '2']).order_by('pk'),
+ [(1, '1'), (2, '2'), (2, '2')],
+ transform=attrgetter('integer', 'string'))
+
def test_update_text(self):
CaseTestModel.objects.update(
text=self.create_expression(
'integer',
[(Value(1), Value('1')),
- (Value(2), Value('2'))]))
+ (Value(2), Value('2'))], default=Value('')))
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, '1'), (2, '2'), (3, None), (2, '2'), (3, None), (3, None), (4, None)],
+ [(1, '1'), (2, '2'), (3, ''), (2, '2'), (3, ''), (3, ''), (4, '')],
transform=attrgetter('integer', 'text'))
def test_update_time(self):
@@ -500,12 +512,12 @@ class BaseCaseExpressionTests(TestCase):
url=self.create_expression(
'integer',
[(Value(1), Value('http://1.example.com/')),
- (Value(2), Value('http://2.example.com/'))]))
+ (Value(2), Value('http://2.example.com/'))], default=Value('')))
self.assertQuerysetEqual(
CaseTestModel.objects.all().order_by('pk'),
- [(1, 'http://1.example.com/'), (2, 'http://2.example.com/'), (3, None), (2, 'http://2.example.com/'),
- (3, None), (3, None), (4, None)],
+ [(1, 'http://1.example.com/'), (2, 'http://2.example.com/'), (3, ''), (2, 'http://2.example.com/'),
+ (3, ''), (3, ''), (4, '')],
transform=attrgetter('integer', 'url'))
def test_update_uuid(self):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment