Skip to content

Instantly share code, notes, and snippets.

@hamukazu
Created August 8, 2014 03:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hamukazu/17c9a22cca0ad2d2a579 to your computer and use it in GitHub Desktop.
Save hamukazu/17c9a22cca0ad2d2a579 to your computer and use it in GitHub Desktop.
diff -u -r cnn_orig/convolutional_mlp.py cnn/convolutional_mlp.py
--- cnn_orig/convolutional_mlp.py 2014-08-08 12:15:47.412611962 +0900
+++ cnn/convolutional_mlp.py 2014-08-08 12:18:02.684614501 +0900
@@ -37,6 +37,18 @@
from logistic_sgd import LogisticRegression, load_data
from mlp import HiddenLayer
+from PIL import Image
+
+def pack(m, n, a, size=28):
+    # Tile up to n*m rows of 2-D `a` (each a row-major, flattened size x size image) onto an n-row by m-column grid; return the grid flattened to 1-D.
+    r = numpy.zeros((n * size, m * size))
+    for k in xrange(min(n * m, a.shape[0])):
+        ii = size * (k // m)
+        jj = size * (k % m)
+        # Image k fills grid cell (k // m, k % m); cells with no image (a has fewer than n*m rows) stay 0.
+        r[ii:ii + size, jj:jj + size] = a[k, :size * size].reshape(size, size)
+    return r.reshape(-1)
+
class LeNetConvPoolLayer(object):
"""Pool Layer of a convolutional network """
@@ -129,6 +141,7 @@
train_set_x, train_set_y = datasets[0]
valid_set_x, valid_set_y = datasets[1]
test_set_x, test_set_y = datasets[2]
+ valid_set = datasets[3]  # raw validation data (4th item returned by the patched load_data), indexed directly when drawing misclassified digits
# compute number of minibatches for training, validation and testing
n_train_batches = train_set_x.get_value(borrow=True).shape[0]
@@ -286,6 +299,21 @@
done_looping = True
break
+ validate_model2=theano.function(  # per-example mismatch flags for one validation minibatch (not a mean error)
+ [index],layer3.failures(y),
+ givens={
+ x: valid_set_x[index * batch_size: (index + 1) * batch_size],
+ y: valid_set_y[index * batch_size: (index + 1) * batch_size]})
+ nonzeros=[]  # validation-set indices of misclassified examples
+ for i in xrange(n_valid_batches):
+ failures=validate_model2(i)  # 1 where the predicted label differs from y, 0 elsewhere
+ nz,=numpy.nonzero(failures)
+ nonzeros+=[j+i*batch_size for j in nz]  # batch-local index -> index into the whole validation set
+ im=Image.new("1",(28*10,28*10))  # NOTE(review): mode "1" is bilevel; mode "L" would keep digit grayscale - confirm intent
+ sampled=nonzeros[:100]  # at most 100 failures fit the 10x10 grid drawn by pack(10,10,...)
+ im.putdata(1-pack(10,10,valid_set[0][sampled,:]),255.,0.)  # 1-x inverts intensities (assumes pixels in [0,1]); putdata scale=255, offset=0
+ im.show()
+
end_time = time.clock()
print('Optimization complete.')
print('Best validation score of %f %% obtained at iteration %i,'\
diff -u -r cnn_orig/logistic_sgd.py cnn/logistic_sgd.py
--- cnn_orig/logistic_sgd.py 2014-08-08 12:15:47.412611962 +0900
+++ cnn/logistic_sgd.py 2014-08-08 12:18:52.308615433 +0900
@@ -143,6 +143,8 @@
else:
raise NotImplementedError()
+    def failures(self, y):  # symbolic per-example 0/1 vector: 1 where the predicted label self.y_pred differs from y (unlike an averaged error rate)
+        return T.neq(y, self.y_pred)
def load_data(dataset):
''' Loads the dataset
@@ -212,7 +214,7 @@
train_set_x, train_set_y = shared_dataset(train_set)
rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),
- (test_set_x, test_set_y)]
+ (test_set_x, test_set_y), valid_set]  # 4th item: raw validation set (not wrapped by shared_dataset). NOTE(review): any caller unpacking exactly 3 items from load_data() will now raise ValueError
return rval
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment