@AnchorBlues
Created October 3, 2020 11:43
A script that tunes LightGBM hyperparameters with hyperopt and saves the best model.
import os
import json
import pickle
from pathlib import Path
from typing import Callable, Dict, List, Optional

import numpy as np
import pandas as pd
from numpy import ndarray as ARR
from pandas import DataFrame as DF
from sklearn.model_selection import cross_val_score
import lightgbm as lgb
from hyperopt import fmin, hp, tpe
def create_objective(
        x_train: ARR,
        y_train: ARR,
        task_type: str,
        scoring: Optional[str],
        cv: int,
        n_jobs: int,
        history: List[Dict[str, float]]
) -> Callable[[Dict[str, float]], float]:
    if task_type == "regression":
        ModelClass = lgb.LGBMRegressor
    elif task_type == "classification":
        ModelClass = lgb.LGBMClassifier
    else:
        raise ValueError(f"unknown task_type: {task_type}")

    def objective(params: Dict[str, float]) -> float:
        # hyperopt samples every parameter as a float; cast the ones that
        # LightGBM expects to be integers, otherwise model construction fails
        for pname in ('num_leaves', 'min_child_samples', 'subsample_freq', 'n_estimators'):
            params[pname] = int(params[pname])
        model = ModelClass(
            **params,
            random_state=0,
            n_jobs=1  # parallelize in cross_val_score instead
        )
        score = cross_val_score(model, x_train, y_train,
                                scoring=scoring, cv=cv, n_jobs=n_jobs).mean()
        print(f"params: {params}, score: {score}")
        d = params.copy()
        d['score'] = score
        history.append(d)
        # hyperopt minimizes the objective, so return the negated score
        return -score

    return objective
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description='Tune LightGBM hyperparameters with hyperopt and save the best model.')
    parser.add_argument('train_csv_path', type=str,
                        help='path to the training data CSV file')
    parser.add_argument('obj_var_name', type=str,
                        help='name of the target column')
    parser.add_argument('task_type', type=str,
                        choices=['regression', 'classification'])
    parser.add_argument('--output_dir_path', type=str, default='../results/')
    parser.add_argument('--n_jobs', type=int, default=-1)
    parser.add_argument('--cv', type=int, default=5)
    parser.add_argument('--scoring', type=str, default=None)
    parser.add_argument('--max_evals', type=int, default=200)
    args = parser.parse_args()
    print(args)

    obj_var_name: str = args.obj_var_name
    output_dir_path: Path = Path(args.output_dir_path)
    df_train: DF = pd.read_csv(args.train_csv_path)
    x_train = df_train.drop(obj_var_name, axis=1).values
    y_train = df_train[obj_var_name].values

    # search space for the LightGBM hyperparameters
    space = {
        'n_estimators': hp.quniform('n_estimators', 50, 1000, 50),
        'num_leaves': hp.quniform('num_leaves', 4, 100, 4),
        'subsample': hp.uniform('subsample', 0.5, 1.0),
        'subsample_freq': hp.quniform('subsample_freq', 1, 20, 2),
        'colsample_bytree': hp.uniform('colsample_bytree', 0.01, 1.0),
        'min_child_samples': hp.quniform('min_child_samples', 1, 50, 1),
        'min_child_weight': hp.loguniform('min_child_weight', np.log(1e-3), np.log(1e+1)),
        'reg_lambda': hp.loguniform('reg_lambda', np.log(1e-2), np.log(1e+3)),
        'learning_rate': hp.loguniform('learning_rate', np.log(1e-3), np.log(1e-1))
    }

    history: List[Dict[str, float]] = []
    objective = create_objective(x_train, y_train,
                                 args.task_type, args.scoring,
                                 args.cv, args.n_jobs, history)

    np.random.seed(0)
    best_params = fmin(fn=objective,
                       space=space,
                       algo=tpe.suggest,
                       max_evals=args.max_evals, rstate=np.random.RandomState(0))
    print(f"best_params: {best_params}")

    # fmin returns the quniform parameters as floats; cast them back to int
    for pname in ('num_leaves', 'min_child_samples', 'subsample_freq', 'n_estimators'):
        best_params[pname] = int(best_params[pname])

    if args.task_type == "regression":
        ModelClass = lgb.LGBMRegressor
    elif args.task_type == "classification":
        ModelClass = lgb.LGBMClassifier
    else:
        raise ValueError(f"unknown task_type: {args.task_type}")

    # refit the best model on the full training data and save the results
    best_model = ModelClass(random_state=0, n_jobs=-1, **best_params)
    best_model.fit(x_train, y_train)
    os.makedirs(output_dir_path, exist_ok=True)
    with open(output_dir_path / "best_params.json", 'w') as f:
        json.dump(best_params, f)
    with open(output_dir_path / "best_model.pickle", mode='wb') as f:
        pickle.dump(best_model, f)
    pd.DataFrame(history).to_csv(output_dir_path / "history.csv", index=False)
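
Example invocation (the script and data file names below are placeholders, not part of the gist; the argument structure follows the argparse definitions above):

python tune_lgbm_hyperopt.py train.csv target regression --output_dir_path ../results/ --max_evals 200

After the run, the fitted model can be restored from the pickle file; a minimal sketch, assuming the default output directory:

import pickle
from pathlib import Path

with open(Path('../results/') / 'best_model.pickle', 'rb') as f:
    best_model = pickle.load(f)
# best_model is the refitted LGBMRegressor / LGBMClassifier; use it for prediction, e.g.
# preds = best_model.predict(x_new)  # x_new: feature rows with the same columns as training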