Skip to content

Instantly share code, notes, and snippets.

@data-hound
Last active January 22, 2021 17:55
Show Gist options
  • Save data-hound/b94ee5b9157ae547be72a8b722090524 to your computer and use it in GitHub Desktop.
Save data-hound/b94ee5b9157ae547be72a8b722090524 to your computer and use it in GitHub Desktop.
SciKeras Tutorial 5: Wrapping the MIMO Estimator and Running Cross-Validation
# Helper to reshape (X, y) into the 2-D arrays passed to GridSearchCV.
def get_sciki_xy(X, y):
    """Pack inputs and targets into two 2-D arrays for scikit-learn / SciKeras.

    Each sample of ``X`` is flattened into one row, then:

    * ``X_sciki`` is ``[flattened X | y]`` (input columns first), and
    * ``y_sciki`` is ``[y | flattened X]`` (target columns first).

    A 1-D ``y`` becomes a single column (``np.column_stack`` behaviour).
    NOTE(review): presumably the MIMO estimator splits these columns back
    into its separate inputs/outputs -- confirm against its implementation.

    Args:
        X: array-like whose first axis indexes samples; any number of
            trailing axes (including none) is allowed.
        y: array-like with the same number of samples as ``X``.

    Returns:
        Tuple ``(X_sciki, y_sciki)`` of 2-D arrays with ``X.shape[0]`` rows.

    Raises:
        ValueError: if ``X`` and ``y`` have different numbers of samples.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            "X and y must have the same number of samples, got {} and {}".format(
                X.shape[0], y.shape[0]))
    # int() is required: for 1-D X, np.prod(()) is the float 1.0,
    # which reshape rejects.
    n_features = int(np.prod(X.shape[1:]))
    X_flat = X.reshape((X.shape[0], n_features))
    X_sciki = np.column_stack([X_flat, y])
    y_sciki = np.column_stack([y, X_flat])
    return X_sciki, y_sciki
def do_cross_val(n_train=1000, cv=5):
    """Grid-search MIMOEstimator hyper-parameters with k-fold cross-validation.

    Loads the data with ``load_mnist()`` and keeps the first ``n_train``
    training samples. It packs them with ``get_sciki_xy`` and runs
    ``GridSearchCV`` over ``model__n_filters_c1`` and ``model__routings``.

    Args:
        n_train: number of leading training samples to use (``None`` uses
            all of them). Defaults to 1000 for faster epochs.
        cv: number of cross-validation folds passed to ``GridSearchCV``.

    Returns:
        Tuple ``(best_estimator, best_params)`` from the fitted search.
    """
    # The test split is not used by the search, so it is discarded.
    (x_train_, y_train_), _ = load_mnist()
    # Optional - trim the data size for faster epochs.
    x_train, y_train = get_sciki_xy(x_train_[:n_train], y_train_[:n_train])

    # Create a MIMOEstimator with the get_model function.
    # Parameters that must be passed to get_model are prefixed with model__.
    clf = MIMOEstimator(
        model=get_model,
        model__input_shape=x_train_.shape[1:],
        # argmax over axis 1: y_train_ is presumably one-hot encoded.
        model__n_class=len(np.unique(np.argmax(y_train_, 1))),
        model__routings=args.routings,
        model__batch_size=args.batch_size,
        model__n_filters_c1=256,
        # epochs=args.epochs,
        # callbacks=[log, checkpoint, lr_decay],
        model__model_type='train',
    )

    # Print the shapes of X and Y.
    print("X input shape = ", x_train.shape)
    print("Y input shape = ", y_train.shape)

    # Define the parameter grid to perform Grid-Search.
    params = {'model__n_filters_c1': [128, 256],
              'model__routings': [4, 5]}
    # No. of examples / cv should be completely divisible by batch_size
    # (presumably because get_model is built with a fixed batch size).
    gs = GridSearchCV(estimator=clf, param_grid=params, cv=cv, verbose=True)
    # fit() returns the fitted search object itself.
    gs_res = gs.fit(X=x_train, y=y_train)

    print("Grid Search Results: ")
    # Printing gs_res would only show its constructor repr; report the
    # per-candidate cross-validation scores instead.
    results = gs_res.cv_results_
    for cand_params, mean, std in zip(results['params'],
                                      results['mean_test_score'],
                                      results['std_test_score']):
        print("  {}: {:.4f} (+/- {:.4f})".format(cand_params, mean, std))

    best_est = gs_res.best_estimator_
    best_score = gs_res.best_score_
    best_params = gs_res.best_params_
    print('Best score obtained after GridSearchCV: ', best_score)
    return best_est, best_params
# Kick off the grid search; keep the best estimator and its hyper-parameters.
[est, params] = do_cross_val()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment