Created February 21, 2021 17:01
Save keimina/4894ba2670d25aa95c7263d844666455 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from itertools import product | |
def is_not_join(i):
    """Convert one cell marker into the flag used by ``Lbl``.

    "T" becomes True and "F" becomes False.  Every other value, including
    the empty string, becomes "" (which ``Lbl`` treats as an empty cell).
    """
    if i == "T":
        return True
    if i == "F":
        return False
    return ""
class Lbl():
    """Hand out run labels for a 1-D sequence of cell markers.

    ``get_lbl`` receives the previous marker (``left``, arr[i-1]) and the
    current marker (``right``, arr[i]).  Both are normalised with
    ``is_not_join`` ("T" -> True, "F" -> False, anything else -> "").  The
    label for the current cell follows this table (N = "", an empty cell):

        | left | right | label                     |
        |------+-------+---------------------------|
        | T    | T / F | next label                |
        | F    | T / F | current label (unchanged) |
        | N    | T / F | next label                |
        | any  | N     | 0                         |

    "Next label" is ``start_lbl`` the first time a label is issued and the
    previous label plus one after that.

    Attributes:
        lbl: most recently issued label (starts at ``start_lbl``).
        has_lbl: whether any label has been issued yet.
        lbl_func: maps a normalised ``(left, right)`` pair to the bound
            handler method that produces the label.
    """

    def __init__(self, start_lbl):
        self.lbl = start_lbl
        self.has_lbl = False
        # Build the 3x3 dispatch table in row-major order over
        # (True, False, '') for (left, right).
        self.lbl_func = {}
        for prev in (True, False, ''):
            for cur in (True, False, ''):
                if cur == '':
                    handler = self.get_lbl_zero
                elif prev is False:
                    handler = self.get_lbl_plus_zero
                else:
                    handler = self.get_lbl_plus_one
                self.lbl_func[(prev, cur)] = handler

    def get_lbl(self, left, right):
        """Return the label for marker ``right`` preceded by marker ``left``."""
        key = (is_not_join(left), is_not_join(right))
        return self.lbl_func[key](*key)

    def get_lbl_plus_one(self, left, right):
        """Start a new run and return its label.

        ``left``/``right`` are unused; every handler shares one signature.
        """
        # The very first label issued is start_lbl itself; only later runs
        # advance the counter.
        if self.has_lbl:
            self.lbl += 1
        else:
            self.has_lbl = True
        return self.lbl

    def get_lbl_plus_zero(self, left, right):
        """Continue the current run: return the current label unchanged."""
        self.has_lbl = True
        return self.lbl

    def get_lbl_zero(self, left, right):
        """Empty cell: always label 0 and leave the counter untouched."""
        return 0
def create_lbl_1d(lst, start_lbl):
    """Label one sequence of cell markers.

    Args:
        lst: indexable sequence of markers ("T", "F", or "" / anything else).
        start_lbl: the first label number to hand out.

    Returns:
        A list with one label per element of ``lst``; empty cells get 0.
    """
    labeler = Lbl(start_lbl)
    # The first cell has no predecessor, so it is paired with "" (empty).
    return [
        labeler.get_lbl(lst[i - 1] if i else "", lst[i])
        for i in range(len(lst))
    ]
def label_2d(lst_2d):
    """Label every column of a 2-D grid of cell markers.

    Each column of ``lst_2d`` is labelled top to bottom with
    ``create_lbl_1d``.  Numbering continues across columns: each column
    starts one past the largest label assigned so far, so labels are
    unique over the whole grid.  Empty cells get label 0.

    Args:
        lst_2d: 2-D array-like of markers ("T", "F", or "" / anything else).

    Returns:
        numpy int array with the same shape as ``np.array(lst_2d)``.
    """
    columns = np.array(lst_2d).T
    arr_2d = np.zeros_like(columns, dtype=int)
    # Keep a running maximum instead of calling arr_2d.max() per column.
    # The rescan cost O(rows * cols**2), and it raised ValueError
    # (zero-size reduction) on a grid with zero rows but >0 columns.
    top = 0
    for n, col in enumerate(columns):
        labels = create_lbl_1d(col, top + 1)
        arr_2d[n] = labels
        top = max([top, *labels])
    return arr_2d.T
################ TEST ################
# Exhaustive check of create_lbl_1d over all 3**4 = 81 length-4 sequences
# of "T" / "F" / "".  Labelling with start_lbl=5 must equal labelling with
# start_lbl=1 after shifting every non-zero label up by 4.
lst = list(product(["T", "F", ""], repeat=4))
out_lst = []
out_lst_2 = []
for i in lst:
    out, out_2 = create_lbl_1d(i, 1), create_lbl_1d(i, 5)
    out_lst.append(out)
    out_lst_2.append(out_2)
    print(i, " -> ", out, out_2)
b = np.array(out_lst_2)
a = np.array(out_lst)
a[a > 0] += 4
print((a == b).all())
################ TEST ################
# Stack the same 81 sequences as rows of an 81x4 grid and label its columns.
a = np.array(lst)
print(label_2d(a))
################ TEST ################
# Degenerate input: a 1x0 grid (one row, no columns).
a = np.array([[]])
print(label_2d(a))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.