Created
February 3, 2017 01:10
-
-
Save catalystfrank/782bab5e51356b70fa4aae361cbd47cc to your computer and use it in GitHub Desktop.
Yet Another Accuracy Func For LSTM+CTC (in mxnet/example/warpctc)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# When processing sequence-tagging problems,
# an accuracy function that is too strict is not convergence-friendly.
# This function is the solution to the classic LeetCode problem LCS (Longest Common Subsequence).
# The original all-correct-or-nothing accuracy function takes 7x as long to reach the same accuracy.
def LCS(p, l):
    """Return the length of the longest common subsequence of ``p`` and ``l``.

    Args:
        p: Predicted sequence (anything ``len()`` and ``list()`` accept).
        l: Reference/label sequence of the same kind.

    Returns:
        The LCS length as a plain ``int``; 0 when either sequence is empty.
    """
    # An empty sequence has no common subsequence with anything. Guarding
    # both sides avoids calling max() on a zero-sized match matrix.
    if len(p) == 0 or len(l) == 0:
        return 0

    pred_row = np.array(list(p)).reshape((1, len(p)))
    label_col = np.array(list(l)).reshape((len(l), 1))
    # Broadcasting a row against a column: match[i, j] is 1 iff l[i] == p[j].
    match = np.int32(pred_row == label_col)
    n_rows, n_cols = match.shape

    # table[i, j] is the LCS length of l[:i] and p[:j]. The extra zero row
    # and column stand for the empty prefixes, so no border special-casing
    # is needed inside the loop.
    table = np.zeros((n_rows + 1, n_cols + 1), dtype=np.int32)
    for i in range(1, n_rows + 1):
        for j in range(1, n_cols + 1):
            table[i, j] = max(
                table[i - 1, j],
                table[i, j - 1],
                table[i - 1, j - 1] + match[i - 1, j - 1],
            )
    return int(table[n_rows, n_cols])
def Accuracy_LCS(label, pred):
    """Return the mean LCS-based accuracy over one batch.

    Each sample scores ``LCS(prediction, label) / len(label)``, so partially
    correct predictions earn partial credit instead of all-or-nothing.

    Reads the module-level ``BATCH_SIZE`` and ``SEQ_LENGTH`` and relies on the
    ``remove_blank``, ``ctc_label`` and ``LCS`` helpers defined elsewhere in
    this file.

    Args:
        label: Indexable by sample; ``label[i]`` is passed to ``remove_blank``.
        pred: Indexable by ``k * BATCH_SIZE + i`` (time step ``k``, sample
            ``i``); each entry is reduced with ``np.argmax``.
            NOTE(review): presumably a time-major
            (SEQ_LENGTH * BATCH_SIZE, num_classes) score matrix -- confirm.

    Returns:
        Average per-sample score as a float; 0.0 when the batch is empty.
    """
    hit = 0.
    total = 0.
    for i in range(BATCH_SIZE):
        l = remove_blank(label[i])
        # Pick the best class at every time step of sample i, then let
        # ctc_label reduce that raw path to the predicted label sequence.
        p = ctc_label(
            [np.argmax(pred[k * BATCH_SIZE + i]) for k in range(SEQ_LENGTH)]
        )
        if len(l) == 0:
            # With nothing to match, dividing by len(l) would fail: an empty
            # prediction is fully correct, anything else is fully wrong.
            hit += 1.0 if len(p) == 0 else 0.0
        else:
            # Dynamic programming: fraction of the label recovered, in order.
            hit += LCS(p, l) * 1.0 / len(l)
        total += 1.0
    # Guard against an empty batch instead of dividing by zero.
    return hit / total if total else 0.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment