Skip to content

Instantly share code, notes, and snippets.

@haje01
Created July 1, 2015 03:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save haje01/d268b745acd532849722 to your computer and use it in GitHub Desktop.
Save haje01/d268b745acd532849722 to your computer and use it in GitHub Desktop.
Caffe 파이썬 테스트를 위해 수정할 것 (Changes to apply in order to try out the Caffe Python interface)
diff --git a/python/caffe/io.py b/python/caffe/io.py
index fc96266..02b2ffb 100644
--- a/python/caffe/io.py
+++ b/python/caffe/io.py
@@ -251,9 +251,13 @@ class Transformer:
ms = (1,) + ms
if len(ms) != 3:
raise ValueError('Mean shape invalid')
- if ms != self.inputs[in_][1:]:
- raise ValueError('Mean shape incompatible with input shape.')
- self.mean[in_] = mean
+ if ms != self.inputs[in_] :
+ print(self.inputs[in_])
+ in_shape = self.inputs[in_][1:]
+ m_min, m_max = mean.min(), mean.max()
+ normal_mean = (mean - m_min) / (m_max - m_min)
+ mean = resize_image(normal_mean.transpose((1,2,0)), in_shape[1:]).transpose((2,0,1)) * (m_max - m_min) + m_min
+ self.mean[in_] = mean
def set_input_scale(self, in_, scale):
"""
diff --git a/python/classify.py b/python/classify.py
index 4544c51..a5b1540 100755
--- a/python/classify.py
+++ b/python/classify.py
@@ -12,6 +12,7 @@ import glob
import time
import caffe
+import pandas as pd
def main(argv):
@@ -86,6 +87,17 @@ def main(argv):
help="Image file extension to take as input when a directory " +
"is given as the input file."
)
+ parser.add_argument(
+ "--print_results",
+ action='store_true',
+ help="Write output text to stdout rather than serializing to a file."
+ )
+ parser.add_argument(
+ "--labels_file",
+ default=os.path.join(pycaffe_dir,"../data/ilsvrc12/synset_words.txt"),
+ help="Readable label definition file."
+ )
+
args = parser.parse_args()
image_dims = [int(s) for s in args.images_dim.split(',')]
@@ -126,13 +138,25 @@ def main(argv):
# Classify.
start = time.time()
+ scores = classifier.predict(inputs, not args.center_only).flatten()
predictions = classifier.predict(inputs, not args.center_only)
print("Done in %.2f s." % (time.time() - start))
+ if args.print_results:
+ with open(args.labels_file) as f:
+ labels_df = pd.DataFrame([{'synset_id':l.strip().split(' ')[0], 'name': ' '.join(l.strip().split(' ')[1:]).split(',')[0]} for l in f.readlines()])
+ labels = labels_df.sort('synset_id')['name'].values
+
+ indices =(-scores).argsort()[:5]
+ predictions = labels[indices]
+
+ meta = [(p, '%.5f' % scores[i]) for i,p in zip(indices, predictions)]
+ print meta
+
# Save
print("Saving results into %s" % args.output_file)
np.save(args.output_file, predictions)
-
+
if __name__ == '__main__':
main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment