This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_model():
    """Assemble the convolutional front end of the network.

    Stacks three ReLU ``Conv2D`` layers with 'same' padding on a
    ``Sequential`` model and flattens the output. Every layer gets its own
    Glorot-normal initializer seeded with 31. The first layer declares the
    input shape from the module-level ``img_rows``, ``img_cols`` and
    ``img_channels``.

    NOTE(review): only this part of the function is visible in the listing;
    no dense head, compile step or ``return`` statement appears here.
    """
    # One entry per conv layer: (filters, kernel_size, strides, extra kwargs).
    conv_stack = (
        (32, (8, 8), (4, 4), {'input_shape': (img_rows, img_cols, img_channels)}),
        (64, (4, 4), (2, 2), {}),
        (64, (3, 3), (1, 1), {}),
    )

    model = Sequential()
    for filters, kernel, stride, extra in conv_stack:
        model.add(
            Conv2D(
                filters,
                kernel_size=kernel,
                strides=stride,
                padding='same',
                activation='relu',
                # A fresh initializer object per layer, all with the same seed.
                kernel_initializer=initializers.glorot_normal(seed=31),
                **extra
            )
        )
    model.add(Flatten())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Split the sampled transitions column-wise into one array per field.
states_batch, action_batch, reward_batch, next_states_batch, done_batch = (
    np.array(column) for column in zip(*minibatch)
)

# Q-values of the successor states, evaluated with the target network.
q_values_next = target_model.predict(next_states_batch, batch_size=BATCH)
best_next_q = np.amax(q_values_next, axis=1)

# Regression targets start at zero and are filled in only at the
# (row, taken action) positions selected by ti_tuple / action_batch.
# NOTE(review): entries for actions that were not taken stay 0, so training
# also regresses those outputs toward zero - confirm this is intended.
targets = np.zeros((BATCH, ACTIONS))  # BATCHxACTIONS

# NOTE(review): the bootstrap term is multiplied by done_batch itself. If
# done_batch marks terminal transitions this mask looks inverted (the usual
# form is 1 - done); verify against the code that fills the replay buffer.
targets[ti_tuple, action_batch] = reward_batch + done_batch * GAMMA * best_next_q

loss += model.train_on_batch(states_batch, targets)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
prev_step_max = 0 | |
-- reward that allows backtracking, only rewarded by getting max trajectory | |
function reward_by_max_trajectory() | |
frame_count = frame_count + 1 | |
local level_done = calc_progress(data) | |
local temp_progress = calc_trajectory_progress(data) | |
--local reward = reward_by_ring(data) --optional, modify reward by ring behavior | |
local reward = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
prev_max_distance_reward = 0 | |
--only get reward for going over prior max | |
--calculates reward based on where you are in terms of waypoints and progress to next waypoint | |
function reward_by_max_waypoint() | |
frame_count = frame_count + 1 | |
local level_done = calc_progress(data) | |
--local reward = reward_by_ring(data) | |
local reward = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ray | |
import gym | |
from ray.rllib.agents import ppo, dqn | |
from ray.tune.registry import register_env | |
env_name = "multienv" | |
class MultiEnv(gym.Env): | |
def __init__(self, env_config): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MultiEnv(gym.Env):
    """Gym environment that wraps one BustAMove challenge state per worker.

    The wrapped environment is chosen from ``env_config.worker_index`` so
    that workers are spread over the ``Challengeplay0`` .. ``Challengeplay4``
    states. The action and observation spaces are taken from the wrapped
    environment.
    """

    def __init__(self, env_config):
        """Create the wrapped environment for this worker.

        Args:
            env_config: configuration object; only its ``worker_index``
                attribute is read here.
        """
        # Map the worker index onto one of the five challenge states.
        # (An alternative state named in the original is
        # BustAMove.1pplay.Level10.)
        challenge_level = env_config.worker_index % 5
        state_name = 'BustAMove.Challengeplay{}'.format(challenge_level)

        self.env = sonic_on_ray.make(game='BustAMove-Snes', state=state_name)

        # Expose the wrapped environment's spaces as this environment's own.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ray.rllib.models.misc import get_activation_fn, flatten, AddCoords | |
... | |
def _build_layers(self, inputs, num_outputs, options):
    """Optionally append coordinate channels to ``inputs``.

    When ``options['custom_options']['add_coordinates']`` is truthy, the
    inputs are passed through an ``AddCoords`` layer. The layer's ``with_r``
    flag follows ``options['custom_options']['add_coords_with_r']``.

    Args:
        inputs: input tensor; assumed to be laid out as
            (batch, x, y, channels) - TODO confirm against ``AddCoords``.
        num_outputs: not used in the portion of the method shown here.
        options: model options dict; only ``custom_options`` is read here.

    NOTE(review): only this opening part of the method is visible in the
    listing; the layers built from ``inputs`` and the return value are not
    shown.
    """
    custom_options = options.get('custom_options', {})
    if custom_options.get('add_coordinates'):
        with_r = bool(custom_options.get('add_coords_with_r'))
        input_shape = np.shape(inputs)
        # Take x_dim and y_dim from the two distinct spatial axes. Reading
        # axis 1 for both (as before) is only correct for square inputs.
        addcoords = AddCoords(
            x_dim=int(input_shape[1]),
            y_dim=int(input_shape[2]),
            with_r=with_r,
        )
        inputs = addcoords(inputs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Critic(Model): | |
def __init__(self, name='critic', td3_variant=False, network='mlp', **network_kwargs): | |
super().__init__(name=name, network=network, **network_kwargs) | |
self.layer_norm = True | |
self.td3_variant = td3_variant | |
def __call__(self, obs, action, reuse=False): | |
with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): | |
if self.td3_variant: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Critic(Model): | |
def __init__(self, name='critic', td3_variant=False, network='mlp', **network_kwargs): | |
super().__init__(name=name, network=network, **network_kwargs) | |
self.layer_norm = True | |
self.td3_variant = td3_variant | |
def __call__(self, obs, action, reuse=False): | |
with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): | |
if self.td3_variant: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if self.td3_variant: | |
logger.info('using TD3 variant model') | |
self.normalized_critic_tf, self.normalized_critic_tf2 = critic(normalized_obs0, self.actions) | |
self.critic_tf = denormalize( | |
tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms) | |
self.normalized_critic_with_actor_tf, _ = critic(normalized_obs0, self.actor_tf, reuse=True) | |
self.critic_with_actor_tf = denormalize( | |
tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), | |
self.ret_rms) | |
out_q1, out_q2 = target_critic(normalized_obs1, target_actor(normalized_obs1)) |
Older | Newer