Skip to content

Instantly share code, notes, and snippets.

@adamcik
Created March 30, 2010 08:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save adamcik/348914 to your computer and use it in GitHub Desktop.
Save adamcik/348914 to your computer and use it in GitHub Desktop.
Custom manager to allow access to proper inet/cidr lookups
from django.db.backends.postgresql_psycopg2.base import *
from django.db.backends.postgresql_psycopg2.base import DatabaseFeatures as PostgresqlDatabaseFeatures
from django.db.backends.postgresql_psycopg2.base import DatabaseOperations as PostgresqlDatabaseOperations
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper as PostgresqlDatabaseWrapper
class DatabaseFeatures(PostgresqlDatabaseFeatures):
    """Feature flags for this backend: PostgreSQL defaults plus one override.

    The only change is opting in to a custom query class, which this module
    provides through ``DatabaseOperations.query_class``.
    """

    # Opt in to the backend-supplied query class.
    uses_custom_query_class = True
class DatabaseOperations(PostgresqlDatabaseOperations):
    """PostgreSQL operations that understand inet comparison lookups.

    ``query_class`` wraps the default query class so the ORM accepts the
    ``inet_*`` lookups and renders them as PostgreSQL network operators.
    """

    def field_cast_sql(self, db_type):
        """Return the placeholder used to render a column in a WHERE clause.

        Always returns a bare ``'%s'``, so no column type is wrapped in a cast
        function. NOTE(review): presumably this stops the stock backend from
        wrapping inet columns in ``HOST(...)`` (which drops the netmask) --
        confirm against the Django version in use.
        """
        return '%s'

    def query_class(self, DefaultQueryClass):
        """Return a subclass of *DefaultQueryClass* supporting ``inet_*`` lookups.

        The returned class extends ``query_terms`` with the lookups in
        ``INET_TERMS`` and uses ``InetAwareWhereNode`` to build WHERE clauses.
        """
        from django.db.models.sql.where import WhereNode
        from django.db.models.sql.constants import QUERY_TERMS

        # Lookup name -> PostgreSQL network operator. The strict and
        # "or equal" containment operators need distinct keys; previously
        # '>>=' was also stored under 'inet_contains' and overwrote '>>'.
        INET_TERMS = {
            'inet_lt': '<',
            'inet_lte': '<=',
            'inet_exact': '=',
            'inet_gte': '>=',
            'inet_gt': '>',
            'inet_not': '<>',
            'inet_is_contained': '<<',
            'inet_is_contained_or_equal': '<<=',
            'inet_contains': '>>',
            'inet_contains_or_equal': '>>=',
        }

        ALL_TERMS = QUERY_TERMS.copy()
        ALL_TERMS.update(INET_TERMS)

        class InetAwareWhereNode(WhereNode):
            """WhereNode that renders ``inet_*`` lookups on inet columns."""

            def make_atom(self, child, qn):
                table_alias, name, db_type, lookup_type, value_annot, params = child
                if db_type != 'inet' or lookup_type not in INET_TERMS:
                    return super(InetAwareWhereNode, self).make_atom(child, qn)
                # Quote identifiers with the supplied quoting function, and
                # fall back to the bare column when there is no table alias
                # (instead of emitting "None.<column>").
                if table_alias:
                    lhs = '%s.%s' % (qn(table_alias), qn(name))
                else:
                    lhs = qn(name)
                # Prefix the parameter with the inet type so PostgreSQL parses
                # it as a network address.
                return ('%s %s inet %%s' % (lhs, INET_TERMS[lookup_type]), params)

        class InetAwareQueryClass(DefaultQueryClass):
            """Default query class extended with the ``inet_*`` lookup names."""

            query_terms = ALL_TERMS

            def __init__(self, model, connection, where=InetAwareWhereNode):
                super(InetAwareQueryClass, self).__init__(model, connection, where)

        return InetAwareQueryClass
class DatabaseWrapper(PostgresqlDatabaseWrapper):
    """psycopg2 connection wrapper using this module's features and operations."""

    def __init__(self, *wrapper_args, **wrapper_kwargs):
        super(DatabaseWrapper, self).__init__(*wrapper_args, **wrapper_kwargs)
        # Swap in the inet-aware subclasses after the base wrapper is built.
        self.features, self.ops = DatabaseFeatures(), DatabaseOperations()
from django.db import models
class IPAddressField(models.IPAddressField):
    """IPAddressField that passes ``inet_*`` lookup values through unchanged."""

    def get_db_prep_lookup(self, lookup_type, value):
        """Prepare *value* for a lookup of kind *lookup_type*.

        For any lookup whose name starts with ``inet_``, the raw value becomes
        the single query parameter. All other lookups are delegated to
        ``models.IPAddressField``.
        """
        if not lookup_type.startswith('inet_'):
            return super(IPAddressField, self).get_db_prep_lookup(lookup_type, value)
        return [value]
from IPy import IP
from django.core.exceptions import ValidationError
from django.db import models, connection
from django.db.models import sql, query
# Lookup name -> PostgreSQL network operator, rendered by NetWhere.make_atom.
# Strict and "or equal" containment variants need distinct keys in both
# directions; previously '>>=' was also stored under 'inet_contains', which
# silently replaced the strict '>>' entry.
NET_TERMS = {
    'inet_lt': '<',
    'inet_lte': '<=',
    'inet_exact': '=',
    'inet_gte': '>=',
    'inet_gt': '>',
    'inet_not': '<>',
    'inet_is_contained': '<<',
    'inet_is_contained_or_equal': '<<=',
    'inet_contains': '>>',
    'inet_contains_or_equal': '>>=',
}
class NetQuery(sql.Query):
    """Query that also accepts the lookup names listed in ``NET_TERMS``."""

    query_terms = sql.Query.query_terms.copy()
    query_terms.update(NET_TERMS)

    def add_filter(self, filter_expr, *args, **kwargs):
        """Add *filter_expr*, a ``(filter_string, value)`` pair, to the query.

        IPy ``IP`` values are converted with ``unicode()`` before being handed
        to the base implementation. The pair is unpacked in the body rather
        than in the signature because tuple parameters are gone in Python 3
        (PEP 3113). The parameter is named ``filter_expr``, as in Django's
        ``Query.add_filter``.
        """
        filter_string, value = filter_expr
        if isinstance(value, IP):
            value = unicode(value)
        return super(NetQuery, self).add_filter((filter_string, value), *args, **kwargs)
class NetWhere(sql.where.WhereNode):
    """WhereNode that renders ``NET_TERMS`` lookups on inet/cidr columns."""

    def make_atom(self, child, qn):
        """Render one filter *child* as an ``(sql, params)`` pair.

        Lookups listed in ``NET_TERMS`` against ``inet``/``cidr`` columns become
        ``<column> <operator> inet %s``. Every other atom is rendered by the
        base class.
        """
        table_alias, name, db_type, lookup_type, value_annot, params = child
        if db_type not in ('cidr', 'inet') or lookup_type not in NET_TERMS:
            return super(NetWhere, self).make_atom(child, qn)
        # Quote identifiers with the supplied quoting function, and use the
        # bare column when there is no table alias (instead of "None.<column>").
        if table_alias:
            lhs = '%s.%s' % (qn(table_alias), qn(name))
        else:
            lhs = qn(name)
        # Prefix the parameter with the inet type so PostgreSQL parses it as
        # a network address.
        return ('%s %s inet %%s' % (lhs, NET_TERMS[lookup_type]), params)
class NetManger(models.Manager):
    """Manager whose querysets use NetQuery and NetWhere for inet/cidr lookups."""

    def get_query_set(self):
        """Return a QuerySet backed by a NetQuery that builds NetWhere clauses."""
        net_query = NetQuery(self.model, connection, NetWhere)
        return query.QuerySet(self.model, net_query)


# Correctly spelled alias; ``NetManger`` is kept so existing code keeps working.
NetManager = NetManger
class _NetAddressField(models.Field):
def to_python(self, value):
if not value:
return None
try:
return IP(value)
except ValueError, e:
raise ValidationError(e)
def get_db_prep_value(self, value):
return unicode(self.to_python(value))
def get_db_prep_lookup(self, lookup_type, value):
value = unicode(value)
if lookup_type in INET_TERMS:
return [value]
return super(_NetAddressField, self).get_db_prep_lookup(lookup_type, value)
class InetAddressField(_NetAddressField):
    """Network-address field stored in a PostgreSQL ``inet`` column."""

    # SubfieldBase routes values assigned to the model attribute through to_python().
    __metaclass__ = models.SubfieldBase
    description = "Postgresql inet field"

    def db_type(self):
        """Return the PostgreSQL column type for this field."""
        return 'inet'
class CidrAddressField(_NetAddressField):
    """Network-address field stored in a PostgreSQL ``cidr`` column."""

    # SubfieldBase routes values assigned to the model attribute through to_python().
    __metaclass__ = models.SubfieldBase
    description = "Postgresql cidr field"

    def db_type(self):
        """Return the PostgreSQL column type for this field."""
        return 'cidr'
class MACAddressField(models.Field):
    """Model field stored in a PostgreSQL ``macaddr`` column."""

    description = "Postgresql macaddr field"

    def __init__(self, *field_args, **field_kwargs):
        # max_length is forced to 17, replacing any value the caller passed
        # (presumably sized for the colon-separated "xx:xx:xx:xx:xx:xx" form).
        field_kwargs.update(max_length=17)
        super(MACAddressField, self).__init__(*field_args, **field_kwargs)

    def db_type(self):
        """Return the PostgreSQL column type for this field."""
        return 'macaddr'
class Foo(models.Model):
    """Sample model with one field of each network type defined above."""
    inet = InetAddressField()
    test = CidrAddressField()
    mac = MACAddressField()
    # Custom manager so Foo.objects builds NetQuery/NetWhere queries and
    # accepts the lookup names in NET_TERMS.
    objects = NetManger()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment