Skip to content

Instantly share code, notes, and snippets.

@deep-diver
Last active May 18, 2021 15:12
Show Gist options
  • Save deep-diver/97be5e5ff9f6a7a00e0579043c165ec6 to your computer and use it in GitHub Desktop.
Save deep-diver/97be5e5ff9f6a7a00e0579043c165ec6 to your computer and use it in GitHub Desktop.
MNIST dataset Split
import pandas as pd
# Shrink the training CSV to its leading `keep_ratio` fraction of rows.
# NOTE: the output path is the same as the input path, so 'train.csv' is
# overwritten in place with the truncated data.
df = pd.read_csv('train.csv')
total_length = len(df)
keep_ratio = 0.5
# Python has no C-style `(int)x` cast; int(...) truncates the product to a
# whole row count that can be used as a slice bound.
keep_idx = int(total_length * keep_ratio)
keep_df = df[:keep_idx]
# index=False keeps the row index from being written as an extra column.
keep_df.to_csv('train.csv', index=False)
import pandas as pd
import numpy as np
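# TODO: read train_ratio from the YAML config
# data_path = (read the source data path from the YAML config)
# train_data_path = (read the training-split output path from the YAML config)
# valid_data_path = (read the validation-split output path from the YAML config)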
def split(df, train_ratio):
    """Split *df* row-wise into a leading train part and a trailing valid part.

    Rows keep their original order; no shuffling is performed.

    Args:
        df: DataFrame to split.
        train_ratio: Fraction of rows (between 0 and 1 inclusive) that go
            into the first returned frame.

    Returns:
        A ``(train_df, valid_df)`` tuple. ``train_df`` holds the first
        ``int(len(df) * train_ratio)`` rows and ``valid_df`` the remainder.

    Raises:
        ValueError: If *train_ratio* is outside ``[0, 1]``.
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError(
            f"train_ratio must be between 0 and 1, got {train_ratio!r}"
        )
    total_data = len(df)
    # Multiply (not floor-divide) by the ratio, then truncate to an int so the
    # result is a valid positional slice bound.
    train_index = int(total_data * train_ratio)
    # .iloc makes the slice explicitly positional regardless of index type.
    return df.iloc[:train_index], df.iloc[train_index:]
def main():
    """Read the dataset, split it into train/valid parts, and save both as CSV.

    Relies on the module-level names ``data_path``, ``train_ratio``,
    ``train_data_path`` and ``valid_data_path`` being defined before the call.
    """
    # 1. Load the data.
    df = pd.read_csv(data_path)
    # 2. Split the data.
    train_df, valid_df = split(df, train_ratio)
    # 3. Save the data. index=False keeps the row index from being written as
    #    an extra unnamed column that would reappear when the CSV is reloaded.
    train_df.to_csv(train_data_path, index=False)
    valid_df.to_csv(valid_data_path, index=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment