Skip to content

Instantly share code, notes, and snippets.

@joshreini1
Created October 7, 2022 20:32
Show Gist options
  • Save joshreini1/8177f2f223aea8ab651151577d4c7e2d to your computer and use it in GitHub Desktop.
Save joshreini1/8177f2f223aea8ab651151577d4c7e2d to your computer and use it in GitHub Desktop.
# ---- Configuration shared by the upload loop ----
project_name = 'Fire_Party'
# Name of the training split registered for every data collection.
train_split_name = 'train'
# Non-feature columns attached to each split as extra data.
extra_data_columns = ['year']
# Samples whose burned fraction exceeds this threshold get a positive label.
burned_fraction_th = 0.01

# ---- Project setup ----
tru.set_environment('local')
tru.add_project(project_name, score_type='probits')
def _model_belongs_to_window(model_name, key):
    """Return True if *model_name* refers to the window identified by *key*.

    A plain substring test (``key in model_name``) is not enough, because
    ``'1year_window'`` is a substring of ``'10year_window'``. With it, the
    1-year pass would also register the 10-year models. Requiring that the
    match is not preceded by a digit keeps the windows apart.
    """
    import re

    return re.search(rf'(?<!\d){re.escape(key)}', model_name) is not None


def _build_model_params(model_name, model):
    """Build the ``train_parameters`` metadata recorded for *model_name*.

    Extra entries depend on the model-name prefix:
      * ``linear_*``   -> ``max_iter=1000``, ``solver='saga'``
      * ``gb_<n>_*``   -> ``n_estimators=<n>``
      * ``SVC<w>_*``   -> ``c_weight={0: 1.0, 1: <w>}``
    NOTE(review): the SVC naming (``SVC<w>_...``) is inferred from how the
    weight was parsed originally; confirm against where ``models`` is built.
    """
    import re

    params = {
        'model_type': type(model).__name__,
    }
    model_prefix = model_name.split('_')[0]
    if model_prefix == 'linear':
        params['max_iter'] = 1000
        params['solver'] = 'saga'
    elif model_prefix == 'gb':
        params['n_estimators'] = int(model_name.split('_')[1])
    elif model_name.startswith('SVC'):
        # The weight comes from the model's own name. The original read it
        # from an unrelated variable `test`, and its `== 'SVC'` prefix check
        # could never match names of the form 'SVC<w>_...'.
        weight_match = re.match(r'SVC(\d+)', model_name)
        if weight_match is not None:
            params['c_weight'] = {0: 1.0, 1: int(weight_match.group(1))}
    return params


for window_size in range(1, 11):
    key = f'{window_size}year_window'
    print(key)
    tru.add_data_collection(key)
    # Training split for this window; the label is True where the target
    # exceeds burned_fraction_th.
    tru.add_data_split(
        train_split_name,
        data_train_x[key],
        label_data=(data_train_y[key] > burned_fraction_th),
        extra_data_df=data_train[key][extra_data_columns],
    )

    # One additional split per calendar year, named after the year.
    yearly = data_yearly[key]
    for year in range(2001, 2017):
        print(year)
        data_tmp = yearly[yearly['year'] == year]
        x = data_tmp.drop(columns=['year', 'burned_fraction'])
        y = data_tmp['burned_fraction']
        extra = data_tmp[extra_data_columns]
        tru.add_data_split(f'{year}', x, label_data=(y > burned_fraction_th), extra_data_df=extra)

    # Register only the models that belong to this window.
    for model_name, model in models.items():
        if not _model_belongs_to_window(model_name, key):
            continue
        print(model_name)
        model_params = _build_model_params(model_name, model)
        tru.add_python_model(
            model_name,
            model,
            train_split_name=train_split_name,
            train_parameters=model_params,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment