Skip to content

Instantly share code, notes, and snippets.

@jorgeramirez
Created May 27, 2019 19:43
Show Gist options
  • Save jorgeramirez/052285d70fe131c3601b94c91788000d to your computer and use it in GitHub Desktop.
get_mmr_regression_gain_multi.py
from ast import literal_eval as make_tuple
import random
import math
import sys
import itertools
import gc
import math
import datetime
import os
import threading
import multiprocessing
import concurrent.futures
import time
from code.cnndm_acl18.PyRouge.Rouge.Rouge import Rouge
from code.cnndm_acl18.Document import Document
# Module-level ROUGE scorer shared by the scoring functions below. It is
# created with use_ngram_buf=True, and main() resets its ngram_buf attribute
# while processing documents.
rouge = Rouge(use_ngram_buf=True)
def load_data(src_file, tgt_file):
    """Read parallel source/target files into a list of documents.

    The two files are read line by line in lockstep (iteration stops at the
    end of the shorter file). Sentences inside a line are separated by the
    ``##SENT##`` marker.

    Args:
        src_file: Path of the UTF-8 file holding the source documents.
        tgt_file: Path of the UTF-8 file holding the target summaries.

    Returns:
        A list with one entry per line pair: a ``Document`` built from the
        split source and target sentences, or ``None`` when either line is
        empty after stripping whitespace.
    """
    separator = '##SENT##'
    documents = []
    with open(src_file, 'r', encoding='utf-8') as src_reader, \
            open(tgt_file, 'r', encoding='utf-8') as tgt_reader:
        for raw_src, raw_tgt in zip(src_reader, tgt_reader):
            source = raw_src.strip()
            target = raw_tgt.strip()
            if source and target:
                documents.append(
                    Document(source.split(separator), target.split(separator)))
            else:
                # Placeholder entry for a pair with a blank side.
                documents.append(None)
    return documents
def load_upperbound(filepath):
    """Load oracle combinations and their scores from a tab-separated file.

    Every line has two tab-separated fields: the combination (a Python
    tuple literal, or text containing ``None``) and a numeric score.

    Args:
        filepath: Path of the UTF-8 oracle file.

    Returns:
        A list of ``(combination, score)`` pairs, where ``combination`` is
        the parsed literal or ``None`` and ``score`` is a float.
    """
    entries = []
    with open(filepath, 'r', encoding='utf-8') as reader:
        for raw in reader:
            fields = raw.strip().split('\t')
            # Any first field mentioning "None" marks a missing oracle.
            combination = None if 'None' in fields[0] else make_tuple(fields[0])
            entries.append((combination, float(fields[1])))
    return entries
def get_mmr_order(oracle, doc):
    """Greedily order the oracle sentences by their ROUGE-2 F1 contribution.

    The sentence with the best stand-alone ROUGE-2 F1 against the reference
    summary is picked first. After that, each round appends the remaining
    sentence that yields the highest ROUGE-2 F1 when added to the sentences
    already selected.

    Args:
        oracle: A ``(combination, score)`` pair; only ``oracle[0]``, the
            sequence of sentence indices into ``doc.doc_sents``, is used.
        doc: Object exposing ``doc_sents`` and ``summary_sents``.

    Returns:
        A tuple with the same sentence indices in greedy selection order.
        An empty combination yields an empty tuple.
    """
    def rouge2_f(sent_ids):
        # ROUGE-2 F1 of the candidate summary made of the given sentences.
        candidate = [doc.doc_sents[idx] for idx in sent_ids]
        result = rouge.compute_rouge([doc.summary_sents], [candidate])
        return result['rouge-2']['f'][0]

    sent_ids = oracle[0]
    if not sent_ids:
        # Nothing to order; previously this raised IndexError.
        return ()

    single_scores = [rouge2_f([idx]) for idx in sent_ids]
    # Stable sort: sentences with equal scores keep their oracle order, and
    # that order is also the tie-break order in the rounds below.
    ranked = sorted(zip(sent_ids, single_scores), key=lambda pair: -pair[1])
    selected = [ranked[0][0]]
    left = [idx for idx, _ in ranked[1:]]

    while left:
        scores = [rouge2_f(selected + [idx]) for idx in left]
        # max() returns the first maximal position, matching a stable
        # descending sort followed by taking element 0.
        best_pos = max(range(len(left)), key=scores.__getitem__)
        selected.append(left.pop(best_pos))
    return tuple(selected)
def get_mmr_regression(oracle, doc):
    """Compute per-step ROUGE-2 F1 gains for every document sentence.

    For each oracle sentence, in order, every sentence of the document is
    scored as if it were appended to the sentences selected so far. The
    emitted value for each sentence is that score minus the best score of
    the previous step (0 for the first step).

    Args:
        oracle: Ordered iterable of sentence indices into ``doc.doc_sents``.
        doc: Object exposing ``doc_sents`` and ``summary_sents``.

    Returns:
        A pair ``(ids, text)``: ``ids`` is the tuple of processed sentence
        indices, and ``text`` holds one tab-separated group per step, each
        group being the space-separated gains of all document sentences.
    """
    chosen_sents = []
    chosen_ids = []
    gain_rows = []
    baseline = 0
    for sent_id in oracle:
        step_scores = []
        for sentence in doc.doc_sents:
            result = rouge.compute_rouge(
                [doc.summary_sents], [chosen_sents + [sentence]])
            step_scores.append(result['rouge-2']['f'][0])
        chosen_sents.append(doc.doc_sents[sent_id])
        chosen_ids.append(sent_id)
        gain_rows.append(
            ' '.join(str(score - baseline) for score in step_scores))
        # The next step measures gains relative to the best score seen here.
        baseline = max(step_scores)
    return tuple(chosen_ids), '\t'.join(gain_rows)
def main(src_file, tgt_file, oracle_file, output_file):
    """Write ordered oracles and their ROUGE gain rows for one data file.

    Loads the documents and the oracle combinations, prints the mean oracle
    score over the usable entries, then writes one line per document/oracle
    pair to ``output_file``: ``<ordered combination>`` and the gain text
    separated by a tab, or ``None`` and ``0`` when the pair cannot be
    processed.

    Args:
        src_file: Path of the source-document file.
        tgt_file: Path of the target-summary file.
        oracle_file: Path of the oracle (upper bound) file.
        output_file: Path of the UTF-8 file to write.
    """
    docs = load_data(src_file, tgt_file)
    oracles = load_upperbound(oracle_file)

    valid_scores = [score for comb, score in oracles if comb is not None]
    # With no usable oracle the mean is undefined; report NaN instead of
    # failing with ZeroDivisionError.
    if valid_scores:
        upper_bound = sum(valid_scores) / len(valid_scores)
    else:
        upper_bound = float('nan')
    print('upper bound: {0}'.format(upper_bound))

    with open(output_file, 'w', encoding='utf-8') as writer:
        for count, (doc, oracle) in enumerate(zip(docs, oracles), start=1):
            if count % 100 == 0:
                print(count)
                # Drop the scorer's cached n-grams every 100 documents.
                rouge.ngram_buf = {}
            # A blank source/target line yields doc None; treat it like a
            # missing oracle rather than crashing while scoring.
            if oracle[0] is None or doc is None:
                writer.write('None\t0' + '\n')
                continue
            ordered = get_mmr_order(oracle, doc)
            ordered, rouge_scores = get_mmr_regression(ordered, doc)
            writer.write('{0}\t{1}'.format(ordered, rouge_scores) + '\n')
def main_chunks(in_file, out_file, init, limit):
    """Run main() on the numbered chunk files ``init`` .. ``limit - 1``.

    Chunk ``i`` reads ``<in_file>_<iii>.src.txt``, ``.tgt.txt`` and
    ``.oracle.txt`` and writes ``<out_file>_<iii>.regain.txt``, where
    ``<iii>`` is the index zero-padded to three digits. Processing stops at
    the first chunk whose source file does not exist.

    Args:
        in_file: Prefix of the input chunk files.
        out_file: Prefix of the output chunk files.
        init: First chunk index (inclusive).
        limit: Last chunk index (exclusive).
    """
    for chunk in range(init, limit):
        stem = "%s_%03d" % (in_file, chunk)
        src_file = stem + ".src.txt"
        fout = "%s_%03d.regain.txt" % (out_file, chunk)
        print("src_file/tgt_file/oracle %s fout %s" % (src_file, fout))
        if not os.path.isfile(src_file):
            print("input %s not found" % src_file)
            # A missing chunk ends this worker's range.
            return
        main(src_file, stem + ".tgt.txt", stem + ".oracle.txt", fout)
def _main_multi(in_file, out_file, total_chunks):
    """Process one file directly, or its numbered chunks in parallel.

    With ``total_chunks == 0`` the single file set ``<in_file>.src.txt``,
    ``.tgt.txt`` and ``.oracle.txt`` is processed in this process and written
    to ``<out_file>.regain.txt``. Otherwise ``cpu_count() + 1`` worker
    processes are started, each running main_chunks() on a contiguous range
    of chunk indices, and this function waits for all of them.

    Args:
        in_file: Input path prefix.
        out_file: Output path prefix.
        total_chunks: Number of chunk files, or 0 for a single file.
    """
    if total_chunks == 0:
        # we do not multiprocess the file
        src_file = "%s.src.txt" % in_file
        tgt_file = "%s.tgt.txt" % in_file
        oracle_file = "%s.oracle.txt" % in_file
        fout = "%s.regain.txt" % out_file
        main(src_file, tgt_file, oracle_file, fout)
        print("%s done!" % in_file)
        # Stop here: falling through used to spawn workers that probed for
        # chunk files after the single-file run had already finished.
        return

    n_cpu = multiprocessing.cpu_count()
    steps = 1
    if n_cpu < total_chunks:
        # Chunks per worker, rounded up so all ranges together cover
        # every chunk index.
        steps = (total_chunks // n_cpu) + 1

    processes = []
    # One extra worker so index ``total_chunks`` itself is also probed;
    # workers whose chunk files do not exist exit immediately.
    for i in range(n_cpu + 1):
        print(i)
        p = multiprocessing.Process(
            target=main_chunks,
            args=(in_file, out_file, i * steps, (i + 1) * steps))
        processes.append(p)
        p.start()
    for process in processes:
        process.join()
    print("%s done!" % in_file)
if __name__ == "__main__":
    # Usage: <script> IN_PREFIX OUT_PREFIX TOTAL_CHUNKS
    #   TOTAL_CHUNKS == 0 processes IN_PREFIX.src.txt / .tgt.txt / .oracle.txt
    #   as a single file; otherwise the numbered chunk files
    #   IN_PREFIX_000.*, IN_PREFIX_001.*, ... are processed in parallel.
    if len(sys.argv) < 4:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("usage: %s IN_PREFIX OUT_PREFIX TOTAL_CHUNKS" % sys.argv[0])
    _main_multi(sys.argv[1], sys.argv[2], int(sys.argv[3]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment