Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save serihiro/7e20215d079045c076951841cb0dae3e to your computer and use it in GitHub Desktop.
Save serihiro/7e20215d079045c076951841cb0dae3e to your computer and use it in GitHub Desktop.
An examination of Chainer's iterator_statemachine
from chainer.iterators._statemachine import IteratorState
from chainer.iterators._statemachine import iterator_statemachine
from chainer.iterators.order_samplers import ShuffleOrderSampler
import numpy
# Drive chainer's iterator state machine by hand and print the state after
# every batch, to observe how current_position / epoch / is_new_epoch evolve
# and when the shuffled order is regenerated at an epoch boundary.
dataset_length = 100
batch_size = 32
repeat = True
n_iterations = 10  # number of batches to draw

# The same sampler produces the initial order and is handed to the state
# machine so it can produce a new order when an epoch boundary is crossed.
order_sampler = ShuffleOrderSampler()
initial_order = order_sampler(numpy.arange(dataset_length), 0)
# Fields: (current_position, epoch, is_new_epoch, order), as shown by the
# printed IteratorState repr in the recorded output.
current_state = IteratorState(0, 0, False, initial_order)

for _ in range(n_iterations):
    current_state, indices = iterator_statemachine(
        current_state,
        batch_size,
        repeat,
        order_sampler,
        dataset_length,
    )
    print(current_state)
    if indices is None:
        # NOTE(review): with repeat=False the state machine yields no indices
        # once the single epoch is exhausted; stop instead of crashing on
        # len(None). Confirm against chainer.iterators._statemachine.
        break
    print(len(indices))
'''
IteratorState(current_position=32, epoch=0, is_new_epoch=False, order=array([43, 50, 30, 95, 39, 23, 41, 51, 21, 96, 12, 40, 24, 1, 9, 49, 64,
19, 83, 58, 48, 76, 74, 72, 73, 32, 52, 65, 11, 97, 67, 18, 68, 82,
16, 62, 26, 84, 88, 70, 91, 60, 2, 59, 94, 75, 92, 56, 35, 99, 69,
47, 13, 8, 17, 55, 78, 33, 0, 25, 29, 89, 38, 44, 71, 28, 7, 6,
63, 80, 81, 20, 27, 36, 54, 87, 37, 22, 4, 34, 14, 31, 86, 90, 3,
85, 10, 98, 42, 46, 61, 45, 77, 79, 66, 93, 5, 57, 53, 15]))
32
IteratorState(current_position=64, epoch=0, is_new_epoch=False, order=array([43, 50, 30, 95, 39, 23, 41, 51, 21, 96, 12, 40, 24, 1, 9, 49, 64,
19, 83, 58, 48, 76, 74, 72, 73, 32, 52, 65, 11, 97, 67, 18, 68, 82,
16, 62, 26, 84, 88, 70, 91, 60, 2, 59, 94, 75, 92, 56, 35, 99, 69,
47, 13, 8, 17, 55, 78, 33, 0, 25, 29, 89, 38, 44, 71, 28, 7, 6,
63, 80, 81, 20, 27, 36, 54, 87, 37, 22, 4, 34, 14, 31, 86, 90, 3,
85, 10, 98, 42, 46, 61, 45, 77, 79, 66, 93, 5, 57, 53, 15]))
32
IteratorState(current_position=96, epoch=0, is_new_epoch=False, order=array([43, 50, 30, 95, 39, 23, 41, 51, 21, 96, 12, 40, 24, 1, 9, 49, 64,
19, 83, 58, 48, 76, 74, 72, 73, 32, 52, 65, 11, 97, 67, 18, 68, 82,
16, 62, 26, 84, 88, 70, 91, 60, 2, 59, 94, 75, 92, 56, 35, 99, 69,
47, 13, 8, 17, 55, 78, 33, 0, 25, 29, 89, 38, 44, 71, 28, 7, 6,
63, 80, 81, 20, 27, 36, 54, 87, 37, 22, 4, 34, 14, 31, 86, 90, 3,
85, 10, 98, 42, 46, 61, 45, 77, 79, 66, 93, 5, 57, 53, 15]))
32
IteratorState(current_position=28, epoch=1, is_new_epoch=True, order=array([28, 9, 24, 49, 48, 2, 84, 15, 90, 13, 32, 38, 95, 74, 76, 70, 77,
87, 46, 80, 52, 33, 12, 96, 31, 18, 86, 50, 75, 23, 44, 81, 1, 91,
94, 51, 68, 39, 66, 7, 89, 25, 69, 71, 30, 22, 29, 4, 42, 79, 65,
21, 99, 82, 6, 40, 43, 83, 26, 54, 20, 67, 64, 57, 58, 55, 37, 98,
85, 5, 10, 56, 93, 88, 17, 73, 36, 14, 45, 63, 59, 19, 60, 3, 27,
47, 62, 16, 8, 0, 92, 53, 41, 61, 35, 72, 11, 97, 78, 34]))
32
IteratorState(current_position=60, epoch=1, is_new_epoch=False, order=array([28, 9, 24, 49, 48, 2, 84, 15, 90, 13, 32, 38, 95, 74, 76, 70, 77,
87, 46, 80, 52, 33, 12, 96, 31, 18, 86, 50, 75, 23, 44, 81, 1, 91,
94, 51, 68, 39, 66, 7, 89, 25, 69, 71, 30, 22, 29, 4, 42, 79, 65,
21, 99, 82, 6, 40, 43, 83, 26, 54, 20, 67, 64, 57, 58, 55, 37, 98,
85, 5, 10, 56, 93, 88, 17, 73, 36, 14, 45, 63, 59, 19, 60, 3, 27,
47, 62, 16, 8, 0, 92, 53, 41, 61, 35, 72, 11, 97, 78, 34]))
32
IteratorState(current_position=92, epoch=1, is_new_epoch=False, order=array([28, 9, 24, 49, 48, 2, 84, 15, 90, 13, 32, 38, 95, 74, 76, 70, 77,
87, 46, 80, 52, 33, 12, 96, 31, 18, 86, 50, 75, 23, 44, 81, 1, 91,
94, 51, 68, 39, 66, 7, 89, 25, 69, 71, 30, 22, 29, 4, 42, 79, 65,
21, 99, 82, 6, 40, 43, 83, 26, 54, 20, 67, 64, 57, 58, 55, 37, 98,
85, 5, 10, 56, 93, 88, 17, 73, 36, 14, 45, 63, 59, 19, 60, 3, 27,
47, 62, 16, 8, 0, 92, 53, 41, 61, 35, 72, 11, 97, 78, 34]))
32
IteratorState(current_position=24, epoch=2, is_new_epoch=True, order=array([81, 92, 7, 97, 34, 70, 30, 22, 0, 52, 82, 72, 71, 10, 85, 68, 8,
36, 93, 60, 62, 47, 19, 75, 37, 28, 21, 14, 48, 78, 39, 95, 18, 43,
55, 86, 53, 77, 24, 73, 67, 79, 45, 46, 63, 98, 1, 49, 94, 96, 99,
54, 84, 90, 64, 26, 23, 66, 13, 91, 61, 5, 3, 57, 11, 83, 56, 65,
88, 59, 6, 44, 89, 80, 76, 69, 31, 12, 15, 74, 25, 42, 58, 38, 51,
35, 9, 29, 87, 2, 4, 33, 32, 27, 17, 41, 50, 40, 16, 20]))
32
IteratorState(current_position=56, epoch=2, is_new_epoch=False, order=array([81, 92, 7, 97, 34, 70, 30, 22, 0, 52, 82, 72, 71, 10, 85, 68, 8,
36, 93, 60, 62, 47, 19, 75, 37, 28, 21, 14, 48, 78, 39, 95, 18, 43,
55, 86, 53, 77, 24, 73, 67, 79, 45, 46, 63, 98, 1, 49, 94, 96, 99,
54, 84, 90, 64, 26, 23, 66, 13, 91, 61, 5, 3, 57, 11, 83, 56, 65,
88, 59, 6, 44, 89, 80, 76, 69, 31, 12, 15, 74, 25, 42, 58, 38, 51,
35, 9, 29, 87, 2, 4, 33, 32, 27, 17, 41, 50, 40, 16, 20]))
32
IteratorState(current_position=88, epoch=2, is_new_epoch=False, order=array([81, 92, 7, 97, 34, 70, 30, 22, 0, 52, 82, 72, 71, 10, 85, 68, 8,
36, 93, 60, 62, 47, 19, 75, 37, 28, 21, 14, 48, 78, 39, 95, 18, 43,
55, 86, 53, 77, 24, 73, 67, 79, 45, 46, 63, 98, 1, 49, 94, 96, 99,
54, 84, 90, 64, 26, 23, 66, 13, 91, 61, 5, 3, 57, 11, 83, 56, 65,
88, 59, 6, 44, 89, 80, 76, 69, 31, 12, 15, 74, 25, 42, 58, 38, 51,
35, 9, 29, 87, 2, 4, 33, 32, 27, 17, 41, 50, 40, 16, 20]))
32
IteratorState(current_position=20, epoch=3, is_new_epoch=True, order=array([50, 80, 23, 20, 97, 10, 24, 63, 67, 54, 73, 47, 38, 28, 66, 45, 30,
18, 49, 52, 78, 44, 31, 41, 85, 34, 58, 32, 21, 91, 43, 62, 53, 99,
69, 40, 26, 68, 65, 59, 89, 42, 4, 8, 77, 39, 70, 88, 2, 22, 16,
72, 9, 29, 12, 1, 75, 55, 82, 11, 51, 14, 96, 33, 98, 79, 71, 93,
76, 74, 46, 81, 94, 17, 60, 19, 5, 86, 57, 13, 37, 6, 90, 56, 35,
7, 84, 25, 87, 36, 0, 48, 27, 64, 92, 3, 83, 61, 15, 95]))
32
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment