Created
March 23, 2018 00:33
-
-
Save akirayou/11e23365b27a12cefa355c72361d9d3c to your computer and use it in GitHub Desktop.
Label-balanced iterator for Chainer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 27 02:01:52 2016
@author: a-noda
"""
import numpy as np | |
from chainer.dataset import iterator | |
class BlanceIterator(iterator.Iterator):
    """Iterator whose batches hold an equal share of samples per label.

    Every batch is split into one contiguous slot per distinct label and
    each slot is filled with dataset indices drawn at random (with
    replacement, the ``np.random.choice`` default) from the samples that
    carry that label.  Rare labels are therefore over-sampled and common
    labels under-sampled.

    An "epoch" is counted every time ``len(dataset)`` samples have been
    emitted; because sampling is random it does not mean that every sample
    has been visited.

    Args:
        dataset: Indexable dataset supporting ``len()`` and integer
            ``__getitem__``.
        ids: Array-like of labels; ``ids[k]`` is the label of
            ``dataset[k]``.  Must not be empty.
        batch_size (int): Number of samples in every batch.
        repeat (bool): If ``False``, iteration stops after the first epoch.

    Raises:
        ValueError: If ``ids`` is empty.
    """

    def __init__(self, dataset, ids, batch_size, repeat=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat

        # Coerce to an array so that ``ids == label`` is element-wise even
        # when the caller passes a plain Python list.
        ids = np.asarray(ids)
        if ids.size == 0:
            # Without any label no batch slot could be filled and the index
            # buffer in __next__ would stay uninitialised.
            raise ValueError('ids must contain at least one label')

        # One array of dataset positions per distinct label.
        self._idxis = [np.nonzero(ids == label)[0]
                       for label in np.unique(ids)]

        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def __next__(self):
        """Return the next label-balanced batch as a list of samples.

        Raises:
            StopIteration: If ``repeat`` is ``False`` and one epoch has
                already been completed.
        """
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        n_samples = len(self.dataset)
        n_labels = len(self._idxis)

        # Builtin ``int`` is what the removed ``np.int`` alias resolved to.
        index = np.empty(self.batch_size, dtype=int)
        for i, candidates in enumerate(self._idxis):
            # Integer arithmetic spreads any remainder over the labels so
            # the slots always cover exactly ``batch_size`` entries.
            start = self.batch_size * i // n_labels
            stop = self.batch_size * (i + 1) // n_labels
            index[start:stop] = np.random.choice(candidates, stop - start)

        batch = [self.dataset[i] for i in index]

        self.current_position += self.batch_size
        if self.current_position >= n_samples:
            # Pin the position at the end when not repeating so that
            # ``epoch_detail`` does not wrap back.
            self.current_position = 0 if self._repeat else n_samples
            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False

        return batch

    next = __next__

    @property
    def epoch_detail(self):
        """Epoch count plus the fraction of the current epoch consumed."""
        return self.epoch + self.current_position / len(self.dataset)

    def serialize(self, serializer):
        """Save or restore the iteration counters through ``serializer``.

        Only the counters are stored: batches are sampled at random, so
        there is no ordering state to persist.
        """
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment