Skip to content

Instantly share code, notes, and snippets.

@wonderbeyond
Last active February 15, 2023 13:13
Show Gist options
  • Save wonderbeyond/0e1f402b7595e0b4f9f653260fed029c to your computer and use it in GitHub Desktop.
Save wonderbeyond/0e1f402b7595e0b4f9f653260fed029c to your computer and use it in GitHub Desktop.
Generic foreign key field for peewee based on postgresql's jsonb type, with Inter-Model-Identifier(IMID) support.
import six
from peewee import DoesNotExist, FieldAccessor, Model
from playhouse.postgres_ext import BinaryJSONField
class GForeignKeyAccessor(FieldAccessor):
    """Descriptor that exposes a generic foreign key as its target instance.

    The raw column value is kept in ``instance.__data__[name]`` as a dict
    carrying a ``'model'`` key (a name registered in
    ``field.allowed_types``) and a ``'pk'`` key (``'id'`` is also accepted
    on read, and an optional ``'rel_field'`` overrides the lookup column).
    The resolved object is cached in ``instance.__rel__[name]``.
    """

    def get_rel_instance(self, instance):
        """Return the object referenced by ``instance``, querying if needed.

        Returns the cached related object when there is one, otherwise
        fetches it with ``rel_model.get(...)`` and caches it. Returns
        ``None`` when the stored value is ``None`` and the field is
        nullable.

        Raises:
            DoesNotExist: the stored value is ``None`` on a non-nullable
                field. The base peewee exception is raised because no
                target model can be determined from an empty value.
        """
        if self.name in instance.__rel__:
            return instance.__rel__[self.name]

        value = instance.__data__.get(self.name)
        if value is None:
            # Must be handled before touching value['model']; indexing None
            # would raise TypeError instead of a meaningful error.
            if not self.field.null:
                raise DoesNotExist(
                    '{} is not set and is not nullable'.format(self.name))
            return None

        model_name = value['model']
        rel_model = self.field.allowed_types[model_name]
        rel_field_name = (value.get('rel_field')
                          or self.field.get_ref_field(model_name).name)
        rel_field = getattr(rel_model, rel_field_name)

        obj = rel_model.get(rel_field == value.get('pk', value.get('id')))
        instance.__rel__[self.name] = obj
        return obj

    def __get__(self, instance, instance_type=None):
        """Return the related object for an instance, or the field on the class."""
        if instance is not None:
            return self.get_rel_instance(instance)
        return self.field

    def __set__(self, instance, obj):
        """Store a new reference and mark the field dirty.

        Accepts an instance of one of the allowed models
        (``fm.subject = compound``), an IMID string
        (``fm.subject = 'Compound/7'``), or any other value, which is
        stored as-is (e.g. what the driver hands over while loading a row
        from the db; see
        http://initd.org/psycopg/docs/extras.html#json-adaptation).
        """
        allowed_models = tuple(self.field.allowed_types.values())
        if isinstance(obj, allowed_models):
            # NOTE(review): relies on db_value() returning a wrapper that
            # exposes the plain dict as ``.adapted`` (psycopg2 ``Json``).
            instance.__data__[self.name] = self.field.db_value(obj).adapted
            instance.__rel__[self.name] = obj
        else:
            if isinstance(obj, six.string_types):
                data = self.field.db_value(obj).adapted
            else:
                data = obj
            prev_value = instance.__data__.get(self.name)
            instance.__data__[self.name] = data
            # The cached object is stale once the stored reference changes.
            if data != prev_value and self.name in instance.__rel__:
                del instance.__rel__[self.name]
        instance._dirty.add(self.name)


class GForeignKeyField(BinaryJSONField):
    """Generic foreign key stored as a jsonb document.

    The stored document has the form ``{'model': <model name>, 'pk': <key>}``.
    A reference can be given either as a model instance or as an
    Inter-Model-Identifier (IMID) string ``'<ModelName>/<pk>'``.

    Args:
        allowed_types: iterable of model classes this field may point to;
            they are indexed by ``__name__`` in ``self.allowed_types``.
        *args, **kwargs: forwarded to ``BinaryJSONField``.
    """

    accessor_class = GForeignKeyAccessor

    # Python 3 sets __hash__ to None on a class that defines __eq__ without
    # __hash__. Restore the inherited hash so the field stays usable as a
    # dict key / set member.
    __hash__ = BinaryJSONField.__hash__

    def __init__(self, allowed_types=None, *args, **kwargs):
        super(GForeignKeyField, self).__init__(*args, **kwargs)
        self.allowed_types = {m.__name__: m for m in allowed_types or []}

    def get_ref_field(self, model):
        """Return the field of ``model`` that references are resolved against.

        ``model`` is either a model class or the name of an allowed model.

        Raises:
            TypeError: ``model`` is a name that is not in ``allowed_types``.
        """
        if isinstance(model, six.string_types):
            if model not in self.allowed_types:
                raise TypeError(
                    '{} not in {}'.format(model, self.allowed_types))
            model = self.allowed_types[model]
        return getattr(model, 'id')

    def db_value(self, value):
        """Convert a model instance or IMID string to the stored document.

        Any other value (e.g. an already-built dict or ``None``) is passed
        to ``BinaryJSONField.db_value`` unchanged.

        Raises:
            ValueError: an IMID string has no ``'/'`` separator.
            TypeError: an IMID string names a model not in ``allowed_types``.
        """
        if isinstance(value, Model):
            model = value._meta.model
            ref_field = self.get_ref_field(model)
            value = {
                'model': model.__name__,
                'pk': getattr(value, ref_field.name),
            }
        elif isinstance(value, six.string_types):
            # Split on the first slash only, so a key that itself contains
            # '/' is kept intact instead of failing to unpack.
            model_name, sep, pk = value.partition('/')
            if not sep:
                raise ValueError(
                    'expected "<Model>/<pk>", got {!r}'.format(value))
            ref_field = self.get_ref_field(model_name)
            value = {
                'model': model_name,
                'pk': ref_field.db_value(pk),
            }
        return super(GForeignKeyField, self).db_value(value)

    def __eq__(self, rhs):
        """Build a jsonb containment expression matching the reference ``rhs``.

        ``None`` is delegated to the base field comparison, since there is
        no document to match against.
        """
        if rhs is None:
            return super(GForeignKeyField, self).__eq__(rhs)
        # NOTE(review): relies on db_value() returning a wrapper exposing the
        # plain dict as ``.adapted`` (psycopg2 ``Json``).
        return self.contains(self.db_value(rhs).adapted)

    def __ne__(self, rhs):
        """Return the negation of the ``==`` expression for ``rhs``."""
        return ~(self.__eq__(rhs))

HOW TO USE

Sample Models

class Post(Model):
    pass


class News(Model):
    pass


class Comment(Model):
    of = GForeignKeyField(allowed_types={Post, News})

Create Comments

p1 = Post(); p1.save()
n1 = News(); n1.save()

Comment(of=p1).save()
Comment(of='News/1').save()  # By IMID notation

Retrieve/Query Comments by target

c1 = Comment.get(Comment.of == p1)
c2 = Comment.get(Comment.of == 'News/1')  # By IMID notation
print(Comment.select().where(Comment.of == n1).sql())
print(Comment.select().where(Comment.of == 'News/1').sql())

Get/Set Comment's target object

print(c1.of)
c1.of = n1
c1.of = 'News/1'  # or By IMID notation
c1.save()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment