Skip to content

Instantly share code, notes, and snippets.

@akirayou
Created Mar 23, 2018
Embed
What would you like to do?
Label Balanced Iterator for Chainer
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 27 02:01:52 2016
@author: a-noda
"""
import numpy as np
from chainer.dataset import iterator
class BlanceIterator(iterator.Iterator):
    """Dataset iterator that draws a label-balanced random batch each step.

    Every batch is split into one contiguous slice per distinct label, the
    slices being as equal in size as integer division allows.  Each slice is
    filled with dataset indices drawn uniformly at random (with replacement)
    from the samples carrying that label, so rare labels are over-sampled
    and frequent labels are under-sampled.

    Because sampling is random, an "epoch" is only a bookkeeping notion: it
    is counted as complete once ``len(dataset)`` samples have been emitted,
    not once every sample has been visited.

    Args:
        dataset: Indexable collection supporting ``len()`` and integer
            ``__getitem__``.
        ids: Sequence of labels, one per dataset entry; position ``k``
            labels ``dataset[k]``.
        batch_size: Number of examples returned by each call to
            ``__next__``.
        repeat: If ``True`` iterate forever; otherwise raise
            ``StopIteration`` once the first epoch has been completed.

    Raises:
        ValueError: If ``ids`` is empty, since no batch could be sampled.
    """

    def __init__(self, dataset, ids, batch_size, repeat=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat

        # Convert up front: with a plain list, ``ids == label`` would be a
        # single scalar comparison instead of an elementwise mask.
        ids = np.asarray(ids)
        if ids.size == 0:
            raise ValueError('ids must contain at least one label')

        # One array of dataset positions per distinct label, in the sorted
        # label order produced by np.unique.
        self._idxis = [np.nonzero(ids == label)[0]
                       for label in np.unique(ids)]

        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def __next__(self):
        """Return the next batch as a list of dataset examples.

        Raises:
            StopIteration: If ``repeat`` is false and one epoch is done.
        """
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        n_samples = len(self.dataset)
        n_labels = len(self._idxis)

        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # type it used to alias.
        index = np.empty(self.batch_size, dtype=int)
        for i, candidates in enumerate(self._idxis):
            # Integer arithmetic spreads any remainder over the labels, so
            # the slices always tile [0, batch_size) exactly.
            start = self.batch_size * i // n_labels
            stop = self.batch_size * (i + 1) // n_labels
            index[start:stop] = np.random.choice(candidates, stop - start)

        batch = [self.dataset[i] for i in index]

        self.current_position += self.batch_size
        if self.current_position >= n_samples:
            # Wrap around when repeating; otherwise pin to the end so that
            # epoch_detail lands on a whole number.
            self.current_position = 0 if self._repeat else n_samples
            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False
        return batch

    next = __next__

    @property
    def epoch_detail(self):
        """Epoch count plus the fraction of the current epoch consumed."""
        # float() keeps this a true division on Python 2 as well.
        return self.epoch + self.current_position / float(len(self.dataset))

    def serialize(self, serializer):
        """Save or restore the iteration counters through ``serializer``.

        Only the position/epoch bookkeeping is stored; batches are sampled
        afresh on every call, so there is no ordering state to persist.
        """
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment