Skip to content

Instantly share code, notes, and snippets.

@Abirdcfly
Created July 9, 2018 09:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Abirdcfly/cdc296f8aab76613aa905d454da233ba to your computer and use it in GitHub Desktop.
Save Abirdcfly/cdc296f8aab76613aa905d454da233ba to your computer and use it in GitHub Desktop.
# coding=utf-8
# HMM tables, populated by load_dicts() and read by viterbi():
#   prob_start[s]    - start score of state s
#   prob_trans[a][b] - score of moving from state a to state b
#   prob_emit[s]     - {UTF-8 encoded character (bytes): emission score}
# viterbi() sums these scores, i.e. they are used as log-probabilities.
prob_start, prob_trans, prob_emit = [], [], []
# Score used by viterbi() for characters missing from an emission table.
MIN_FLOAT = -3.14e100
# State index -> tag letter.  "E" and "S" are the tags after which
# add_slash() ends a word (conventionally B=begin, E=end, M=middle,
# S=single-character word).
STATE = dict(enumerate("BEMS"))
def load_dicts(filename="hmm_model.utf8"):
    """Load the HMM tables from a jieba-style ``hmm_model.utf8`` file.

    The file is read as bytes.  Blank lines and lines starting with "#"
    (after stripping) are skipped.  Of the remaining lines:

    * line 0 holds the four start scores,
    * lines 1-4 hold the 4x4 transition matrix, one row per line,
    * lines 5-8 hold one emission table per state, each a comma-separated
      list of ``char:score`` pairs.

    Rebinds the module globals ``prob_start``, ``prob_trans`` and
    ``prob_emit``, so calling it again replaces the model instead of
    appending to it, and also returns them.

    :param filename: path of the model file.
    :returns: ``(prob_start, prob_trans, prob_emit)``.  Emission keys are
        the UTF-8 encoded characters (bytes), which is how viterbi() looks
        them up.
    :raises IndexError: if the file has fewer than nine data lines.
    :raises ValueError: if a number or ``char:score`` pair is malformed.
    """
    global prob_start, prob_trans, prob_emit
    with open(filename, "rb") as f:
        stripped = [line.strip() for line in f]
    # Byte literals: lines read in "rb" mode are bytes on Python 3, where
    # comparing/splitting them with str arguments raises TypeError.
    lines = [line for line in stripped if line and not line.startswith(b"#")]

    start = [float(x) for x in lines[0].split()]
    trans = [[float(x) for x in lines[row].split()] for row in range(1, 5)]
    emit = []
    for row in range(5, 9):
        table = {}
        for item in lines[row].split(b","):
            if not item:  # tolerate a trailing comma
                continue
            # rsplit so that a key which is itself ":" still parses.
            char, score = item.rsplit(b":", 1)
            table[char] = float(score)
        emit.append(table)

    prob_start, prob_trans, prob_emit = start, trans, emit
    return prob_start, prob_trans, prob_emit
def viterbi(obs):
    """Return the most likely tag string for *obs* under the loaded HMM.

    Runs Viterbi decoding over the four states of ``STATE``, summing scores
    from the module-level ``prob_start``, ``prob_trans`` and ``prob_emit``
    tables (filled by ``load_dicts``).  A character missing from a state's
    emission table scores ``MIN_FLOAT`` there.

    :param obs: text to tag; each character is one observation.
    :returns: one tag letter per character of *obs*.  The last tag is
        restricted to "E" or "S" (ties favour "E").  An empty *obs* yields "".
    """
    n = len(obs)
    if n == 0:
        return ""
    # Emission tables are keyed by UTF-8 bytes; encode each character once
    # instead of once per (state, predecessor) pair.
    keys = [ch.encode("utf8") for ch in obs]
    weight = [[0] * n for _ in range(4)]   # best score ending in state j at i
    path = [[-1] * n for _ in range(4)]    # predecessor state of that best path
    for j in range(4):
        weight[j][0] = prob_start[j] + prob_emit[j].get(keys[0], MIN_FLOAT)
    for i in range(1, n):
        for j in range(4):
            emit = prob_emit[j].get(keys[i], MIN_FLOAT)
            weight[j][i], path[j][i] = max(
                (weight[k][i - 1] + prob_trans[k][j] + emit, k) for k in range(4))

    # Start the backtrace from the chosen final state; the original code
    # appended a hard-coded 3 ("S") here, mislabelling paths ending in "E".
    state = 3 if weight[1][-1] < weight[3][-1] else 1
    tags = [state]
    for i in range(n - 1, 0, -1):
        state = path[state][i]
        tags.append(state)
    tags.reverse()
    return "".join(STATE[s] for s in tags)
def add_slash(sentence, pos):
    """Insert "/" after each character whose tag ends a word.

    Prints *pos* before building the result.

    :param sentence: text to mark up.
    :param pos: tag string indexed in step with *sentence*; an "E" or "S"
        at index i puts a "/" after sentence[i].
    :returns: *sentence* with the slashes inserted; no slash is added after
        the final character.
    """
    print(pos)
    last = len(sentence) - 1
    pieces = []
    for idx, ch in enumerate(sentence):
        ends_word = pos[idx] in ("E", "S")
        pieces.append(ch + "/" if ends_word and idx != last else ch)
    return "".join(pieces)
def split_clauses(text):
    """Split *text* on ASCII/CJK punctuation and spaces, dropping empty pieces.

    re.split() yields empty strings around adjacent or trailing separators,
    and viterbi() needs at least one character, so those are filtered out.

    :param text: unicode text to split.
    :returns: list of non-empty clauses.
    """
    import re
    # One character class instead of a long single-character alternation.
    # A unicode pattern is used so that the CJK separators match unicode
    # text on Python 2 as well.  The full-width forms of , ; ! ( ) are
    # included because the original alternation listed those characters
    # twice, evidently meaning both widths.
    pattern = (u"[,./;'`\\[\\]<>?:\"{}~!@#$%\\^&()\\-=_+"
               u"。、‘’【】· …"
               u"\uff0c\uff1b\uff01\uff08\uff09]")
    return [piece for piece in re.split(pattern, text) if piece]


if __name__ == '__main__':
    # Model file: https://github.com/yanyiwu/cppjieba/blob/master/dict/hmm_model.utf8
    load_dicts(filename="./hmm_model.utf8")
    s = u"小明硕士毕业于中国科学院计算所,研究人工智能"
    for s1 in split_clauses(s):
        print(add_slash(s1, viterbi(s1)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment