Created
March 3, 2022 09:49
Inspecting feature importance (using a random forest as an example)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier

# train_X is expected to be a pandas DataFrame, train_Y the matching labels
feature_labels = list(train_X.columns)  # keep the column names for display later
forest = RandomForestClassifier().fit(train_X, train_Y)  # fit the model

# Get the importances via `.feature_importances_`
importances = forest.feature_importances_

# Indices sorted by importance (descending), used for display below
indices = np.argsort(importances)[::-1]

# Print every feature with its importance
for f in range(train_X.shape[1]):
    print("%2d) %-*s %f" % (f + 1, 30,
                            feature_labels[indices[f]],
                            importances[indices[f]]))

# Plot as a bar chart
plt.title('Feature Importance')
plt.bar(range(train_X.shape[1]),
        importances[indices],
        align='center')
plt.xticks(range(train_X.shape[1]),
           [feature_labels[x] for x in indices], rotation=45)
plt.xlim([-1, train_X.shape[1]])
plt.tight_layout()
plt.show()
Column names:
![image](https://user-images.githubusercontent.com/47293699/156539912-61ef5b11-e96a-45ac-9f66-eccd289b33d5.png)
Importances:
![image](https://user-images.githubusercontent.com/47293699/156539953-44fba811-1b07-4322-9fb2-afcf91efe30a.png)
Printed output
![image](https://user-images.githubusercontent.com/47293699/156540009-fa89a53e-7e4c-4c7e-9ff1-e2daca1aaaa7.png)
Plot
![image](https://user-images.githubusercontent.com/47293699/156540048-2d414baa-d525-4bd6-905b-1bb0e2d89c9e.png)
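
The snippet above assumes `train_X` and `train_Y` already exist. As a minimal, self-contained sketch of the same ranking step, the example below substitutes scikit-learn's bundled iris dataset for the original data (an assumption, since the gist does not say which dataset was used). The helper name `rank_feature_importances` and the fixed `random_state` are hypothetical additions for illustration and reproducibility.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def rank_feature_importances(X, y, random_state=0):
    """Fit a random forest and return (feature name, importance) pairs,
    ordered from most to least important. Hypothetical helper."""
    forest = RandomForestClassifier(random_state=random_state).fit(X, y)
    importances = forest.feature_importances_
    indices = np.argsort(importances)[::-1]  # descending order
    return [(X.columns[i], float(importances[i])) for i in indices]


# Assumption: iris stands in for the original train_X / train_Y.
iris = load_iris(as_frame=True)
train_X, train_Y = iris.data, iris.target

ranking = rank_feature_importances(train_X, train_Y)
for rank, (name, score) in enumerate(ranking, start=1):
    print(f"{rank:2d}) {name:<30s} {score:f}")
```

Returning the ranked pairs instead of printing them directly makes it possible to reuse the same list for both the text output and the bar chart.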