Created July 1, 2015 03:32
-
-
Save haje01/d268b745acd532849722 to your computer and use it in GitHub Desktop.
Caffe 파이썬 테스트를 위해 수정할 것 (Changes needed to test Caffe from Python)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/python/caffe/io.py b/python/caffe/io.py
index fc96266..02b2ffb 100644
--- a/python/caffe/io.py
+++ b/python/caffe/io.py
@@ -252,7 +252,24 @@ class Transformer:
             if len(ms) != 3:
                 raise ValueError('Mean shape invalid')
             if ms != self.inputs[in_][1:]:
-                raise ValueError('Mean shape incompatible with input shape.')
+                in_shape = self.inputs[in_][1:]
+                if ms[0] != in_shape[0]:
+                    # A channel mismatch cannot be fixed by spatial resizing.
+                    raise ValueError('Mean shape incompatible with input shape.')
+                # Resize the mean to the input's spatial size instead of
+                # failing. resize_image takes channels-last arrays, hence the
+                # transposes; the [0, 1] normalization is kept from the
+                # original workaround.
+                mean = mean.reshape(ms)
+                m_min, m_max = mean.min(), mean.max()
+                if m_max > m_min:
+                    normal_mean = (mean - m_min) / (m_max - m_min)
+                    resized = resize_image(normal_mean.transpose((1, 2, 0)),
+                                           in_shape[1:])
+                    mean = resized.transpose((2, 0, 1)) * (m_max - m_min) + m_min
+                else:
+                    # Constant mean: nothing to interpolate; avoid dividing by 0.
+                    mean = np.full(in_shape, m_min, dtype=mean.dtype)
         self.mean[in_] = mean
 
     def set_input_scale(self, in_, scale):
diff --git a/python/classify.py b/python/classify.py
index 4544c51..a5b1540 100755
--- a/python/classify.py
+++ b/python/classify.py
@@ -86,6 +86,16 @@ def main(argv):
         help="Image file extension to take as input when a directory " +
              "is given as the input file."
     )
+    parser.add_argument(
+        "--print_results",
+        action='store_true',
+        help="Also print the top-5 labels of each input to stdout."
+    )
+    parser.add_argument(
+        "--labels_file",
+        default=os.path.join(pycaffe_dir, "../data/ilsvrc12/synset_words.txt"),
+        help="Readable label definition file."
+    )
     args = parser.parse_args()
 
     image_dims = [int(s) for s in args.images_dim.split(',')]
@@ -128,6 +138,20 @@ def main(argv):
     start = time.time()
     predictions = classifier.predict(inputs, not args.center_only)
     print("Done in %.2f s." % (time.time() - start))
+
+    if args.print_results:
+        # Each labels-file line is "<synset_id> <name>[, <alias>...]"; keep
+        # the first name and order labels by synset id, as the original
+        # workaround did (assumes class indices follow that order).
+        with open(args.labels_file) as f:
+            entries = [line.strip().partition(' ') for line in f if line.strip()]
+        labels = [name.split(',')[0] for _, _, name in sorted(entries)]
+        # predictions presumably holds one row of class scores per input;
+        # report each row's top 5 without overwriting the array saved below.
+        for i, probs in enumerate(predictions):
+            top5 = (-probs).argsort()[:5]
+            print("Input %d: %s" % (
+                i, [(labels[k], '%.5f' % probs[k]) for k in top5]))
 
     # Save
     print("Saving results into %s" % args.output_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.