Skip to content

Instantly share code, notes, and snippets.

Created January 21, 2020 21:18
Show Gist options
  • Save literadix/90d77b56221261f0fe34e66e1295794a to your computer and use it in GitHub Desktop.
Save literadix/90d77b56221261f0fe34e66e1295794a to your computer and use it in GitHub Desktop.
xgboost example
import pandas as pd
import numpy as np
import xgboost as xgb
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import joblib
from sklearn.metrics import precision_score

# Load the Iris dataset from scikit-learn: X is the feature data,
# y holds the class labels.
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into train and test sets with an 80/20 split. The fixed
# random_state keeps the split the same from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Create the XGBoost-specific DMatrix data format from the numpy arrays.
# XGBoost can also load data from svmlight files and other formats; here
# we work with numpy arrays directly.
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# Training parameters for the booster.
param = {
    'max_depth': 3,  # the maximum depth of each tree
    'eta': 0.3,  # the training step (learning rate) for each iteration
    # Quiet logging. 'verbosity' replaces the deprecated 'silent' flag,
    # which newer XGBoost releases no longer recognise.
    'verbosity': 0,
    'objective': 'multi:softprob',  # multiclass, per-class probabilities
    'num_class': 3,  # the number of classes that exist in this dataset
}
num_round = 20  # the number of training iterations

# Train the model.
bst = xgb.train(param, dtrain, num_round)

# Predict: with 'multi:softprob' each row of preds is expected to hold one
# probability per class, so the predicted label is the column index of the
# largest value in that row.
preds = bst.predict(dtest)
best_preds = np.argmax(preds, axis=1)

# Macro-averaged precision: the unweighted mean of the per-class precision.
print(precision_score(y_test, best_preds, average='macro'))

# Persist the trained booster to disk and read it back.
# NOTE(review): joblib files are pickle-based -- only load files you trust.
joblib.dump(bst, 'bst_model.pkl', compress=True)
bst = joblib.load('bst_model.pkl')  # load it later
# s = pd.Series([1, 3, 5, np.nan, 6, 8])
# dates = pd.date_range('20130101', periods=6)
# df = pd.DataFrame(np.random.randn(6, 4), index=dates, columns=list('ABCD'))
# print(df)
#name = input('What is your name? ')
#greeting = f'Hello {name}'
#print (greeting)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment