Skip to content

Instantly share code, notes, and snippets.

@vishalp-perpetualny
Last active May 31, 2022 11:28
Show Gist options
  • Save vishalp-perpetualny/4cb478f1862bfb6d5e8e82be1233ff87 to your computer and use it in GitHub Desktop.
Save vishalp-perpetualny/4cb478f1862bfb6d5e8e82be1233ff87 to your computer and use it in GitHub Desktop.
nested create mixin
import copy
import logging

from rest_framework.utils import model_meta
class SimplifiedNestedCreateMixin:
    """
    Serializer mixin with helper methods for nested create / update / delete.

    Usage:
        Declare ``nesting_map`` on the parent serializer. It is a dict keyed
        by the name of a to-many relation on the parent model; each value is
        a dict with:
        - 'serializer' : the serializer *instance* that handles the children.
          It must use this mixin as well, because ``create`` is called on it
          with the parent passed as an extra keyword argument, and
          ``pre_update`` / ``delete_obj`` are called on it during updates.
        - 'parent_key' : name of the field on the child model that points
          back to the parent object.

    Optional ``Meta`` attributes read by this mixin:
        - model_validation (default True): call ``instance.clean()`` before
          saving in ``create`` / ``update``.
        - repr_remove_empty (default True): drop ``None`` and ``[]`` values
          from the output of ``to_representation``.
        - edit_only_fields (default []): when non-empty, ``update`` assigns
          only these attributes on the instance.

    Example:
        class AddressSerializer(SimplifiedNestedCreateMixin,
                                serializers.ModelSerializer):
            class Meta:
                model = Address
                fields = ("id", "street1", "street2", "pincode", "user")

        class UserSerializer(SimplifiedNestedCreateMixin,
                             serializers.ModelSerializer):
            # a serializer that needs to create / update related objects
            nesting_map = {
                "addresses": {
                    "serializer": AddressSerializer(),
                    "parent_key": "user",
                }
            }

            class Meta:
                model = User
                fields = ("last_name", "addresses", "age", "first_name")
    """

    # Overridden by subclasses; treated as read-only configuration here.
    nesting_map = {}

    @staticmethod
    def get_model_pk_by_serializer(serializer):
        """
        Find the model behind a serializer and the name of its primary key.

        The primary key name differs between models, and it is what decides
        whether an incoming child item is an update or a new object.

        Returns ``(model, pk_name)``, or ``(False, False)`` when the
        serializer has no ``Meta.model`` or the model exposes no ``pk``.
        """
        meta = getattr(serializer, "Meta", False)
        model = getattr(meta, "model", False) if meta else False
        if not model:
            return False, False
        model_pk = getattr(model._meta, "pk", False)
        if not model_pk:
            return False, False
        return model, model_pk.name

    def _nesting_config(self, key):
        """Return ``(serializer, parent_key)`` configured for relation ``key``."""
        config = self.nesting_map.get(key, {})
        return config.get('serializer'), config.get('parent_key')

    def remove_children(self, validated_data):
        """
        Split to-many relation values out of ``validated_data``.

        To-many fields listed in ``nesting_map`` go to ``children``; every
        other to-many field goes to ``many_to_many``.

        NOTE: ``validated_data`` is modified in place (the relation keys are
        popped) and the same dict is returned as the first element.

        Returns ``(remaining_data, children, many_to_many)``.
        """
        children, many_to_many = {}, {}
        info = model_meta.get_field_info(self.Meta.model)
        for field_name, relation_info in info.relations.items():
            if not relation_info.to_many or field_name not in validated_data:
                continue
            bucket = children if field_name in self.nesting_map else many_to_many
            bucket[field_name] = validated_data.pop(field_name)
        return validated_data, children, many_to_many

    def nested_create(self, instance, children):
        """Create every child item in ``children`` with ``instance`` as parent."""
        for key, items in children.items():
            if key not in self.nesting_map:
                continue
            serializer, param_key = self._nesting_config(key)
            for obj_data in items:
                self.create_child(serializer=serializer, object_data=obj_data,
                                  param_key=param_key, instance=instance)

    def create_child(self, serializer, object_data, param_key, instance):
        """Delegate creation of one child; the parent is passed as ``param_key``."""
        serializer.create(validated_data=object_data, **{param_key: instance})

    def self_create(self, data):
        """Build (but do not save) an instance of ``Meta.model`` from ``data``."""
        return self.Meta.model(**data)

    def to_representation(self, instance):
        """
        Serialize ``instance``; unless ``Meta.repr_remove_empty`` is falsy,
        entries whose value is ``None`` or ``[]`` are left out of the result.
        """
        from collections import OrderedDict

        ret = super().to_representation(instance)
        if not getattr(self.Meta, 'repr_remove_empty', True):
            return ret
        return OrderedDict(
            (key, val) for key, val in ret.items()
            if not (val is None or val == [])
        )

    def _save_instance(self, instance, many_to_many):
        """Validate (if enabled), save, then assign plain many-to-many values."""
        if getattr(self.Meta, 'model_validation', True):
            instance.clean()
        instance.save()
        # Related managers can only be assigned after the instance is saved.
        for field_name, value in many_to_many.items():
            logging.getLogger(__name__).debug(
                "setting many-to-many field %s", field_name)
            getattr(instance, field_name).set(value)

    def create(self, validated_data, **kwargs):
        """
        Create the parent instance, then its nested children.

        Extra ``kwargs`` are merged into the model data (this is how a parent
        serializer hands its instance to a child serializer). No transaction
        is opened here, so a failure in a child happens after the parent has
        already been saved.
        """
        data, children, many_to_many = self.remove_children(validated_data)
        data.update(kwargs)
        instance = self.self_create(data)
        self._save_instance(instance, many_to_many)
        logging.getLogger(__name__).debug(
            "created %s with pk=%s", type(instance).__name__, instance.pk)
        self.nested_create(instance, children)
        return instance

    def update(self, instance, validated_data, **kwargs):
        """
        Update the parent instance, then synchronise its nested children.

        Relations absent from ``validated_data`` are left untouched; a
        relation that is present is synchronised by ``nested_multi_update``.
        """
        data, children, many_to_many = self.remove_children(validated_data)
        data.update(kwargs)
        instance = self.self_update(instance, data)
        self._save_instance(instance, many_to_many)
        logging.getLogger(__name__).debug(
            "updated %s with pk=%s", type(instance).__name__, instance.pk)
        self.nested_update(instance, children)
        return instance

    def self_update(self, instance, data):
        """
        Assign ``data`` onto ``instance`` without saving it.

        When ``Meta.edit_only_fields`` is non-empty only those attributes are
        assigned; otherwise every key in ``data`` is assigned.
        """
        allowed_fields = getattr(self.Meta, "edit_only_fields", [])
        for attr, val in data.items():
            if not allowed_fields or attr in allowed_fields:
                setattr(instance, attr, val)
        return instance

    def nested_update(self, instance, children):
        """Synchronise every configured relation found in ``children``."""
        for key, data in children.items():
            if key in self.nesting_map:
                self.nested_multi_update(instance, key, data)

    def nested_multi_update(self, instance, key, data):
        """
        Make the children of relation ``key`` match the list ``data``.

        - existing children whose pk is not mentioned in ``data`` are deleted
        - items whose pk matches an existing child are updated
        - items without a pk are created with ``instance`` as parent
        - items carrying a pk that is not one of this parent's children are
          ignored, so a request cannot touch rows owned by another parent

        Work is done in the order delete, update, create.
        """
        serializer, param_key = self._nesting_config(key)
        _model, pk_field = self.get_model_pk_by_serializer(serializer)

        # Without the relation there are no existing children to match
        # against; start from empty collections instead of unbound names.
        relation = getattr(instance, key, False)
        existing_ids = []
        if relation:
            # Evaluate the queryset once; membership tests below use a set.
            existing_ids = list(
                relation.all().values_list(pk_field, flat=True))
        known_ids = set(existing_ids)

        retained_ids = set()
        update_items, create_items = [], []
        for item in data:
            if pk_field and pk_field in item:
                obj_id = item[pk_field]
                if obj_id in known_ids:
                    retained_ids.add(obj_id)
                    update_items.append(item)
            else:
                create_items.append(item)

        for obj_id in existing_ids:
            if obj_id not in retained_ids:
                self.delete_child(serializer, obj_id, key)
        for item in update_items:
            self.update_child(serializer, item)
        for item in create_items:
            self.create_child(serializer=serializer, object_data=item,
                              param_key=param_key, instance=instance)

    def update_child(self, serializer, object_data):
        """Delegate the update of one existing child to its serializer."""
        serializer.pre_update(object_data)

    def pre_update(self, data):
        """Load the object whose pk is in ``data`` and run ``update`` on it."""
        model, pk_field = self.get_model_pk_by_serializer(self)
        obj = model.objects.get(**{pk_field: data[pk_field]})
        self.update(obj, validated_data=data)

    def delete_child(self, serializer, object_data, key):
        """Delegate deletion of the child whose pk is ``object_data``.

        ``key`` is unused here; it is kept so overrides can see the relation.
        """
        serializer.delete_obj(object_data)

    def delete_obj(self, id):
        """Delete the object of this serializer's model with primary key ``id``."""
        # The parameter name shadows the builtin but is kept for callers.
        model, pk_field = self.get_model_pk_by_serializer(self)
        model.objects.get(**{pk_field: id}).delete()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment