@numb3r3
Created June 17, 2020 12:30
textrank++ demo
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""textrank++ demo: read tab-separated documents, build a word/sentence
graph for each one, rank its vertices, and evaluate the top-ranked words
against human-annotated term-importance labels.
"""
import sys
import argparse

from tqdm import tqdm

from text_utils.segmenter import SentenceSegmenter
from text_utils.tokenizer import Tokenizer
from text_utils.stopwords import is_stopword
from utils.metrics import evaluate, summary_metrics
from utils.client import Client
from graph import DocGraph

def main(args):
    # print('load stopwords')
    # load_stopwords(args.stopwords)
    print('loading text graph ...')
    segmenter = SentenceSegmenter(token_limits=args.max_sent_len)
    tokenizer = Tokenizer()

    # Optional RPC clients: one predicts term weights, one encodes sentences.
    select_client = None
    if args.select_ip:
        select_client = Client(args.select_ip, args.select_port)
    encode_client = None
    if args.encode_ip:
        encode_client = Client(args.encode_ip, args.encode_port)

    topn_results_1 = []
    topn_results_2 = []
    # sample_count = 0
    with open(args.input, 'r') as fin:
        for line in tqdm(fin):
            # Each line: doc_id <TAB> title <TAB> content <TAB> gt_rank.
            items = line.strip().split('\t')
            doc_id = items[0]
            title = items[1]
            content = items[2]
            # labels = items[3]
            gt_rank = items[3]
            # gt_rank is a "\001"-separated list of "term_label" pairs;
            # label 0 / 1 / 2 marks important / middle / unimportant terms.
            gt_rank_list = gt_rank.split("\001")
            important_term_set = set(item.split("_")[0] for item in gt_rank_list if item.split("_")[1] == "0")
            middle_term_set = set(item.split("_")[0] for item in gt_rank_list if item.split("_")[1] == "1")
            unimportant_term_set = set(item.split("_")[0] for item in gt_rank_list if item.split("_")[1] == "2")

            title_tokenized_str = ' '.join(tokenizer.tokenize(title))
            sent_tokenized_strs = []
            for sent in segmenter.segment(content):
                tokens = tokenizer.tokenize(sent)
                # Drop degenerate one-token sentences.
                if len(tokens) >= 2:
                    sent_tokenized_strs.append(' '.join(tokens))

            # Optionally fetch term-weight predictions and sentence encodings
            # for the title plus all kept sentences.
            weight_preds = None
            if select_client:
                weight_preds = select_client.predict([title_tokenized_str] + sent_tokenized_strs)
            sent_encodes = None
            if encode_client:
                sent_encodes = encode_client.predict([title_tokenized_str] + sent_tokenized_strs)

            doc_graph = DocGraph(doc_id=doc_id,
                                 title=title_tokenized_str,
                                 sentences=sent_tokenized_strs,
                                 term_weights=weight_preds,
                                 sent_encodes=sent_encodes,
                                 filter_term_func=lambda t: not is_stopword(t),
                                 sim_measure=args.sim_measure)
            ranks = doc_graph.rank(iters=args.maxiters, normalize=True)

            def _valid_token(token):
                # Keep only word vertices; optionally restrict to title words.
                is_valid = doc_graph.get_vertex(token).meta['vertex_type'] == 'word'
                if is_valid and args.only_title:
                    is_valid = (token in title)
                return is_valid

            sorted_pred_ranks = [key for key in sorted(ranks, key=ranks.get, reverse=True) if _valid_token(key)]

            # Score against the strict gold set (important terms only) and the
            # relaxed gold set (important + middle terms).
            result1 = evaluate([3, 5, 10], sorted_pred_ranks, important_term_set)
            result2 = evaluate([3, 5, 10], sorted_pred_ranks, important_term_set | middle_term_set)
            topn_results_1.append(result1)
            topn_results_2.append(result2)

            if args.manual_check:
                print('Title: %s' % title)
                print('Body: %s' % content)
                print('TOP-5 important sentences:')
                sorted_sents = [k for k in sorted(ranks, key=ranks.get, reverse=True)
                                if doc_graph.get_vertex(k).meta['vertex_type'] == 'sentence']
                summary_sents = {}
                for t, k in enumerate(sorted_sents):
                    vertex = doc_graph.get_vertex(k)
                    idx = vertex.meta['vertex_index']
                    x = vertex.meta['vertex_str']
                    print('[%d] (%.2f) %s' % (idx, ranks[k], x))
                    summary_sents[idx] = x
                    if t >= 4:
                        break
                print('Static summary:')
                print(' '.join([summary_sents[k] for k in sorted(summary_sents)]))
                print('TOP-10 important words:')
                key_words = sorted_pred_ranks[:10]
                print(' '.join(['%s - %.2f' % (k, ranks[k]) for k in key_words]))
                print('Most important words: %s' % (' '.join(important_term_set)))
                print('Second-tier important words: %s' % (' '.join(middle_term_set)))
                for n, p, r, f in result1:
                    print('TOP-%d: P: %.2f, R: %.2f, F1: %.2f' % (n, p, r, f))
                print('#############################')

    precision_1, recall_1, f1_score_1 = summary_metrics(topn_results_1, verbose=True)
    precision_2, recall_2, f1_score_2 = summary_metrics(topn_results_2, verbose=True)
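

def _topn_eval_sketch(topns, ranked_terms, gold_terms):
    """Hedged sketch only: utils.metrics.evaluate is not included in this
    gist, but from how its result is unpacked above ((n, p, r, f) tuples) it
    appears to compute precision/recall/F1 of the top-n ranked terms against
    a gold term set. This stand-in illustrates that computation and is NOT
    the real helper.
    """
    results = []
    for n in topns:
        pred = set(ranked_terms[:n])
        hits = len(pred & gold_terms)
        p = hits / float(n)
        r = hits / float(len(gold_terms)) if gold_terms else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        results.append((n, p, r, f))
    return results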


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-input", type=str, help="the input file")
    parser.add_argument("-algo", type=str, help="the rank algorithm name")
    parser.add_argument("-sim_measure", type=str, help="the similarity measure", default="euclidean_sim")
    parser.add_argument("-stopwords", type=str, help="the stopwords file")
    parser.add_argument("-maxiters", type=int, help="the maximum number of iterations", default=150)
    # argparse's `type=bool` treats any non-empty string as True, so boolean
    # flags are declared with action="store_true" instead.
    parser.add_argument("-only_title", action="store_true", help="only consider the title words")
    parser.add_argument("-select_port", type=int, help="the select-word service port", default=30311)
    parser.add_argument("-select_ip", type=str, help="the select-word service ip")
    parser.add_argument("-encode_port", type=int, help="the sentence-encoding service port", default=30310)
    parser.add_argument("-encode_ip", type=str, help="the sentence-encoding service ip")
    parser.add_argument("-manual_check", action="store_true", help="whether to print detailed results")
    parser.add_argument("-max_sent_len", type=int, help="the maximum sentence length", default=20)
    return parser.parse_args()


def usage():
    print("Usage: %s" % (sys.argv[0]), file=sys.stderr)


if __name__ == "__main__":
    args = parse_args()
    main(args)
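
# Example invocation (illustrative; the script filename and data path are
# assumptions, not part of the gist):
#
#   python textrank_demo.py -input data/docs.tsv -sim_measure euclidean_sim \
#       -maxiters 150 -manual_check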