Skip to content

Instantly share code, notes, and snippets.

@zh3389
Last active February 7, 2021 09:01
Show Gist options
  • Save zh3389/150881af0ba9c30403bb319582547355 to your computer and use it in GitHub Desktop.
Save zh3389/150881af0ba9c30403bb319582547355 to your computer and use it in GitHub Desktop.
用于对训练模型的数据进行数据均衡操作。
# -*- encoding:utf-8 -*-
"""
@作者:Mr.zhang
@文件名:Data_replication_equalization.py
@时间:20-6-03 上午10:54
@文档说明:
1. 使用pandas包
2. 对csv文件的数据做数据均衡
3. 通过复制的方式均衡数据
"""
import pandas as pd
class Data_equalization_initialization:
    """Balance a labelled DataFrame by replicating minority-class rows.

    Every class is up-sampled to the size of the largest class: the class's
    rows are repeated whole as many times as fits, and the shortfall is filled
    with a random sample of that class's rows.

    Attributes:
        data_column_name: name of the column holding the samples.
        class_column_name: name of the column holding the class labels.
        data: the input restricted to those two columns.
        supplementary_dict: per-class ``{"multiple": int, "remainder": int}``
            as produced by ``get_replenishment_quantity``.
        final_data: result of the most recent ``group_equalization`` call
            (an empty DataFrame until it is called).
    """

    def __init__(self, data, column_names):
        """
        :param data: source DataFrame; only the two named columns are kept.
        :param column_names: ``[data_column_name, class_column_name]``.
        """
        self.data_column_name = column_names[0]
        self.class_column_name = column_names[1]
        self.data = data[[self.data_column_name, self.class_column_name]]
        self.supplementary_dict = self.get_replenishment_quantity()
        self.final_data = pd.DataFrame()

    def group_equalization(self, random_state=None):
        """Balance the data class by class.

        :param random_state: optional seed forwarded to ``DataFrame.sample``
            so the remainder rows are reproducible.
        :return: DataFrame in which every class has as many rows as the
            largest class; it is also stored in ``self.final_data``.
        """
        # observed=True: for a categorical label column, skip categories that
        # have no rows (they have no entry in supplementary_dict).
        grouped = self.data.groupby(self.class_column_name, observed=True)
        parts = [
            self.data_copy(
                group, self.supplementary_dict[class_name], random_state=random_state
            )
            for class_name, group in grouped
        ]
        # Rebuild the result on every call instead of appending to the previous
        # one (which duplicated everything on a second call), and concatenate
        # once rather than growing a frame inside the loop.
        if parts:
            self.final_data = pd.concat(parts)
        else:
            # Empty input: an empty frame that still carries the two columns.
            self.final_data = self.data.iloc[0:0].copy()
        return self.final_data

    def get_replenishment_quantity(self):
        """Compute how much extra data each class needs to match the largest one.

        For a class with ``n`` rows and a largest class of ``max_num`` rows,
        ``multiple = max_num // n - 1`` extra whole copies plus
        ``remainder = max_num % n`` sampled rows bring it to exactly ``max_num``.

        :return: dict mapping class label -> {"multiple": int, "remainder": int};
            empty when the class column holds no non-null values.
        """
        counts = self.data[self.class_column_name].value_counts()
        # Unobserved categories of a categorical column report a count of 0;
        # they have nothing to replicate and would divide by zero below.
        counts = counts[counts > 0]
        if counts.empty:
            return {}
        max_num = int(counts.max())
        supplementary_dict = {}
        for class_name, count in counts.items():
            multiple, remainder = divmod(max_num, int(count))
            supplementary_dict[class_name] = {
                "multiple": multiple - 1,
                "remainder": remainder,
            }
        return supplementary_dict

    def data_copy(self, dataframe, sup_dict, random_state=None):
        """Replicate one class's rows according to ``sup_dict``.

        :param dataframe: rows belonging to a single class.
        :param sup_dict: {"multiple": extra whole copies, "remainder": rows to sample}.
        :param random_state: optional seed forwarded to ``DataFrame.sample``.
        :return: the original rows, ``multiple`` extra copies of them, and
            ``remainder`` rows sampled from the original rows.
        """
        parts = [dataframe] * (sup_dict["multiple"] + 1)
        remainder = sup_dict["remainder"]
        if remainder:
            # Sample from the original group so the top-up rows are distinct
            # (remainder < len(dataframe) whenever sup_dict came from
            # get_replenishment_quantity); fall back to sampling with
            # replacement if a caller asks for more rows than exist.
            parts.append(
                dataframe.sample(
                    remainder,
                    replace=remainder > len(dataframe),
                    random_state=random_state,
                )
            )
        return pd.concat(parts)
if __name__ == '__main__':
    # Load the training set; the first name is the sample column and the
    # second is the label column whose class sizes get balanced.
    csv_path = "./data/train_data.csv"
    label_cols = ["data", "class"]
    label_name = label_cols[1]
    raw = pd.read_csv(csv_path)

    print("数据均衡前:")
    print(raw[label_name].value_counts())

    balancer = Data_equalization_initialization(raw, label_cols)
    balanced = balancer.group_equalization()

    print("数据均衡后:")
    print(balanced[label_name].value_counts())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment