Skip to content

Instantly share code, notes, and snippets.

@cnk
Created October 14, 2022 21:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cnk/11fc729eb0fb6a4964a750cdecc0b893 to your computer and use it in GitHub Desktop.
Import code for changing references in rich text from their html versions into draftail references
import json
import re
import uuid
from django.db import connection
from core.logging import logger
from .utils import (
get_document_by_import_id,
get_image_by_import_id,
load_page_by_import_id,
)
class RichTextFieldMigrator:
    """
    Rewrite html-style rich text references (image embeds, page links and
    document links) so their ids point at real database records instead of
    the import ids baked into the exported markup.
    """

    EMBED_REGEX = re.compile(r'<embed.*?id="(?P<image_id>\d+)"/>', re.M)
    PAGELINK_REGEX = re.compile(r'<a.*?id="(?P<page_id>\d+)" linktype="page">', re.M)
    DOCUMENT_REGEX = re.compile(r'<a.*?id="(?P<document_id>\d+)" linktype="document">', re.M)

    def process(self, site_helper, text):
        """Run all three rewrites over ``text`` and return the result."""
        for rewrite in (self.update_image_ids,
                        self.update_page_ids,
                        self.update_document_ids):
            text = rewrite(site_helper, text)
        return text

    def update_image_ids(self, site_helper, text):
        """Replace image import ids inside <embed/> tags with database ids."""
        for import_id in self.EMBED_REGEX.findall(text):
            image = get_image_by_import_id(site_helper, import_id)
            if not image:
                # missing images are logged by get_image_by_import_id
                continue
            text = re.sub(
                '<embed(.*?)id="{}"/>'.format(import_id),
                r'<embed\1id="{}"/>'.format(image.id),
                text
            )
        return text

    def update_page_ids(self, site_helper, text):
        """Replace page-link import ids with actual page database ids."""
        for import_id in self.PAGELINK_REGEX.findall(text):
            try:
                page = load_page_by_import_id(site_helper, import_id)
            except KeyError:
                logger.warning(
                    'importer.flexpage.update_page_references.no_such_page',
                    import_id=site_helper.import_id(import_id)
                )
            else:
                text = re.sub(
                    '<a id="{}"'.format(import_id),
                    r'<a id="{}"'.format(page.id),
                    text
                )
        return text

    def update_document_ids(self, site_helper, text):
        """Replace document-link import ids with actual document database ids."""
        for import_id in self.DOCUMENT_REGEX.findall(text):
            document = get_document_by_import_id(site_helper, import_id)
            if document:
                text = re.sub(
                    '<a id="{}"'.format(import_id),
                    r'<a id="{}"'.format(document.id),
                    text
                )
            else:
                logger.warning(
                    'importer.flexpage.update_document_references.no_such_document',
                    import_id=site_helper.import_id(import_id)
                )
        return text
class BaseStreamFieldMigrator:
    """
    Shared machinery for rewriting import-time placeholder ids inside
    StreamField JSON into real database ids.
    """

    def __init__(self, site_helper, data):
        """
        Manage the json data for a StreamField in the database.

        :param site_helper: a SiteHelper object that we can use to get the actual import_id from ``page_import_id``
        :type site_helper: a ``core.utils.SiteHelper`` object
        :param data: the StreamField data
        :type data: list of dicts
        """
        self.site_helper = site_helper
        self.data = data

    def _recurse(self, data, lookup_function, singulars, plurals, attribute='id'):
        """
        In actual StreamField data, Django model references are foreign keys
        into the appropriate lookup table. In our site data to be imported, we
        don't know what those ids will be until we actually import the relevant
        lookup table data, so in the body field data in the site.yml file, we
        set the id of the record to be some database agnostic unique key for that
        record, like "name".

        Thus before we can stuff our body field data into the database, we need
        to update the record references in the data to point to the actual id of
        the record in the database.

        This method recurses through the tree of streamfield JSON data and
        updates the appropriate fields: keys in ``singulars`` hold one
        reference; keys in ``plurals`` hold a list of references (or a list of
        nested structs, which we recurse into instead).
        """
        if isinstance(data, dict):
            # list() so we can assign back into ``data`` while iterating
            for k, v in list(data.items()):
                if (isinstance(v, dict) or isinstance(v, tuple)):
                    self._recurse(v, lookup_function, singulars, plurals, attribute=attribute)
                elif isinstance(v, list):
                    if k in plurals:
                        if v:
                            if not isinstance(v[0], int) and not isinstance(v[0], str):
                                # this is a list of structs, so recurse into each
                                for item in v:
                                    self._recurse(item, lookup_function, singulars, plurals, attribute=attribute)
                            else:
                                # This is a list of tags: resolve each element
                                # to a record (or its ``attribute``), dropping
                                # any that can't be found
                                resolved = []
                                for datum in v:
                                    record = lookup_function(self.site_helper, datum)
                                    if record:
                                        if attribute:
                                            resolved.append(getattr(record, attribute))
                                        else:
                                            resolved.append(record)
                                data[k] = resolved
                    else:
                        # not a reference list; keep walking the tree
                        self._recurse(v, lookup_function, singulars, plurals, attribute=attribute)
                else:
                    # We just have to know the name of the field names because
                    # at this point we don't know the field type
                    if k in singulars:
                        if v is not None:
                            record = lookup_function(self.site_helper, v)
                            if record:
                                if attribute:
                                    data[k] = getattr(record, attribute)
                                else:
                                    data[k] = record
        elif isinstance(data, list) or isinstance(data, tuple):
            for item in data:
                self._recurse(item, lookup_function, singulars, plurals, attribute=attribute)

    def _update_page_references_and_rich_text(self, data):
        """
        In actual StreamField data, page references are foreign keys into the
        pages table. In our site data to be imported, we don't know what those
        ids will be until we actually import all the pages, so in the StreamField
        data in the site.yml file, we set the id of the page to be that of the
        import_id of the page.

        Thus before we can say we are finished inserting our JSON data into the
        database, we need to update the page references in the data to point to
        the actual id of the page in the database.
        """
        if isinstance(data, dict):
            for k, v in data.items():
                if isinstance(v, dict) or isinstance(v, list) or isinstance(v, tuple):
                    self._update_page_references_and_rich_text(v)
                else:
                    if k in ('page', 'root_page'):
                        if v is not None:
                            try:
                                page = load_page_by_import_id(self.site_helper, v)
                                data[k] = page.page_ptr_id
                            except KeyError:
                                # leave the placeholder in place; just log it
                                logger.warning(
                                    'importer.streamfield.update_page_references.no_such_page',
                                    page_id=v,
                                    import_id=self.site_helper.import_id(v)
                                )
                    elif k in ['text', 'content']:
                        if v is not None:
                            # rich text values carry their own embedded
                            # image/page/document references to rewrite
                            data[k] = RichTextFieldMigrator().process(self.site_helper, v)
        elif isinstance(data, list) or isinstance(data, tuple):
            for item in data:
                self._update_page_references_and_rich_text(item)

    def _maybe_add_block_ids(self, data):
        """
        In actual FlexPage body field data, blocks have block ids. Our source
        data might not have them, so we need to add them.
        """
        if isinstance(data, dict):
            if 'type' in data and 'value' in data:
                # this dict is a StreamField block; give it an id if missing
                if 'id' not in data:
                    data['id'] = str(uuid.uuid4())
                self._maybe_add_block_ids(data['value'])
            else:
                for k, v in list(data.items()):
                    self._maybe_add_block_ids(v)
        elif isinstance(data, list) or isinstance(data, tuple):
            for item in data:
                self._maybe_add_block_ids(item)

    def save(self):
        # subclasses know which table/column the data lives in
        raise NotImplementedError

    def update(self):
        # subclasses define the full fix-up pipeline
        raise NotImplementedError
class StreamFieldMigrator(BaseStreamFieldMigrator):
    """
    Read, fix up and write back the JSON for one StreamField column belonging
    to a single imported page, identified by its import_id.
    """

    def __init__(self, page_import_id, site_helper, data, table, column):
        """
        Manage the json data for a StreamField in the database.

        :param page_import_id: the import_id for the page that this StreamField data should belong to
        :type page_import_id: string
        :param site_helper: a SiteHelper object that we can use to get the actual import_id from ``page_import_id``
        :type site_helper: a ``core.utils.SiteHelper`` object
        :param data: the StreamField data
        :type data: list of dicts
        :param table: the name of the table which owns the StreamField column we want to work with
        :type table: string
        :param column: the name of column on ``table`` that represents the StreamField
        :type column: string
        """
        super().__init__(site_helper, data)
        self.import_id = page_import_id
        self.table = table
        self.column = column

    def update_page_references_and_rich_text(self):
        """
        Correct any page references in the StreamField JSON to point to actual
        page ids in the database instead of import ids; we have them set up to
        refer to import ids in our export file.
        """
        # table/column come from our own code, not user input, so interpolating
        # them into the SQL is safe; the values still go through placeholders
        select_sql = 'select {} from {} where import_id = %s'.format(self.column, self.table)
        with connection.cursor() as cursor:
            cursor.execute(select_sql, [self.site_helper.import_id(self.import_id)])
            body = json.loads(cursor.fetchone()[0])
        self._update_page_references_and_rich_text(body)
        update_sql = 'UPDATE {} SET {} = %s where import_id = %s'.format(self.table, self.column)
        with connection.cursor() as cursor:
            cursor.execute(update_sql, [json.dumps(body), self.site_helper.import_id(self.import_id)])

    def update_image_references(self):
        """
        Correct any image references in the StreamField JSON to point to actual
        image ids in the database instead of import ids; we have them set up to
        refer to import ids in our export file.
        """
        self._recurse(
            self.data,
            get_image_by_import_id,
            ['image', 'background_image'],
            ['images'],
        )

    def save(self):
        """
        Write the streamfield data as a JSON blob to the appropriate column on
        the appropriate table.
        """
        statement = 'UPDATE {} SET {} = %s where import_id = %s'.format(self.table, self.column)
        with connection.cursor() as cursor:
            cursor.execute(statement, [json.dumps(self.data), self.site_helper.import_id(self.import_id)])

    def maybe_add_block_ids(self):
        """Ensure every block in our data has a block id."""
        self._maybe_add_block_ids(self.data)

    def update(self):
        """
        Take the streamfield data from our export file, update the import_id
        references with actual ids of the referenced records, and persist it.
        """
        self.maybe_add_block_ids()
        self.update_image_references()
        self.save()
class HomePageMigrator(BaseSpecificPageMigrator):
    # This version of the HomePage doesn't have header, so can't use BaseStreamFieldPageMigrator

    def load(self, page_model, site, page_data, page=None):
        """Default ``page`` to the site's existing home page before delegating."""
        target = page if page else site.home_page
        return super().load(page_model, site, page_data, page=target)

    def create(self, site, page_data, page=None, dry_run=False):
        """Create/update the HomePage and migrate its ``body`` StreamField."""
        page = super().create(HomePage, site, page_data, page=page, dry_run=dry_run)
        if not dry_run:
            # StreamField Tag and Display Location replacement must be done
            # before we write to the database, otherwise we run into Block
            # validation problems: tag and display location references need to
            # be ints in the database, but we have them as strings in the yaml
            migrator = StreamFieldMigrator(
                page_data['id'],
                site,
                page_data['body'],
                HomePage._meta.db_table,
                'body'
            )
            migrator.update()
        return page

    def after_page_import(self, site, page_data):
        """Resolve body page references once every page has been imported."""
        StreamFieldMigrator(
            page_data['id'],
            site,
            page_data['body'],
            HomePage._meta.db_table,
            'body'
        ).update_page_references_and_rich_text()

    def export(self, page, site):
        """Export page data, dropping slug/parent when this is the live home page."""
        output = super().export(page, site)
        if page.depth == 2:
            # if this is the current home page, remove slug and parent
            output.pop('slug', None)
            output['parent'] = None
        # now export page data
        output['body'] = remove_block_ids(page.body.raw_data)
        return output
class BaseSpecificPageMigrator:
    """
    Base import/export logic shared by all page-type migrators.

    2021-01-04 CNK
    We are not using the translations system so I am omitting the following fields from import/export:
        alias_of_id
        locale (exporting but not importing)
    """

    # Scalar fields copied between page_data and the page object.
    export_fields = [
        'title',
        'subtitle',
        'nav_title',
        'breadcrumb_title',
        'seo_title',
        'draft_title',
        'slug',
        'search_description',
        'teaser_video',
        'go_live_at',
        'expire_at',
        'latest_revision_created_at',
        'first_published_at',
        'last_published_at',
        'locked_at',
        'teaser_title',
    ]
    # Boolean fields always exported (with an explicit False default).
    export_boolean_fields = [
        'expired',
        'show_title',
        'show_in_menus',
        'hide_from_search_engines',
        'has_unpublished_changes',
        'locked',
    ]

    def load(self, page_model, site_helper, page_data, page=None, dry_run=False):
        """
        Create or get a page whose import_id matches
        ``site_helper.import_id(page_data['import_id'])`` and update that page's fields from the
        ``page_data`` dict. Specifically, update these fields:

        * ``title``
        * ``import_id``
        * ``slug`` (optional)
        * ``seo_title`` (optional, set to ``title`` if ``seo_title`` is not defined)
        * ``live`` (optional, defaults to ``True``)
        * ``show_in_menus`` (optional; defaults to ``True`` here)
        * ``search_description`` (optional)

        :param page_model: the class object for the page model to create/update
        :type page_model: a class
        :param site_helper: a SiteHelper object for the site we're working on
        :type site_helper: a `core.utils.SiteHelper` object
        :param page_data: a dict of page data
        :type page_data: dict
        :param page: (optional) use this page object instead of getting one from the database
        :type page: a Page type object
        :rtype: 2-tuple: (page object, boolean)
        """
        if not page:
            if not dry_run:
                if page_data.get('alias'):
                    # reuse the page previously imported for this url, if any
                    import_rec = get_import_information_model().find_by_url(
                        site_helper.site,
                        page_data['alias']).first()
                    if import_rec:
                        page, created = get_or_generate(page_model, id=import_rec.page_id)
                    else:
                        page = page_model()
                        created = True
                else:
                    page = page_model()
                    created = True
            else:
                # dry run: build an unsaved instance so the rest of the import can run
                page = page_model()
                created = True
        else:
            created = False
        page.import_id = site_helper.import_id(page_data['id'])
        # set some of our boolean fields
        page.show_in_menus = page_data.get('show_in_menus', True)
        page.live = page_data.get('published', True)
        page.locked = page_data.get('locked', False)
        page.expired = page_data.get('expired', False)
        page.has_unpublished_changes = page_data.get('has_unpublished_changes', False)
        # BUGFIX: these previously used ``if getattr(page, field, None):`` which
        # tests the *current value's* truthiness, so a field whose present value
        # was False could never be imported from page_data. hasattr() matches
        # the intent (and the hasattr check used for export_fields below).
        if hasattr(page, 'hide_from_search_engines'):
            page.hide_from_search_engines = page_data.get('hide_from_search_engines', False)
        if hasattr(page, 'show_title'):
            page.show_title = page_data.get('show_title', True)
        date_fields = ['first_published_at', 'last_published_at', 'latest_revision_created_at',
                       'go_live_at', 'expire_at', 'locked_at']
        # This is the main set of imports
        for field in [f for f in self.export_fields if f not in date_fields]:
            # If this model has this field, set it to the value in page_data
            if hasattr(page, field):
                try:
                    setattr(page, field, page_data.get(field, ''))
                except AttributeError:
                    # NewsPage has a nav_title property so hasattr but can't setattr; log and move on
                    logger.warning('importer.page.cant_set_attribute', field=field, page=page)
        # pyyaml creates naive datettimes. We know we output UTC, so just bash the dates into UTC
        # at least until this PR is accepted https://github.com/yaml/pyyaml/pull/113
        for field in date_fields:
            naive_timestamp = page_data.get(field, None)
            if naive_timestamp:
                setattr(page, field, naive_timestamp.replace(tzinfo=pytz.utc))
        if page_data.get('locked_by', None):
            # The user should exist - but if it doesn't I am just going to let this be null
            page.locked_by_id = User.objects.filter(username=page_data['locked_by']).values_list('id', flat=True).first()  # noqa
        return page, created

    def add_to_parent(self, page, site_helper, parent, dry_run):
        # Add newly-created Pages to the page tree.
        # If we got an integer for 'parent', look up the parent page
        # also make sure site is not a Stuff obj.
        if isinstance(parent, int) and not isinstance(site_helper.site, Stuff):
            try:
                parent = load_page_by_import_id(site_helper, parent)
            except KeyError:
                logger.error(
                    'importer.page.add_to_parent.no_such_import_id',
                    parent_id=parent,
                    import_id=site_helper.import_id(parent)
                )
                raise
        if not dry_run:
            page = parent.add_child(instance=page)
        else:
            logger.info(
                'importer.base_page_importer.create.dry_run',
                page_model=type(page),
                # NOTE(review): ``page`` here is a Page instance, not a raw id;
                # confirm site_helper.import_id() handles that as intended.
                import_id=site_helper.import_id(page),
                parent_id=getattr(parent, 'id', None)
            )
        return page

    def move_to_parent(self, page, site_helper, parent):
        # If we got an integer for 'parent', look up the parent page
        if isinstance(parent, int):
            try:
                parent = load_page_by_import_id(site_helper, parent)
            except KeyError:
                logger.error(
                    'importer.page.move_to_parent.no_such_import_id',
                    parent_id=parent,
                    import_id=site_helper.import_id(parent)
                )
                raise
        # only move when the page is not already under this parent
        if page.get_parent().id != parent.id:
            if page.can_move_to(parent):
                page.move(parent, pos="last-child")

    def create(self, page_model, site_helper, page_data, page=None, dry_run=False):
        """
        Create or update a page from ``page_data``, attach/move it in the page
        tree, and restore any draft revision recorded in the export.
        """
        loaded_page, created = self.load(page_model, site_helper, page_data, page=page, dry_run=dry_run)
        if created:
            loaded_page = self.add_to_parent(loaded_page, site_helper, page_data['parent'], dry_run)
        else:
            if not dry_run:
                loaded_page.save()
                if page_data.get('parent', None):
                    self.move_to_parent(loaded_page, site_helper, page_data['parent'])
            else:
                logger.info(
                    'importer.base_page_importer.update.dry_run',
                    # NOTE(review): this logs the ``page`` argument (possibly
                    # None), not ``loaded_page`` — confirm that is intentional.
                    page_model=type(page),
                    import_id=site_helper.import_id(page),
                )
        if not dry_run:
            if loaded_page.has_unpublished_changes and page_data.get('draft', None):
                user = User.objects.filter(username=page_data['draft']['username']).first()
                user_id = user.id if user else None
                go_live = page_data['draft'].get('approved_go_live_at', None)
                # pyyaml yields naive datetimes; the export is UTC, so attach UTC
                go_live_at = go_live.replace(tzinfo=pytz.utc) if go_live else None
                created = page_data['draft'].get('created_at', None)
                created_at = created.replace(tzinfo=pytz.utc) if created else None
                loaded_page.revisions.create(
                    content_json=page_data['draft']['content_json'],
                    user_id=user_id,
                    submitted_for_moderation=page_data['draft']['submitted_for_moderation'],
                    approved_go_live_at=go_live_at,
                    created_at=created_at,
                )
        return loaded_page

    def export(self, page, site):
        """
        Serialize ``page`` into an OrderedDict of exportable data (identity,
        tree position, scalar fields, booleans and the latest draft revision).
        """
        d = OrderedDict([
            ('id', page.id),
            ('treebeard_path', page.path),
            ('parent', page.get_parent().id),
            ('type', page.__class__.__name__),
            ('alias', page.relative_url()[1:]),
            ('published', page.live),
            ('locale', page.locale.language_code),
            ('translation_key', str(page.translation_key)),
        ])
        for field in self.export_fields:
            if getattr(page, field, None):
                d[field] = getattr(page, field)
        if getattr(page, 'locked_by_id', None):
            d['locked_by'] = page.locked_by.username
        if getattr(page, 'teaser_image_id', None):
            d['teaser_image'] = page.teaser_image_id
        for field in self.export_boolean_fields:
            d[field] = getattr(page, field, False)
        if page.has_unpublished_changes:
            d['has_unpublished_changes'] = True
            revision = page.get_latest_revision()
            if revision:
                d['draft'] = OrderedDict([
                    ('id', revision.id),
                    ('page_id', revision.page_id),
                    ('username', revision.user.username if revision.user else ''),
                    ('submitted_for_moderation', revision.submitted_for_moderation),
                    ('created_at', revision.created_at),
                    ('approved_go_live_at', revision.approved_go_live_at),
                    ('content_json', revision.content_json),
                ])
        return d

    def after_page_import(self, site_helper, page_data):
        """
        This is something that our subclasses can implement to do work after all the pages have been imported.
        For instance, they can use this to resolve page import_ids to actual database ids for page references.
        """
        pass
import os
import requests
import traceback
from io import BytesIO
from collections import OrderedDict
from django.core.files.images import ImageFile
from wagtail.core.models import Collection
from wagtail.images import get_image_model
from djunk.utils import get_or_generate
from core.logging import logger
class ImageMigrator(object):
    """Import/export a single Wagtail image record described by a yaml dict."""

    def __get_title(self, yml):
        """Derive a title: explicit title_text, else the local filename, else the last url path segment."""
        if yml.get('title_text'):
            return yml['title_text']
        if yml.get('file', None):
            # This import is dealing with files in the local file system
            return yml.get('file')
        elif yml.get('url', None):
            # This import is getting urls from which we can upload the images, retrieve the actual image content
            return yml['url'].split('/')[-1]
        # NOTE(review): if none of title_text/file/url is present this returns
        # None and create() will fail; __get_image_data raises RuntimeError for
        # the same condition — confirm that is the desired failure mode.

    def __get_image_data(self, yml):
        """
        Return a (file-like object, filename) tuple for the image content.

        The local file handle is deliberately left open: ImageFile/save needs
        to read from it later.

        :raises RuntimeError: if the yml has neither a ``file`` nor a ``url``
        """
        if yml.get('file', None):
            # This import is dealing with files in the local file system
            return (open(os.path.join(os.getcwd(), 'images', yml['file']), 'rb'), yml['file'])
        elif yml.get('url', None):
            # This import is getting urls from which we can upload the images, retrieve the actual image content
            response = requests.get(yml['url'])
            return (BytesIO(response.content), yml['url'].split('/')[-1])
        else:
            raise RuntimeError('We need either a local file or a url from which we can retrieve the file.')

    def __file_needs_update(self, image, yml):
        # figure out if we need to create new file object or not
        # If there isn't a current file, we definitely need to update
        if not image.file:
            return True
        # Otherwise, check if the exported file hash equals the one recorded in the database
        return yml.get('file_hash', None) != image.get_file_hash()

    def create(self, site_helper, yml, dry_run=False):
        """
        Create or update the image described by ``yml``, e.g.:

        images:
          - id: 1
            title_text: foobar
            photo_credit: credit
            caption: caption
            alt_text: alt
            file: foobar.jpg
            tags:
              - tag1
              - tag2
        """
        image, created = get_or_generate(get_image_model(), import_id=site_helper.import_id(yml['id']))
        if not dry_run:
            # '%20' appears in titles/filenames exported from urls; unescape it
            image.title = self.__get_title(yml).replace('%20', ' ')
            image.photo_credit = yml.get('photo_credit', '')
            image.caption = yml.get('caption', '')
            image.alt = yml.get('alt_text', yml.get('title_text', image.title))
            image.focal_point_x = yml.get('focal_point_x')
            image.focal_point_y = yml.get('focal_point_y')
            image.focal_point_width = yml.get('focal_point_width')
            image.focal_point_height = yml.get('focal_point_height')
            if yml.get('collection_name', None):
                other_collection = Collection.objects.descendant_of(site_helper.collection).filter(name=yml['collection_name']).first()
                image.collection = other_collection
            else:
                image.collection = site_helper.collection
            if created or self.__file_needs_update(image, yml):
                image_data, filename = self.__get_image_data(yml)
                image.file = ImageFile(image_data, name=filename.replace('%20', ' '))
                image.file_size = len(image.file)
                image.file_hash = yml.get('file_hash', '')
            try:
                image.save()
                image.get_file_hash()
                logger.info(
                    'importer.image.{}'.format('create' if created else 'update'),
                    file=yml.get('file', yml.get('url', 'NO FILE?!'))
                )
            except Exception as err:
                # If the image file isn't valid requests still gets a 200 status code so I can't test before saving
                # BUGFIX: traceback.print_tb() writes to stderr and returns None,
                # so the logged ``error`` field was always empty; format_tb()
                # returns the text so it actually lands in the log entry.
                logger.warning(
                    'importer.image.invalid_image',
                    file=yml.get('file', yml.get('url', 'NO FILE?!')),
                    error=''.join(traceback.format_tb(err.__traceback__)),
                )
            # now save the file hash so we can later use it in __file_needs_update
            image.get_file_hash()
            if yml.get('tags', None):
                image.tags.set(yml['tags'])
        else:
            logger.info(
                f"importer.image.{'create' if created else 'update'}.dry-run",
                file=yml.get('file', yml.get('url', 'NO FILE?!'))
            )
        return image

    def export(self, img):
        """Serialize ``img`` into an OrderedDict of exportable image data."""
        img_data = OrderedDict([('id', img.id),
                                ('file', img.filename),
                                ('title_text', img.title),
                                ])
        # if this is in a collection other than the default collection, export collection name
        if img.collection.depth > 2:
            img_data['collection_name'] = img.collection.name
        if getattr(img, 'alt', None) and not img.alt == 'Replace this default':
            img_data['alt_text'] = img.alt
        for field in ['caption', 'photo_credit', 'focal_point_height', 'focal_point_width',
                      'focal_point_x', 'focal_point_y', 'height', 'width', 'file_hash']:
            if getattr(img, field, None):
                img_data[field] = getattr(img, field)
        # There is a name mismatch for this attribute but keeping it until Associates is gone
        if getattr(img, 'file_size', None):
            img_data['filesize'] = img.file_size
        if img.tags.exists():
            img_data['tags'] = [tag.name for tag in img.tags.all()]
        return img_data
from collections import OrderedDict
from collections.abc import MutableSequence
from django.db.models import Q
from wagtail.core.models import Site
from wagtail.images import get_image_model
from wagtail.documents import get_document_model
from core.logging import logger
from core.models import (
NewsArticleTag,
NewsCategory,
DisplayLocation,
SpotlightType,
)
from core.utils import get_alias_and_hostname_validators
def validate_hostname(hostname):
    """
    Run every site-creator alias/hostname validator against ``hostname``.

    Each validator raises on an invalid value; returning without an exception
    means the hostname is acceptable.
    """
    for check in get_alias_and_hostname_validators(site_creator=True):
        check(hostname)
def get_image_by_import_id(site_helper, import_id):
    """
    Return the image whose ``import_id`` column matches
    ``site_helper.import_id(import_id)``; return None (after logging a
    warning) when there is no such image or ``import_id`` is falsy.
    """
    if not import_id:
        return None
    image_model = get_image_model()
    try:
        return image_model.objects.get(import_id=site_helper.import_id(import_id))
    except image_model.DoesNotExist:
        logger.warning(
            'importer.update_image_references.no_such_image',
            import_id=site_helper.import_id(import_id)
        )
        return None
def get_document_by_import_id(site_helper, import_id):
    """
    Return the document whose ``import_id`` column matches
    ``site_helper.import_id(import_id)``; return None (after logging a
    warning) when there is no such document or ``import_id`` is falsy.
    """
    if not import_id:
        return None
    document_model = get_document_model()
    try:
        return document_model.objects.get(import_id=site_helper.import_id(import_id))
    except document_model.DoesNotExist:
        logger.warning(
            'importer.update_document_references.no_such_document',
            import_id=site_helper.import_id(import_id)
        )
        return None
def load_page_by_import_id(site_helper, import_id):
    """
    Look through the site_helper page types list of Page models for
    pages whose ``import_id`` field matches
    ``site_helper.import_id(import_id)``. If such a page exists, return it,
    otherwise raise KeyError.

    :param site_helper: a SiteHelper object for the site we're working on
    :type site_helper: a SiteHelper object
    :param import_id: we'll look for objects with their import_id column set to ``site_helper.import_id(import_id)``
    :type import_id: string
    :raises KeyError: if no page of any known type has this import_id
    :rtype: a page object of the correct subtype
    """
    # resolve once instead of per page type (and again in each log/raise below)
    resolved_id = site_helper.import_id(import_id)
    page = None
    for page_type in site_helper.page_types:
        try:
            page = page_type.objects.get(import_id=resolved_id).specific
            # import_ids are unique, so stop querying once a match is found
            break
        except page_type.DoesNotExist:
            pass
    if page is None:
        logger.warning('importer.load_page_by_import_id.not_found', import_id=resolved_id)
        raise KeyError('No page matching import_id {} found.'.format(resolved_id))
    else:
        return page
def remove_block_ids(data):
    """
    If block ids are really supposed to be unique, then we don't want to keep
    the ids and keep importing more and more copies of them if we import a
    site template, so strip every 'id' key while recursing through the data.

    As of our upgrade to Wagtail 2.12, this also converts OrderedDicts to
    regular dicts (and sequences to plain lists) to avoid errors when dumping
    nested OrderedDicts.
    """
    if isinstance(data, (OrderedDict, dict)):
        return {key: remove_block_ids(value)
                for key, value in data.items()
                if not key == 'id'}
    if isinstance(data, (MutableSequence, tuple, list)):
        return [remove_block_ids(item) for item in data]
    return data
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment