Skip to content

Instantly share code, notes, and snippets.

@zacharysyoung
Last active June 30, 2023 03:18
Show Gist options
  • Save zacharysyoung/79c29599e1ec8ae9ab6df9c9ea61cc71 to your computer and use it in GitHub Desktop.
Save zacharysyoung/79c29599e1ec8ae9ab6df9c9ea61cc71 to your computer and use it in GitHub Desktop.
SO-76508000

Making it run not so slow

I mocked up a 60 MB XML by taking all the small samples in your original ZIP archive and just copying them all 200 times, which ended up with over 425k tok elements.

I then profiled your code and found a really bad culprit for chewing up time.

To process that XML took about 35 seconds:

Thu Jun 29 10:50:59 2023    profile.stats

         1521023 function calls (1520459 primitive calls) in 36.464 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      9/1    0.000    0.000   36.464   36.464 {built-in method builtins.exec}
        1    0.323    0.323   36.464   36.464 /Users/zyoung/develop/StackOverflow/main.py:1(<module>)
        1    0.030    0.030   36.130   36.130 /Users/zyoung/develop/StackOverflow/main.py:232(run)
        1    0.482    0.482   36.100   36.100 /Users/zyoung/develop/StackOverflow/main.py:114(op_extract)
    18600   35.098    0.002   35.098    0.002 {method 'index' of 'list' objects}

and you can see that the call to index took 96% of the total runtime, yikes!!

You call index to get the original position in the list of toks/dtoks:

...
for el in matching_toks:
    ...
    pos = all_toks.index(el)   # <-- right here!

    RelevantPrecedingElements = all_toks[max(pos - 6, 0) : pos]
    ...

Why does the call to index make this so slow? As the matching_toks loop progresses, the call to all_toks.index(...) must look through more and more of all_toks to find the matching element. In Big-O notation it takes your for el in matching_toks loop from O(n) to O(n*n), "O of n-squared".

We can see the same behavior in this simple example:

def test_index(size: int):
    list_ = [x for x in range(size)]
    for x in list_:
        list_.index(x)

When I time this for sizes 10_000, 50_000, 100_000 I see the following run times:

size 10000: 0.29s
size 50000: 7.2s
size 100000: 29s

50_000 takes 25 times as long as 10_000 (despite being only 5 times bigger, because 5*5), and 100_000 takes 100 times longer (despite being only 10 times bigger, because 10*10).

We can easily fix this by just giving index a meaningful and helpful start position:

def test_index_with_start(size: int):
    list_ = [x for x in range(size)]
    last_x = 0
    for x in list_:
        list_.index(x, last_x)
        last_x = x

Now, for say x = 100_000, index doesn't have to start looking at position 0 to find 100_000, it can start looking at position 99_999. The times now look linear:

size 10000: 0.00053s
size 50000: 0.0025s
size 100000: 0.0051s

The same can be applied to your loop and index(...) call:

prev_pos = 0
for el in matching_toks:
    ...
    pos = all_toks.index(el, prev_pos)
    prev_pos = pos
    ...

With those changes, processing that big XML only took 0.7 seconds (down from 35).

You can view my complete code down in main.py, I made some other changes to appease the type hinter: mostly around not mixing None and str... if a variable should end up a string, always treat it as a string, even if empty (not None).

I also changed both your and my functions to return the (CSV) rows for each document/file. This allowed me to compare the results and make sure that my changes didn't affect the accuracy.

concat.py takes all the small XMLs from the ZIP archive you shared and creates big.xml, if you want to see what I tested against.

import glob
import xml.etree.ElementTree as ET
group = ET.Element("group")
for xml_file in glob.glob("xmls/*.xml"):
if "big.xml" in xml_file:
continue
root = ET.parse(xml_file).getroot()
el_text = root.find("text")
if el_text is not None:
el_text.set("file", xml_file)
group.append(el_text)
# Serialize to make source for copies, later
group_s = ET.tostring(group, encoding="unicode")
root = ET.Element("root")
for i in range(1, 201):
group = ET.fromstring(group_s)
group.set("group_no", str(i))
root.append(group)
tree = ET.ElementTree(root)
ET.indent(tree)
tree.write("big.xml", encoding="utf-8")
from lxml import etree as et
from lxml.etree import _Element # type: ignore - ignore using private ("_") things
Row = list[str]
toks_xpath = et.XPath("//tok|//dtok")
def my_extract_index(root: _Element) -> list[Row]:
llistas: list[Row] = []
all_toks = toks_xpath(root)
matching_toks = filter(
lambda tok: tok.get("xpos") is not None
and tok.get("xpos", "").startswith("A")
and not (tok.get("xpos", "").startswith("AX")),
all_toks,
)
prev_pos = 0
for el in matching_toks:
# preceding_tok = el.xpath("./preceding-sibling::tok[1][@lemma and @xpos]")
# preceding_tok_with_dtoks = el.xpath("./preceding-sibling::tok[1][not(@lemma) and not(@xpos)]")
# following_dtok_of_dtok = el.xpath("./preceding-sibling::dtok[1]")
if el.tag == "tok":
tok_dtok = "tok"
Adj = "".join(el.itertext())
Adj_lemma = el.get("lemma")
Adj_xpos = el.get("xpos")
else:
tok_dtok = "dtok"
Adj = el.get("form")
Adj_lemma = el.get("lemma")
Adj_xpos = el.get("xpos")
pos = all_toks.index(el, prev_pos)
RelevantPrecedingElements = all_toks[max(pos - 6, 0) : pos]
RelevantFollowingElements = all_toks[pos + 1 : max(pos + 6, 1)]
if RelevantPrecedingElements:
prec1 = RelevantPrecedingElements[-1]
else:
prec1 = None
if RelevantFollowingElements:
foll1 = RelevantFollowingElements[0]
else:
foll1 = None
ElementsContext = all_toks[max(pos - 6, 0) : pos + 1]
context_list: list[str] = []
if ElementsContext:
for elem in ElementsContext:
elem_text = "".join(elem.itertext())
# assert elem_text != ""
context_list.append(elem_text)
Adj = f"<{Adj}>"
for elem in RelevantFollowingElements:
elem_text = "".join(elem.itertext())
# assert elem_text != ""
context_list.append(elem_text)
fol_lem = foll1.get("lemma") if foll1 is not None else ""
prec_lem = prec1.get("lemma") if prec1 is not None else ""
fol_xpos = foll1.get("xpos") if foll1 is not None else ""
prec_xpos = prec1.get("xpos") if prec1 is not None else ""
fol_form = ""
if foll1 is not None:
if foll1.tag == "tok":
fol_form = foll1.text
elif foll1.tag == "dtok":
fol_form = foll1.get("form")
prec_form = ""
if prec1 is not None:
if prec1.tag == "tok":
prec_form = prec1.text
elif prec1.tag == "dtok":
prec_form = prec1.get("form")
context = " ".join(context_list).replace(" ,", ",").replace(" .", ".").replace(" ", " ").replace(" ", " ")
# print(f"Context is: {context}")
llistas.append(
[
context,
prec_form,
Adj,
fol_form,
prec_lem,
Adj_lemma,
fol_lem,
prec_xpos,
Adj_xpos,
fol_xpos,
tok_dtok,
]
)
prev_pos = pos
return llistas
def op_extract(root: _Element) -> list[Row]:
llistas: list[Row] = []
all_toks = toks_xpath(root)
matching_toks = filter(
lambda tok: tok.get("xpos") is not None
and tok.get("xpos", "").startswith("A")
and not (tok.get("xpos", "").startswith("AX")),
all_toks,
)
for el in matching_toks:
# preceding_tok = el.xpath("./preceding-sibling::tok[1][@lemma and @xpos]")
# preceding_tok_with_dtoks = el.xpath("./preceding-sibling::tok[1][not(@lemma) and not(@xpos)]")
# following_dtok_of_dtok = el.xpath("./preceding-sibling::dtok[1]")
if el.tag == "tok":
tok_dtok = "tok"
Adj = "".join(el.itertext())
Adj_lemma = el.get("lemma")
Adj_xpos = el.get("xpos")
else:
tok_dtok = "dtok"
Adj = el.get("form")
Adj_lemma = el.get("lemma")
Adj_xpos = el.get("xpos")
pos = all_toks.index(el)
RelevantPrecedingElements = all_toks[max(pos - 6, 0) : pos]
RelevantFollowingElements = all_toks[pos + 1 : max(pos + 6, 1)]
if RelevantPrecedingElements:
prec1 = RelevantPrecedingElements[-1]
else:
prec1 = None
if RelevantFollowingElements:
foll1 = RelevantFollowingElements[0]
else:
foll1 = None
ElementsContext = all_toks[max(pos - 6, 0) : pos + 1]
context_list: list[str] = []
if ElementsContext:
for elem in ElementsContext:
elem_text = "".join(elem.itertext())
# assert elem_text != ""
context_list.append(elem_text)
Adj = f"<{Adj}>"
for elem in RelevantFollowingElements:
elem_text = "".join(elem.itertext())
# assert elem_text != ""
context_list.append(elem_text)
fol_lem = foll1.get("lemma") if foll1 is not None else ""
prec_lem = prec1.get("lemma") if prec1 is not None else ""
fol_xpos = foll1.get("xpos") if foll1 is not None else ""
prec_xpos = prec1.get("xpos") if prec1 is not None else ""
fol_form = ""
if foll1 is not None:
if foll1.tag == "tok":
fol_form = foll1.text
elif foll1.tag == "dtok":
fol_form = foll1.get("form")
prec_form = ""
if prec1 is not None:
if prec1.tag == "tok":
prec_form = prec1.text
elif prec1.tag == "dtok":
prec_form = prec1.get("form")
context = " ".join(context_list).replace(" ,", ",").replace(" .", ".").replace(" ", " ").replace(" ", " ")
# print(f"Context is: {context}")
llistas.append(
[
context,
prec_form,
Adj,
fol_form,
prec_lem,
Adj_lemma,
fol_lem,
prec_xpos,
Adj_xpos,
fol_xpos,
tok_dtok,
]
)
return llistas
fname = "big.xml"
# fname = "xmls/A-02_potineig.xml"
root = et.parse(fname).getroot()
def verify():
    """Assert the optimized extractor reproduces the OP's rows exactly."""
    assert my_extract_index(root) == op_extract(root)
def run():
    """Time both extractors on the module-level parsed document."""
    import time

    for extractor in (op_extract, my_extract_index):
        started = time.monotonic()
        extractor(root)
        took = time.monotonic() - started
        print(f"ran {extractor.__name__} in {took:.2g}s")
def test_index(size: int):
list_ = [x for x in range(size)]
for x in list_:
list_.index(x)
def test_index_with_start(size: int):
list_ = [x for x in range(size)]
last_x = 0
for x in list_:
list_.index(x, last_x)
last_x = x
def time_it():
    """Time test_index against test_index_with_start at three list sizes."""
    import timeit

    sizes = [10_000, 50_000, 100_000]
    for x in sizes:
        t = timeit.timeit(f"test_index({x})", "from __main__ import test_index", number=1)
        print(f"size {x}: {t:.2g}s")
    # Same statement string; the setup aliases the with-start variant to test_index.
    for x in sizes:
        t = timeit.timeit(f"test_index({x})", "from __main__ import test_index_with_start as test_index", number=1)
        print(f"size {x}: {t:.2g}s")
run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment