Skip to content

Instantly share code, notes, and snippets.

@moaminsharifi
Created December 25, 2020 14:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save moaminsharifi/d8a706a123fc5691fdb34e74aefad667 to your computer and use it in GitHub Desktop.
Save moaminsharifi/d8a706a123fc5691fdb34e74aefad667 to your computer and use it in GitHub Desktop.
Dataset Train and Test split
import numpy as np


class _InvalidSplitError(ValueError, AssertionError):
    """Raised by ``separate`` when its inputs cannot be split.

    It is a ``ValueError`` because that is the conventional type for bad
    arguments, and also an ``AssertionError`` so callers written against the
    earlier ``assert``-based validation keep working.  Unlike ``assert``,
    an explicit ``raise`` is not stripped when Python runs with ``-O``.
    """


def separate(X, y, train_percent=70, *, verbose=True):
    """Split a dataset into train and test parts, stratified by class.

    Every class (label) in ``y`` is split on its own, so both parts keep
    roughly the class proportions of the full dataset.  The split is
    deterministic and NOT shuffled: for each class, the first
    ``floor(class_size * train_percent / 100)`` samples (in their original
    order) go to the training set and the rest go to the test set.  Classes
    are processed in the sorted order returned by ``np.unique``, so the rows
    of every output are grouped by class.

    Parameters
    ----------
    X : numpy array or list
        Features of the dataset, indexed along the first axis.
    y : numpy array or list
        Labels of the dataset, one per row of ``X``.
    train_percent : int, default 70
        Percentage of each class assigned to the training set; must be
        between 1 and 99 inclusive.
    verbose : bool, default True (keyword-only)
        When true, print the size of every class and of both resulting sets.

    Returns
    -------
    tuple
        ``(X_train, y_train, X_test, y_test)`` where ``X_train`` and
        ``y_train`` hold about ``train_percent``% of the dataset and
        ``X_test`` and ``y_test`` hold the remaining ``100 - train_percent``%.
        List/tuple inputs come back as numpy arrays.

    Raises
    ------
    ValueError
        (also an ``AssertionError``) if ``train_percent`` is out of range,
        ``y`` contains no labels, or some class is too small to contribute
        at least one sample to both sets.
    """
    if not 1 <= train_percent <= 99:
        raise _InvalidSplitError("at least train_percent must be one and maximum 99")

    # Plain Python sequences support neither element-wise comparison
    # (``y == label`` would be a single bool) nor fancy indexing
    # (``X[[0, 3]]``), so convert them.  Other array-likes are left as-is.
    if isinstance(X, (list, tuple)):
        X = np.asarray(X)
    if isinstance(y, (list, tuple)):
        y = np.asarray(y)

    unique_class = np.unique(y)
    if len(unique_class) < 1:
        raise _InvalidSplitError("y must contain at least one class")

    # Map each class label to the positions of its samples in y.
    indices_by_class = {label: np.where(y == label)[0] for label in unique_class}

    train_set_index = []
    test_set_index = []
    for label in unique_class:
        class_indices = indices_by_class[label]
        class_size = len(class_indices)
        if verbose:
            print(f"count of {label} label is : {class_size}")
        # Integer arithmetic on purpose: the float form (58 / 100) * 50
        # evaluates to 28.999999999999996, which int() truncates to 28.
        count_of_train = int(class_size * train_percent // 100)
        count_of_test = class_size - count_of_train
        if count_of_train < 1 or count_of_test < 1:
            raise _InvalidSplitError("one of the sets is zero member")
        train_set_index.extend(class_indices[:count_of_train])
        test_set_index.extend(class_indices[count_of_train:])

    X_train, y_train = X[train_set_index], y[train_set_index]
    X_test, y_test = X[test_set_index], y[test_set_index]
    if verbose:
        print(f"train set is {len(train_set_index)} \ntesting set is {len(test_set_index)}")
    return (X_train, y_train, X_test, y_test)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment