Last active
October 14, 2020 00:55
-
-
Save kota7/0526f967cda2b1533398eb1206d1b2d7 to your computer and use it in GitHub Desktop.
Apply JUMAN++ to many texts in parallel.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
""" | |
Apply JUMAN++ to many texts in parallel. | |
Requirements. | |
JUMAN++: http://nlp.ist.i.kyoto-u.ac.jp/index.php?JUMAN++ | |
Python library: | |
mojimoji (https://github.com/studio-ousia/mojimoji) or | |
jaconv (https://github.com/ikegami-yukino/jaconv). | |
Outputs. | |
Text files containing JUMAN++ results. | |
If IDs are given, each result starts with "#{id}".
""" | |
import csv
import math
import os
import re
import subprocess
from argparse import ArgumentParser
from logging import INFO, basicConfig, getLogger
from multiprocessing import Pool, cpu_count
try:
    from mojimoji import han_to_zen
except ImportError:
    # mojimoji is not installed: fall back to jaconv behind an adapter that
    # exposes the same name and keyword arguments, so the rest of the file
    # can call han_to_zen() regardless of which library is present.
    from jaconv import h2z

    def han_to_zen(text, ascii=True, digit=True, kana=True):
        """Convert half-width characters in *text* to full-width using jaconv.

        The defaults are set explicitly so that ASCII, digits and kana are
        all converted, matching how the rest of this file calls
        ``han_to_zen(text)`` with no options.
        """
        return h2z(text, ascii=ascii, digit=digit, kana=kana)
# One-letter level + timestamp truncated to seconds, e.g. "[I|2020-10-14 00:55:00] ...".
_LOG_FORMAT = "[%(levelname).1s|%(asctime).19s] %(message)s"

logger = getLogger(__file__)
# Configures the root logger at import time; INFO replaces the magic number 20.
basicConfig(level=INFO, format=_LOG_FORMAT)
def _preprocess(text, id_=None):
    """Normalize one input text for JUMAN++ and optionally prepend its id.

    All whitespace runs (including CR/LF) are collapsed into single spaces
    and leading/trailing whitespace is removed, so each text occupies one
    line of JUMAN++ input. The result is then passed through
    ``han_to_zen`` (half-width to full-width conversion).

    Args:
        text: the raw input string.
        id_: optional identifier; when not None, a ``#{id_}`` line is placed
            before the text. (Assumes ``id_`` itself contains no newline.)

    Returns:
        The preprocessed text, preceded by ``#{id_}\\n`` when *id_* is given.
    """
    # str.split() with no argument splits on any whitespace run and drops
    # leading/trailing whitespace -- equivalent to the former pair of
    # re.sub calls plus strip(), without the invalid "\s" string escape.
    text = " ".join(text.split())
    # add a space to avoid error due to empty input
    text = han_to_zen(text + " ")
    if id_ is not None:
        text = "#{}\n{}".format(id_, text)
    return text
def do_jumanpp(texts, outfile, ids=None, encoding="utf8", command=("jumanpp",)):
    """Run JUMAN++ over *texts* and write its raw output to *outfile*.

    Each text is normalized with ``_preprocess``. All texts are sent to a
    single JUMAN++ process as newline-separated input, and its stdout is
    written to *outfile* unchanged.

    Args:
        texts: sequence of input strings.
        outfile: path of the output file (overwritten if it exists).
        ids: optional sequence of identifiers, one per text; each is written
            as a ``#{id}`` line before its text.
        encoding: encoding used for the bytes piped to JUMAN++.
        command: argv used to launch JUMAN++; defaults to ``jumanpp`` on PATH.

    Returns:
        *outfile*.

    Raises:
        ValueError: if *ids* is given and its length differs from *texts*.
        subprocess.CalledProcessError: if JUMAN++ exits with non-zero status.
        OSError: if the JUMAN++ executable cannot be started
            (e.g. FileNotFoundError when it is not installed).
    """
    if ids is not None and len(ids) != len(texts):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError("len(ids) = {} does not match len(texts) = {}"
                         .format(len(ids), len(texts)))
    logger.info("Start jumanpp job: len(texts) = %s, outfile = `%s`",
                len(texts), outfile)
    if ids is None:
        lines = (_preprocess(t) for t in texts)
    else:
        lines = (_preprocess(t, i) for t, i in zip(texts, ids))
    payload = "\n".join(lines).encode(encoding)

    argv = list(command)
    with open(outfile, "wb") as f:
        proc = subprocess.Popen(argv, stdin=subprocess.PIPE, stdout=f)
        proc.communicate(payload)
    # Without this check a crashed JUMAN++ leaves a truncated file that
    # would still be reported as done.
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, argv)
    logger.info("Done. Output file: `%s`", outfile)
    return outfile
def do_jumanpp_batch(texts, ids=None, encoding="utf8",
                     processes=1, batchsize=1000,
                     outfile_base="_jumanpp-result_%d.txt"):
    """Run JUMAN++ over *texts* in fixed-size batches using a process pool.

    Batch ``k`` (0-based) contains ``texts[k*batchsize:(k+1)*batchsize]`` and
    is written to ``outfile_base % k`` by one ``do_jumanpp`` call.

    Args:
        texts: sequence of input strings.
        ids: optional sequence of identifiers, same length as *texts*.
        encoding: encoding of the bytes piped to JUMAN++.
        processes: number of worker processes; None or a value below 1
            means ``cpu_count()``. No more workers than batches are started.
        batchsize: number of texts per output file; must be at least 1.
        outfile_base: %-format pattern with one integer slot for the batch
            index; its directory is created if missing.

    Returns:
        List of output file paths in batch order (empty when *texts* is empty).

    Raises:
        ValueError: if *batchsize* < 1, or *ids* is given and its length
            differs from *texts*.
    """
    # Validate before touching the filesystem or starting processes.
    if batchsize < 1:
        raise ValueError("batchsize must be at least 1, got {}".format(batchsize))
    if ids is not None and len(ids) != len(texts):
        raise ValueError("len(ids) = {} does not match len(texts) = {}"
                         .format(len(ids), len(texts)))

    os.makedirs(os.path.abspath(os.path.dirname(outfile_base)), exist_ok=True)
    if processes is None or processes < 1:
        processes = cpu_count()
    logger.info("Number of processes: %s", processes)

    # One (texts, outfile, ids, encoding) tuple per batch, in the positional
    # order expected by do_jumanpp. Slicing past the end is safely clamped.
    args = [
        (texts[start:start + batchsize],
         outfile_base % index,
         None if ids is None else ids[start:start + batchsize],
         encoding)
        for index, start in enumerate(range(0, len(texts), batchsize))
    ]
    logger.info("Results will be split into %d files", len(args))
    if not args:
        return []
    # A pool larger than the number of batches would only start idle workers
    # (and Pool(0) would raise).
    with Pool(min(processes, len(args))) as p:
        return p.starmap(do_jumanpp, args)
def _parse_rows(rows):
    """Split CSV rows into parallel lists of texts and (optional) ids.

    Empty rows are skipped. The first non-empty row decides the layout:
    exactly two columns means every row is ``(id, text)``; any other width
    means the text is taken from the first column and no ids are collected.

    Args:
        rows: iterable of lists of strings, e.g. a ``csv.reader``.

    Returns:
        ``(texts, ids)`` where ``ids`` is None when there is no id column.

    Raises:
        ValueError: if the input has an id column but a later row has
            fewer than two columns.
    """
    texts = []
    ids = None
    has_id = None  # undecided until the first non-empty row
    for rownum, row in enumerate(rows, start=1):
        if not row:
            continue
        if has_id is None:
            has_id = len(row) == 2
            if has_id:
                ids = []
        if has_id:
            if len(row) < 2:
                raise ValueError("Row {}: expected (id, text) but got {!r}"
                                 .format(rownum, row))
            ids.append(row[0])
            texts.append(row[1])
        else:
            texts.append(row[0])
    return texts, ids


def main():
    """Command-line entry point: run JUMAN++ in parallel over a CSV of texts.

    Reads texts (and ids, if the CSV has an id column) from the input file
    and hands them to ``do_jumanpp_batch`` with the options given on the
    command line.
    """
    # The first positional argument of ArgumentParser is `prog`; this string
    # is meant as the description shown in --help.
    parser = ArgumentParser(description="jumanpp parallel minibatch")
    parser.add_argument("input_file", type=str,
                        help="Input CSV file. "
                             "Each line should contain either one element: text, "
                             "or two elements: (id, text).")
    parser.add_argument("-s", "--batchsize", default=500, type=int,
                        help="Minibatch size. Default: 500")
    parser.add_argument("-p", "--num-process", default=-1, type=int,
                        help="Number of processes. "
                             "Non-positive indicates all available CPUs. "
                             "Default: -1")
    parser.add_argument("-o", "--outfile-pattern", default="_juman-results/%d.txt", type=str,
                        help="Base output file name. Default: '_juman-results/%%d.txt'")
    parser.add_argument("-e", "--encoding", default="utf8", type=str,
                        help="String encoding. Default: 'utf8'")
    parser.add_argument("--input-encoding", default=None, type=str,
                        help="Encoding of the input CSV file. "
                             "Default: the platform's preferred encoding")
    args = parser.parse_args()

    # newline="" is required by the csv module so that quoted fields
    # containing line breaks are read as a single field.
    with open(args.input_file, newline="", encoding=args.input_encoding) as f:
        texts, ids = _parse_rows(csv.reader(f))
    logger.info("Number of inputs: %s", len(texts))
    do_jumanpp_batch(texts=texts,
                     ids=ids,
                     encoding=args.encoding,
                     outfile_base=args.outfile_pattern,
                     batchsize=args.batchsize,
                     processes=args.num_process)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment