Last active
August 29, 2019 08:55
-
-
Save shijieyao/1a747bd928f0aad5c15581ff167f70f9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
from ipywidgets import widgets, Output, HBox, VBox, interact, interactive, fixed, Layout | |
from IPython.display import display, clear_output | |
class CheckerRouter(object):
    """Top-level menu for choosing which checker to run in the notebook.

    Shows three buttons: label check, annotation-consistency check, and
    reset. Any click clears the cell output and redraws this menu. The two
    checker buttons then also start the matching checker app below the menu.
    """

    def __init__(self):
        def make_handler(start_app):
            # Click handler: wipe output, redraw the menu, then optionally
            # launch the selected checker app underneath it.
            def handler(_button):
                clear_output()
                self.__init__()
                if start_app is not None:
                    start_app()
            return handler

        # (button caption, app launcher or None for a plain reset)
        routes = (
            ("标签质检", lambda: DataAnnoCheckerApp()),
            ("一致性质检", lambda: AnnoConsistencyCheckerApp()),
            ("Reset", None),
        )
        buttons = [widgets.Button(description=caption) for caption, _ in routes]
        for button in buttons:
            display(button)
        for button, (_, start_app) in zip(buttons, routes):
            button.on_click(make_handler(start_app))
class DataAnnoCheckerApp(object):
    """Notebook UI that flags labels outside a domain's allowed label set.

    The user picks a domain, which selects the valid label set. They also
    enter the path of a tab-separated file and the 1-based column that holds
    the label. When "Start checking" is clicked, every label that is not in
    the valid set is reported with the 1-based line numbers where it occurs.
    If all three constructor arguments are given, they take precedence over
    the widget values.
    """

    def __init__(self, in_fp=None, label_col=None, domain=None):
        self.in_fp = in_fp
        self.label_col = label_col
        self.domain = domain
        confirm_button, reset_button, in_fp_Text, in_domain_button, label_col_IntText = \
            self.display_label_checker_user_interface()
        self.get_msg_from_buttons(
            confirm_button, reset_button, in_fp_Text, in_domain_button, label_col_IntText)

    @staticmethod
    def _get_label2lines_from_in_fp(in_fp=None, label_col=None):
        """Group the line numbers of *in_fp* by the label found in *label_col*.

        Args:
            in_fp: path of a UTF-8, tab-separated file.
            label_col: 1-based index of the label column.

        Returns:
            defaultdict(list) mapping each label to the 1-based line numbers
            (as str) on which it appears. Blank lines are skipped.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a non-blank line has no column *label_col*.
        """
        label2lines = defaultdict(list)
        with open(in_fp, "r", encoding="UTF-8") as fin:
            # start=1 so the numbers match editor line numbers and the
            # numbering used by AnnoConsistencyCheckerApp.
            for line_no, line in enumerate(fin, start=1):
                # Remove only the line terminator. strip() would also remove
                # leading/trailing tabs, which shifts columns when an edge
                # column is empty.
                line = line.rstrip("\r\n")
                if not line.strip():
                    continue
                fields = line.split("\t")
                if not 1 <= label_col <= len(fields):
                    raise ValueError(
                        "line {}: column {} not found (line has {} column(s))".format(
                            line_no, label_col, len(fields)))
                label2lines[fields[label_col - 1]].append(str(line_no))
        return label2lines

    @staticmethod
    def _get_valid_labels(domain):
        """Return the list of valid labels for *domain* ([] if unknown)."""
        # NOTE(review): "a"/"b"/"c" look like placeholder label sets - confirm.
        domain2labels = {"新护肤": ["a"], "洗护": ["b"], "新彩妆": ["c"]}
        return domain2labels.get(domain, [])

    @staticmethod
    def _get_customised_labelset(fp):
        """Read one label per line from *fp* (UTF-8), stripping whitespace."""
        with open(fp, "r", encoding="UTF-8") as fin:
            return [i.strip() for i in fin]

    @staticmethod
    def _check_if_invalid_label(valid_labels: set, label2lines: defaultdict):
        """Return {label: "n1 n2 ..."} for every label not in *valid_labels*.

        *valid_labels* may be any iterable of labels. *label2lines* maps each
        label to a list of line-number strings.
        """
        valid = set(valid_labels)  # O(1) membership checks
        return {label: " ".join(lines)
                for label, lines in label2lines.items()
                if label not in valid}

    @staticmethod
    def _display_invalid_label2lines(invalid_label2lines):
        """Render invalid labels and their line numbers as one HTML table."""
        from html import escape

        if not invalid_label2lines:
            display_html("No invalid labels found.", "h4")
            return
        # Build the whole table as a single string. Each display_html() call
        # creates a separate widget, so a table split across calls never
        # renders. Labels come from the input file, so escape them.
        rows = "".join(
            "<tr> <td><font color='red'>{}</font></td> <td>{}</td> </tr>".format(
                escape(label), escape(lines))
            for label, lines in invalid_label2lines.items())
        display_html(
            '<table style="width:100%"> <tr> <th>Label</th> <th>Lines</th> </tr>'
            "{}</table>".format(rows))

    def display_label_checker_user_interface(self):
        """Display the domain, file-path and label-column inputs plus buttons.

        Returns:
            (confirm_button, reset_button, in_fp_Text, in_domain_button,
            label_col_IntText)
        """
        # Step 1: choose the domain (the domain determines the label set).
        display_html("1. 选择domain(domain决定了label set)", "h3")
        in_domain_button = widgets.RadioButtons(
            options=["新护肤", "洗护", "新彩妆", "待开放/自定义待添加"],
            style={"description_width": "initial",
                   "description_height": "initial"})
        display(in_domain_button)
        display_html("", "h4")
        # Step 2: absolute path of the file to check.
        display_html("2. 输入待检查文件的绝对路径", "h3")
        in_fp_Text = widgets.Text(
            description="File path",
            placeholder="Path to file for checking",
            style={"description_width": "initial",
                   "description_height": "initial"})
        display(in_fp_Text)
        display_html("", "h4")
        # Step 3: 1-based column holding the label to check.
        display_html("3. 输入待检查文件中 待检查标签 的列数(从1开始为第一列)", "h3")
        label_col_IntText = widgets.IntText(
            value=1,
            style={"description_width": "initial",
                   "description_height": "initial"})
        display(label_col_IntText)
        box_layout = widgets.Layout(display='flex',
                                    flex_flow='column',
                                    align_items='center',
                                    width='100%')
        confirm_button = widgets.Button(description="Start checking")
        reset_button = widgets.Button(description="Reset")
        box = widgets.HBox(children=[confirm_button, reset_button], layout=box_layout)
        display(box)
        return confirm_button, reset_button, in_fp_Text, in_domain_button, label_col_IntText

    def get_msg_from_buttons(self, confirm_button, reset_button, in_fp_Text, in_domain_button, label_col_IntText):
        """Wire the confirm and reset buttons to their handlers.

        confirm() reads widget values at click time, so no value observers
        need to be registered on the input widgets.
        """
        self.confirm(confirm_button, in_fp_Text, in_domain_button,
                     label_col_IntText)
        self.reset(reset_button)

    def check_label_pipeline_w_buttons(self, confirm_button, reset_button, in_fp_Text, in_domain_button, label_col_IntText):
        """Wire the confirm and reset buttons (same as get_msg_from_buttons)."""
        self.confirm(confirm_button, in_fp_Text, in_domain_button, label_col_IntText)
        self.reset(reset_button)

    def confirm(self, confirm_button, in_fp_Text, in_domain_button, label_col_IntText):
        """Run the label check and show the result when *confirm_button* is clicked."""
        from html import escape

        def get_invalid_label2lines(in_fp_Text, in_domain_button, label_col_IntText):
            # Constructor arguments win only when all three were supplied;
            # otherwise use the current widget values.
            if self.in_fp and self.domain and self.label_col:
                in_fp = self.in_fp
                in_domain = self.domain
                label_col = self.label_col
            else:
                in_fp = in_fp_Text.value
                in_domain = in_domain_button.value
                label_col = label_col_IntText.value
            valid_labels = self._get_valid_labels(in_domain)
            label2lines = self._get_label2lines_from_in_fp(in_fp, label_col)
            return self._check_if_invalid_label(valid_labels, label2lines)

        def display_invalid_label2lines(invalid_label2lines):
            display_html("", "h4")
            display_html(">>>>>>>>>>>>>>> RESULT <<<<<<<<<<<<<<<", "h3")
            self._display_invalid_label2lines(invalid_label2lines)

        def on_confirm_button_clicked(b):
            try:
                invalid_label2lines = get_invalid_label2lines(
                    in_fp_Text, in_domain_button, label_col_IntText)
            except (OSError, ValueError) as err:
                # Show bad paths / columns / encodings in the notebook instead
                # of leaving them in the widget callback log.
                display_html("<font color='red'>Error: {}</font>".format(
                    escape(str(err))), "h4")
                return
            display_invalid_label2lines(invalid_label2lines)

        confirm_button.on_click(on_confirm_button_clicked)

    def reset(self, reset_button):
        """Clear the output and redraw a fresh form when *reset_button* is clicked."""
        def on_reset_button_clicked(b):
            clear_output()
            self.__init__()
        reset_button.on_click(on_reset_button_clicked)
class AnnoConsistencyCheckerApp(object):
    """Notebook UI that finds queries annotated with more than one label.

    The user enters the path of a tab-separated file plus the 1-based query
    and label columns. When "Start checking" is clicked, every query that
    appears with two or more distinct labels is listed, along with the
    1-based line numbers for each label.
    """

    def __init__(self):
        confirm_button, reset_button, in_fp_Text, query_col_IntText, label_col_IntText = \
            self.display_user_interface()
        self._get_msg_from_buttons(confirm_button, reset_button,
                                   in_fp_Text, query_col_IntText, label_col_IntText)

    def display_user_interface(self):
        """Display the file-path and column inputs plus buttons.

        Returns:
            (confirm_button, reset_button, in_fp_Text, query_col_IntText,
            label_col_IntText)
        """
        # Step 1: absolute path of the file to check.
        display_html("1. 输入待检查文件的绝对路径", "h3")
        in_fp_Text = widgets.Text(description="File path",
                                  placeholder="Path to file for checking",
                                  style={"description_width": "initial",
                                         "description_height": "initial"})
        display(in_fp_Text)
        display_html("", "h4")
        # Step 2: 1-based query and label columns.
        display_html("2. 输入待检查文件中 待检查query 和 待检查标签 的列数(第一列为1)", "h3")
        query_col_IntText = widgets.IntText(description="query col",
                                            value=1,
                                            style={
                                                "description_width": "initial",
                                                "description_height": "initial"})
        label_col_IntText = widgets.IntText(description="label col",
                                            value=2,
                                            style={
                                                "description_width": "initial",
                                                "description_height": "initial"})
        display(query_col_IntText)
        display(label_col_IntText)
        box_layout = widgets.Layout(display='flex',
                                    flex_flow='column',
                                    align_items='center',
                                    width='100%')
        confirm_button = widgets.Button(description="Start checking")
        reset_button = widgets.Button(description="Reset")
        box = widgets.HBox(children=[confirm_button, reset_button],
                           layout=box_layout)
        display(box)
        return confirm_button, reset_button, in_fp_Text, query_col_IntText, label_col_IntText

    def confirm(self, confirm_button, in_fp_Text, query_col_IntText, label_col_IntText):
        """Run the consistency check and show the result when *confirm_button* is clicked."""
        from html import escape

        def display_anno_inconsistency(query2inconsistent_label_idx):
            display_html("", "h4")
            display_html(">>>>>>>>>>>>>>> RESULT <<<<<<<<<<<<<<<", "h3")
            if not query2inconsistent_label_idx:
                display_html("No inconsistent annotations found.", "h4")
                return
            for query, label_idx in query2inconsistent_label_idx.items():
                # Queries and labels come from the input file, so escape them.
                display_html("<font color='red'>Query: {}</font>".format(escape(query)), "h4")
                rows = "".join(
                    "<tr> <td>{}</td> <td>{}</td> </tr>".format(
                        escape(label), " ".join(str(i) for i in sorted(idx)))
                    for label, idx in label_idx.items())
                display_html(
                    '<table style="width:100%"> <tr> <th>Anno</th> <th>Lines</th> </tr>'
                    "{}</table>".format(rows))

        def on_confirm_button_clicked(b):
            try:
                query2inconsistent_label_idx = self._check_anno_inconsistency(
                    in_fp_Text, query_col_IntText, label_col_IntText)
            except (OSError, ValueError) as err:
                # Show bad paths / columns / encodings in the notebook instead
                # of leaving them in the widget callback log.
                display_html("<font color='red'>Error: {}</font>".format(
                    escape(str(err))), "h4")
                return
            display_anno_inconsistency(query2inconsistent_label_idx)

        confirm_button.on_click(on_confirm_button_clicked)

    def reset(self, reset_button):
        """Clear the output and redraw a fresh form when *reset_button* is clicked."""
        def on_reset_button_clicked(b):
            clear_output()
            self.__init__()
        reset_button.on_click(on_reset_button_clicked)

    def _get_msg_from_buttons(self, confirm_button, reset_button, in_fp_Text, query_col_IntText, label_col_IntText):
        """Wire the confirm and reset buttons to their handlers.

        Widget values are read at click time by _check_anno_inconsistency, so
        no submit/value callbacks are registered on the input widgets.
        """
        self.confirm(confirm_button, in_fp_Text, query_col_IntText, label_col_IntText)
        self.reset(reset_button)

    @staticmethod
    def _check_anno_inconsistency(in_fp_Text, query_col_IntText, label_col_IntText):
        """Find queries that carry more than one distinct label.

        Args:
            in_fp_Text: object whose ``.value`` is the path of a UTF-8,
                tab-separated file.
            query_col_IntText / label_col_IntText: objects whose ``.value`` is
                the 1-based query / label column.

        Returns:
            defaultdict(dict) mapping each inconsistent query to
            {label: set of 1-based line numbers}. Blank lines are skipped.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a non-blank line lacks the query or label column.
        """
        in_fp = in_fp_Text.value
        query_col = query_col_IntText.value
        label_col = label_col_IntText.value
        query2label_idx = defaultdict(dict)
        with open(in_fp, "r", encoding="UTF-8") as fin:
            for line_index, line in enumerate(fin, start=1):
                # Remove only the line terminator. strip() would also remove
                # leading/trailing tabs and shift the columns.
                line = line.rstrip("\r\n")
                if not line.strip():
                    continue
                fields = line.split("\t")
                for col in (query_col, label_col):
                    if not 1 <= col <= len(fields):
                        raise ValueError(
                            "line {}: column {} not found (line has {} column(s))".format(
                                line_index, col, len(fields)))
                query, label = fields[query_col - 1], fields[label_col - 1]
                query2label_idx[query].setdefault(label, set()).add(line_index)
        query2inconsistent_label_idx = defaultdict(dict)
        for query, label_idx in query2label_idx.items():
            if len(label_idx) > 1:
                query2inconsistent_label_idx[query] = label_idx
        return query2inconsistent_label_idx
# helper funcs | |
def display_html(content, h=None):
    """Show *content* as an HTML widget in the current notebook output.

    When *h* is truthy (a tag name such as "h3"), the content is first
    wrapped as <h>content</h>. Returns the result of IPython's display().
    """
    markup = content if not h else "<{0}>{1}</{0}>".format(h, content)
    return display(widgets.HTML(markup))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "007e1b5f5110478b86ccaf58bafe911b", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Button(description='标签质检', style=ButtonStyle())" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "176f404a4f894bfd9ee550821201536a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Button(description='一致性质检', style=ButtonStyle())" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f6db48b3dc014e368567c410a7812e90", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Button(description='Reset', style=ButtonStyle())" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"if __name__ == \"__main__\":\n", | |
" import checker_app as app\n", | |
" checker = app.CheckerRouter()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
aaaa |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment