Skip to content

Instantly share code, notes, and snippets.

@ctodd
Created July 9, 2020 05:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ctodd/a411a8bb516601c5b1f6ca89ed92cb70 to your computer and use it in GitHub Desktop.
Save ctodd/a411a8bb516601c5b1f6ca89ed92cb70 to your computer and use it in GitHub Desktop.
import json
def _write_manifest(path, records):
    """Write ``records`` to ``path`` as JSON Lines (one JSON document per line)."""
    with open(path, 'w') as f:
        f.writelines(json.dumps(record) + '\n' for record in records)


# Split the labeled output manifest into a 90% training manifest and a
# 10% validation manifest, both written next to the input file.
#
# NOTE(review): `local_manifest_dir`, `jsonlines` and `np` are not defined in
# this snippet; they are assumed to be in scope already (e.g. from an earlier
# notebook cell) -- confirm before running standalone.
augmented_manifest_filename_output = local_manifest_dir + '/output.manifest'
with jsonlines.open(augmented_manifest_filename_output, 'r') as reader:
    # Materialize all records so they can be shuffled and sliced.
    lines = list(reader)

# Shuffle data in place so the split is not biased by the file's ordering.
# No seed is set here, so the split differs between runs unless the NumPy
# global RNG was seeded elsewhere.
np.random.shuffle(lines)

dataset_size = len(lines)
num_training_samples = round(dataset_size * 0.9)
train_data = lines[:num_training_samples]
validation_data = lines[num_training_samples:]

augmented_manifest_filename_train = local_manifest_dir + '/train.manifest'
_write_manifest(augmented_manifest_filename_train, train_data)

augmented_manifest_filename_validation = local_manifest_dir + '/validation.manifest'
_write_manifest(augmented_manifest_filename_validation, validation_data)

print(f'training samples: {num_training_samples}, validation samples: {len(validation_data)}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment