Skip to content

Instantly share code, notes, and snippets.

@fallthrough
Last active September 17, 2018 10:40
Show Gist options
Save fallthrough/92399e5fd4f70fecdad5 to your computer and use it in GitHub Desktop.
Caffe Introduction: Codes
#!/usr/bin/python
# Copyright 2014 SIG2D
# Licensed under the Apache License, Version 2.0.
import os
import shutil
import subprocess
import sys
from caffe.proto import caffe_pb2
import leveldb
import numpy as np
import PIL.Image
import random
# Edge length, in pixels, of the square thumbnails stored in the database.
# make_datum() records this same value as the Datum's width and height.
THUMBNAIL_SIZE = 32


def make_thumbnail(image):
    """Return a THUMBNAIL_SIZE x THUMBNAIL_SIZE RGB thumbnail of *image*.

    The largest centered square is cropped out of the image and then
    resized to exactly THUMBNAIL_SIZE pixels per side.

    Args:
        image: a PIL image in any mode; it is converted to RGB first.

    Returns:
        A new RGB PIL image of size (THUMBNAIL_SIZE, THUMBNAIL_SIZE).
    """
    image = image.convert('RGB')
    width, height = image.size
    square_size = min(width, height)
    # Floor division keeps the crop box integral on both Python 2 and 3.
    offset_x = (width - square_size) // 2
    offset_y = (height - square_size) // 2
    image = image.crop((offset_x, offset_y,
                        offset_x + square_size, offset_y + square_size))
    # Pillow >= 10 removed the ANTIALIAS alias; very old PIL lacks LANCZOS.
    resample = getattr(PIL.Image, 'LANCZOS', None)
    if resample is None:
        resample = PIL.Image.ANTIALIAS
    # resize() (unlike thumbnail()) also enlarges sources smaller than
    # THUMBNAIL_SIZE, so the result always has exactly the size that the
    # Datum header will claim.
    return image.resize((THUMBNAIL_SIZE, THUMBNAIL_SIZE), resample)
def make_datum(thumbnail, label):
    """Pack an RGB thumbnail and its class label into a Caffe Datum.

    Args:
        thumbnail: RGB PIL image, expected to be THUMBNAIL_SIZE pixels on
            each side (as produced by make_thumbnail()).
        label: integer class label.

    Returns:
        A caffe_pb2.Datum whose ``data`` holds the pixels as uint8 bytes in
        channel-first (C, H, W) order.

    Raises:
        ValueError: if the thumbnail is not THUMBNAIL_SIZE x THUMBNAIL_SIZE
            with 3 channels; the Datum header would otherwise not match its
            payload.
    """
    pixels = np.asarray(thumbnail, dtype=np.uint8)
    expected_shape = (THUMBNAIL_SIZE, THUMBNAIL_SIZE, 3)
    if pixels.shape != expected_shape:
        raise ValueError('thumbnail has shape %r, expected %r'
                         % (pixels.shape, expected_shape))
    # The array comes out as (H, W, C); move the channel axis to the front.
    # tobytes() replaces the deprecated tostring().
    return caffe_pb2.Datum(
        channels=3,
        width=THUMBNAIL_SIZE,
        height=THUMBNAIL_SIZE,
        label=label,
        data=np.rollaxis(pixels, 2).tobytes())
def create_leveldb(name):
    """Create a fresh, empty LevelDB under ~/caffe/examples/cifar10.

    Any existing directory with the same name is deleted first.

    Args:
        name: directory name of the database, e.g. 'cifar10_train_leveldb'.

    Returns:
        An open leveldb.LevelDB handle.

    Raises:
        OSError: if an existing path cannot be removed (any error other than
            the path simply not existing).
    """
    import errno

    # expanduser('~') falls back to the password database when $HOME is
    # unset, whereas os.environ['HOME'] would raise KeyError.
    path = os.path.join(os.path.expanduser('~'), 'caffe/examples/cifar10', name)
    try:
        shutil.rmtree(path)
    except OSError as e:
        # A missing directory is expected on the first run. Anything else
        # (e.g. permissions) would only make the error_if_exists open below
        # fail with a less helpful message, so surface it here.
        if e.errno != errno.ENOENT:
            raise
    print('opening %s' % path)
    return leveldb.LevelDB(
        path, create_if_missing=True, error_if_exists=True, paranoid_checks=True)
def _collect_labeled_files(root='.'):
    """Return (filepath, label) pairs for the images found under *root*.

    A file's label is the integer name of the first directory below *root*
    on its path, e.g. ``3/a.png`` and ``3/x/a.png`` both get label 3. Files
    directly in *root*, and trees whose top directory name is not an
    integer, are skipped. Extensions .png, .jpg and .jpeg are accepted
    case-insensitively.
    """
    filepath_and_label = []
    for dirpath, _, filenames in os.walk(root):
        top_dir = os.path.relpath(dirpath, root).split(os.sep)[0]
        try:
            label = int(top_dir)
        except ValueError:
            # root itself ('.') or a non-numeric directory: not a class.
            continue
        for filename in filenames:
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                filepath_and_label.append(
                    (os.path.join(dirpath, filename), label))
    return filepath_and_label


def main():
    """Build the train/test LevelDBs from integer-named image folders.

    Images are collected from the current directory, shuffled, turned into
    thumbnails, and written as serialized Datum records keyed by their
    zero-padded sequence number. Records with seq % 10 == 0 go to the test
    database and the rest go to the train database.
    """
    train_db = create_leveldb('cifar10_train_leveldb')
    test_db = create_leveldb('cifar10_test_leveldb')
    filepath_and_label = _collect_labeled_files('.')
    random.shuffle(filepath_and_label)
    for seq, (filepath, label) in enumerate(filepath_and_label):
        print('%d %d %s' % (seq, label, filepath))
        # Own the file handle so it is closed right after make_thumbnail()
        # has decoded the pixels (its convert() call loads the image).
        with open(filepath, 'rb') as f:
            thumbnail = make_thumbnail(PIL.Image.open(f))
        datum = make_datum(thumbnail, label)
        db = test_db if seq % 10 == 0 else train_db
        db.Put('%08d' % seq, datum.SerializeToString())
if __name__ == '__main__':
    # Whatever main() returns becomes the process exit status.
    raise SystemExit(main())
<!doctype html>
<html>
<head>
<meta charset="UTF-8">
<!--
Copyright 2014 SIG2D
Licensed under the Apache License, Version 2.0.
-->
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Online Kin-patsu Image Classifier</title>
<link href="//netdna.bootstrapcdn.com/bootstrap/3.1.1/css/bootstrap.min.css" rel="stylesheet">
<!--[if lt IE 9]>
<script src="https://oss.maxcdn.com/libs/html5shiv/3.7.0/html5shiv.js"></script>
<script src="https://oss.maxcdn.com/libs/respond.js/1.4.2/respond.min.js"></script>
<![endif]-->
<style>
.result {
display: inline-block;
margin: 2px;
text-align: center;
}
</style>
</head>
<body>
<div class="container-fluid">
<div class="page-header">
<h1>Online Kin-patsu Image Classifier</h1>
</div>
<div class="row">
<div class="col-lg-12">
<div class="well well-sm">
Drag &amp; drop image files here to classify.<br>
API is available: try <code>curl -F file=@image.jpg http://kinpatsu.api.sig2d.org/</code>
</div>
</div>
</div>
<div class="row">
<div class="col-lg-12">
<div id="results">
</div>
</div>
</div>
</div>
<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.11.0/jquery.min.js"></script>
<script src="//netdna.bootstrapcdn.com/bootstrap/3.1.1/js/bootstrap.min.js"></script>
<script>
// Files waiting to be uploaded, in the order they were dropped.
var queue = [];
// True while processNext() is working through the queue.
var processing = false;

// Enqueue a dropped file and kick off the upload loop if it is idle.
var queueFile = function(file) {
  queue.push(file);
  if (processing) {
    return;
  }
  processNext();
};
// Append a result cell for `file`: the image scaled to fit a 128px box,
// the predicted label, and its score as a whole percentage. Blonde and
// silver hair labels get a tinted background.
var showResult = function(file, data) {
  var reader = new FileReader();
  reader.onload = function() {
    var img = $('<img>');
    img.on('load', function() {
      var width = img[0].width;
      var height = img[0].height;
      var zoom = 128 / Math.max(width, height);
      var displayWidth = Math.floor(width * zoom);
      var displayHeight = Math.floor(height * zoom);
      var label = data['result']['label'];
      var score = data['result']['score'];
      var percent = Math.floor(100 * score) + '%';
      // The margins pad every image to the same 144x144 footprint.
      img.attr('title', score)
        .css('width', displayWidth + 'px')
        .css('height', displayHeight + 'px')
        .css('margin',
             ((128 - displayHeight) / 2 + 8) + 'px ' +
             ((128 - displayWidth) / 2 + 8) + 'px');
      var cell = $('<div class="result">');
      cell.append($('<div>').append(img));
      cell.append($('<div>').text(label));
      cell.append($('<div>').text(percent));
      if (label === 'blonde_hair') {
        cell.css('background-color', '#fe0');
      }
      if (label === 'silver_hair') {
        cell.css('background-color', '#ddd');
      }
      $('#results').append(cell);
    });
    img.attr('src', reader.result);
  };
  reader.readAsDataURL(file);
};

// Append a cell reporting that classifying `file` failed, so a failed
// request is visible instead of silently disappearing.
var showError = function(file) {
  var cell = $('<div class="result">');
  cell.append($('<div>').text(file.name));
  cell.append($('<div>').text('classification failed'));
  $('#results').append(cell);
};

// Upload the next queued file to the classifier (same URL, POST) and show
// the outcome. Requests go out one at a time; when the queue is empty the
// loop stops and clears `processing` so queueFile() can restart it.
var processNext = function() {
  if (queue.length === 0) {
    processing = false;
    return;
  }
  processing = true;
  var file = queue.shift();
  var formData = new FormData();
  formData.append('file', file);
  $.ajax({
    url: '',
    type: 'POST',
    data: formData,
    contentType: false,
    processData: false
  }).always(processNext).then(function(data) {
    showResult(file, data);
  }, function() {
    showError(file);
  });
};
// Make the whole page a drop target. Cancelling dragenter and dragover is
// what marks the page as a valid drop location; cancelling drop stops the
// browser from opening the dropped file itself. Every dropped file is
// queued for classification.
$(document).on('dragenter dragover', function(e) {
  e.preventDefault();
  return false;
}).on('drop', function(e) {
  e.preventDefault();
  var dataTransfer = e.originalEvent.dataTransfer;
  var files = dataTransfer ? dataTransfer.files : [];
  for (var i = 0; i < files.length; ++i) {
    queueFile(files[i]);
  }
  return false;
});
</script>
</body>
</html>
#!/usr/bin/python
# Copyright 2014 SIG2D
# Licensed under the Apache License, Version 2.0.
import sys
import bottle
import caffe
from caffe.proto import caffe_pb2
import numpy as np
@bottle.get('/')
def index_handler():
    """Render and return the 'frontend.html' template.

    Presumably this is the drag-and-drop upload page shipped alongside this
    server; no template variables are passed.
    """
    page = bottle.template('frontend.html')
    return page
@bottle.post('/')
def classify_handler():
    """Classify the first uploaded image and return the top-scoring label.

    The image may arrive under any multipart field name; the frontend and
    the documented curl example both use ``file``.

    Returns:
        ``{'result': {'label': <label>, 'score': <float>}}``; bottle
        serializes the dict to JSON.

    Raises:
        bottle.HTTPError: 400 when the request contains no file upload.
    """
    # Works on Python 2 (list) and 3 (view), and yields None when empty
    # instead of raising IndexError (which bottle would turn into a 500).
    upload = next(iter(bottle.request.files.values()), None)
    if upload is None:
        bottle.abort(400, 'No file uploaded.')
    image = caffe.io.load_image(upload.file)
    prediction = g_classifier.predict([image], oversample=False)[0]
    # argmax returns the first maximal index, like max() over enumerate();
    # this replaces the Python-2-only `lambda (i, p): p` syntax.
    clazz = int(np.argmax(prediction))
    score = float(prediction[clazz])
    label = g_labels[clazz]
    return {'result': {'label': label, 'score': score}}
def _load_mean_array(path):
    """Parse a Caffe BlobProto file into a float32 array shaped (C, H, W)."""
    mean_blob = caffe_pb2.BlobProto()
    # Serialized protobufs are binary: text mode breaks on Python 3 and on
    # platforms that translate newlines.
    with open(path, 'rb') as f:
        mean_blob.ParseFromString(f.read())
    return np.asarray(mean_blob.data, dtype=np.float32).reshape(
        (mean_blob.channels, mean_blob.height, mean_blob.width))


def _load_labels(path):
    """Return the lines of the UTF-8 text file *path* as unicode strings."""
    import io
    # io.open decodes on both Python 2 and 3; str.decode is Python-2-only.
    with io.open(path, encoding='utf-8') as f:
        return f.read().splitlines()


def main():
    """Load the trained network and labels, then serve the classifier.

    Reads mean.binaryproto, cifar10_quick.prototxt,
    cifar10_quick_iter_5000.caffemodel and labels.txt from the current
    directory, stores the classifier and labels in the module globals
    g_classifier / g_labels used by the request handlers, and starts bottle
    on 0.0.0.0:8080.
    """
    global g_classifier
    global g_labels
    g_classifier = caffe.Classifier(
        'cifar10_quick.prototxt',
        'cifar10_quick_iter_5000.caffemodel',
        mean=_load_mean_array('mean.binaryproto'),
        raw_scale=255)
    # Line i of labels.txt names output index i of the network.
    g_labels = _load_labels('labels.txt')
    # NOTE(review): debug mode combined with host='0.0.0.0' can expose error
    # details to any client that reaches the port; consider disabling it
    # for public deployments.
    bottle.debug()
    bottle.run(host='0.0.0.0', port=8080)
if __name__ == '__main__':
    # Whatever main() returns becomes the process exit status.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment