Created
September 8, 2021 20:32
-
-
Save soldni/2902c22d4aedb7133a0e8a39492f210e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
import os | |
# Configuration settings here.
# Root of the PyTerrier index directory: ``data.properties`` and the
# ``docs.unqlite`` / ``title.unqlite`` stores are read from here.
TERRIER_DESTINATION = '/home/ubuntu/wikipedia/pyterrier/ruwiki'
# Currently not read anywhere in this script.
PROCESSED_DATA_PATH = '/home/ubuntu/wikipedia/extracted/ruwiki'
# Number of documents retrieved per query. Can be overridden through the
# NUM_RESULTS environment variable; when it is unset we fall back to 100
# instead of aborting with a KeyError.
DEFAULT_NUM_RESULTS = 100
NUM_RESULTS = int(os.environ.get('NUM_RESULTS', DEFAULT_NUM_RESULTS))
# Separator used to re-join filtered tokens into a single query string.
LANGUAGE_JOIN_CHAR = ' '
# Terrier retrieval controls: BM25 weighting with Bo1 query expansion on.
CONTROLS = {"wmodel": "BM25", "qemodel": "Bo1", "qe": "on"}
import argparse | |
import json | |
import os | |
import bz2 | |
import shutil | |
import hashlib | |
import sys | |
from contextlib import ExitStack | |
import unidecode | |
import pathlib | |
import pyterrier as pt | |
import pandas as pd | |
import unqlite | |
import tqdm | |
import spacy | |
import html | |
class Tokenizer:
    """Normalise Russian text into a joined string of content tokens.

    Text is split with a spaCy ``Russian`` language object; tokens flagged
    as stop words, punctuation or whitespace are discarded and the remaining
    tokens are joined with ``LANGUAGE_JOIN_CHAR``.
    """

    def __init__(self):
        # The Russian language class is imported when a Tokenizer is built
        # rather than at module import time.
        from spacy.lang import ru
        self._tokenizer = ru.Russian()

    def __call__(self, text):
        kept = []
        for token in self._tokenizer(text.strip()):
            if token.is_stop or token.is_punct or token.is_space:
                continue
            kept.append(str(token))
        return LANGUAGE_JOIN_CHAR.join(kept)
def hash_fn(s):
    """Return the MD5 hex digest of *s* used as a stable query id.

    Surrounding whitespace is stripped before hashing, so lines that differ
    only in leading/trailing whitespace map to the same id.
    """
    payload = s.strip().encode('utf-8')
    digest = hashlib.md5(payload)
    return digest.hexdigest()
def clean_up_string(s):
    """Return *s* with HTML entities decoded and the text transliterated to ASCII.

    Entities are unescaped *before* transliteration. In the opposite order,
    ``html.unescape`` would turn entities such as ``&eacute;`` or ``&#1076;``
    back into non-ASCII characters after ``unidecode`` had already run.
    """
    s = html.unescape(s)
    s = unidecode.unidecode(s)
    return s
# Read one query per non-blank line, from the file named by the first
# command-line argument or, when no argument is given, from stdin.
if len(sys.argv) > 1:
    queries_path = sys.argv[1]
    # The index is Russian Wikipedia, so queries are decoded explicitly as
    # UTF-8 instead of relying on the platform's locale encoding.
    with open(queries_path, encoding='utf-8') as f:
        queries = [ln.strip() for ln in f if ln.strip()]
else:
    queries = [ln.strip() for ln in sys.stdin if ln.strip()]
# A length can never be negative, so test for emptiness directly.
if not queries:
    raise ValueError('No queries!')
# Map a stable id (md5 of the stripped query) back to the raw query text so
# each result can report the original query; duplicate lines collapse to a
# single id.
qid_mapping = {hash_fn(q): q for q in queries}
tok = Tokenizer()
# Retrieval input: one row per unique query, with ``qid`` and the
# stop-word/punctuation-filtered token string as ``query``.
queries = pd.DataFrame([{'qid': h, 'query': tok(q)}
                        for h, q in qid_mapping.items()])
# Start the Terrier Java service once, loading terrier-prf as a boot package.
if not pt.started():
    pt.init(boot_packages=["com.github.terrierteam:terrier-prf:-SNAPSHOT"])
index = pt.IndexFactory.of(f'{TERRIER_DESTINATION}/data.properties')
# Retrieval configured by CONTROLS, returning the docno/wikiId/url metadata
# for the top NUM_RESULTS documents per query. ``termpipelines`` is emptied
# so Terrier applies no stemming/stopword stage of its own -- presumably
# because queries were already filtered by Tokenizer; confirm against the
# indexing configuration.
pipeline = pt.BatchRetrieve(
    index,
    controls=CONTROLS,
    metadata=["docno", "wikiId", "url"],
    num_results=NUM_RESULTS,
    properties={"termpipelines": "",
                "tokeniser": "UTFTokeniser"},
)
# ``compile()`` returns the (possibly rewritten) pipeline instead of
# modifying it in place, so keep its result.
pipeline = pipeline.compile()
results = pipeline.transform(queries)
# Emit one JSON object per retrieved (query, document) pair on stdout. The
# document body and title are looked up by docno in the UnQLite stores, and
# the raw query text replaces the qid.
docs_db = unqlite.UnQLite(os.path.join(TERRIER_DESTINATION, 'docs.unqlite'))
title_db = unqlite.UnQLite(os.path.join(TERRIER_DESTINATION, 'title.unqlite'))
try:
    # The row label is unused; ``_`` also avoids shadowing the Terrier
    # ``index`` object bound earlier.
    for _, row in results.iterrows():
        row.pop('docid')
        docid = row.pop('docno')
        row['org_query'] = qid_mapping[row.pop('qid')]
        row['doc'] = docs_db.fetch(docid).decode('utf-8')
        row['title'] = title_db.fetch(docid).decode('utf-8')
        sys.stdout.write(json.dumps(row.to_dict(), sort_keys=True) + '\n')
    # Flush here so a closed pipe surfaces inside this ``try`` rather than
    # during interpreter shutdown.
    sys.stdout.flush()
except BrokenPipeError:
    # The downstream reader (e.g. ``head``) stopped early. Point stdout at
    # devnull so Python's final flush does not raise a second
    # BrokenPipeError on exit.
    devnull = os.open(os.devnull, os.O_WRONLY)
    os.dup2(devnull, sys.stdout.fileno())
    os.close(devnull)
finally:
    docs_db.close()
    title_db.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment