Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created June 28, 2021 02:45
Show Gist options
  • Save enijkamp/549236d67eb14844b1828b17bb88aec2 to your computer and use it in GitHub Desktop.
tfrecordresumableloader.py
import argparse
import numpy as np
import tensorflow as tf
class TFRecordResumableLoader:
    """Batch iterator over a list of TFRecord files with checkpoint/resume support.

    The loader tracks which files have been fully consumed, which file is
    currently being read, and the index of the last batch yielded from it.
    ``get_state()`` snapshots that position and ``set_state()`` restores it;
    iteration then resumes with the batch *after* the saved one (the saved
    file is re-read and already-seen batches are skipped).
    """

    def __init__(self, files, batch_size, batch_prefetch, parse_fn, map_fn=lambda x: x):
        # files: ordered list of TFRecord paths; order must be stable across
        # runs for restore to line up with the saved state.
        self.files = files
        # batch_size may be a scalar or a shape tuple; np.prod() collapses it
        # to the flat per-step batch count.
        self.batch_size = batch_size
        self.batch_prefetch = batch_prefetch
        self.parse_fn = parse_fn
        # Applied to each batch after batching; identity by default.
        self.map_fn = map_fn
        # Resumable-iteration state.
        self.state_files_used = []
        self.state_file_current = None
        self.state_batch_index = 0
        self.state_restore = False

    def set_state(self, state):
        """Restore a position previously captured by get_state()."""
        self.state_files_used = state['state_files_used']
        self.state_file_current = state['state_file_current']
        self.state_batch_index = state['state_batch_index']
        # Arms skip-ahead logic in sample() for the next iteration pass.
        self.state_restore = True

    def get_state(self):
        """Snapshot the current position as a plain dict (list is copied)."""
        return { 'state_files_used': list(self.state_files_used), 'state_file_current': self.state_file_current, 'state_batch_index': self.state_batch_index }

    def reset_state(self):
        """Forget all progress; the next sample() pass starts from scratch."""
        self.state_files_used = []
        self.state_file_current = None
        self.state_batch_index = 0

    def sample(self):
        """Yield batches from all not-yet-consumed files, one full pass."""
        def unused_files():
            return [f for f in self.files if f not in self.state_files_used]
        for f in unused_files():
            if self.state_restore:
                # When resuming, the first unused file must be the one the
                # state was saved in — otherwise the file list changed.
                assert self.state_file_current == f
            else:
                self.state_file_current = f
            ds = tf.data.TFRecordDataset(f)
            ds = ds.map(self.parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
            ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(np.prod(self.batch_size), drop_remainder=True))
            # Apply the user batch transform (identity by default). Previously
            # map_fn was stored but never used.
            ds = ds.map(self.map_fn, num_parallel_calls=tf.data.AUTOTUNE)
            ds = ds.prefetch(self.batch_prefetch)
            for batch_index, batch in enumerate(ds):
                if self.state_restore:
                    if batch_index <= self.state_batch_index:
                        # Already yielded before the checkpoint — skip.
                        continue
                    else:
                        self.state_restore = False
                self.state_batch_index = batch_index
                yield batch
            # If the checkpoint pointed at the LAST batch of this file, every
            # batch was skipped above and the restore flag is still armed;
            # clear it here so the assert on the next file does not fire
            # against the stale state_file_current.
            self.state_restore = False
            self.state_files_used.append(f)
        self.reset_state()

    def samples(self):
        """Endless batch stream: repeated full passes via sample()."""
        while True:
            for sample in self.sample():
                yield sample
def create_args(args=None):
    """Populate *args* with dataset hyper-parameters and return it.

    Args:
        args: optional ``argparse.Namespace`` to fill in; a fresh one is
            created when omitted. (The original signature used a mutable
            default ``argparse.Namespace()``, so all no-arg calls shared and
            mutated the same object.)

    Returns:
        The namespace with ``ds_batch_size`` and ``ds_prefetch`` set.
    """
    if args is None:
        args = argparse.Namespace()
    args.ds_batch_size = 2
    args.ds_prefetch = 10
    return args
def tf_parse(example_proto):
    """Parse one serialized Example into a dense uint32 token tensor.

    Expects a single variable-length int64 feature under the key 'text';
    the sparse values are reordered, densified, and cast to uint32.
    """
    feature_spec = {'text': tf.io.VarLenFeature(tf.int64)}
    parsed = tf.io.parse_single_example(example_proto, feature_spec)
    ordered = tf.sparse.reorder(parsed['text'])
    dense = tf.sparse.to_dense(ordered)
    return tf.cast(dense, tf.uint32)
def load_records(args):
    """Smoke test: build a loader over a fixed record file and print one batch."""
    record_files = ['/export/home/gptc/bigquery_bpe/c/data_000000000000.json']
    loader = TFRecordResumableLoader(
        files=record_files,
        batch_size=args.ds_batch_size,
        batch_prefetch=args.ds_prefetch,
        parse_fn=tf_parse,
    )
    first_batch = next(loader.samples())
    print(first_batch)
def test_restore_state(args):
    """End-to-end check that set_state() resumes exactly where get_state() left off."""
    files = ['/export/home/gptc/bigquery_bpe/c/data_000000000000.json']
    loader = TFRecordResumableLoader(
        files=files,
        batch_size=args.ds_batch_size,
        batch_prefetch=args.ds_prefetch,
        parse_fn=tf_parse,
    )
    # (1) advance the stream to build up some state
    stream = loader.samples()
    for _ in range(8):
        next(stream)
    # (2) snapshot the position
    state = loader.get_state()
    print(state)
    # (3) remember the next batch, then keep consuming past it
    expected = next(stream)
    for _ in range(8):
        next(stream)
    # (4) roll back to the snapshot and restart iteration
    loader.set_state(state)
    stream = loader.samples()
    # (5) the first batch after restore must match the remembered one
    actual = next(stream)
    mismatch_count = tf.reduce_sum(tf.cast(tf.not_equal(expected, actual), tf.uint32))
    assert mismatch_count.numpy() == 0
def main():
    """Entry point: run the state-restore check with the default hyper-parameters."""
    args = create_args()
    # load_records(args)  # one-shot smoke test, left disabled
    test_restore_state(args)
    print('done.')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment