Skip to content

Instantly share code, notes, and snippets.

@uenoku
Last active March 12, 2018 16:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save uenoku/553b88f56322f389a44694f8063809af to your computer and use it in GitHub Desktop.
Save uenoku/553b88f56322f389a44694f8063809af to your computer and use it in GitHub Desktop.
import numpy as np
import numpy.random as random
from numpy.testing import assert_allclose
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Flatten, Concatenate
from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.processors import MultiInputProcessor
from rl.core import Env
class MultiInputTestEnv(Env):
    """Stub environment that emits random observations of a configured shape.

    ``observation_shape`` is either a single shape (anything accepted by
    ``np.random.random``) or a list of shapes. With a list, every observation
    is a list holding one random array per shape, in the same order.
    """

    def __init__(self, observation_shape):
        """Remember the shape(s) used to generate observations."""
        self.observation_shape = observation_shape

    def step(self, action):
        """Ignore ``action`` and return a random transition.

        Returns:
            A 4-tuple: a fresh random observation, a value drawn from
            ``[0, 1]``, a flag drawn from ``[True, False]`` and an empty dict.
        """
        # Draw in this order (observation, reward, done) so the sequence of
        # RNG calls matches evaluating the returned tuple left to right.
        observation = self._get_obs()
        reward = random.choice([0, 1])
        done = random.choice([True, False])
        return observation, reward, done, {}

    def reset(self):
        """Return a fresh random observation."""
        return self._get_obs()

    def _get_obs(self):
        """Build one random observation matching ``observation_shape``."""
        # isinstance (instead of an exact type comparison) also accepts list
        # subclasses; a tuple is still treated as one single shape.
        if isinstance(self.observation_shape, list):
            return [np.random.random(shape) for shape in self.observation_shape]
        return np.random.random(self.observation_shape)

    def __del__(self):
        # Intentionally a no-op.
        # NOTE(review): presumably this shadows cleanup performed by the base
        # Env.__del__ -- confirm against rl.core before removing.
        pass
def test_multi_dqn_input1():
    """Fit a two-input DQN agent on observations shaped (15, 1) and (3,).

    Each network input is the matching observation shape with the memory
    window length prepended. The first input goes through Dense(2) and is
    flattened, the second is flattened directly; both are concatenated and
    fed to a final Dense layer with one unit per action.
    """
    window_length = 2
    nb_actions = 2
    obs_shapes = [(15, 1), (3,)]

    # Input shapes are (2, 15, 1) and (2, 3).
    first_input = Input(shape=(window_length,) + obs_shapes[0])
    second_input = Input(shape=(window_length,) + obs_shapes[1])

    first_branch = Dense(2)(first_input)
    first_branch = Flatten()(first_branch)
    second_branch = Flatten()(second_input)
    merged = Concatenate()([first_branch, second_branch])
    q_values = Dense(nb_actions)(merged)
    network = Model(inputs=[first_input, second_input], outputs=q_values)

    replay = SequentialMemory(limit=10, window_length=window_length)
    multi_processor = MultiInputProcessor(nb_inputs=len(obs_shapes))
    dqn = DQNAgent(
        network,
        memory=replay,
        nb_actions=nb_actions,
        nb_steps_warmup=5,
        batch_size=4,
        processor=multi_processor,
    )
    dqn.compile('sgd')
    dqn.fit(MultiInputTestEnv(obs_shapes), nb_steps=10000, verbose=2)
    # => ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
def test_multi_dqn_input2():
    """Fit a two-input DQN agent on observations shaped (2, 2) and (3,).

    Each network input is the matching observation shape with the memory
    window length prepended. The first input goes through Dense(2) and is
    flattened, the second is flattened directly; both are concatenated and
    fed to a final Dense layer with one unit per action.
    """
    window_length = 2
    nb_actions = 2
    obs_shapes = [(2, 2), (3,)]

    # Input shapes are (2, 2, 2) and (2, 3).
    first_input = Input(shape=(window_length,) + obs_shapes[0])
    second_input = Input(shape=(window_length,) + obs_shapes[1])

    first_branch = Dense(2)(first_input)
    first_branch = Flatten()(first_branch)
    second_branch = Flatten()(second_input)
    merged = Concatenate()([first_branch, second_branch])
    q_values = Dense(nb_actions)(merged)
    network = Model(inputs=[first_input, second_input], outputs=q_values)

    replay = SequentialMemory(limit=10, window_length=window_length)
    multi_processor = MultiInputProcessor(nb_inputs=len(obs_shapes))
    dqn = DQNAgent(
        network,
        memory=replay,
        nb_actions=nb_actions,
        nb_steps_warmup=5,
        batch_size=4,
        processor=multi_processor,
    )
    dqn.compile('sgd')
    dqn.fit(MultiInputTestEnv(obs_shapes), nb_steps=10000, verbose=2)
    # => ValueError: operands could not be broadcast together with shapes (2,2) (3,)
if __name__ == '__main__':
    # Run the reproduction cases in order when executed as a script; an
    # exception raised by one case stops the run before the next one starts.
    for reproduction_case in (test_multi_dqn_input1, test_multi_dqn_input2):
        reproduction_case()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment