Last active
March 24, 2020 01:34
-
-
Save Jongbhin/07281557439df55fb3c64c31f2865102 to your computer and use it in GitHub Desktop.
[Split train/dev file] #python #ml #dl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
from sklearn.model_selection import train_test_split

# Split the balanced dataset into train (80%) and dev (20%) sets and save each
# to a tab-separated file.
# NOTE(review): output_file_balanced, output_file_train and output_file_dev are
# not defined in this snippet; they must be bound before this code runs.
all_data = pd.read_csv(output_file_balanced, sep='\t')

# train_test_split shuffles the rows itself (shuffle=True by default), so no
# separate shuffle step is needed. DataFrame.sample() returns a new frame and
# would have no effect unless its result were assigned.
train_set, dev_set = train_test_split(all_data, test_size=0.2)

# header=False: the header row parsed from the input is NOT written back out.
# NOTE(review): presumably the downstream consumer expects header-less TSV — confirm.
train_set.to_csv(output_file_train, sep='\t', header=False, index=False)
dev_set.to_csv(output_file_dev, sep='\t', header=False, index=False)
def split_test_dev(all_data, train_ratio, seed=None):
    """Randomly split a dataset into train and dev partitions.

    Despite the function name, the two partitions returned are (train, dev).

    Args:
        all_data: A sequence (list, tuple, ...) or a pandas DataFrame/Series.
            The input object is never modified.
        train_ratio: Fraction of items placed in the train split, in [0, 1].
            The split point is ``round(train_ratio * len(all_data))``.
        seed: Optional seed for a reproducible shuffle. When ``None``, the
            module-level ``random`` state is used for sequences, and pandas'
            default (numpy global) random state is used for pandas input.

    Returns:
        A ``(train, dev)`` tuple. For pandas input both parts are pandas
        objects that keep their original index labels; otherwise both parts
        are lists.

    Raises:
        ValueError: If ``train_ratio`` is not within [0, 1].
    """
    import random

    import pandas as pd

    # A ratio outside [0, 1] would otherwise yield a silently wrong split
    # (a negative split point becomes a negative slice index).
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    split_point = int(round(train_ratio * len(all_data)))

    if isinstance(all_data, (pd.DataFrame, pd.Series)):
        # random.shuffle() cannot shuffle a DataFrame: item access selects
        # columns, not rows. Shuffle the rows with pandas instead.
        shuffled = all_data.sample(frac=1, random_state=seed)
        return shuffled.iloc[:split_point], shuffled.iloc[split_point:]

    # Copy into a list so the caller's data stays untouched and immutable
    # sequences (e.g. tuples) can be shuffled as well.
    shuffled = list(all_data)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)
    return shuffled[:split_point], shuffled[split_point:]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment