import numpy as np
import pandas as pd


def bootstrap(df, n, to_df=True):
    """
    Generate a bootstrap sample of size n from each column of a DataFrame.
    """
    assert isinstance(df, pd.DataFrame), \
        f"Expected pandas.DataFrame, got type: {type(df)}"
    # column: bootstrap sample (each column is resampled independently)
    sample = {column: np.random.choice(df[column], size=int(n)) for column in df.columns}
    if to_df:
        sample = pd.DataFrame.from_dict(sample)  # convert to DataFrame
    return sample
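`bootstrap` resamples each column independently, so any relationship between columns (for example between the predictors and `y`) is not preserved in the resampled data. That may matter little when only training time is measured, but if the resampled rows should keep their pairing, a row-wise resample is one option. The sketch below assumes pandas' `DataFrame.sample` with `replace=True`; `bootstrap_rows` and the toy frame are hypothetical names used only for illustration.

```python
import pandas as pd


def bootstrap_rows(df: pd.DataFrame, n: int, seed=None) -> pd.DataFrame:
    """Hypothetical helper: draw n whole rows with replacement."""
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"Expected pandas.DataFrame, got type: {type(df)}")
    # sample entire rows so values from the same row stay together
    return df.sample(n=int(n), replace=True, random_state=seed).reset_index(drop=True)


# toy data where y is always 10 * x
toy = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [10.0, 20.0, 30.0]})
boot = bootstrap_rows(toy, 1000, seed=0)
print(boot.shape)  # (1000, 2)
# every resampled row keeps the original x -> y pairing
print(bool((boot["y"] == 10 * boot["x"]).all()))  # True
```

With the column-wise `bootstrap` above, the same check would generally fail, because `x` and `y` are drawn separately.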
import numpy as np
from sklearn.model_selection import cross_validate


def test_model(model, sample, k=5):
    """
    Get the mean time it takes to train the model over k folds.
    """
    x = sample.copy()  # set our predictors
    y = x.pop('y')  # set our response (removes 'y' from the predictors)
    # cross_validate reports fit_time in seconds, one entry per fold;
    # the mean is then scaled by 100
    mean_fit_time = np.mean(cross_validate(model, x, y, cv=k)['fit_time']) * 100
    return mean_fit_time
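`cross_validate` returns a dictionary of per-fold arrays, and `test_model` averages its `fit_time` entry. A minimal sketch of what that dictionary contains, assuming a synthetic regression dataset from `make_regression` and a `LinearRegression` model (both chosen only for illustration, not taken from the original):

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate

# illustrative data: 500 rows, 4 predictors, plus a response column named 'y'
X, y = make_regression(n_samples=500, n_features=4, noise=0.1, random_state=0)
sample = pd.DataFrame(X, columns=[f"x{i}" for i in range(4)])
sample["y"] = y

predictors = sample.drop(columns="y")
response = sample["y"]
scores = cross_validate(LinearRegression(), predictors, response, cv=5)

# one fit time (in seconds) per fold
fit_times = scores["fit_time"]
print(len(fit_times))  # 5
print(np.mean(fit_times))
```

`sample.drop(columns="y")` is equivalent to the `copy()` plus `pop('y')` pair in `test_model`; both leave the caller's DataFrame unchanged.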
import os


def evaluate_cores(model, data, n_samples, stride=1000, printer=False):
    """
    Evaluate multicore model training time.
    """
    # sample sizes to test: stride, 2*stride, ..., rounded up to cover n_samples
    N = list(range(stride, n_samples + stride, stride))
    cores = os.cpu_count()  # get the number of cores
    # construct name: instance pairs, one model per n_jobs setting
    model_dict = {f"model_{n}": model(n_jobs=n) for n in range(1, cores + 1)}
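The rest of `evaluate_cores` is not shown above. One possible way to finish the idea, sketched here as a hypothetical `time_fits` helper rather than the author's code, is to time a single fit for every sample size and every `n_jobs` setting and collect the results in a DataFrame with an `N` column, which is the layout `visualize` expects when it pops `'N'`. The random-forest model, the toy data, the row resampling and the cap of two cores are assumptions made so the example runs quickly.

```python
import os
import time

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor


def time_fits(make_model, data, sizes, max_cores=None, seed=0):
    """Hypothetical helper: one 'N' column plus one timing column per core count."""
    cores = max_cores or os.cpu_count() or 1
    rng = np.random.default_rng(seed)
    rows = []
    for n in sizes:
        # resample n rows with replacement (stand-in for a bootstrap step)
        sample = data.iloc[rng.integers(0, len(data), size=n)]
        x = sample.drop(columns="y")
        y = sample["y"]
        row = {"N": n}
        for k in range(1, cores + 1):
            model = make_model(n_jobs=k)
            start = time.perf_counter()
            model.fit(x, y)
            row[f"model_{k}"] = time.perf_counter() - start  # wall-clock seconds
        rows.append(row)
    return pd.DataFrame(rows)


# toy data: three predictors and a response that is their sum
rng = np.random.default_rng(1)
data = pd.DataFrame(rng.normal(size=(200, 3)), columns=["a", "b", "c"])
data["y"] = data.sum(axis=1)

results = time_fits(
    lambda n_jobs: RandomForestRegressor(n_estimators=10, n_jobs=n_jobs),
    data,
    sizes=range(100, 301, 100),
    max_cores=2,
)
print(results.columns.tolist())  # ['N', 'model_1', 'model_2']
```

A single timed fit per cell is noisy; averaging over folds, as `test_model` does, gives steadier numbers at the cost of a longer run.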
import matplotlib.pyplot as plt


def visualize(data, save=False):
    """
    Plot the results.
    """
    y = data.copy()  # get y-values (one column per model)
    x = y.pop('N')  # get x-values (number of samples)
    plt.plot(x, y)
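`plt.plot(x, y)` draws one line per remaining column but leaves the lines unlabeled, and the `save` flag is not used in the lines shown. A hedged sketch of a labeled version, assuming matplotlib's object-oriented API and a hypothetical `plot_timings` helper with an optional output `path`; the numbers in `demo` are placeholders, not measurements.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd


def plot_timings(data, path=None):
    """Hypothetical helper: one labeled line per model column against 'N'."""
    x = data["N"]
    y = data.drop(columns="N")
    fig, ax = plt.subplots()
    for column in y.columns:
        ax.plot(x, y[column], label=column)
    ax.set_xlabel("number of samples (N)")
    ax.set_ylabel("mean fit time")
    ax.legend()
    if path is not None:
        fig.savefig(path)  # write the figure only when a path is given
    return fig, ax


# placeholder values for illustration only, not measurements
demo = pd.DataFrame({"N": [1000, 2000, 3000],
                     "model_1": [1.0, 2.0, 3.0],
                     "model_2": [0.6, 1.1, 1.7]})
fig, ax = plot_timings(demo)
print(len(ax.get_lines()))  # 2
```

Returning `fig` and `ax` lets the caller adjust the plot further before saving or showing it.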
import gym
import gym_pull
gym_pull.pull('github.com/ppaquette/gym-super-mario')
env = gym.make('ppaquette/SuperMarioBros-1-1-v0')
from gym.envs.registration import register

register(
    id='SuperMarioBros-1-1-v0',
    entry_point='gym.envs.ppaquette_gym_super_mario:MetaSuperMarioBrosEnv'
)
add_group(
    id='ppaquette_gym_super_mario',
    name='ppaquette_gym_super_mario',
    description='super_mario'
)
# mario bros
add_task(
    id='ppaquette/SuperMarioBros-1-1-v0',
    group='ppaquette_gym_super_mario',
# Super Mario
register(
    id='SuperMarioBros-1-1-v0',
    entry_point='gym.envs.ppaquette_gym_super_mario:MetaSuperMarioBrosEnv',
)
register(
    id='SuperMarioBros-1-2-v0',
    entry_point='gym.envs.ppaquette_gym_super_mario:MetaSuperMarioBrosEnv',
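Each level gets its own `register(...)` call with the same entry point. Writing one call per level by hand is repetitive; a hedged sketch that builds the keyword arguments in a loop instead, assuming the package covers worlds 1 to 8 with stages 1 to 4 (the standard Super Mario Bros. layout) and that every level shares `MetaSuperMarioBrosEnv` as in the calls above. `mario_registrations` is a hypothetical helper; the `register` call itself is left commented out so the sketch runs without gym installed.

```python
# entry point shared by every level in the registrations above
ENTRY_POINT = "gym.envs.ppaquette_gym_super_mario:MetaSuperMarioBrosEnv"


def mario_registrations(worlds=range(1, 9), stages=range(1, 5), version=0):
    """Hypothetical helper: keyword arguments for one register() call per level."""
    return [
        {"id": f"SuperMarioBros-{w}-{s}-v{version}", "entry_point": ENTRY_POINT}
        for w in worlds
        for s in stages
    ]


specs = mario_registrations()
print(len(specs))      # 32
print(specs[1]["id"])  # SuperMarioBros-1-2-v0

# With gym installed, each spec could then be registered:
# from gym.envs.registration import register
# for spec in specs:
#     register(**spec)
```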