Skip to content

Instantly share code, notes, and snippets.

@literadix
Created January 21, 2020 21:18
Show Gist options
  • Save literadix/90d77b56221261f0fe34e66e1295794a to your computer and use it in GitHub Desktop.
Save literadix/90d77b56221261f0fe34e66e1295794a to your computer and use it in GitHub Desktop.
xgboost example
# https://www.kdnuggets.com/2017/03/simple-xgboost-tutorial-iris-dataset.html
# https://joblib.readthedocs.io/en/latest/
import pandas as pd
import numpy as np
import xgboost as xgb
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import joblib
# Load the iris dataset bundled with scikit-learn. X is the feature data
# and y the class labels.
iris = datasets.load_iris()
X, y = iris.data, iris.target

# Hold out 20% of the samples for testing (an 80/20 split). The fixed
# random_state keeps the split the same from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# XGBoost's native training API consumes its own DMatrix container. A
# DMatrix can be built directly from NumPy arrays, as done here. It can also
# be loaded from other formats such as svmlight files.
dtrain = xgb.DMatrix(data=X_train, label=y_train)
dtest = xgb.DMatrix(data=X_test, label=y_test)
# Booster parameters for XGBoost's native training API.
param = {
    'max_depth': 3,  # the maximum depth of each tree
    'eta': 0.3,  # learning rate: shrinkage applied to each new tree's contribution
    # Quiet logging. The older 'silent' key is deprecated; current XGBoost
    # ignores it and prints an "unused parameter" warning instead.
    'verbosity': 0,
    # Multiclass learning objective; predict() then returns one probability
    # per class for every sample.
    'objective': 'multi:softprob',
    'num_class': 3,  # the number of classes that exist in this dataset
}
num_round = 20  # the number of boosting rounds
# train
bst = xgb.train(param, dtrain, num_round)
from sklearn.metrics import precision_score

# predict: with the 'multi:softprob' objective, each row of preds holds one
# probability per class for the corresponding test sample.
preds = bst.predict(dtest)
# The predicted class is the index of the highest probability in each row.
# A single vectorized argmax along axis 1 replaces the Python-level loop
# over rows and gives the same result.
best_preds = np.argmax(preds, axis=1)
print(preds)
print(best_preds)
# average='macro' takes the unweighted mean of the per-class precisions,
# so every class counts equally regardless of how often it occurs in y_test.
print(precision_score(y_test, best_preds, average='macro'))
# Round-trip the trained booster through disk. joblib writes it, compressed,
# to bst_model.pkl, and bst is then rebound to the copy read back from that
# file.
# NOTE(review): XGBoost's model-IO docs recommend Booster.save_model /
# load_model for long-term storage, because a pickled Booster may fail to
# load under a different xgboost version. Consider switching.
joblib.dump(bst, filename='bst_model.pkl', compress=True)
bst = joblib.load(filename='bst_model.pkl')  # load it later
# Print the label vector for the whole dataset (every sample, not just the
# test split). The commented-out pandas and input() scratch code that used
# to sit here was dead and unrelated to this example, so it was removed.
print(y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment