Skip to content

Instantly share code, notes, and snippets.

@ebsaral
Created July 12, 2017 12:29
Show Gist options
  • Save ebsaral/51ce560c5ef98f42259348db213e693e to your computer and use it in GitHub Desktop.
Save ebsaral/51ce560c5ef98f42259348db213e693e to your computer and use it in GitHub Desktop.
Converting a model to multi-table inheritance in Django - Migration Script
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations, models
class CopyFieldsBetweenTables(migrations.operations.base.Operation):
    """Copy rows of a child table into a new parent table and link them back.

    For every row of ``<app_label>_<model_from_name>`` a row is inserted into
    ``<app_label>_<model_to_name>`` carrying the listed ``columns``. The new
    row is tagged with the source row's id (``temp_id``) and the source model
    name (``temp_obj_name``). Each source row's parent-pointer column is then
    set to the id of the parent row created for it.

    Preconditions: the parent table already has ``temp_id`` and
    ``temp_obj_name`` columns, and the child table already has the pointer
    column. The operation only touches the database (never the migration
    state) and is not reversible.
    """

    reversible = False

    def __init__(self, model_from_name, model_to_name, columns,
                 ptr_column=None):
        """
        :param model_from_name: lowercased child model name to copy from.
        :param model_to_name: lowercased parent model name to copy into.
        :param columns: column names, present in both tables, to copy.
        :param ptr_column: pointer column on the child table to fill in.
            Defaults to ``<model_to_name>_ptr_id``, the column Django uses
            for a relation field named ``<model_to_name>_ptr``.
        """
        self.model_from_name = model_from_name
        self.model_to_name = model_to_name
        self.columns = columns
        self.ptr_column = ptr_column

    def state_forwards(self, app_label, state):
        # Data-only operation: the migration state is left untouched.
        pass

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        qn = schema_editor.quote_name
        table_from = qn("{}_{}".format(app_label, self.model_from_name))
        table_to = qn("{}_{}".format(app_label, self.model_to_name))
        # The original hard-coded "lineitem_ptr_id", which does not exist on
        # these tables; derive the pointer column from the parent model.
        ptr_column = qn(
            self.ptr_column or "{}_ptr_id".format(self.model_to_name))
        columns = ", ".join(qn(column) for column in self.columns)

        # Run each statement in its own execute() call: some backends
        # (e.g. SQLite) refuse more than one statement per call. The source
        # model name is passed as a query parameter instead of being
        # interpolated into the SQL as a quoted literal.
        schema_editor.execute(
            "INSERT INTO {to} (temp_obj_name, temp_id, {cols}) "
            "SELECT %s, id, {cols} FROM {frm}".format(
                to=table_to, cols=columns, frm=table_from),
            [self.model_from_name],
        )
        schema_editor.execute(
            "UPDATE {frm} SET {ptr} = ("
            "SELECT id FROM {to} "
            "WHERE temp_id = {frm}.id AND temp_obj_name = %s "
            "LIMIT 1)".format(frm=table_from, ptr=ptr_column, to=table_to),
            [self.model_from_name],
        )

    def database_backwards(self, app_label, schema_editor, from_state,
                           to_state):
        # Never called: reversible = False makes Django refuse to unapply.
        pass

    def describe(self):
        # Operation has no ``name`` attribute; describe using the models.
        return "Copies between two tables for %s -> %s" % (
            self.model_from_name, self.model_to_name)
class Migration(migrations.Migration):
    """Turn the abstract ``Address`` base into a concrete multi-table parent.

    Existing ``HomeAddress`` and ``WorkAddress`` rows are copied into the new
    ``Address`` table. Each child is then re-linked through an ``address_ptr``
    parent link, which becomes its primary key. Finally, the fields now
    inherited from ``Address`` are dropped from the child tables.
    """

    # NOTE(review): these dependencies look copied from another project. This
    # migration alters models in ``my_app`` and references
    # ``another_app.Owner``, so it presumably must depend on the migrations
    # that create those models; confirm before use.
    dependencies = [
        ('accounting', '0026_migrate_existing_expenses'),
    ]

    operations = [
        # Create the base class (convert abstract to real model)
        migrations.CreateModel(
            name='Address',
            fields=[
                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
                ('name', models.CharField(max_length=20)),
                # on_delete is mandatory since Django 2.0; CASCADE was the
                # implicit default before, so behavior is unchanged.
                ('owner', models.ForeignKey(to='another_app.Owner', on_delete=models.deletion.CASCADE)),
            ],
        ),
        # Add temporary holder fields
        migrations.AddField(
            model_name='address',
            name='temp_id',
            field=models.IntegerField(default=0, null=True, blank=True),
        ),
        migrations.AddField(
            model_name='address',
            name='temp_obj_name',
            field=models.CharField(max_length=255, null=True, blank=True),
        ),
        # Add the pointer as a nullable foreign key first, so existing child
        # rows are valid before their parent rows exist. It is converted to
        # the real parent-link OneToOneField further below.
        migrations.AddField(
            model_name='homeaddress',
            name='address_ptr',
            field=models.ForeignKey('my_app.Address', null=True, blank=True, on_delete=models.deletion.CASCADE),
        ),
        migrations.AddField(
            model_name='workaddress',
            name='address_ptr',
            field=models.ForeignKey('my_app.Address', null=True, blank=True, on_delete=models.deletion.CASCADE),
        ),
        # Copy data from HomeAddress and WorkAddress to Address, and update
        # relation fields
        CopyFieldsBetweenTables(
            model_from_name='homeaddress',
            model_to_name='address',
            columns=['name', 'owner_id'],
        ),
        CopyFieldsBetweenTables(
            model_from_name='workaddress',
            model_to_name='address',
            columns=['name', 'owner_id'],
        ),
        # Remove temporary fields in Address
        migrations.RemoveField(
            model_name='address',
            name='temp_id'
        ),
        migrations.RemoveField(
            model_name='address',
            name='temp_obj_name'
        ),
        # Remove id fields from WorkAddress and HomeAddress,
        # since we won't be using them anymore
        migrations.RemoveField(
            model_name='homeaddress',
            name='id',
        ),
        migrations.RemoveField(
            model_name='workaddress',
            name='id',
        ),
        # Convert pointer field initially created to a OneToOneField
        migrations.AlterField(
            model_name='homeaddress',
            name='address_ptr',
            field=models.OneToOneField(parent_link=True,
                                       on_delete=models.deletion.CASCADE,
                                       auto_created=True, primary_key=True,
                                       serialize=False,
                                       to='my_app.Address'),
            preserve_default=False,
        ),
        migrations.AlterField(
            model_name='workaddress',
            name='address_ptr',
            field=models.OneToOneField(parent_link=True,
                                       on_delete=models.deletion.CASCADE,
                                       auto_created=True, primary_key=True,
                                       serialize=False,
                                       to='my_app.Address'),
            preserve_default=False,
        ),
        # Remove copied fields which came from abstract base class
        migrations.RemoveField(
            model_name='homeaddress',
            name='name',
        ),
        migrations.RemoveField(
            model_name='homeaddress',
            name='owner',
        ),
        migrations.RemoveField(
            model_name='workaddress',
            name='name',
        ),
        migrations.RemoveField(
            model_name='workaddress',
            name='owner',
        ),
    ]
@kasaiee
Copy link

kasaiee commented Dec 4, 2020

Hello, first of all, thank you for the great solution. This code cannot be executed with sqlite3; it raises this error:

sqlite3.Warning: You can only execute one statement at a time

Also, lineitem_ptr_id should be changed to address_ptr_id, and you forgot to add on_delete=django.db.models.deletion.CASCADE. I think this version of your code works better:

# Generated by Django 3.1.4 on 2020-12-04 14:05

from django.db import migrations, models
import django.db.models.deletion


class CopyFieldsBetweenTables(migrations.operations.base.Operation):
    """Copy rows of a child table into a new parent table and link them back.

    For every row of ``<app_label>_<model_from_name>`` a row is inserted into
    ``<app_label>_<model_to_name>`` carrying the listed ``columns``. The new
    row is tagged with the source row's id (``temp_id``) and the source model
    name (``temp_obj_name``). Each source row's parent-pointer column is then
    set to the id of the parent row created for it.

    Preconditions: the parent table already has ``temp_id`` and
    ``temp_obj_name`` columns, and the child table already has the pointer
    column. The operation only touches the database (never the migration
    state) and is not reversible.
    """

    reversible = False

    def __init__(self, model_from_name, model_to_name, columns,
                 ptr_column=None):
        """
        :param model_from_name: lowercased child model name to copy from.
        :param model_to_name: lowercased parent model name to copy into.
        :param columns: column names, present in both tables, to copy.
        :param ptr_column: pointer column on the child table to fill in.
            Defaults to ``<model_to_name>_ptr_id``, the column Django uses
            for a relation field named ``<model_to_name>_ptr``.
        """
        self.model_from_name = model_from_name
        self.model_to_name = model_to_name
        self.columns = columns
        self.ptr_column = ptr_column

    def state_forwards(self, app_label, state):
        # Data-only operation: the migration state is left untouched.
        pass

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        qn = schema_editor.quote_name
        table_from = qn("{}_{}".format(app_label, self.model_from_name))
        table_to = qn("{}_{}".format(app_label, self.model_to_name))
        # The previous version still hard-coded "lineitem_ptr_id", which does
        # not exist on these tables; derive the pointer column instead.
        ptr_column = qn(
            self.ptr_column or "{}_ptr_id".format(self.model_to_name))
        columns = ", ".join(qn(column) for column in self.columns)

        # One statement per execute() call: SQLite rejects multi-statement
        # strings. The source model name is passed as a query parameter
        # instead of being interpolated into the SQL as a quoted literal.
        schema_editor.execute(
            "INSERT INTO {to} (temp_obj_name, temp_id, {cols}) "
            "SELECT %s, id, {cols} FROM {frm}".format(
                to=table_to, cols=columns, frm=table_from),
            [self.model_from_name],
        )
        schema_editor.execute(
            "UPDATE {frm} SET {ptr} = ("
            "SELECT id FROM {to} "
            "WHERE temp_id = {frm}.id AND temp_obj_name = %s "
            "LIMIT 1)".format(frm=table_from, ptr=ptr_column, to=table_to),
            [self.model_from_name],
        )

    def database_backwards(self, app_label, schema_editor, from_state,
                           to_state):
        # Never called: reversible = False makes Django refuse to unapply.
        pass

    def describe(self):
        # Operation has no ``name`` attribute; describe using the models.
        return "Copies between two tables for %s -> %s" % (
            self.model_from_name, self.model_to_name)

class Migration(migrations.Migration):
    """Turn the abstract ``Address`` base into a concrete multi-table parent.

    Existing ``HomeAddress`` and ``WorkAddress`` rows are copied into the new
    ``Address`` table. Each child is then re-linked through an ``address_ptr``
    parent link, which becomes its primary key. Finally, the fields now
    inherited from ``Address`` are dropped from the child tables.
    """

    # NOTE(review): these dependencies look copied from another project. This
    # migration alters models in ``my_app`` and references
    # ``another_app.Owner``, so it presumably must depend on the migrations
    # that create those models; confirm before use.
    dependencies = [
        ('accounting', '0026_migrate_existing_expenses'),
    ]

    operations = [
        # Create the base class (convert abstract to real model)
        migrations.CreateModel(
            name='Address',
            fields=[
                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
                ('name', models.CharField(max_length=20)),
                # on_delete is mandatory since Django 2.0; without it this
                # field raises TypeError.
                ('owner', models.ForeignKey(to='another_app.Owner', on_delete=django.db.models.deletion.CASCADE)),
            ],
        ),
        # Add temporary holder fields
        migrations.AddField(
            model_name='address',
            name='temp_id',
            field=models.IntegerField(default=0, null=True, blank=True),
        ),
        migrations.AddField(
            model_name='address',
            name='temp_obj_name',
            field=models.CharField(max_length=255, null=True, blank=True),
        ),
        # Add the pointer as a nullable foreign key first, so existing child
        # rows are valid before their parent rows exist. It is converted to
        # the real parent-link OneToOneField further below.
        migrations.AddField(
            model_name='homeaddress',
            name='address_ptr',
            field=models.ForeignKey('my_app.Address', null=True, blank=True, on_delete=django.db.models.deletion.CASCADE),
        ),
        migrations.AddField(
            model_name='workaddress',
            name='address_ptr',
            field=models.ForeignKey('my_app.Address', null=True, blank=True, on_delete=django.db.models.deletion.CASCADE),
        ),
        # Copy data from HomeAddress and WorkAddress to Address, and update
        # relation fields
        CopyFieldsBetweenTables(
            model_from_name='homeaddress',
            model_to_name='address',
            columns=['name', 'owner_id'],
        ),
        CopyFieldsBetweenTables(
            model_from_name='workaddress',
            model_to_name='address',
            columns=['name', 'owner_id'],
        ),
        # Remove temporary fields in Address
        migrations.RemoveField(
            model_name='address',
            name='temp_id'
        ),
        migrations.RemoveField(
            model_name='address',
            name='temp_obj_name'
        ),
        # Remove id fields from WorkAddress and HomeAddress,
        # since we won't be using them anymore
        migrations.RemoveField(
            model_name='homeaddress',
            name='id',
        ),
        migrations.RemoveField(
            model_name='workaddress',
            name='id',
        ),
        # Convert pointer field initially created to a OneToOneField
        migrations.AlterField(
            model_name='homeaddress',
            name='address_ptr',
            field=models.OneToOneField(parent_link=True,
                                       on_delete=models.deletion.CASCADE,
                                       auto_created=True, primary_key=True,
                                       serialize=False,
                                       to='my_app.Address'),
            preserve_default=False,
        ),
        migrations.AlterField(
            model_name='workaddress',
            name='address_ptr',
            field=models.OneToOneField(parent_link=True,
                                       on_delete=models.deletion.CASCADE,
                                       auto_created=True, primary_key=True,
                                       serialize=False,
                                       to='my_app.Address'),
            preserve_default=False,
        ),
        # Remove copied fields which came from abstract base class
        migrations.RemoveField(
            model_name='homeaddress',
            name='name',
        ),
        migrations.RemoveField(
            model_name='homeaddress',
            name='owner',
        ),
        migrations.RemoveField(
            model_name='workaddress',
            name='name',
        ),
        migrations.RemoveField(
            model_name='workaddress',
            name='owner',
        ),
    ]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment