Skip to content

Instantly share code, notes, and snippets.

@kuenishi
Last active July 31, 2019 09:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kuenishi/46e3527b6c4d749c7a8fc47d30d7cce3 to your computer and use it in GitHub Desktop.
Save kuenishi/46e3527b6c4d749c7a8fc47d30d7cce3 to your computer and use it in GitHub Desktop.
import argparse
import chainer
from chainer import iterators
import chainermn
from chainermn.extensions.multi_node_evaluator import MultiNodeAggregationEvaluator
from chainercv.utils import ProgressHook
from chainercv.utils.iterator.unzip import unzip
from eval_detection import models
from eval_detection import setup
# Select OpenCV ("cv2") as the image-resize backend via a global
# chainer.config flag. NOTE(review): presumably consumed by chainercv's
# resize utilities used during dataset preprocessing — confirm which
# code paths in eval_detection depend on it.
chainer.config.cv_resize_backend = "cv2"
def _flatten(iterator):
return (sample for batch in iterator for sample in batch)
class DetEvaluator(MultiNodeAggregationEvaluator):
    """Multi-node evaluator for object-detection models.

    ``target.predict`` is used as the per-batch evaluation function.
    :meth:`aggregate` receives the ``(results, rest_values)`` pairs produced
    by :meth:`postprocess` (presumably collected across processes by the
    base class — confirm against ``MultiNodeAggregationEvaluator``),
    flattens them to per-sample sequences and hands them to
    ``eval_func_all``.

    Args:
        comm: ChainerMN communicator. Its ``intra_rank`` is passed to the
            base class as ``device``.
        iterator: Iterator over this process's samples. The first field of
            each sample is fed to the model; the remaining fields are passed
            through untouched.
        target: Detection model; its ``predict`` method is the eval function.
        eval_func_all: Callable invoked as
            ``eval_func_all(out_values, rest_values)`` with the flattened
            predictions and pass-through values.
        hook: Optional callable invoked as
            ``hook(in_values, None, rest_values)`` after every batch (for
            example a progress indicator). ``None`` disables it.
    """

    def __init__(self, comm, iterator, target, eval_func_all,
                 hook=None):
        super(DetEvaluator, self).__init__(comm, iterator, target,
                                           device=comm.intra_rank,
                                           eval_func=target.predict)
        self.hook = hook
        self.eval_func_all = eval_func_all

    def preprocess(self, batch):
        """Split a batch of samples into model inputs and pass-through values.

        Returns:
            ``(in_values, rest_values)``: each a tuple holding one list per
            field, i.e. the per-sample layout of *batch* transposed.
        """
        in_values = []
        rest_values = []
        for sample in batch:
            # Only the first field is fed to the model.
            in_values.append(sample[:1])
            rest_values.append(sample[1:])
        # Transpose [(a0, b0), (a1, b1)] -> ([a0, a1], [b0, b1]).
        in_values = tuple(list(v) for v in zip(*in_values))
        rest_values = tuple(list(v) for v in zip(*rest_values))
        return in_values, rest_values

    def postprocess(self, in_values, results, rest_values):
        """Call the per-batch hook and pair predictions with rest values."""
        # Explicit None check: an arbitrary hook object need not be truthy.
        if self.hook is not None:
            # Predictions are deliberately not forwarded to the hook.
            self.hook(in_values, None, rest_values)
        return (results, rest_values)

    def aggregate(self, results):
        """Flatten all collected batches and run ``eval_func_all`` on them.

        Args:
            results: Iterable of ``(results, rest_values)`` pairs as returned
                by :meth:`postprocess`, one per batch.
        """
        out_values, rest_values = unzip(results)
        # Split each side into one iterator per field, then flatten the
        # per-batch lists into per-sample sequences.
        out_values = tuple(map(_flatten, unzip(out_values)))
        rest_values = tuple(map(_flatten, unzip(rest_values)))
        self.eval_func_all(out_values, rest_values)
def main():
    """Evaluate a detection model on VOC or COCO across ChainerMN processes.

    Every process builds the dataset and model via ``setup``, moves the model
    to the GPU given by its intra-node rank, evaluates its own shard of the
    scattered dataset, and :class:`DetEvaluator` aggregates the results.
    """
    parser = argparse.ArgumentParser()
    # Both values are passed straight to setup()/models; with None they
    # presumably fail deep inside setup, so require them up front and let
    # argparse report a clear usage error instead.
    parser.add_argument('--dataset', choices=('voc', 'coco'), required=True)
    parser.add_argument('--model', choices=sorted(models.keys()),
                        required=True)
    parser.add_argument('--pretrained-model')
    parser.add_argument('--batchsize', type=int)
    parser.add_argument('--communicator', default='pure_nccl',
                        help='ChainerMN communicator name')
    args = parser.parse_args()

    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank

    dataset, eval_, model, batchsize = setup(
        args.dataset, args.model, args.pretrained_model, args.batchsize)

    chainer.cuda.get_device_from_id(device).use()
    model.to_gpu()
    model.use_preset('evaluate')

    dataset = chainermn.scatter_dataset(dataset, comm, force_equal_length=False)
    print('after scatter:', len(dataset), comm.rank)

    # Each process only evaluates its own shard, so the progress total must
    # be the shard size; sizing it from the full dataset (before scattering)
    # meant the progress never reached 100%. Only rank 0 reports progress so
    # the per-process outputs do not interleave on the terminal.
    hook = ProgressHook(len(dataset)) if comm.rank == 0 else None

    iterator = iterators.MultithreadIterator(
        dataset, batchsize, repeat=False, shuffle=False)
    evaluator = DetEvaluator(comm, iterator, model, eval_, hook=hook)
    evaluator(None)
if __name__ == "__main__":
    # Launch the distributed evaluation only when run as a script,
    # never as a side effect of importing this module.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment