Skip to content

Instantly share code, notes, and snippets.

@Tjorriemorrie
Last active July 28, 2017 03:06
Show Gist options
  • Save Tjorriemorrie/40980a4f6c7a96f790a2184cc1714580 to your computer and use it in GitHub Desktop.
tensorforce
import os
from collections import deque
import numpy as np
from tensorforce import Configuration
from tensorforce.agents.random_agent import RandomAgent
from tensorforce.agents import TRPOAgent
from tensorforce.core.networks import layered_network_builder
# Fixed observation passed to the agent on every step: a length-9 vector of
# zeros with dtype 'float' (numpy's float64).
state = np.zeros(9, dtype='float')

# Discrete action labels, each wrapped in a 1-tuple. The agent config below
# only uses how many there are (len(actions)).
actions = [(label,) for label in ('foo', 'bar', 'baz', 'bez', 'boz')]
# random_config = Configuration(
# states=dict(shape=state.shape, type=state.dtype),
# actions=dict(continuous=False, num_actions=len(actions)),
# )
# random_agent = RandomAgent(random_config)
# a = random_agent.act(state)
# print(f'random action {a}')
# random_agent.observe(a, True)
# Configure a TRPO agent for the fixed `state` vector and the discrete
# `actions` list: the state spec mirrors the vector's shape/dtype, the action
# spec is non-continuous with one choice per entry in `actions`, and the
# network is built from an empty layer list.
trpo_config = Configuration(
    states=dict(shape=state.shape, type=state.dtype),
    actions=dict(continuous=False, num_actions=len(actions)),
    network=layered_network_builder([]),
)
trpo_agent = TRPOAgent(trpo_config)

# Checkpoints are stored under a `trpo/` directory next to this script,
# using the path prefix `trpo/model`.
model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trpo')
model_path = os.path.join(model_dir, 'model')
print(model_path)

# Resume from an earlier run when the checkpoint directory is present.
# NOTE(review): the directory existing does not guarantee a checkpoint was
# actually written into it -- confirm load_model tolerates that case.
if os.path.isdir(model_dir):
    trpo_agent.load_model(model_path)

# Rolling window holding the most recent 10 rewards for the running total.
rews = deque(maxlen=10)
for _ in range(10):
    a = trpo_agent.act(state)
    # Reward is |3 - a| for the action returned by the agent.
    reward = abs(3 - a)
    trpo_agent.observe(reward=reward, terminal=True)
    # Store the reward (not the action) so the printed total is a reward sum.
    rews.append(reward)
    print('past reward {}'.format(sum(rews)))

# On a first run the checkpoint directory does not exist yet; create it so
# the save has a parent directory to write into.
os.makedirs(model_dir, exist_ok=True)
trpo_agent.save_model(model_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment