Skip to content

Instantly share code, notes, and snippets.

@andrey-khropov
Last active March 22, 2019 12:43
Show Gist options
  • Save andrey-khropov/ae83131549229c45b887e96196ef885e to your computer and use it in GitHub Desktop.
Save andrey-khropov/ae83131549229c45b887e96196ef885e to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import os.path
import sys
import catboost
from catboost import Pool, CatBoost, CatBoostRegressor
def data_file(*args):
    """Return the path of a file inside the CatBoost pytest data directory.

    The path components in ``args`` are joined onto the data root. The root
    defaults to the checkout location this script was written against and
    can be overridden through the ``CATBOOST_PYTEST_DATA_DIR`` environment
    variable, so the script is usable outside the original author's machine.

    Args:
        *args: path components relative to the data root.

    Returns:
        The joined path as a string.
    """
    # An unset or empty override falls back to the original hard-coded root.
    data_root = (
        os.environ.get('CATBOOST_PYTEST_DATA_DIR')
        or '/home/akhropov/src/trunk-cb-4/arcadia/catboost/pytest/data'
    )
    return os.path.join(data_root, *args)
def test_output_path(path):
return path
# Querywise dataset inputs: train pool, test pool and their shared column
# description file, all resolved through data_file().
QUERYWISE_TRAIN_FILE, QUERYWISE_TEST_FILE, QUERYWISE_CD_FILE = (
    data_file('querywise', file_name)
    for file_name in ('train', 'test', 'train.cd')
)

# File names used for the model in each serialization format.
OUTPUT_CBM_MODEL_PATH, OUTPUT_COREML_MODEL_PATH, OUTPUT_JSON_MODEL_PATH = (
    'model.cbm',
    'model.mlmodel',
    'model.json',
)
def test_coreml_import_export(task_type):
    """Train a small model and round-trip it through the cbm and json formats.

    Fits a 20-iteration RMSE model on the querywise training pool, saves it
    in ``cbm`` format, loads that file into a fresh ``CatBoostRegressor`` and
    saves the loaded model again in ``json`` format.

    Args:
        task_type: value passed through as the ``task_type`` training
            parameter (``main`` passes ``'CPU'``).
    """
    print(QUERYWISE_TRAIN_FILE)

    train_pool = Pool(QUERYWISE_TRAIN_FILE, column_description=QUERYWISE_CD_FILE)
    # Only the disabled CoreML comparison below reads test_pool; it is still
    # constructed so the test data file keeps being loaded as before.
    test_pool = Pool(QUERYWISE_TEST_FILE, column_description=QUERYWISE_CD_FILE)

    model = CatBoost(params={'loss_function': 'RMSE', 'iterations': 20, 'thread_count': 8, 'task_type': task_type, 'devices': '0'})
    model.fit(train_pool)

    # Step 1: trained model -> cbm file.
    output_cbm_model_path = test_output_path(OUTPUT_CBM_MODEL_PATH)
    model.save_model(output_cbm_model_path, format="cbm")

    # Step 2: cbm file -> fresh regressor object.
    cbm_loaded_model = CatBoostRegressor()
    cbm_loaded_model.load_model(output_cbm_model_path, format="cbm")

    # Step 3: reloaded model -> json file.
    output_json_model_path = test_output_path(OUTPUT_JSON_MODEL_PATH)
    cbm_loaded_model.save_model(output_json_model_path, format="json")

    # NOTE(review): the CoreML export/import comparison that gives this
    # function its name is disabled below; the reason is not visible in this
    # file. compare_canonical_models is also not defined here - confirm where
    # it comes from before re-enabling.
    #output_coreml_model_path = test_output_path(OUTPUT_COREML_MODEL_PATH)
    #model.save_model(output_coreml_model_path, format="coreml")
    #canon_pred = model.predict(test_pool)
    #coreml_loaded_model = CatBoostRegressor()
    #coreml_loaded_model.load_model(output_coreml_model_path, format="coreml")
    #assert all(canon_pred == coreml_loaded_model.predict(test_pool))
    #return compare_canonical_models(output_coreml_model_path)
def main():
    """Print the Python and CatBoost versions, then run the model round-trip on CPU."""
    # Version banner first, so the output identifies the environment used.
    print('python version=', sys.version)
    print('catboost.version.VERSION', catboost.version.VERSION)
    test_coreml_import_export(task_type='CPU')
# Run only when executed as a script, so importing this module does not
# start a training run as a side effect.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment