Skip to content

Instantly share code, notes, and snippets.

@ppwwyyxx
Last active June 23, 2017 00:03
Show Gist options
  • Save ppwwyyxx/62d5723dea411a515ee1a52b1a87a637 to your computer and use it in GitHub Desktop.
Save ppwwyyxx/62d5723dea411a515ee1a52b1a87a637 to your computer and use it in GitHub Desktop.
Patch for f1 (https://github.com/ebonyclock/vizdoom_cig2017) so that it runs under the latest TensorFlow and the latest tensorpack (as of June 2017)
diff --git i/f1/F1_track1/agent.py w/f1/F1_track1/agent.py
index a717842..4ae8405 100644
--- i/f1/F1_track1/agent.py
+++ w/f1/F1_track1/agent.py
@@ -140,15 +140,15 @@ class FinalEnv(RLEnvironment):
class Model(ModelDesc):
- def _get_input_vars(self):
- return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'image'),
- InputVar(tf.float32, (None, 5), 'vars')]
+ def _get_inputs(self):
+ return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'image'),
+ InputDesc(tf.float32, (None, 5), 'vars')]
- def _get_NN_prediction(self, state, is_training):
+ def _get_NN_prediction(self, state):
""" image: [0,255]"""
image, vars = state
image = image / 255.0
- with argscope(Conv2D, nl=PReLU.f):
+ with argscope(Conv2D, nl=PReLU.symbolic_function):
l = (LinearWrap(image)
.Conv2D('conv0', out_channel=32, kernel_shape=7, stride=2)
# 60
@@ -160,16 +160,16 @@ class Model(ModelDesc):
# 7
.Conv2D('conv4', out_channel=192, kernel_shape=3, padding='VALID')
# 5
- .FullyConnected('fcimage', 1024, nl=PReLU.f)())
+ .FullyConnected('fcimage', 1024, nl=PReLU.symbolic_function)())
vars = tf.tile(vars, [1, 10], name='tiled_vars')
- feat = tf.concat(1, [l, vars])
+ feat = tf.concat([l, vars], 1)
policy = FullyConnected('fc-pi-m', feat, out_dim=NUM_ACTIONS, nl=tf.identity) * 0.1
return policy
- def _build_graph(self, inputs, is_training):
- policy = self._get_NN_prediction(inputs, is_training)
+ def _build_graph(self, inputs):
+ policy = self._get_NN_prediction(inputs)
self.logits = tf.nn.softmax(policy, name='logits')
class Model2(ModelDesc):
@@ -213,9 +213,9 @@ class Runner(object):
cfg = PredictConfig(
model=Model(),
session_init=SaverRestore('model.tfmodel'),
- input_var_names=['image', 'vars'],
- output_var_names=['logits'])
- self._pred_func = get_predict_func(cfg)
+ input_names=['image', 'vars'],
+ output_names=['logits'])
+ self._pred_func = OfflinePredictor(cfg)
def action_func(self, inputs):
f = self._pred_func
diff --git i/f1/F1_track1/history.py w/f1/F1_track1/history.py
index 5049412..511cef6 100644
--- i/f1/F1_track1/history.py
+++ w/f1/F1_track1/history.py
@@ -11,12 +11,12 @@ __all__ = ['HistoryPlayerWithVar']
class HistoryPlayerWithVar(HistoryFramePlayer):
def current_state(self):
assert len(self.history) != 0
- assert len(self.history[0]) == 2, "state needs to be like [img, vars]"
- diff_len = self.history.maxlen - len(self.history)
- zeros = [np.zeros_like(self.history[0][0]) for k in range(diff_len)]
- for k in self.history:
+ assert len(self.history.buf[0]) == 2, "state needs to be a 2-list like [img, vars]"
+ diff_len = self.history.buf.maxlen - len(self.history)
+ zeros = [np.zeros_like(self.history.buf[0][0]) for k in range(diff_len)]
+ for k in self.history.buf:
zeros.append(k[0])
img = np.concatenate(zeros, axis=2)
- gvar = self.history[-1][1]
+ gvar = self.history.buf[-1][1]
return img, gvar
@ppwwyyxx
Copy link
Author

Thanks. The diff has been updated.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment