Skip to content

Instantly share code, notes, and snippets.

@ijan10
Last active March 9, 2023 11:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save ijan10/d9ce2cbbe23250545a4bb8443c508187 to your computer and use it in GitHub Desktop.
Save ijan10/d9ce2cbbe23250545a4bb8443c508187 to your computer and use it in GitHub Desktop.
# Spark StructType schema listing the columns of the training data.
# Presumably the return schema of the grouped-map pandas UDF (which references
# config.PANDAS_UDF_SCHEMA) -- NOTE(review): confirm the two are the same.
# Each spec below is (column name, Spark data type, nullable).
PANDAS_UDF_SCHEMA = StructType(list(
    StructField(field_name, field_type, is_nullable)
    # A generator (not a list comprehension) keeps the loop names out of the
    # module namespace on Python 2.
    for field_name, field_type, is_nullable in (
        ("dt", StringType(), False),
        ("request_time", TimestampType(), True),
        ("request_hour", IntegerType(), True),
        ("region", StringType(), False),
        ("dma", IntegerType(), True),
        ("city", StringType(), False),
        ("form_factor", StringType(), False),
        ("brand_name", StringType(), False),
        ("advertised_device_os", StringType(), False),
        ("advertised_browser", StringType(), False),
        ("netspeed", StringType(), False),
        ("model_name", StringType(), False),
        ("pub_acnt_id", IntegerType(), True),
        ("domain", StringType(), False),
        ("ad_unit_id", IntegerType(), True),
        ("placement_type_name", StringType(), False),
        ("advertiser_category_id", StringType(), True),
        ("io_number", IntegerType(), True),
        ("io_line_item_number", IntegerType(), True),
        ("variant_id", IntegerType(), True),
        ("product_id", IntegerType(), True),
        ("y_label", IntegerType(), True),
    )
))
@pandas_udf(config.PANDAS_UDF_SCHEMA, PandasUDFType.GROUPED_MAP)
def pandas_udf_ctr_create_model(pdf_src):
    """Train an XGBoost CTR model on one group of rows and export it as PMML.

    Registered as a GROUPED_MAP pandas UDF. The driver groups the Spark
    DataFrame by the column whose name is
    ``command_args.value.model_feature_name``. Each call therefore receives the
    rows for a single value of that column, and that value becomes part of
    the exported file name.

    The function runs on the Spark executors, not on the driver. Every module
    it (or anything it calls) needs must be available there, which is why
    imports are done inside the function body.

    Args:
        pdf_src: pandas DataFrame holding all rows of one group. Groups may be
            large (around 10M rows per the original notes).

    Returns:
        ``pdf_src`` itself, unchanged. The caller does not use the returned
        rows; the PMML file written by this function is the real output.

    Raises:
        Any exception raised while preparing data, fitting, or exporting is
        printed with its traceback and then re-raised with its original type.
    """
    import traceback

    # Project modules imported on the executor. They are not referenced below.
    # NOTE(review): kept from the original; confirm they are not needed for
    # import side effects before removing them.
    from binary_classification.metrics_wrapper import MetricsWrapper
    from binary_classification.logger_manager import LoggerManager
    from binary_classification.infra_utils import InfraUtils

    # Target column. NOTE(review): taken from the "y_label" field of
    # PANDAS_UDF_SCHEMA; confirm pre_processing() keeps this column.
    label_col = 'y_label'

    model_feature_name = command_args.value.model_feature_name
    # Every row of the group shares one value of model_feature_name, so the
    # first row is representative. .iloc is positional; plain [0] would be a
    # label lookup and fail if the index has no label 0.
    model_feature_value = str(pdf_src[model_feature_name].iloc[0])
    try:
        # The data arrives via a Spark-to-pandas conversion, so pre_processing()
        # may need to adjust it. It gets an explicit copy so that pdf_src, which
        # is returned to Spark, stays untouched.
        pdf_train_test = pre_processing(pdf_src.copy())

        # 80/20 split. Only the training part is used; this function does not
        # evaluate on the held-out rows.
        train_pdf, _holdout_pdf = train_test_split(pdf_train_test, test_size=0.2)
        train_x = train_pdf.drop(label_col, axis=1)
        train_y = train_pdf[label_col]

        xgb_params = {
            'n_estimators': 100,
            'learning_rate': 0.5,
            'seed': 0,
            'subsample': 0.8,
            # NOTE(review): 50 threads per UDF call may oversubscribe the
            # executor's cores; consider tying this to the executor config.
            'n_jobs': 50,
            'colsample_bytree': 0.8,
            'objective': 'binary:logistic',
            'max_depth': 10,
            'min_child_weight': 300,
            'gamma': 2,
            'max_delta_step': 6,
        }
        estimator = xgb.XGBClassifier(**xgb_params)

        # Columns are handled by dtype:
        #   - Non-object, non-bool columns pass through unchanged.
        #   - object/bool columns are treated as categorical and label-encoded.
        # For categorical columns, the most frequent training value is set as
        # the replacement for missing and invalid values.
        feature_defs = []
        for col, dtype in zip(train_x.columns.values, train_x.dtypes.values):
            if dtype != 'object' and dtype != 'bool':
                feature_defs.append((col, None))
                continue
            most_frequent = train_x[col].value_counts().idxmax()
            feature_defs.append((col, [
                CategoricalDomain(
                    missing_value_treatment="as_value",
                    invalid_value_treatment="as_missing",
                    missing_value_replacement=most_frequent,
                    invalid_value_replacement=most_frequent),
                LabelEncoder(),
            ]))
        mapper = DataFrameMapper(feature_defs, input_df=True, df_out=True)

        pipeline = PMMLPipeline([("mapper", mapper), ("classifier", estimator)])
        pipeline.fit(train_x, train_y)

        model_base_name = 'ctr_' + model_feature_name + '_' + model_feature_value
        pmml_model_name = model_base_name + '.pmml'
        # Bare file name: written relative to the worker's current directory.
        sklearn2pmml(pipeline, pmml_model_name, with_repr=True)
        # The PMML file could be uploaded (e.g. to S3) here; that step is not
        # implemented in this function.

        # Spark requires a DataFrame back; the caller ignores its contents.
        return pdf_src
    except Exception:
        print("**************** pandas_udf_ctr_create_model Exception ***************************")
        traceback.print_exc()
        # A bare raise preserves the original exception type and traceback.
        # Exception(e.message) lost both, and .message does not exist on
        # Python 3.
        raise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment