Last active: December 27, 2019 16:41
Save edenau/8d338a1c4e631f5d80960fe695dd7ca4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class layer:
    """One layer of a fully connected feed-forward network.

    Attributes set on every instance:
        layer_index: position of the layer in the network; 0 marks the input layer.
        is_output: True for the output layer, False otherwise.
        input_dim: number of values fed into this layer.
        output_dim: number of units in this layer.
        activation: name of the activation function (only stored here).

    Every layer except the input layer also gets randomly initialised parameters:
        W: weight array of shape (output_dim, input_dim).
        b: bias array of shape (output_dim, 1).
    The input layer (layer_index == 0) has no W or b attribute at all.
    """

    def __init__(self, layer_index, is_output, input_dim, output_dim, activation):
        self.layer_index = layer_index
        self.is_output = is_output
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.activation = activation

        if layer_index == 0:
            # The input layer just holds the data; it owns no weights or biases.
            return

        # Standard-normal draws scaled by sqrt(2 / fan_in) (He initialisation).
        # The same scale is applied to the biases. The weights are drawn before
        # the biases so the random stream is consumed in a fixed order.
        weights = np.random.randn(output_dim, input_dim)
        scale = np.sqrt(2 / input_dim)
        self.W = weights * scale
        self.b = np.random.randn(output_dim, 1) * scale
# Change layers_dim to configure your own neural net!
# Layout: input layer, then the hidden layers, then the output layer.
layers_dim = [X_num_row, 4, 4, y_num_row]
neural_net = []
# Construct the net layer by layer. The input layer is created with
# input_dim 0, hidden layers use 'relu' and the last layer uses 'linear'.
for layer_index, n_units in enumerate(layers_dim):
    if layer_index == 0:
        # Input layer: the activation name is only a placeholder.
        neural_net.append(layer(layer_index, False, 0, n_units, 'irrelevant'))
        continue
    n_inputs = layers_dim[layer_index - 1]
    if layer_index == len(layers_dim) - 1:
        neural_net.append(layer(layer_index, True, n_inputs, n_units, activation='linear'))
    else:
        neural_net.append(layer(layer_index, False, n_inputs, n_units, activation='relu'))
# Simple check on overfitting: compare the number of trainable parameters
# (weights and biases) with the number of data points.
# A layer with n_in inputs and n_out units has an (n_out, n_in) weight matrix
# plus an (n_out, 1) bias vector, i.e. (n_in + 1) * n_out parameters.
pred_n_param = sum((n_in + 1) * n_out for n_in, n_out in zip(layers_dim, layers_dim[1:]))
# The input layer (index 0) has no W/b, so counting starts at layer 1.
act_n_param = sum(net_layer.W.size + net_layer.b.size for net_layer in neural_net[1:])
# These counts are model parameters learned during training, not hyperparameters.
print(f'Predicted number of parameters: {pred_n_param}')
print(f'Actual number of parameters: {act_n_param}')
print(f'Number of data: {X_num_col}')
# Crude rule of thumb: at least as many free parameters as data points is
# treated as a guaranteed overfit. ValueError subclasses Exception, so existing
# `except Exception` handlers still catch it.
if act_n_param >= X_num_col:
    raise ValueError('It will overfit.')
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.