Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active June 30, 2017 15:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wassname/1d3b675e7aa6326fb4e0983c434e2d11 to your computer and use it in GitHub Desktop.
Save wassname/1d3b675e7aa6326fb4e0983c434e2d11 to your computer and use it in GitHub Desktop.
keras-rl's TrainIntervalLogger, but using a tqdm progress bar for Jupyter notebooks.
from rl.callbacks import TrainIntervalLogger
from tqdm import tqdm_notebook
import timeit
class TrainIntervalLoggerTQDMNotebook(TrainIntervalLogger):
    """TrainIntervalLogger that reports per-step progress with tqdm_notebook.

    Intended as a drop-in replacement for keras-rl's ``TrainIntervalLogger``
    in Jupyter notebooks. A single ``tqdm_notebook`` bar covering all
    ``self.params['nb_steps']`` steps is created in ``on_train_begin``. The
    bar advances by one on every step and shows the latest step reward in its
    description. It is closed in ``on_train_end``.
    """

    def reset(self):
        """Clear the per-interval accumulators and restart the interval timer.

        Deliberately does not create a progress bar. The tqdm bar has to
        persist across intervals, so it is created once in
        ``on_train_begin`` instead.
        """
        self.interval_start = timeit.default_timer()
        self.metrics = []
        self.infos = []
        self.info_names = None
        self.episode_rewards = []

    def on_train_begin(self, logs):
        """Create the notebook progress bar and record the training start time.

        Args:
            logs: callback logs dict (not read here).
        """
        nb_steps = self.params['nb_steps']
        self.progbar = tqdm_notebook(desc='', total=nb_steps, leave=True)
        self.train_start = timeit.default_timer()
        self.metrics_names = self.model.metrics_names
        print('Training for {} steps ...'.format(nb_steps))

    def on_train_end(self, logs):
        """Close the progress bar, then run the parent's end-of-training hook.

        Without this the notebook widget created in ``on_train_begin`` is
        never closed.

        Args:
            logs: callback logs dict, forwarded unchanged to the parent.
        """
        # getattr guards against on_train_end firing without on_train_begin.
        progbar = getattr(self, 'progbar', None)
        if progbar is not None:
            progbar.close()
        super(TrainIntervalLoggerTQDMNotebook, self).on_train_end(logs)

    def on_step_end(self, step, logs):
        """Advance the progress bar by one step and accumulate step data.

        Args:
            step: step index supplied by the caller (not read here; the
                internal ``self.step`` counter is used instead).
            logs: dict read for the keys ``'reward'``, ``'metrics'`` and
                ``'info'`` (the latter is expected to be a mapping).
        """
        if self.info_names is None:
            # Snapshot the first step's info keys as a list, fixing their
            # order. A live keys view would track later mutation of that dict.
            self.info_names = list(logs['info'].keys())
        self.progbar.desc = 'reward={reward: 2.6f}'.format(
            reward=logs['reward'])
        # update() redraws the bar, which also picks up the new description.
        self.progbar.update(1)
        self.step += 1
        self.metrics.append(logs['metrics'])
        if self.info_names:
            self.infos.append([logs['info'][k] for k in self.info_names])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment