Skip to content

Instantly share code, notes, and snippets.

@mokemokechicken
Created January 24, 2016 04:55
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mokemokechicken/7e512a0a97193bd7420c to your computer and use it in GitHub Desktop.
Save mokemokechicken/7e512a0a97193bd7420c to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding: utf-8
"""Create Sample Sequence Data"""
import random
from cPickle import dump, HIGHEST_PROTOCOL
import tensorflow as tf
import json
# Command-line flags, declared through TensorFlow's flags wrapper.
flags = tf.flags
# --dataset: output file path; main() writes a pickle when the path ends with
# ".pkl" and JSON otherwise.
flags.DEFINE_string("dataset", "dataset.pkl", "path to output data")
# --n: number of training sequences; main() also derives the validation-set
# size (20% of n) and the test-set size (n) from this value.
flags.DEFINE_integer("n", 10000, "number of train_data")
FLAGS = flags.FLAGS
def generate_seq(rng=None):
    """
    Generate one random digit sequence terminated by an end-of-sequence token.

    The sequence always starts with ``1``. Each step draws one number in
    [0, 1) and chooses an action from it:

    * draw < 0.1: append the end-of-sequence token ``10`` and return;
    * draw < 0.6: append ``(last + 1) % 10``;
    * otherwise: look back ``n = (last % 5) + 1`` positions (clamped to the
      first element when the sequence is still shorter than ``n``) and append
      ``(that value + 1) % 10``.

    :param rng: optional object exposing a ``random()`` method that returns
        floats in [0, 1), e.g. ``random.Random(seed)``. When omitted, the
        module-level ``random`` generator is used, as before.
    :return: sequence (list of ints in 0..9 followed by a final ``10``)
    """
    # Resolve the draw function once so the loop body stays identical for
    # both the default generator and an injected one.
    draw = random.random if rng is None else rng.random
    ret = [1]
    while True:
        rand = draw()
        if rand < 0.1:
            ret.append(10)  # End Of Sequence
            return ret
        if rand < 0.6:
            ret.append((ret[-1] + 1) % 10)
        else:
            n = (ret[-1] % 5) + 1  # 1~5
            # n-th most recent number; min() keeps the index inside the list
            # while the sequence is shorter than n.
            ref = ret[-min(n, len(ret))]
            ret.append((ref + 1) % 10)
def main(unused_args):
    """
    Generate train/validation/test sequence sets and write them to disk.

    Reads ``FLAGS.n`` and ``FLAGS.dataset``. Builds ``n`` training sequences,
    ``int(n * 0.2)`` validation sequences and ``n`` test sequences with
    ``generate_seq()``, then stores ``[train, validation, test]`` in a single
    file: pickled when the path ends with ``.pkl``, JSON otherwise.

    :param unused_args: positional arguments supplied by the caller; ignored.
    """
    random.seed()
    num_train = FLAGS.n
    data_path = FLAGS.dataset
    train_data = [generate_seq() for _ in range(num_train)]
    validation_data = [generate_seq() for _ in range(int(num_train * 0.2))]
    test_data = [generate_seq() for _ in range(num_train)]
    dataset = [train_data, validation_data, test_data]
    if data_path.endswith('.pkl'):
        # HIGHEST_PROTOCOL is a binary pickle format, so the file must be
        # opened in binary mode; text mode can corrupt the stream through
        # newline translation (or reject bytes outright).
        with open(data_path, 'wb') as f:
            dump(dataset, f, protocol=HIGHEST_PROTOCOL)
    else:
        # JSON is text, so text mode is kept for this branch.
        with open(data_path, 'w') as f:
            json.dump(dataset, f)
# Script entry point. tf.app.run() is expected to parse the command-line flags
# and then call main() -- TF 1.x behaviour; confirm for the installed version.
if __name__ == '__main__':
    tf.app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment