Skip to content

Instantly share code, notes, and snippets.

@keimina
Created February 21, 2021 17:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save keimina/4894ba2670d25aa95c7263d844666455 to your computer and use it in GitHub Desktop.
Save keimina/4894ba2670d25aa95c7263d844666455 to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
from itertools import product
def is_not_join(i):
    """Decode one cell flag into the key used by ``Lbl``'s dispatch table.

    Args:
        i: a cell value, normally "T", "F" or "".

    Returns:
        True for "T", False for "F", and "" for anything else (including the
        empty string), so unrecognised values behave like empty cells.
    """
    if i == "T":
        return True
    if i == "F":
        return False
    # Empty and unrecognised cells share the same "no flag" marker.
    return ""
class Lbl():
    """Stateful label allocator used to number the cells of a 1-D sequence.

    ``get_lbl(left, right)`` receives the previous cell (``left``, i.e.
    arr[i-1]) and the current cell (``right``, i.e. arr[i]). It decodes both
    with ``is_not_join`` and dispatches on the pair through ``lbl_func``:

        | arr[i-1] | arr[i] | lbl[i]  |
        |----------+--------+---------|
        | T        | T      | lbl + 1 |
        | T        | F      | lbl + 1 |
        | T        | N      | 0       |
        | F        | T      | lbl     |
        | F        | F      | lbl     |
        | F        | N      | 0       |
        | N        | T      | lbl + 1 |
        | N        | F      | lbl + 1 |
        | N        | N      | 0       |

    N means an empty (or otherwise unrecognised) cell. The first "lbl + 1"
    does not increment; it returns ``start_lbl`` itself, so the first
    labelled cell receives ``start_lbl``.

    Attributes:
        lbl: the most recently issued label; starts at ``start_lbl``.
        lbl_func: dispatch table keyed by decoded ``(left, right)`` pairs.
        has_lbl: False until ``get_lbl_plus_one`` or ``get_lbl_plus_zero``
            has issued a label.
    """

    def __init__(self, start_lbl):
        self.lbl = start_lbl
        new_label = self.get_lbl_plus_one
        same_label = self.get_lbl_plus_zero
        no_label = self.get_lbl_zero
        self.lbl_func = {
            (True, True): new_label,
            (True, False): new_label,
            (True, ''): no_label,
            (False, True): same_label,
            (False, False): same_label,
            (False, ''): no_label,
            ('', True): new_label,
            ('', False): new_label,
            ('', ''): no_label,
        }
        self.has_lbl = False

    def get_lbl(self, left, right):
        """Return the label for the current cell given its left neighbour."""
        key = (is_not_join(left), is_not_join(right))
        handler = self.lbl_func[key]
        return handler(*key)

    def get_lbl_plus_one(self, left, right):
        """Start a new label; the very first call reuses ``start_lbl``."""
        if not self.has_lbl:
            self.has_lbl = True
        else:
            self.lbl += 1
        return self.lbl

    def get_lbl_plus_zero(self, left, right):
        """Continue the current label (the cell joins its left neighbour)."""
        self.has_lbl = True
        return self.lbl

    def get_lbl_zero(self, left, right):
        """Return 0 for an empty cell; internal state is left untouched."""
        return 0
def create_lbl_1d(lst, start_lbl):
    """Label every cell of a 1-D sequence of flag strings.

    Args:
        lst: a sequence of flag strings ("T", "F" or "").
        start_lbl: the label given to the first labelled cell.

    Returns:
        A list with one label per element of ``lst``. Each label comes from
        ``Lbl.get_lbl(previous_cell, current_cell)``. The first cell is
        paired with "" as its left neighbour, and empty cells get 0.
    """
    labeller = Lbl(start_lbl)
    labels = []
    left = ""  # the first cell has no predecessor, so it is treated as empty
    for right in lst:
        labels.append(labeller.get_lbl(left, right))
        left = right
    return labels
def label_2d(lst_2d):
    """Label a 2-D grid of flags one column at a time.

    The grid is transposed so that each original column becomes a row, and
    each row is labelled with ``create_lbl_1d``. Numbering for each column
    starts one past the largest label issued so far (1 for the first
    column), so non-zero labels never collide across columns.

    Args:
        lst_2d: 2-D array-like of flag values ("T", "F", ""; anything else
            is treated like ""). Presumably rectangular; verify at call sites.

    Returns:
        An int ndarray with the same shape as ``np.array(lst_2d)``. A value
        of 0 marks an unlabelled (empty) cell.
    """
    columns = np.array(lst_2d).T
    arr_2d = np.zeros_like(columns, dtype=int)
    # The maximum of all labels written so far is tracked incrementally
    # rather than by calling arr_2d.max() for every column. This keeps each
    # column O(rows) instead of rescanning the whole grid. It also avoids
    # numpy's "zero-size array to reduction operation" ValueError on empty
    # grids such as shape (0, k).
    max_lbl = 0
    for n, col in enumerate(columns):
        labels = create_lbl_1d(col, max_lbl + 1)
        arr_2d[n] = labels
        if labels:
            max_lbl = max(max_lbl, max(labels))
    return arr_2d.T
################ TEST ################
# Exhaustive 1-D check: label all 3**4 = 81 length-4 flag sequences twice,
# once starting at 1 and once at 5. Shifting every non-zero label of the
# first run by 4 should reproduce the second run, so this is expected to
# print True.
flags = ["T", "F", ""]
lst = list(product(flags, repeat=4))
out_lst = []
out_lst_2 = []
for seq in lst:
    from_one = create_lbl_1d(seq, 1)
    from_five = create_lbl_1d(seq, 5)
    print(seq, " -> ", from_one, from_five)
    out_lst.append(from_one)
    out_lst_2.append(from_five)
b = np.array(out_lst_2)
a = np.array(out_lst)
a[a > 0] += 4
print(np.all(a == b))
################ TEST ################
# 2-D check: treat the 81 sequences as an (81, 4) grid and label it
# column by column.
a = np.array(lst)
print(label_2d(a))
################ TEST ################
# Edge case: a (1, 0) grid that contains no cells.
a = np.array([[]])
print(label_2d(a))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment