Created
December 29, 2018 09:15
-
-
Save kbrx93/fc5237531571985a2a27e1d95c75cb79 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from sklearn.neighbors import KNeighborsClassifier | |
from sklearn.model_selection import train_test_split | |
from sklearn.preprocessing import StandardScaler | |
def knn(path="./train.csv", *, x_range=(1.0, 1.25), y_range=(2.5, 2.75),
        min_checkins=200, n_neighbors=5, test_size=0.25, random_state=None):
    """Predict users' check-in location (``place_id``) with a KNN classifier.

    Reads the check-in CSV, keeps only rows strictly inside the given (x, y)
    window, expands the ``time`` column (Unix seconds) into day/hour/weekday
    features, keeps only places with more than ``min_checkins`` rows, then
    trains a k-nearest-neighbours classifier on standardised features and
    prints the test-set predictions and accuracy.

    All parameters are optional; the defaults reproduce the original
    hard-coded behaviour.

    :param path: CSV to read. It must contain at least the columns ``row_id``,
        ``x``, ``y``, ``time`` and ``place_id``. Every column other than
        ``place_id`` and ``row_id`` is used as a feature after preprocessing.
    :param x_range: ``(low, high)`` exclusive bounds on ``x``.
    :param y_range: ``(low, high)`` exclusive bounds on ``y``.
    :param min_checkins: a place is kept only if it has strictly more rows
        than this.
    :param n_neighbors: ``k`` for the KNN classifier.
    :param test_size: fraction of samples held out for evaluation.
    :param random_state: seed forwarded to ``train_test_split``. ``None``
        (default) keeps the split random, as before.
    :raises ValueError: if no samples remain after filtering.
    :return: None
    """
    # 1. Read the data.
    data = pd.read_csv(path)

    # 2. Light preprocessing.
    # Narrow the area down. .copy() yields an independent frame, so the
    # column assignments below are not chained assignments and pandas'
    # global warning option no longer needs to be switched off.
    (x_lo, x_hi), (y_lo, y_hi) = x_range, y_range
    in_window = (
        (data["x"] > x_lo) & (data["x"] < x_hi)
        & (data["y"] > y_lo) & (data["y"] < y_hi)
    )
    data = data[in_window].copy()
    print("data.size: ", data.size)

    # The raw timestamp is in Unix seconds: expand it into calendar
    # features, then drop the raw value.
    timestamps = pd.to_datetime(data["time"], unit="s")
    data["day"] = timestamps.dt.day
    data["hour"] = timestamps.dt.hour
    data["weekday"] = timestamps.dt.weekday
    data = data.drop(columns=["time"])

    # Sample-count filtering: count rows per place.
    checkins_per_place = data.groupby("place_id")["row_id"].count()
    print("---------------分组成功------------------")
    # Places with enough check-ins to be learnable.
    popular_places = checkins_per_place[checkins_per_place > min_checkins].index
    print("---------------重置列表成功------------------")
    data = data[data["place_id"].isin(popular_places)]
    print("---------------取出相同数据样本成功------------------")

    if data.empty:
        raise ValueError(
            "no place has more than {} check-ins in the selected area".format(
                min_checkins))

    # Target and features. row_id is only a record identifier; leaving it
    # in would add a meaningless dimension to the KNN distance.
    y = data["place_id"]
    x = data.drop(columns=["place_id", "row_id"])
    print("---------------取出特征值和目标值成功------------------")

    # 3. Split into training and test sets.
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, random_state=random_state)
    print("---------------分割数据集成功------------------")

    # Standardise using statistics from the training set only. The float
    # cast avoids dtype-conversion warnings from the scaler.
    std = StandardScaler()
    x_train = std.fit_transform(x_train.astype(float))
    x_test = std.transform(x_test.astype(float))
    print("---------------数据标准化成功------------------")

    # 4. Fit the KNN model (named so it does not shadow this function).
    estimator = KNeighborsClassifier(n_neighbors=n_neighbors)
    estimator.fit(x_train, y_train)

    # Predicted place_id for each test sample.
    y_predict = estimator.predict(x_test)
    print("预测的目标签到位置为:", y_predict)
    # Mean accuracy on the held-out set.
    print("预测的准确率为:", estimator.score(x_test, y_test))
    return None
if __name__ == "__main__":
    # Script entry point: run the KNN check-in prediction with its defaults.
    knn()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment