# NOTE: This file may contain bidirectional Unicode text that can be interpreted
# or compiled differently than it appears. Review it in an editor that reveals
# hidden Unicode characters.
# Output schema of the grouped-map pandas UDF.
# Each spec is (column name, Spark data type, nullable), in column order.
# A generator (not a list comprehension) is used so the loop variable does not
# leak into the module namespace under Python 2.
PANDAS_UDF_SCHEMA = StructType(list(
    StructField(*field_spec)
    for field_spec in (
        ("dt", StringType(), False),
        ("request_time", TimestampType(), True),
        ("request_hour", IntegerType(), True),
        ("region", StringType(), False),
        ("dma", IntegerType(), True),
        ("city", StringType(), False),
        ("form_factor", StringType(), False),
        ("brand_name", StringType(), False),
        ("advertised_device_os", StringType(), False),
        ("advertised_browser", StringType(), False),
        ("netspeed", StringType(), False),
        ("model_name", StringType(), False),
        ("pub_acnt_id", IntegerType(), True),
        ("domain", StringType(), False),
        ("ad_unit_id", IntegerType(), True),
        ("placement_type_name", StringType(), False),
        ("advertiser_category_id", StringType(), True),
        ("io_number", IntegerType(), True),
        ("io_line_item_number", IntegerType(), True),
        ("variant_id", IntegerType(), True),
        ("product_id", IntegerType(), True),
        ("y_label", IntegerType(), True),
    )
))
def _build_feature_mapper(train_x):
    """Build the feature-preprocessing DataFrameMapper for the PMML pipeline.

    Columns whose dtype compares equal to ``'object'`` or ``'bool'`` are wrapped
    in a ``CategoricalDomain`` followed by a ``LabelEncoder``. Missing values
    are kept as a value and replaced by the column's most frequent training
    value; invalid values are treated as missing with the same replacement.
    Every other column is passed through unchanged (``None`` transformer).

    :param train_x: pandas DataFrame of training features.
    :return: a ``DataFrameMapper`` built with ``input_df=True, df_out=True``.
    """
    features = []
    for column, dtype in zip(train_x.columns.values, train_x.dtypes.values):
        if dtype != 'object' and dtype != 'bool':
            features.append((column, None))
            continue
        # Compute the mode once per column (it was computed twice before).
        # NOTE(review): idxmax() raises ValueError on an all-null column.
        most_frequent = train_x[column].value_counts().idxmax()
        features.append((column, [
            CategoricalDomain(
                missing_value_treatment="as_value",
                invalid_value_treatment="as_missing",
                missing_value_replacement=most_frequent,
                invalid_value_replacement=most_frequent),
            LabelEncoder(),
        ]))
    return DataFrameMapper(features, input_df=True, df_out=True)


@pandas_udf(config.PANDAS_UDF_SCHEMA, PandasUDFType.GROUPED_MAP)
def pandas_udf_ctr_create_model(pdf_src):
    """Train one XGBoost CTR model for a single group and save it as PMML.

    Spark calls this once per group. The driver groups by the column named
    ``command_args.value.model_feature_name``, so every row of ``pdf_src``
    shares a single value of that column. That value is embedded in the model
    file name.

    :param pdf_src: pandas DataFrame holding one group's rows.
    :return: ``pdf_src`` unchanged. The real work is the side effect of writing
        ``ctr_<model_feature_name>_<value>.pmml`` to the worker's current
        directory; the caller does not use the returned frame.
    :raises Exception: any error raised during preprocessing, training or
        export is printed with its traceback and re-raised with its original
        type.
    """
    # This body runs on the executors (in a Python worker), not on the driver.
    # Modules used here, or by code it calls, must be shipped to the executors.
    # Importing them here makes a missing dependency fail fast.
    from binary_classification.metrics_wrapper import MetricsWrapper  # noqa: F401
    from binary_classification.logger_manager import LoggerManager  # noqa: F401
    from binary_classification.infra_utils import InfraUtils  # noqa: F401

    # Target column of PANDAS_UDF_SCHEMA; every other column is a feature.
    # NOTE(review): assumes pre_processing keeps 'y_label' as the label — confirm.
    label_column = 'y_label'

    # command_args is presumably a broadcast variable holding the parsed CLI args.
    model_feature_name = command_args.value.model_feature_name
    # Positional access: label-based [0] raises KeyError if the group's index
    # does not contain the label 0.
    model_feature_value = str(pdf_src[model_feature_name].iloc[0])

    try:
        # Work on a copy so pre_processing cannot mutate the frame we return.
        # Spark-to-pandas conversion may need adjustments inside pre_processing.
        pdf_train_test = pre_processing(pdf_src.copy())

        # NOTE: test_pdf is currently unused; evaluate the model on it here.
        train_pdf, test_pdf = train_test_split(pdf_train_test, test_size=0.2)
        train_x = train_pdf.drop(label_column, axis=1)
        train_y = train_pdf[label_column]

        xgb_params = {
            'n_estimators': 100,
            'learning_rate': 0.5,
            'seed': 0,
            'subsample': 0.8,
            # NOTE(review): 50 threads per worker may oversubscribe executor cores.
            'n_jobs': 50,
            'colsample_bytree': 0.8,
            'objective': 'binary:logistic',
            'max_depth': 10,
            'min_child_weight': 300,
            'gamma': 2,
            'max_delta_step': 6,
        }
        pipeline = PMMLPipeline([
            ("mapper", _build_feature_mapper(train_x)),
            ("classifier", xgb.XGBClassifier(**xgb_params)),
        ])
        pipeline.fit(train_x, train_y)

        # The previous concatenation had mismatched quotes ('_") and did not parse.
        # NOTE(review): values containing '/' would produce a nested path.
        pmml_model_name = 'ctr_' + model_feature_name + '_' + model_feature_value + '.pmml'
        # Written relative to the worker's current directory; upload (e.g. to S3) from here.
        sklearn2pmml(pipeline, pmml_model_name, with_repr=True)

        # A grouped-map UDF must return a frame matching its schema; echo the input.
        return pdf_src
    except Exception:
        print("**************** pandas_udf_ctr_create_model Exception ***************************")
        traceback.print_exc()
        # Bare raise keeps the original exception type and traceback.
        # Exception(e.message) lost both and fails on Python 3.
        raise
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment