Created
June 3, 2018 17:15
-
-
Save swerwath/1cd44afa565bffc4503c3ef642cdbb40 to your computer and use it in GitHub Desktop.
Split a CSV into training and testing sets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Randomly split the rows of a CSV file into two new CSV files.
#
# Usage: python <script> INPUT_PATH PORTION
# Writes INPUT_PATH + "1" (roughly PORTION of the rows) and INPUT_PATH + "2"
# (the remaining rows); both outputs repeat the input's header row.
import argparse
import csv
from random import shuffle


def _write_rows(output_path, fieldnames, rows):
    """Write ``rows`` to ``output_path`` as CSV, preceded by a header row.

    Args:
        output_path: Destination file; overwritten if it already exists.
        fieldnames: Column names, in output order.
        rows: Iterable of dicts as produced by ``csv.DictReader``.
    """
    # newline='' hands line-ending control to the csv module; without it the
    # text layer rewrites "\r\n" on Windows and every row gains a blank line.
    with open(output_path, 'w', newline='') as csv_out:
        writer = csv.DictWriter(csv_out, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def split_csv(input_path, portion):
    """Shuffle the data rows of a CSV and split them across two output files.

    The first ``int(portion * row_count)`` shuffled rows go to
    ``input_path + "1"`` and the rest to ``input_path + "2"``.

    Args:
        input_path: Path of the CSV to split; its first row is the header.
        portion: Fraction (0 to 1 inclusive) of rows for the first output.

    Returns:
        A ``(first_path, second_path)`` tuple of the files that were written.

    Raises:
        ValueError: If ``portion`` is outside [0, 1] or the input has no
            header row at all.
    """
    if not 0 <= portion <= 1:
        raise ValueError(
            'portion must be between 0 and 1, got {}'.format(portion))

    with open(input_path, newline='') as csv_in:
        reader = csv.DictReader(csv_in)
        # Take the header from the reader rather than from the first row so a
        # header-only CSV (zero data rows) still splits cleanly.
        fieldnames = reader.fieldnames
        rows = list(reader)

    if fieldnames is None:
        raise ValueError(
            '{} is empty; expected a CSV header row'.format(input_path))

    shuffle(rows)
    split_point = int(portion * len(rows))
    first_path = input_path + "1"
    second_path = input_path + "2"

    _write_rows(first_path, fieldnames, rows[:split_point])
    # Drop the rows already written before handling the rest, so big CSVs do
    # not keep both halves (plus a full copy) alive at the same time.
    del rows[:split_point]
    _write_rows(second_path, fieldnames, rows)

    return first_path, second_path


def main(argv=None):
    """Parse command-line arguments and run the split.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Randomly splits a CSV\'s rows into two sets')
    parser.add_argument('input_path', help='input path of CSV to be split')
    parser.add_argument('portion', type=float, help='portion of CSV for the first output')
    args = parser.parse_args(argv)

    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if not 0 <= args.portion <= 1:
        parser.error('portion must be between 0 and 1')

    split_csv(args.input_path, args.portion)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment