Skip to content

Instantly share code, notes, and snippets.

@kuenishi
Created September 3, 2019 09:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kuenishi/3c9420ee01af86edcd77d73524fb1435 to your computer and use it in GitHub Desktop.
Save kuenishi/3c9420ee01af86edcd77d73524fb1435 to your computer and use it in GitHub Desktop.
import argparse
import copy
import chainer
from chainer import iterators
from chainer import function
import chainermn
from chainermn.extensions import GenericMultiNodeEvaluator
from chainercv.utils import ProgressHook
from chainercv.utils.iterator.unzip import unzip
from eval_detection import models
from eval_detection import setup
# Set the global Chainer config entry `cv_resize_backend` to "cv2" at import
# time, before any model or dataset is built.
# NOTE(review): presumably this makes ChainerCV's resize transform use
# OpenCV -- confirm against the ChainerCV version in use.
chainer.config.cv_resize_backend = "cv2"
class DetMNEvaluator(GenericMultiNodeEvaluator):
    """Multi-node evaluator that hands collected predictions to a callback.

    ``calc_local`` splits every sample of a batch into its first field
    (fed to the target's ``predict``) and the remaining fields.
    ``aggregate`` regroups the per-batch results field by field and
    passes them to the stored ``eval`` callable.
    """

    def __init__(self, eval, *args, **kwargs):
        """Remember the evaluation callback and initialise the base class.

        Args:
            eval: Callable invoked by ``aggregate`` as
                ``eval(out_values, rest_values)``.
            *args: Positional arguments forwarded unchanged to
                ``GenericMultiNodeEvaluator.__init__``.
            **kwargs: Keyword arguments forwarded unchanged to
                ``GenericMultiNodeEvaluator.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.eval = eval

    def calc_local(self, batch):
        """Run prediction on one batch of samples.

        Args:
            batch: Iterable of sliceable samples. The first field of each
                sample is the prediction input; everything after it is
                carried along untouched.

        Returns:
            A pair ``(out_values, rest_values)``: whatever
            ``predict`` returned, and a tuple holding one list per
            remaining field.
        """
        # Split each sample once into (first field, remaining fields).
        pieces = [(sample[0:1], sample[1:]) for sample in batch]

        # Transpose from per-sample tuples to per-field lists.
        in_values = tuple(map(list, zip(*(head for head, _ in pieces))))
        rest_values = tuple(map(list, zip(*(tail for _, tail in pieces))))

        predictor = self._targets['main']
        return (predictor.predict(*in_values), rest_values)

    def aggregate(self, results):
        """Regroup the collected results and pass them to ``self.eval``.

        Args:
            results: Iterable of ``(out_values, rest_values)`` pairs.
                NOTE(review): presumably the ``calc_local`` outputs
                gathered by the base class -- confirm.

        The return value of ``self.eval`` is discarded; this method
        returns ``None``.
        """
        predictions, extras = unzip(results)
        # For each field, chain the per-batch pieces into one lazy
        # stream of samples.
        flat_predictions = tuple(
            _flatten(column) for column in unzip(predictions))
        flat_extras = tuple(_flatten(column) for column in unzip(extras))
        self.eval(flat_predictions, flat_extras)
def _flatten(iterator):
return (sample for batch in iterator for sample in batch)
def _noop_convert(batch, device):
return batch
def main():
    """Evaluate a detection model across processes and report the result.

    Parses ``--dataset``, ``--model``, ``--pretrained-model`` and
    ``--batchsize``, creates a ``pure_nccl`` communicator, builds the
    dataset/model via ``setup``, moves the model to the GPU indexed by the
    communicator's ``intra_rank``, scatters the dataset over the
    communicator and runs a ``DetMNEvaluator`` once.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', choices=('voc', 'coco'))
    cli.add_argument('--model', choices=sorted(models.keys()))
    cli.add_argument('--pretrained-model')
    cli.add_argument('--batchsize', type=int)
    opts = cli.parse_args()

    comm = chainermn.create_communicator('pure_nccl')
    gpu_id = comm.intra_rank

    full_dataset, eval_, model, batchsize = setup(
        opts.dataset, opts.model, opts.pretrained_model, opts.batchsize)

    # Bind this process to its GPU before transferring the model.
    chainer.cuda.get_device_from_id(gpu_id).use()
    model.to_gpu()
    model.use_preset('evaluate')

    # Each process keeps only its own shard; shards may differ in length.
    local_dataset = chainermn.scatter_dataset(
        full_dataset, comm, force_equal_length=False)

    progress = ProgressHook(len(local_dataset))

    def report_progress(batch):
        # ProgressHook takes three positional arguments; only the first is
        # supplied with data here, wrapped in a one-element list.
        progress([batch], None, None)

    batch_iter = iterators.MultithreadIterator(
        local_dataset, batchsize, repeat=False, shuffle=False)

    evaluator = DetMNEvaluator(
        eval_, comm, batch_iter, model,
        converter=_noop_convert,
        progress_hook=report_progress)
    evaluator.initialize()
    evaluator(None)
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment