Created
October 4, 2017 23:43
-
-
Save d0p3t/c5b5e9c5b3e7ec8db7b39613b3ea1960 to your computer and use it in GitHub Desktop.
Version 1 DDQN Cloney
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# =============================
# -----------DQN TODO ---------
# =============================
def setup_play_ddqn(self):
    """Prepare everything the DDQN play loop needs.

    Loads the context classifier from the plugin directory and registers it
    under ``self.machine_learning_models["context_classifier"]``, resets the
    game state, and creates the movement agent as ``self.dqn_movement``.
    """
    self.plugin_path = offshoot.config["file_paths"]["plugins"]

    classifier_model_path = f"{self.plugin_path}/SerpentCloneyGameAgentPlugin/files/ml_models/cloney_context_classifier.model"

    classifier = CNNInceptionV3ContextClassifier(input_shape=(288, 512, 3))
    classifier.prepare_generators()
    classifier.load_classifier(classifier_model_path)
    self.machine_learning_models["context_classifier"] = classifier

    self._reset_game_state()

    # Only W and S are mapped for now; A, D, ENTER and the left mouse click
    # are not wired into the agent yet.
    key_bindings = dict(
        W=[KeyboardKey.KEY_W],
        S=[KeyboardKey.KEY_S],
    )

    # None is the first entry of the default keys, next to the two mapped keys.
    movement_actions = KeyboardMouseActionSpace(default_keys=[None, "W", "S"])

    # TODO: pass a weights file path here once one should be loaded; for now
    # no model file is given.
    ddqn_settings = dict(
        model_file_path=None,
        input_shape=(72, 128, 4),
        input_mapping=key_bindings,
        action_space=movement_actions,
        replay_memory_size=5000,
        max_steps=100000,
        observe_steps=500,
        batch_size=32,
        initial_epsilon=0.1,
        final_epsilon=0.0001,
        override_epsilon=False,
    )
    self.dqn_movement = DDQN(**ddqn_settings)
def handle_play_ddqn(self, game_frame):
    """Advance the DDQN movement agent by one game frame.

    The frame is classified by the context classifier; any context other
    than ``"GAME_WORLD_1"`` marks the agent as not alive. Once a frame stack
    exists, the previous step is recorded (TRAIN mode) or the frame stack is
    refreshed (RUN mode). When the agent is not alive, the run is closed:
    garbage is collected, counters are reset, mini-batches are trained
    (TRAIN mode), the next mode is selected and the game is restarted.
    Otherwise the next action is picked and sent to the input controller.

    NOTE(review): the branch nesting was reconstructed from a copy of this
    method whose indentation had been stripped -- confirm it against the
    original file, in particular the first-run branch and the placement of
    the end-of-run branch.

    Args:
        game_frame: Frame for the current tick. Its ``frame`` attribute is
            passed to the context classifier and its
            ``eighth_resolution_grayscale_frame`` seeds the frame stack.

    Returns:
        None.
    """
    # Automatic collection stays off while stepping; a manual collection is
    # run between runs in the end-of-run branch below.
    gc.disable()

    context = self.machine_learning_models["context_classifier"].predict(game_frame.frame)

    if self.dqn_movement.first_run:
        if context == "GAME_OVER":
            self.input_controller.click_screen_region(screen_region="GAME_OVER_PLAY")

        self.dqn_movement.first_run = False
        return None

    self.game_state["alive"] = bool(context == "GAME_WORLD_1")

    if self.dqn_movement.frame_stack is None:
        self.dqn_movement.build_frame_stack(game_frame.eighth_resolution_grayscale_frame)
    else:
        if self.dqn_movement.mode == "TRAIN":
            reward = self._calculate_dragon_train_reward(context=context)
            self.game_state["run_reward"] += reward

            self.dqn_movement.append_to_replay_memory(
                self.game_frame_buffer,
                reward,
                terminal=not self.game_state["alive"],
            )

            # Every 2000 steps, save latest weights to disk
            if self.dqn_movement.current_step % 2000 == 0:
                self.dqn_movement.save_model_weights(
                    file_path_prefix="datasets/cloney_weights_movement"
                )

            # Every 20000 steps, save weights checkpoint to disk
            if self.dqn_movement.current_step % 20000 == 0:
                self.dqn_movement.save_model_weights(
                    file_path_prefix="datasets/cloney_weights_movement",
                    is_checkpoint=True,
                )
        elif self.dqn_movement.mode == "RUN":
            self.dqn_movement.update_frame_stack(self.game_frame_buffer)

        if not self.game_state["alive"]:
            gc.enable()
            gc.collect()
            gc.disable()

            # TODO: per-run duration and record statistics are not tracked
            # here yet, so "last_run_duration" is not refreshed by this
            # method.

            # Compute APS. Guard the division: a zero or missing duration
            # would otherwise abort the end-of-run handling.
            run_duration = self.game_state.get("last_run_duration", 0)
            run_steps = self.game_state["current_run_steps"]
            self.game_state["average_aps"] = run_steps / run_duration if run_duration else 0.0
            self.game_state["current_run_steps"] = 0

            self.input_controller.click_screen_region(screen_region="GAME_OVER_PLAY")

            if self.dqn_movement.mode == "TRAIN":
                for i in range(32):
                    print("\033c")
                    print(f"TRAINING ON MINI-BATCHES: {i + 1}/32")
                    print(f"NEXT RUN: {self.game_state['current_run'] + 1} {'- AI RUN' if (self.game_state['current_run'] + 1) % 20 == 0 else ''}")

                    self.dqn_movement.train_on_mini_batch()

            self.game_state["run_timestamp"] = datetime.utcnow()
            self.game_state["current_run"] += 1
            # "run_reward" is the key accumulated during TRAIN steps above, so
            # it has to be cleared here for the value to stay per-run. The
            # "run_reward_movement" reset is kept for any other readers.
            self.game_state["run_reward"] = 0
            self.game_state["run_reward_movement"] = 0
            self.game_state["run_predicted_actions"] = 0
            self.game_state["alive"] = True

            if self.dqn_movement.mode in ("TRAIN", "RUN"):
                current_run = self.game_state["current_run"]

                # Every 100th run, refresh the target model (DDQN only).
                if current_run > 0 and current_run % 100 == 0:
                    if self.dqn_movement.type == "DDQN":
                        self.dqn_movement.update_target_model()

                # Every 20th run is played in RUN mode; all others train.
                if current_run > 0 and current_run % 20 == 0:
                    self.dqn_movement.enter_run_mode()
                else:
                    self.dqn_movement.enter_train_mode()

            self.input_controller.tap_key(KeyboardKey.KEY_ENTER)
            return None

    self.dqn_movement.pick_action()
    self.dqn_movement.generate_action()

    self.input_controller.handle_keys(self.dqn_movement.get_input_values())

    if self.dqn_movement.current_action_type == "PREDICTED":
        self.game_state["run_predicted_actions"] += 1

    self.dqn_movement.erode_epsilon(factor=1)
    self.dqn_movement.next_step()

    self.game_state["current_run_steps"] += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment