@ornithos
Last active June 16, 2022 13:14
Parsing / chunking a large XML file from StackExchange into multiple JSON files
"""
Chunk a large XML file consisting of M elements of tag `row` into N JSON blobs
which can be read directly into memory
"""
import math
import subprocess
from datetime import datetime
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict
import ujson # faster than built-in json
from lxml import etree
from tqdm import tqdm
PATH_SE = Path(...)  # define the path of the Stack Exchange data dump here
PATH_POSTS = PATH_SE / "Posts.xml"
# Attributes we want to store
POST_ATTRIBS = ["PostTypeId", "ParentId", "OwnerUserId", "AcceptedAnswerId", "CreationDate", "Score", "Title", "Body", "Tags"]
# We want to split the data into ~100 roughly equally-sized JSON blobs so that we can
# read consecutive posts into memory without overflowing RAM. We need two components:
# 1. A streaming XML reader (lxml's etree.iterparse); n.b. the parsed-DOM clean-up
#    strategy is borrowed with thanks from andrekamman/stackexchangeparser, although
#    in retrospect I wonder whether `element.clear(keep_tail=True)` would be a cleaner
#    strategy.
# 2. A `handle_save` function that accumulates the extracted elements and periodically
#    saves a chunk to disk once the desired size is reached.
file_nlines = int(subprocess.check_output(["wc", "-l", str(PATH_POSTS)]).split()[0])  # apologies to Windows folks
file_noutput = 100
@dataclass
class SaveState:
    """Mutable state for accumulating posts and writing them out in batches."""
    MAX_SIZE: int = 500_000
    CUR_SIZE: int = 0
    BATCH_IX: int = 0
    CUR_DICT: Dict = field(default_factory=dict)
    ID_MAP: Dict = field(default_factory=dict)
def handle_save(key, value, state):
    """Add (key, value) to the current batch; write the batch to disk once it is full."""
    if state.CUR_SIZE >= state.MAX_SIZE:
        with open(PATH_SE / "post_blobs" / f"posts_{state.BATCH_IX}.json", "w") as f:
            ujson.dump(state.CUR_DICT, f)
        state.BATCH_IX += 1
        state.CUR_DICT = {}
        state.CUR_SIZE = 0
    state.CUR_DICT[key] = value
    state.ID_MAP[key] = state.BATCH_IX  # record which blob each post Id ends up in
    state.CUR_SIZE += 1
state = SaveState(MAX_SIZE=math.ceil(file_nlines / file_noutput))
# Main loop: iterparse over the rows of the XML file
xml_context = etree.iterparse(str(PATH_POSTS), events=("end",), tag="row")
for _, element in tqdm(xml_context, total=file_nlines):
    attributes = {column: element.attrib.get(column, "") for column in POST_ATTRIBS}
    handle_save(element.attrib["Id"], attributes, state)
    # dynamically remove DOM to free memory from previous rows
    while element.getprevious() is not None:
        del element.getparent()[0]
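
# The loop above only writes a blob out once it is full, so the final (partial) batch and
# the Id -> blob-index map are still sitting in `state` at this point. A minimal sketch of
# flushing them; the `posts_id_map.json` filename is my assumption, not part of the gist:
if state.CUR_SIZE > 0:
    with open(PATH_SE / "post_blobs" / f"posts_{state.BATCH_IX}.json", "w") as f:
        ujson.dump(state.CUR_DICT, f)
with open(PATH_SE / "post_blobs" / "posts_id_map.json", "w") as f:
    ujson.dump(state.ID_MAP, f)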
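
# Usage sketch: to pull a single post back out later, look up its blob via the id map.
# (Assumes the blobs and the `posts_id_map.json` written above exist; `post_id` is a
# hypothetical post Id string taken from the dump.)
#
#   id_map = ujson.load(open(PATH_SE / "post_blobs" / "posts_id_map.json"))
#   blob = ujson.load(open(PATH_SE / "post_blobs" / f"posts_{id_map[post_id]}.json"))
#   post = blob[post_id]  # dict with the POST_ATTRIBS fields for that post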