Skip to content

Instantly share code, notes, and snippets.

@oiehot
Last active February 19, 2017 04:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oiehot/3da4da64caa07142eb41301ea197eed3 to your computer and use it in GitHub Desktop.
Save oiehot/3da4da64caa07142eb41301ea197eed3 to your computer and use it in GitHub Desktop.
평가함수, 평균제곱오차 & 교차엔트로피오차
import sys, os
import numpy as np
# One-hot ground-truth labels for five samples over ten classes (0-9).
# Row i is all zeros except for a 1 at the true class of sample i:
# the true classes are 2, 3, 5, 0 and 1 respectively.
t = np.eye(10, dtype=int)[[2, 3, 5, 0, 1]]

# Predicted scores for the same five samples. Only the true-class entry of
# each row is non-zero, so the rows do not sum to 1 (they are not full
# probability distributions); they simply exercise the loss functions below.
y = np.zeros((5, 10))
y[np.arange(5), [2, 3, 5, 0, 1]] = [.8, .5, .3, .7, .2]
# Mean squared error (MSE)
def meanSquaredError(y, t):
    """Return half the sum of squared differences between ``y`` and ``t``.

    Computes ``0.5 * sum((y - t) ** 2)`` over every element. Despite the
    name, the result is NOT divided by the number of elements or samples.

    Args:
        y: Predicted values (an array supporting NumPy arithmetic).
        t: Target values, broadcast-compatible with ``y``.

    Returns:
        The scalar loss value.
    """
    residual = y - t
    return 0.5 * np.sum(np.square(residual))
# Cross-entropy error (CEE)
def crossEntropyError(y, t):
    """Return the summed cross-entropy ``-sum(t * log(y + 1e-7))``.

    The sum runs over every element, so a 2-D batch yields the total over
    all samples rather than a per-sample mean.

    Args:
        y: Predicted values (an array supporting NumPy arithmetic).
        t: One-hot targets, broadcast-compatible with ``y``.

    Returns:
        The scalar loss value.
    """
    delta = 1e-7  # keeps log() finite when a predicted value is exactly 0
    log_y = np.log(y + delta)
    return -np.sum(t * log_y)
def crossEntropyError_total(y, t):
    """Return the cross-entropy error averaged over every sample.

    A 1-D ``y`` is treated as a single sample and promoted to a batch of one,
    so the result is divided by 1 rather than by the number of classes.

    Args:
        y: Predicted values; a single sample as a 1-D array, or a batch as a
            2-D array whose rows are samples.
        t: One-hot targets with the same shape as ``y``.

    Returns:
        ``-sum(t * log(y + 1e-7)) / n_samples`` as a scalar.
    """
    y = np.asarray(y)
    t = np.asarray(t)
    if y.ndim == 1:
        # Single sample -> batch of one. ``reshape`` returns a new array and
        # does not modify its input, so the result must be assigned back.
        y = y.reshape(1, y.size)
        t = t.reshape(1, t.size)
    # Read the sample count only *after* the promotion above.
    batch_size = y.shape[0]
    return -np.sum(t * np.log(y + 1e-7)) / batch_size
def crossEntropyError_batch(y, t, batch_size):
    """Return the cross-entropy error averaged over the first ``batch_size`` samples.

    The batch is the leading ``batch_size`` rows (no random sampling). It is
    capped at the number of available samples. A 1-D ``y`` is treated as a
    single sample.

    Args:
        y: Predicted values; a single sample as a 1-D array, or a batch as a
            2-D array whose rows are samples.
        t: One-hot targets with the same shape as ``y``.
        batch_size: Number of leading samples to average over (>= 1).

    Returns:
        ``-sum(t_b * log(y_b + 1e-7)) / b`` over the selected ``b`` rows.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (the average would be
            undefined).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1, got %r" % (batch_size,))
    y = np.asarray(y)
    t = np.asarray(t)
    if y.ndim == 1:
        # Single sample -> batch of one. ``reshape`` returns a new array, so
        # the result must be assigned back.
        y = y.reshape(1, y.size)
        t = t.reshape(1, t.size)
    # Cap against the sample count, which is only meaningful after promotion.
    batch_size = min(batch_size, y.shape[0])
    batch_y = y[:batch_size]
    batch_t = t[:batch_size]
    return -np.sum(batch_t * np.log(batch_y + 1e-7)) / batch_size
# Compare the three cross-entropy variants on the sample data above,
# printing each result as soon as it is computed.
for loss_fn, extra_args in (
    (crossEntropyError, ()),
    (crossEntropyError_total, ()),
    (crossEntropyError_batch, (3,)),
):
    print(loss_fn(y, t, *extra_args))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment