Skip to content

Instantly share code, notes, and snippets.

@kuenishi
Last active August 1, 2019 05:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kuenishi/fa20a88881beffd3cccbd6c15d62522e to your computer and use it in GitHub Desktop.
Save kuenishi/fa20a88881beffd3cccbd6c15d62522e to your computer and use it in GitHub Desktop.
import argparse
import copy
import chainer
from chainer import iterators
from chainer import function
import chainermn
from chainermn.extensions.multi_node_evaluator import GatherEvaluator
from chainercv.utils import ProgressHook
from chainercv.utils.iterator.unzip import unzip
from eval_detection import models
from eval_detection import setup
# Set the global Chainer config entry ``cv_resize_backend`` to "cv2" before
# any model or dataset is built.
# NOTE(review): presumably this selects OpenCV as ChainerCV's image-resize
# backend (instead of PIL) -- confirm against the ChainerCV version in use.
chainer.config.cv_resize_backend = "cv2"
def _flatten(iterator):
    """Return a lazy iterator over every sample of every batch.

    Args:
        iterator: An iterable whose items are themselves iterables
            (batches of samples).

    Returns:
        A generator yielding the samples of each batch in order, one
        batch after another. Nothing is consumed until it is iterated.
    """
    return (sample for batch in iterator for sample in batch)
def main():
    """Evaluate a detection model with the dataset split across workers.

    Command-line options:
        --dataset: ``voc`` or ``coco``.
        --model: one of the keys of ``eval_detection.models``.
        --pretrained-model: forwarded to ``eval_detection.setup``.
        --batchsize: integer batch size, forwarded to ``setup``.

    The dataset is scattered over the ChainerMN communicator, each worker
    runs ``model.predict`` on its own shard, and ``GatherEvaluator`` is
    given ``aggregate_func`` to feed the collected predictions into the
    evaluation function returned by ``setup``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=('voc', 'coco'))
    parser.add_argument('--model', choices=sorted(models.keys()))
    parser.add_argument('--pretrained-model')
    parser.add_argument('--batchsize', type=int)
    args = parser.parse_args()

    comm = chainermn.create_communicator('pure_nccl')
    # The rank inside the node doubles as the GPU id for this process.
    device = comm.intra_rank

    dataset, eval_, model, batchsize = setup(
        args.dataset, args.model, args.pretrained_model, args.batchsize)

    chainer.cuda.get_device_from_id(device).use()
    model.to_gpu()
    model.use_preset('evaluate')

    # Shards may differ in length (force_equal_length=False), so no sample
    # is duplicated to pad a shard.
    dataset = chainermn.scatter_dataset(
        dataset, comm, force_equal_length=False)
    # Progress is tracked against the local shard size; the hook is only
    # invoked on rank 0 (see eval_func).
    hook = ProgressHook(len(dataset))
    iterator = iterators.MultithreadIterator(
        dataset, batchsize, repeat=False, shuffle=False)

    def eval_func(batch):
        """Run prediction on one batch.

        Each sample is split into its first element (the model input) and
        the remaining elements, which are passed through untouched.
        NOTE(review): the remaining elements are presumably ground-truth
        annotations -- confirm against the dataset returned by ``setup``.

        Returns:
            A pair ``(out_values, rest_values)``: the output of
            ``model.predict`` and a tuple of per-field lists of the
            non-input sample elements.
        """
        in_values = []
        rest_values = []
        for sample in batch:
            in_values.append(sample[0:1])
            rest_values.append(sample[1:])

        # Transpose from per-sample tuples to per-field lists.
        in_values = tuple(list(v) for v in zip(*in_values))
        rest_values = tuple(list(v) for v in zip(*rest_values))

        out_values = model.predict(*in_values)

        # Only the root process reports progress.
        if comm.rank == 0:
            hook(in_values, out_values, rest_values)

        return (out_values, rest_values)

    def aggregate_func(results):
        """Flatten the collected per-batch results and evaluate them.

        NOTE(review): ``results`` is assumed to be an iterable of the
        ``(out_values, rest_values)`` pairs produced by ``eval_func`` --
        confirm against ``GatherEvaluator``.
        """
        out_values, rest_values = unzip(results)
        # Each field is still grouped by batch; chain the batches so the
        # evaluation function sees one flat stream of samples per field.
        out_values = tuple(map(_flatten, unzip(out_values)))
        rest_values = tuple(map(_flatten, unzip(rest_values)))
        eval_(out_values, rest_values)

    evaluator = GatherEvaluator(comm, iterator,
                                model, aggregate_func,
                                eval_func=eval_func)
    evaluator(None)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment