@Paulescu
Last active May 5, 2022 07:38
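# NOTE: ENV_NAME (the environment name) and agent (the agent to train) are not
# defined in this snippet; they must be created before running it.
# unique agent_id to identify this run in TensorBoard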
from src.utils import get_agent_id
agent_id = get_agent_id(ENV_NAME)
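# Hypothetical sketch of a get_agent_id-style helper (an assumption for
# illustration, NOT the actual src.utils implementation): tag each run with the
# environment name, a timestamp and a short random suffix, so two runs on the
# same environment never end up in the same TensorBoard folder.
import uuid
from datetime import datetime

def sketch_agent_id(env_name: str) -> str:
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")  # e.g. 20220505-073800
    suffix = uuid.uuid4().hex[:6]  # random part keeps IDs unique within the same second
    return f"{env_name}_{stamp}_{suffix}"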
# tensorboard logger to see training curves
from src.utils import get_logger, get_model_path
logger = get_logger(env_name=ENV_NAME, agent_id=agent_id)
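# Hypothetical sketch of a get_logger-style helper (the "tensorboard_logs" root
# and the folder layout are assumptions, NOT the actual src.utils code): one
# TensorBoard log directory per (environment, agent_id) pair, so every run shows
# up as its own training curve. Creating the writer requires PyTorch.
from pathlib import Path

def sketch_log_dir(env_name: str, agent_id: str, root: str = "tensorboard_logs") -> Path:
    # e.g. tensorboard_logs/<env_name>/<agent_id>
    return Path(root) / env_name / agent_id

def sketch_get_logger(env_name: str, agent_id: str):
    from torch.utils.tensorboard import SummaryWriter  # imported lazily; needs torch installed
    return SummaryWriter(log_dir=str(sketch_log_dir(env_name, agent_id)))

# Usage: writer.add_scalar("train/reward", reward, global_step=step),
# then `tensorboard --logdir tensorboard_logs` to view the curves.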
# path to save policy network weights and hyperparameters
model_path = get_model_path(env_name=ENV_NAME, agent_id=agent_id)
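# Hypothetical sketch of a get_model_path-style helper plus a way to store the
# hyperparameters next to the weights (the "saved_agents" root, the folder layout
# and the "hparams.json" file name are assumptions, NOT the actual src.utils code).
import json
from pathlib import Path

def sketch_model_path(env_name: str, agent_id: str, root: str = "saved_agents") -> Path:
    path = Path(root) / env_name / agent_id
    path.mkdir(parents=True, exist_ok=True)  # create the folder if it does not exist yet
    return path

def sketch_save_hparams(path: Path, hparams: dict) -> None:
    # store hyperparameters as JSON so the run can be reproduced later
    with open(path / "hparams.json", "w") as f:
        json.dump(hparams, f, indent=2)

# Weights of a (hypothetical) PyTorch policy network could then be saved with
# torch.save(policy.state_dict(), path / "model.pt").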
# let's train!
agent.train(
    n_policy_updates=5000,
    batch_size=256,
    logger=logger,
    model_path=model_path,
)