Skip to content

Instantly share code, notes, and snippets.

@devvspaces
Created October 24, 2022 07:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save devvspaces/94b5331a903fb7d9aa41105a2d256beb to your computer and use it in GitHub Desktop.
Save devvspaces/94b5331a903fb7d9aa41105a2d256beb to your computer and use it in GitHub Desktop.
Mixin that monitors changes to model fields and calls a function when a change is detected.
from typing import Callable, Dict, List, Optional

from django.db import models
class ModelChangeFunc(models.Model):
    """
    Mixin to monitor changes on model fields,
    and call a function when a change is detected.

    Setup:
    1. Add a dict named ``monitor_change`` to the model
    2. Map each field name to monitor to the function to call

    Example:
    ```python
    class Model(ModelChangeFunc):
        field = models.CharField(max_length=100)
        check = None

        def check_field(self):
            self.check = True

        monitor_change = {
            'field': check_field,
        }
    ```

    Whenever an instance that has a primary key is saved and `field`
    differs from its value when the instance was last loaded or saved,
    `check_field(instance)` is called just before the row is written.
    Anything the function sets on the instance is therefore included in
    that same save. A function mapped to several changed fields runs
    only once per save.
    """

    class Meta:
        abstract = True

    # Field name -> function called as ``func(instance)`` when that field's
    # value differs from the last loaded/saved value.
    monitor_change: Optional[Dict[str, Callable[..., None]]] = None

    @classmethod
    def from_db(cls, db, field_names, values):
        """
        Build an instance from a database row and record the loaded values
        of the monitored fields.

        Without this baseline, the first save after a query would compare
        every monitored field against ``None`` and fire its function even
        though nothing changed.
        """
        instance = super().from_db(db, field_names, values)
        instance._snapshot_monitored_fields()
        return instance

    @property
    def monitor_change_fields(self) -> List[str]:
        """
        Get all monitored field names, in ``monitor_change`` order.
        """
        if self.monitor_change:
            return list(self.monitor_change)
        return []

    @property
    def monitor_change_funcs(self) -> List[Callable[..., None]]:
        """
        Get all distinct functions to run when a field is changed.

        A function mapped to several fields is listed once. Order follows
        its first occurrence in ``monitor_change``, so it is deterministic
        (unlike iterating a ``set``).
        """
        if self.monitor_change:
            return list(dict.fromkeys(self.monitor_change.values()))
        return []

    def get_clone_field(self, name: str) -> str:
        """
        Get the name of the attribute holding the last known value of ``name``.
        """
        return f"__{name}"

    def get_attr(self, field: str):
        """
        Get the value of an attribute, or ``None`` if it is not set.
        """
        return getattr(self, field, None)

    def call_updates(self):
        """Forcefully call all update functions, once each."""
        for function in self.monitor_change_funcs:
            function(self)

    def _snapshot_monitored_fields(self) -> None:
        """Record the current value of every loaded monitored field."""
        deferred = self.get_deferred_fields()
        for field in self.monitor_change_fields:
            # Reading a deferred field would issue an extra query; it has no
            # known value yet, so it is left without a snapshot.
            if field in deferred:
                continue
            setattr(self, self.get_clone_field(field), self.get_attr(field))

    def save(self, *args, **kwargs):
        """
        Save the model, calling update functions for changed fields first.

        All positional and keyword arguments (``force_insert``,
        ``force_update``, ``using``, ``update_fields``, ...) are forwarded
        unchanged to ``Model.save``.
        """
        if self.pk:
            deferred = self.get_deferred_fields()
            called = set()
            for field in self.monitor_change_fields:
                if field in deferred:
                    # Neither loaded nor assigned, so it cannot have changed.
                    continue
                clone_field = self.get_clone_field(field)
                if self.get_attr(field) == self.get_attr(clone_field):
                    continue
                function = self.monitor_change[field]
                # Functions are called as changes are found, so one that
                # edits a later monitored field still triggers that field's
                # function; each function still runs at most once per save.
                if function not in called:
                    called.add(function)
                    function(self)
        super().save(*args, **kwargs)
        self._snapshot_monitored_fields()
@devvspaces
Copy link
Author

Tests for the mixin:

import pytest
from django.db import models
from utils.mixins import ModelChangeFunc


@pytest.mark.django_db
class TestModelChangeMixin:
    """Tests for the ``ModelChangeFunc`` mixin."""

    class _Model(ModelChangeFunc):
        # Monitors only ``field``; ``check`` records that the callback ran.
        field = models.CharField(max_length=100)
        other = models.CharField(max_length=100)
        check = None

        def check_field(self):
            self.check = True

        monitor_change = {
            'field': check_field,
        }

    class NoModelCheck(ModelChangeFunc):
        # Uses the mixin without declaring ``monitor_change``.
        field = models.CharField(max_length=100)
        other = models.CharField(max_length=100)

    class SimilarModelCheck(ModelChangeFunc):
        # Maps two fields to the same callback.
        field = models.CharField(max_length=100)
        other = models.CharField(max_length=100)
        check = None

        def check_field(self):
            self.check = True

        monitor_change = {
            'field': check_field,
            'other': check_field,
        }

    class MultiModel(ModelChangeFunc):
        # Records one entry in ``check`` per callback invocation.
        field = models.CharField(max_length=100)
        other = models.CharField(max_length=100)
        # ``None`` instead of ``[]``: a class-level list would be shared by
        # every instance, so calls on one instance would leak into others.
        check = None

        def check_field(self):
            if self.check is None:
                self.check = []
            self.check.append(True)

        monitor_change = {
            'field': check_field,
            'other': check_field,
        }

    def test_monitor_change_fields(self):
        assert self._Model().monitor_change_fields == ['field']

    def test_monitor_change_fields_no_model_check(self):
        assert self.NoModelCheck().monitor_change_fields == []

    def test_monitor_change_funcs(self):
        assert self._Model().monitor_change_funcs == [
            self._Model.check_field,
        ]

    def test_monitor_change_funcs_no_model_check(self):
        assert self.NoModelCheck().monitor_change_funcs == []

    def test_monitor_change_funcs_similar_model_check(self):
        # The shared callback must be listed only once.
        assert self.SimilarModelCheck().monitor_change_funcs == [
            self.SimilarModelCheck.check_field,
        ]

    def test_get_clone_field(self):
        assert self._Model().get_clone_field('field') == '__field'

    def test_get_attr(self):
        model = self._Model(
            field='test1',
            other='test2',
        )
        model.save()
        assert model.get_attr('field') == 'test1'
        assert model.get_attr('other') == 'test2'

    def test_call_updates(self):
        model = self._Model(
            field='test1',
            other='test2',
        )
        model.save()
        assert model.check is None
        assert model.field == 'test1'

        model.call_updates()
        assert model.check is True
        assert model.field == 'test1'

    def test_model_change_func_valid_change(self):
        model = self._Model(
            field='test1',
            other='test2',
        )
        model.save()
        assert model.check is None
        assert model.field == 'test1'

        model.field = 'test'
        model.save()
        assert model.check is True
        assert model.field == 'test'

    def test_model_change_func_no_change(self):
        model = self._Model(
            field='test1',
            other='test2',
        )
        model.save()
        assert model.check is None
        assert model.field == 'test1'

        # Re-assigning the same value is not a change.
        model.field = 'test1'
        model.save()
        assert model.check is None
        assert model.field == 'test1'

    def test_model_change_func_invalid_change(self):
        model = self._Model(
            field='test1',
            other='test2',
        )
        model.save()
        assert model.check is None
        assert model.field == 'test1'

        # ``other`` is not monitored, so changing it must not fire the callback.
        model.other = "error"
        model.save()
        assert model.check is None
        assert model.field == 'test1'
        assert model.other == 'error'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment