Skip to content

Instantly share code, notes, and snippets.

@catalystfrank
Created February 3, 2017 01:10
Show Gist options
  • Save catalystfrank/782bab5e51356b70fa4aae361cbd47cc to your computer and use it in GitHub Desktop.
Save catalystfrank/782bab5e51356b70fa4aae361cbd47cc to your computer and use it in GitHub Desktop.
Yet Another Accuracy Func For LSTM+CTC (in mxnet/example/warpctc)
# When training models for sequence-tagging problems (e.g. LSTM+CTC),
# an accuracy function that is too strict is not convergence-friendly.
# This function instead scores predictions with the Longest Common Subsequence (LCS),
# a classic dynamic-programming problem (a well-known easy LeetCode exercise).
# With the original all-correct-or-nothing accuracy function, training took about 7x as long to reach a given accuracy.
def LCS(p, l):
    """Return the length of the longest common subsequence of *p* and *l*.

    Items are compared with ``==``, so any sequences of comparable tokens
    (e.g. predicted and reference label ids) work.

    Args:
        p: predicted token sequence (any iterable).
        l: reference token sequence (any iterable).

    Returns:
        The LCS length as an int; 0 when either sequence is empty.
    """
    p = list(p)
    l = list(l)
    # Guard both sides: an empty label previously produced a zero-size
    # array whose .max() raised ValueError.
    if not p or not l:
        return 0
    # Classic O(len(p) * len(l)) DP keeping only two rows:
    # prev[j] is the LCS length of (l[:i], p[:j]) for the previous row i.
    prev = [0] * (len(p) + 1)
    for a in l:
        cur = [0]
        for j, b in enumerate(p):
            if a == b:
                # Extend the diagonal match.
                cur.append(prev[j] + 1)
            else:
                # Best of dropping the current label item or prediction item.
                cur.append(max(prev[j + 1], cur[j]))
        prev = cur
    return prev[-1]
def Accuracy_LCS(label, pred):
    """Return the batch mean of an LCS-based partial-credit accuracy.

    For each sample ``i`` the score is ``LCS(prediction, label) / len(label)``,
    i.e. the fraction of label tokens recovered, in order, by the decoded
    prediction. The result is the mean over ``BATCH_SIZE`` samples. Unlike
    all-or-nothing sequence accuracy this gives partial credit. Note that it
    does not penalize extra tokens present only in the prediction.

    Args:
        label: per-sample label rows; ``label[i]`` is passed to
            ``remove_blank`` to obtain the reference sequence.
        pred: network output indexed as ``pred[t * BATCH_SIZE + i]`` for time
            step ``t`` of sample ``i``; each such row is argmax-ed to get the
            predicted class at that step.

    Reads the module-level ``BATCH_SIZE`` and ``SEQ_LENGTH`` (no ``global``
    statement is needed because they are only read).
    """
    hit = 0.0
    for i in range(BATCH_SIZE):
        l = remove_blank(label[i])
        # Greedy decoding: best class at each time step, then post-processed
        # by ctc_label (presumably collapsing repeats / dropping blanks).
        p = ctc_label(
            [np.argmax(pred[k * BATCH_SIZE + i]) for k in range(SEQ_LENGTH)]
        )
        if len(l) == 0:
            # Nothing to recall: avoid dividing by zero and count the sample
            # as correct only if the prediction is empty as well.
            hit += 1.0 if len(p) == 0 else 0.0
            continue
        hit += float(LCS(p, l)) / len(l)
    return hit / BATCH_SIZE
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment