Skip to content

Instantly share code, notes, and snippets.

Created May 27, 2019 19:44
Show Gist options
  • Save jorgeramirez/9ef253e65505cc04e4b561417f03feff to your computer and use it in GitHub Desktop.
Save jorgeramirez/9ef253e65505cc04e4b561417f03feff to your computer and use it in GitHub Desktop.
import sys
import itertools
import gc
import math
import datetime
import os
import threading
import multiprocessing
import concurrent.futures
import time
from code.cnndm_acl18.PyRouge.Rouge.Rouge import Rouge
from code.cnndm_acl18.Document import Document
# Shared ROUGE scorer used by every document; solve() periodically resets its
# `ngram_buf` attribute (use_ngram_buf presumably enables that n-gram cache).
rouge = Rouge(use_ngram_buf=True)
# Largest combination size (number of sentences) solve_one will try.
# NOTE(review): solve_one references MAX_COMB_L but its definition was missing
# from this copy of the script; 5 is a placeholder — TODO confirm the value.
MAX_COMB_L = 5
# solve_one stops growing the combination size once C(#candidates, l) exceeds
# this many combinations.
MAX_COMB_NUM = 100000
def c_n_x(n, x):
    """Return the binomial coefficient C(n, x) ("n choose x") as an exact int.

    Args:
        n: size of the set.
        x: size of the subsets being counted.

    Returns:
        The number of x-element subsets of an n-element set; 0 when x is
        outside [0, n].
    """
    if x < 0 or x > n:
        return 0
    # C(n, x) == C(n, n - x): use the smaller of the two so the loops below
    # run at most n // 2 times (the threshold is n / 2, i.e. n >> 1).
    if x > (n >> 1):
        x = n - x
    res = 1
    # res = n * (n-1) * ... * (n-x+1)
    for i in range(n, n - x, -1):
        res *= i
    # Divide by x, x-1, ..., 1. Every prefix product x*(x-1)*...*k divides x!,
    # which divides res, so each floor division is exact.
    for i in range(x, 0, -1):
        res = res // i
    return res
def solve_one(document):
    """Find the sentence combination of *document* with the best ROUGE-2 F1.

    Candidate sentences are those whose individual ROUGE-2 recall against the
    reference summary is > 0. For combination sizes l = 1, 2, ... every
    l-combination of candidates is scored by ROUGE-2 F1 against the summary.
    The search stops when:
      - a size fails to beat the best score found at a smaller size,
      - l exceeds MAX_COMB_L, or
      - C(#candidates, l) exceeds MAX_COMB_NUM.

    Args:
        document: object exposing doc_len, summary_len, doc_sents and
            summary_sents (see the Document class).

    Returns:
        (best_comb, best_score): a tuple of candidate sentence indices (or
        None if no combination scored above 0) and its ROUGE-2 F1. Returns
        (None, 0) for a document with no sentences or no summary.
    """
    if document.doc_len == 0 or document.summary_len == 0:
        return None, 0

    # Keep only sentences sharing at least one bigram with the summary
    # (non-zero ROUGE-2 recall on their own).
    candidates = []
    for idx, sent in enumerate(document.doc_sents):
        # NOTE(review): this passes [sent] while the combination scoring below
        # passes [list_of_sents]; confirm compute_rouge accepts both shapes.
        scores = rouge.compute_rouge([document.summary_sents], [sent])
        if scores['rouge-2']['r'][0] > 0:
            candidates.append(idx)

    n_candidates = len(candidates)
    all_best_l = 1
    all_best_score = 0
    all_best_comb = None
    # Include l == n_candidates; otherwise a document with a single candidate
    # sentence is never scored at all.
    for l in range(1, n_candidates + 1):
        if l > MAX_COMB_L:
            print('Exceed MAX_COMB_L')
            break
        # c_n_x returns an exact int, so a plain comparison is enough
        # (math.isnan/isinf would raise OverflowError on huge ints).
        if c_n_x(n_candidates, l) > MAX_COMB_NUM:
            print('Exceed MAX_COMB_NUM')
            break
        l_best_score = 0
        l_best_choice = None
        for comb in itertools.combinations(candidates, l):
            c_string = [document.doc_sents[i] for i in comb]
            rouge_scores = rouge.compute_rouge(
                [document.summary_sents], [c_string])
            rouge_bigram_f1 = rouge_scores['rouge-2']['f'][0]
            if rouge_bigram_f1 > l_best_score:
                l_best_score = rouge_bigram_f1
                l_best_choice = comb
        if l_best_score > all_best_score:
            all_best_l = l
            all_best_score = l_best_score
            all_best_comb = l_best_choice
        # Adding another sentence did not help: larger sizes are not tried.
        if l > all_best_l:
            break
    return all_best_comb, all_best_score
def solve(documents, output_file):
    """Compute the oracle for every document and write one line per document.

    Each output line is '<best_comb>\\t <best_score>' as returned by
    solve_one. The file is line-buffered, so finished lines reach disk
    immediately, and it is closed even if solve_one raises.

    Args:
        documents: iterable of documents accepted by solve_one.
        output_file: path of the UTF-8 text file to (over)write.
    """
    with open(output_file, 'w', encoding='utf-8', buffering=1) as writer:
        for idx, doc in enumerate(documents):
            # Reset the shared scorer's n-gram buffer every 50 documents,
            # presumably to keep its memory bounded.
            if idx % 50 == 0:
                rouge.ngram_buf = {}
            best_comb, best_score = solve_one(doc)
            writer.write('{0}\t {1}'.format(best_comb, best_score) + '\n')
def load_data(src_file, tgt_file):
    """Read parallel source/target files into a list of Document objects.

    Line i of *src_file* (the article) and line i of *tgt_file* (its summary)
    form one document. Sentences within a line are separated by the literal
    marker '##SENT##'. Pairs where either side is blank are skipped.

    Args:
        src_file: path of the UTF-8 source (document) file.
        tgt_file: path of the UTF-8 target (summary) file.

    Returns:
        list of Document built from (source sentences, target sentences).
    """
    docs = []
    # NOTE(review): zip() stops at the shorter file without warning.
    with open(src_file, 'r', encoding='utf-8') as src_reader, \
            open(tgt_file, 'r', encoding='utf-8') as tgt_reader:
        for src_line, tgt_line in zip(src_reader, tgt_reader):
            src_line = src_line.strip()
            tgt_line = tgt_line.strip()
            # NOTE(review): skipping blank pairs means output line numbers
            # no longer match input line numbers — confirm that is intended.
            if not src_line or not tgt_line:
                continue
            src_sents = src_line.split('##SENT##')
            tgt_sents = tgt_line.split('##SENT##')
            docs.append(Document(src_sents, tgt_sents))
    return docs
def main(src_file, tgt_file, outfile_name):
    """Load one source/target file pair and write its oracle results.

    Args:
        src_file: path of the source (document) file.
        tgt_file: path of the target (summary) file.
        outfile_name: path of the file receiving one result line per document.
    """
    solve(load_data(src_file, tgt_file), outfile_name)
def main_chunks(in_file, out_file, init, limit):
    """Process the numbered chunk files init, init+1, ..., limit-1.

    Chunk i is read from '<in_file>_<iii>.src.txt' and
    '<in_file>_<iii>.tgt.txt' (i zero-padded to 3 digits). Its results are
    written to '<out_file>_<iii>.txt'. Chunks whose source file does not
    exist are reported and skipped.

    Args:
        in_file: path prefix of the input chunk files.
        out_file: path prefix of the output files.
        init: first chunk index (inclusive).
        limit: last chunk index (exclusive).
    """
    for i in range(init, limit):
        src_file = "%s_%03d.src.txt" % (in_file, i)
        tgt_file = "%s_%03d.tgt.txt" % (in_file, i)
        # NOTE(review): the output pattern was empty in this copy
        # ("" % (...) raises TypeError); "%s_%03d.txt" is a reconstruction —
        # TODO confirm the intended output naming.
        fout = "%s_%03d.txt" % (out_file, i)
        print("src_file/tgt_file %s fout %s" % (src_file, fout))
        if not os.path.isfile(src_file):
            print("input %s not found" % src_file)
            continue
        main(src_file, tgt_file, fout)
def _main_multi(in_file, out_file, total_chunks):
starttime = time.time()
n_cpu = multiprocessing.cpu_count()
steps = 1
N = n_cpu
if n_cpu < total_chunks:
steps = (total_chunks // n_cpu) + 1
processes = []
for i in range(N + 1):
p = multiprocessing.Process(target=main_chunks,
args=(in_file, out_file, i * steps, (i + 1) * steps))
for process in processes:
# with concurrent.futures.ThreadPoolExecutor(max_workers=N) as executor:
# for i in range(N + 1):
# print(i)
# executor.submit(main_chunks, in_file, out_file,
# i * steps, (i + 1) * steps)
print("%s done!" % in_file)
if __name__ == "__main__":
    # Usage: <script> IN_PREFIX OUT_PREFIX TOTAL_CHUNKS
    # Fail with a usage message instead of a bare IndexError traceback.
    if len(sys.argv) < 4:
        sys.exit("usage: %s IN_PREFIX OUT_PREFIX TOTAL_CHUNKS" % sys.argv[0])
    _main_multi(sys.argv[1], sys.argv[2], int(sys.argv[3]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment