|
def PoweredRelatedFactory(related_factory_class, related_name, pre_save=None):
    """Improvement of ``factory.RelatedFactory``.

    Returns a ``factory.PostGeneration`` declaration that generates one or
    more related objects, each pointing back at the generated (parent)
    object through the ``related_name`` attribute.

    :param related_factory_class: (class) factory class of the related objects
    :param related_name: (str) name of the attribute of the related factory
        that receives the generated (parent) object
    :param pre_save: (callable, optional) called just before the related
        factory class is invoked. It receives:

        - ``related_kwargs`` (dict) ATTR__SUBATTR kwargs of the related object
          about to be generated. It is a private copy, so it may be mutated
          freely to adjust what is passed to the related factory.
        - ``related_n`` (int) index of the object in the list of related
          objects when a list (or other iterable) was passed; 0 otherwise.
        - ``related_tot`` (int|None) total number of related objects when a
          list (or other iterable) was passed; ``None`` otherwise.
    :return: a ``factory.PostGeneration`` declaration

    Example of usage::

        import factory
        from myapp import models  # app module defining Product and Price

        class ProductFactory(factory.django.DjangoModelFactory):
            class Meta:
                model = models.Product

            prices = PoweredRelatedFactory(PriceFactory, 'product', pre_save=set_quantities)

    Where the model of ``PriceFactory`` has a foreign key to ``Product``::

        from django.db import models

        class Price(models.Model):
            product = models.ForeignKey(Product, on_delete=models.CASCADE)

    This way the factory can be called like this::

        ProductFactory(prices__USD=15, prices__EUR=12, prices__GBP=10)
        ProductFactory(prices={'USD': 102, 'GBP': 90, 'EUR': 81})
        ProductFactory(prices=[{'USD': 102, 'GBP': 90, 'EUR': 81}, {'USD': 12, 'GBP': 8, 'EUR': 10}])
    """

    def post_generation_function(obj, create, extracted, **kwargs):
        """Function wrapped by ``factory.PostGeneration``.

        :param obj: the object just generated by the parent factory
        :param create: (bool) True when the parent used the create strategy,
            False when it used the build strategy
        :param extracted: (None|dict|iterable of ATTR__SUBATTR dicts) value
            passed for the declaration itself, used to populate the related
            factory, e.g.:

            - None
            - {'USD': 102, 'GBP': 90, 'EUR': 81}
            - [{'USD': 102, 'GBP': 90, 'EUR': 81}, {'USD': 12, 'GBP': 8, 'EUR': 10}]
        :param kwargs: ATTR__SUBATTR dict (possibly empty) used to populate a
            single related object when ``extracted`` is neither a mapping nor
            an iterable (e.g. None), e.g.:

            - {}
            - {'USD': 102, 'GBP': 90, 'EUR': 81}

        Note:
            ATTR__SUBATTR follows the Django ORM lookup syntax,
            e.g. ``product_type__name``.
        """
        from collections.abc import Mapping

        def create_related_obj(related_kwargs, related_n=0, related_tot=None):
            # Work on a copy: neither pre_save nor the back-reference below
            # may leak into a dict owned by the caller (e.g. a constant dict
            # reused across several factory calls).
            related_kwargs = dict(related_kwargs)
            if pre_save is not None:
                pre_save(related_kwargs, related_n, related_tot)
            related_kwargs[related_name] = obj
            if create:
                related_factory_class(**related_kwargs)
            else:
                # The parent was generated with the build strategy; build the
                # related objects too rather than persisting rows that point
                # at an unsaved parent.
                related_factory_class.build(**related_kwargs)

        if isinstance(extracted, Mapping):
            # A single mapping describes exactly one related object.
            create_related_obj(extracted)
        elif hasattr(extracted, '__iter__'):
            # Materialize first so generators/iterators also support len().
            related_objs = list(extracted)
            tot_extracted = len(related_objs)
            for related_obj_n, related_obj in enumerate(related_objs):
                create_related_obj(related_obj, related_obj_n, tot_extracted)
        else:
            # Nothing usable was passed directly: fall back to the
            # ``<declaration>__<attr>`` kwargs (possibly empty).
            create_related_obj(kwargs)

    return factory.PostGeneration(post_generation_function)