-
-
Save sld/5cb766398471c63903f3 to your computer and use it in GitHub Desktop.
Caffe-Additionals
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
import os | |
import sys | |
import cv2 | |
caffe_python = '/opt/caffe/python' | |
sys.path.insert(0, caffe_python) | |
import caffe | |
def calculate(basepath, shape=(3, 32, 32), labels=range(0, 10)):
    """Compute the per-pixel mean image over a directory tree of images.

    Images are read from ``<basepath>/<label>/<filename>`` for each label in
    *labels*; file names starting with ``.`` (hidden files) are skipped.
    Every image is decoded with OpenCV, converted from BGR to RGB and
    reordered from (height, width, channels) to Caffe's
    (channels, height, width) layout before being summed.

    :param basepath: root directory with one sub-directory per label.
    :param shape: expected (channels, height, width) of every image.
    :param labels: iterable of sub-directory names under *basepath*.
    :returns: float64 numpy array of shape ``(1,) + shape`` holding the mean.
    :raises ValueError: if an image cannot be decoded, has a shape other than
        *shape*, or no images were found at all.
    """
    shape = tuple(shape)
    mean_image = np.zeros(shape)
    count = 0
    for label in labels:
        path = os.path.join(basepath, str(label))
        for filename in os.listdir(path):
            if filename.startswith('.'):
                continue
            filepath = os.path.join(path, filename)
            print(filepath)
            bgr_img = cv2.imread(filepath)
            if bgr_img is None:
                # cv2.imread reports an unreadable/undecodable file by
                # returning None instead of raising.
                raise ValueError("could not read image {0}".format(filepath))
            # NOTE(review): Caffe's own tooling (convert_imageset,
            # caffe.io.Transformer with channel_swap) works in BGR order;
            # confirm that an RGB mean is what the consumer expects.
            rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
            # (H, W, C) -> (C, H, W). The previous ``.T`` reversed all axes,
            # producing (C, W, H), i.e. a spatially transposed mean.
            chw_img = rgb_img.transpose(2, 0, 1)
            if chw_img.shape != shape:
                raise ValueError("image {0} has shape {1}, expected {2}".format(
                    filepath, chw_img.shape, shape))
            mean_image += chw_img
            count += 1
    if count == 0:
        # Dividing by zero would silently yield an all-NaN mean.
        raise ValueError("no images found under {0}".format(basepath))
    mean_image /= count
    return mean_image.reshape((1,) + shape)
def convert_to_binaryproto(mean, filename):
    """Serialize *mean* as a Caffe BlobProto and write the bytes to *filename*.

    :param mean: numpy array accepted by ``caffe.io.array_to_blobproto``.
    :param filename: destination path; an existing file is overwritten.
    :returns: None.
    """
    proto = caffe.io.array_to_blobproto(mean)
    payload = proto.SerializeToString()
    with open(filename, 'wb') as out_file:
        out_file.write(payload)
if __name__ == '__main__':
    # Arguments: <basepath to train images> <output .binaryproto path>
    if len(sys.argv) >= 3:
        mean = calculate(sys.argv[1])
        sys.exit(convert_to_binaryproto(mean, sys.argv[2]))
    sys.exit("Pass basepath to train images and binaryproto save filename")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import sys | |
import csv | |
import os | |
from sys import argv | |
from os.path import basename | |
caffe_python = '/opt/caffe/python' | |
sys.path.insert(0, caffe_python) | |
import caffe | |
class Net:
    """Thin wrapper around ``caffe.Classifier`` that returns argmax labels.

    NOTE(review): ``set_mode_cpu``/``set_mode_gpu``/``set_phase_test`` are
    called on the Classifier instance, which matches the older pycaffe API;
    newer Caffe releases expose these as module-level ``caffe.set_mode_*``
    functions. Confirm against the installed Caffe version.
    """

    def __init__(self, model_file, pretrained, mean, cpu=True,
                 image_dims=(32, 32), raw_scale=255, channel_swap=(2, 1, 0)):
        """Load the network definition and weights.

        :param model_file: path of the deploy prototxt.
        :param pretrained: path of the trained caffemodel weights.
        :param mean: mean array handed to ``caffe.Classifier``.
        :param cpu: run on CPU when True, otherwise on GPU.
        """
        self.net = caffe.Classifier(model_file, pretrained,
                                    image_dims=image_dims,
                                    raw_scale=raw_scale,
                                    channel_swap=channel_swap,
                                    mean=mean)
        if cpu:
            self.net.set_mode_cpu()
        else:
            self.net.set_mode_gpu()
        self.net.set_phase_test()

    def predict(self, caffe_images):
        """Return the index of the highest score for each input image."""
        scores = self.net.predict(caffe_images)
        return [score.argmax() for score in scores]

    def predictions(self, images, batch_size=10):
        """Classify image files in batches.

        :param images: list of image file paths; each base name (up to the
            first ``.``) must parse as an integer id.
        :param batch_size: number of images passed to the network at once.
        :returns: list of ``(image_name, label)`` pairs sorted by the
            integer value of ``image_name``.
        :raises ValueError: if *batch_size* is not positive.
        """
        if batch_size < 1:
            raise ValueError(
                "batch_size must be positive, got {0}".format(batch_size))
        count = len(images)
        predicts = []
        for batch_number, img_batch in enumerate(chunks(images, batch_size), 1):
            caffe_imgs = [caffe.io.load_image(img) for img in img_batch]
            labels = self.predict(caffe_imgs)
            # Cap at count so the final, possibly short, batch does not
            # report more images than exist.
            done = min(batch_number * batch_size, count)
            print("{0}/{1}".format(done, count))
            img_names = [basename(img).split('.')[0] for img in img_batch]
            predicts.extend(zip(img_names, labels))
        return sorted(predicts, key=lambda pair: int(pair[0]))
def chunks(l, n):
    """Yield successive slices of *l* holding at most *n* items each.

    ``range`` replaces the Python 2-only ``xrange`` so this works on both
    Python 2 and Python 3.

    :param l: any sliceable sequence supporting ``len()``.
    :param n: positive chunk size.
    :raises ValueError: if *n* is not positive (raised on first iteration).
    """
    if n < 1:
        raise ValueError("chunk size must be positive, got {0}".format(n))
    for start in range(0, len(l), n):
        yield l[start:start + n]
def save_to_file(filename, arr):
    """Write *arr* (an iterable of rows) to *filename* as comma-separated CSV.

    The csv module requires the output file to be opened with
    ``newline=''`` on Python 3 and in binary mode on Python 2; otherwise
    extra blank lines appear between rows on Windows.
    """
    if sys.version_info[0] >= 3:
        handle = open(filename, 'w', newline='')
    else:
        handle = open(filename, 'wb')
    with handle as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerows(arr)
def kaggle(model_file, pretrained, mean, imgs_basepath, save_filename,
           test_dir="test_32x32"):
    """Classify every image in ``<imgs_basepath>/<test_dir>`` and write a CSV.

    The CSV has a header row ``Image_Name,Digit`` followed by one
    ``(image_name, label)`` row per image, sorted by integer image name.

    :param model_file: path of the deploy prototxt.
    :param pretrained: path of the trained caffemodel weights.
    :param mean: path of a ``.npy`` file holding the mean image.
    :param imgs_basepath: directory containing *test_dir*.
    :param save_filename: output CSV path.
    :param test_dir: sub-directory of *imgs_basepath* holding the test images.
    """
    # [0] takes the first entry along the saved array's leading axis;
    # presumably mean.npy holds a (1, C, H, W) array -- TODO confirm.
    mean_image = np.load(mean)[0]
    net = Net(model_file, pretrained, mean_image)
    basepath = os.path.join(imgs_basepath, test_dir)
    # Skip hidden files (e.g. .DS_Store): they are not images and their
    # names cannot be parsed as integer ids when predictions are sorted.
    image_files = [os.path.join(basepath, filename)
                   for filename in os.listdir(basepath)
                   if not filename.startswith('.')]
    result = [['Image_Name', 'Digit']] + net.predictions(image_files)
    save_to_file(save_filename, result)
# 40000
# Example arguments: model_file='winny_deploy.prototxt',
# pretrained='iter_51000.caffemodel', mean='mean.npy', imgs_basepath='/opt/SVHN'
if __name__ == '__main__':
    # Expected arguments:
    #   model_file pretrained mean imgs_basepath save_filename
    if len(argv) < 6:
        sys.exit("Usage: {0} model_file pretrained mean imgs_basepath "
                 "save_filename".format(argv[0]))
    model_file, pretrained, mean, imgs_basepath, save_filename = argv[1:6]
    kaggle(model_file, pretrained, mean, imgs_basepath, save_filename)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/home/deploy/train_32x32/0/10005.png 0
/home/deploy/train_32x32/1/10024.png 1
/home/deploy/train_32x32/2/10026.png 2
/home/deploy/train_32x32/1/1003.png 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment