Skip to content

Instantly share code, notes, and snippets.

@jorgeramirez
Created May 27, 2019 19:44
Show Gist options
  • Save jorgeramirez/9ef253e65505cc04e4b561417f03feff to your computer and use it in GitHub Desktop.
Save jorgeramirez/9ef253e65505cc04e4b561417f03feff to your computer and use it in GitHub Desktop.
find_oracle_multi.py
import sys
import itertools
import gc
import math
import datetime
import os
import threading
import multiprocessing
import concurrent.futures
import time
from code.cnndm_acl18.PyRouge.Rouge.Rouge import Rouge
from code.cnndm_acl18.Document import Document
# Largest number of sentences solve_one() will put into one oracle
# combination; the size search stops once this is exceeded.
MAX_COMB_L = 5
# Largest number of combinations solve_one() will enumerate for one size;
# the size search stops when C(#candidates, size) exceeds this.
MAX_COMB_NUM = 100000
# Module-wide ROUGE scorer used by solve_one() and solve(); solve() resets
# its ``ngram_buf`` attribute every 50 documents.
rouge = Rouge(use_ngram_buf=True)
def c_n_x(n, x):
    """Return the binomial coefficient C(n, x) ("n choose x") as an int.

    Uses the symmetry C(n, x) == C(n, n - x) so that the loops run over the
    smaller of x and n - x. The running product n * (n-1) * ... * (n-x+1)
    equals x! * C(n, x), so dividing it by x, x-1, ..., 1 in turn is exact
    at every step.

    Args:
        n: size of the set (non-negative int).
        x: number of elements chosen.

    Returns:
        C(n, x), or 0 when x < 0 or x > n.
    """
    if x < 0 or x > n:
        return 0
    # Reflect onto the smaller half: n >> 1 == n // 2.
    if x > (n >> 1):
        x = n - x
    res = 1
    for i in range(n, n - x, -1):
        res *= i
    for i in range(x, 0, -1):
        res //= i
    return res
def solve_one(document):
    """Find the oracle extractive summary of ``document`` by bounded search.

    Only sentences whose individual ROUGE-2 recall against the reference
    summary is positive are kept as candidates. Combinations of candidates
    are then enumerated by increasing size, and for each size the
    combination with the highest ROUGE-2 F1 is kept. The search stops when:
      * the size exceeds ``MAX_COMB_L``, or
      * C(#candidates, size) exceeds ``MAX_COMB_NUM``, or
      * a size larger than the current best size fails to beat the best
        F1 found so far.

    Args:
        document: ``None`` or an object exposing ``doc_len``, ``summary_len``,
            ``doc_sents`` (indexable sentences) and ``summary_sents``.

    Returns:
        ``(best_comb, best_score)``: ``best_comb`` is a tuple of sentence
        indices in ascending order (``None`` if nothing scored above 0) and
        ``best_score`` is its ROUGE-2 F1 (0 if none).
    """
    if document is None or document.doc_len == 0 or document.summary_len == 0:
        return None, 0

    # Sentences with zero individual bigram recall are never considered.
    candidates = []
    for idx, sent in enumerate(document.doc_sents):
        scores = rouge.compute_rouge([document.summary_sents], [sent])
        if scores['rouge-2']['r'][0] > 0:
            candidates.append(idx)

    best_size = 1
    best_score = 0
    best_comb = None
    # The upper bound is inclusive so that the full candidate set is also
    # tried; in particular a document with a single candidate sentence
    # still gets that sentence as its oracle.
    for size in range(1, len(candidates) + 1):
        if size > MAX_COMB_L:
            print('Exceed MAX_COMB_L')
            break
        # c_n_x returns an exact int, so a plain comparison is sufficient.
        if c_n_x(len(candidates), size) > MAX_COMB_NUM:
            print('Exceed MAX_COMB_NUM')
            break
        size_best_score = 0
        size_best_comb = None
        for comb in itertools.combinations(candidates, size):
            c_string = [document.doc_sents[idx] for idx in comb]
            rouge_scores = rouge.compute_rouge(
                [document.summary_sents], [c_string])
            rouge_bigram_f1 = rouge_scores['rouge-2']['f'][0]
            if rouge_bigram_f1 > size_best_score:
                size_best_score = rouge_bigram_f1
                size_best_comb = comb
        if size_best_score > best_score:
            best_size = size
            best_score = size_best_score
            best_comb = size_best_comb
        elif size > best_size:
            # Growing the combination did not help; larger sizes are skipped.
            break
    return best_comb, best_score
def solve(documents, output_file):
    """Compute the oracle of every document and write one line per document.

    Each output line is ``"<comb>\\t <score>"``, formatted from the pair
    returned by ``solve_one``. ``None`` entries (``load_data`` appends
    ``None`` for blank input lines) are written as ``"None\\t 0"`` instead of
    being passed to ``solve_one``, so output lines stay aligned with input.

    Args:
        documents: iterable of documents, possibly containing ``None``.
        output_file: path of the UTF-8 file to (over)write.
    """
    # buffering=1 selects line buffering, so each result line is flushed
    # as soon as it is written.
    with open(output_file, 'w', encoding='utf-8', buffering=1) as writer:
        for idx, doc in enumerate(documents):
            if idx % 50 == 0:
                # Every 50 documents: log a timestamp, reset the scorer's
                # n-gram buffer and run the garbage collector (presumably
                # to keep memory bounded on long inputs).
                print(datetime.datetime.now())
                rouge.ngram_buf = {}
                gc.collect()
            if doc is None:
                comb, score = None, 0
            else:
                comb, score = solve_one(doc)
            writer.write('{0}\t {1}\n'.format(comb, score))
def load_data(src_file, tgt_file):
    """Read a pair of parallel UTF-8 files into a list of documents.

    Line k of ``src_file`` and line k of ``tgt_file`` form one example; the
    sentences inside a line are separated by ``##SENT##``. Pairing stops at
    the end of the shorter file (``zip`` semantics). A pair where either
    side is blank after stripping contributes ``None`` so that list
    positions still correspond to line numbers.

    Returns:
        list whose items are ``Document(src_sents, tgt_sents)`` or ``None``.
    """
    separator = '##SENT##'
    documents = []
    with open(src_file, 'r', encoding='utf-8') as src_reader, \
            open(tgt_file, 'r', encoding='utf-8') as tgt_reader:
        for raw_src, raw_tgt in zip(src_reader, tgt_reader):
            src_text, tgt_text = raw_src.strip(), raw_tgt.strip()
            if src_text and tgt_text:
                documents.append(Document(src_text.split(separator),
                                          tgt_text.split(separator)))
            else:
                documents.append(None)
    return documents
def main(src_file, tgt_file, outfile_name):
    """Compute oracles for one source/target file pair.

    Loads the pair with ``load_data`` and writes the results to
    ``outfile_name`` with ``solve``.
    """
    solve(load_data(src_file, tgt_file), outfile_name)
def main_chunks(in_file, out_file, init, limit):
    """Process chunk files numbered ``init`` up to, but excluding, ``limit``.

    Chunk i is read from ``<in_file>_<iii>.src.txt`` and
    ``<in_file>_<iii>.tgt.txt``, and its oracles are written to
    ``<out_file>_<iii>.oracle.txt``, where ``<iii>`` is i zero-padded to
    three digits. Processing stops at the first chunk whose source file is
    missing, so ``limit`` may safely exceed the last existing chunk.
    """
    for chunk in range(init, limit):
        src_file, tgt_file = ("%s_%03d.%s.txt" % (in_file, chunk, part)
                              for part in ("src", "tgt"))
        fout = "%s_%03d.oracle.txt" % (out_file, chunk)
        print("src_file/tgt_file %s fout %s" % (src_file, fout))
        if os.path.isfile(src_file):
            main(src_file, tgt_file, fout)
            continue
        print("input %s not found" % src_file)
        break
def _main_multi(in_file, out_file, total_chunks, n_workers=None):
    """Compute oracles for chunks ``0 .. total_chunks - 1`` in parallel.

    The chunk indices are split into contiguous, near-equal ranges, one per
    worker process. Each range is processed by ``main_chunks``.

    Args:
        in_file: path prefix of the input chunk files (see ``main_chunks``).
        out_file: path prefix of the output oracle files.
        total_chunks: number of chunks. Chunk files are assumed to be
            numbered from 0 (``main_chunks`` stops at the first missing
            file) -- TODO confirm against the script that produced them.
        n_workers: maximum number of worker processes; defaults to
            ``multiprocessing.cpu_count()``.
    """
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    processes = []
    if total_chunks > 0:
        # Never start more workers than chunks. Each worker gets
        # ceil(total_chunks / workers) chunks, so the ranges are disjoint
        # and together cover exactly [0, total_chunks).
        workers = max(1, min(n_workers, total_chunks))
        steps = (total_chunks + workers - 1) // workers
        for worker_id, start in enumerate(range(0, total_chunks, steps)):
            print(worker_id)
            stop = min(start + steps, total_chunks)
            p = multiprocessing.Process(target=main_chunks,
                                        args=(in_file, out_file, start, stop))
            processes.append(p)
            p.start()
    for p in processes:
        p.join()
    # Surface crashed workers instead of reporting success unconditionally.
    failed = [p.exitcode for p in processes if p.exitcode != 0]
    if failed:
        print("%d worker(s) failed with exit codes %s" % (len(failed), failed))
    print("%s done!" % in_file)
if __name__ == "__main__":
    # Expected invocation: <script> IN_PREFIX OUT_PREFIX TOTAL_CHUNKS
    # (extra arguments are ignored). sys.exit(str) prints the message to
    # stderr and exits with status 1 instead of an IndexError traceback.
    if len(sys.argv) < 4:
        sys.exit("usage: %s IN_PREFIX OUT_PREFIX TOTAL_CHUNKS" % sys.argv[0])
    _main_multi(sys.argv[1], sys.argv[2], int(sys.argv[3]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment