This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// create world request initializes the environment | |
if( rqo.requestType == DmEnvRpc.V1.EnvironmentRequest.CreateWorldFieldNumber){ | |
// CreateWorld comes in with config settings | |
// in this case just the reset position of the agent | |
int startPosition = 10; | |
// unpack the Tensor format sent in from the grpc | |
if (rqo.unpackedTensorDict.ContainsKey("start_position") ){ | |
if(rqo.unpackedTensorDict["start_position"] != null && rqo.unpackedTensorDict["start_position"].Count > 0) | |
startPosition = rqo.unpackedTensorDict["start_position"][0]; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//requests and responses are a bidirection gRPC stream. | |
while (await requestStream.MoveNext()) | |
{ | |
// A request is received | |
EnvironmentRequest envRequest = requestStream.Current; | |
// add the request to the RequestQUeu | |
this.requestQueue.AddRequestQueueObject(envRequest); | |
// ask the session layer to do its work and generate a reponse to send to the client | |
var envResponseList = await this.agentSession.HandleEnvironmentRequest(); | |
foreach (DmEnvRpc.V1.EnvironmentResponse ero in envResponseList) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Start the example environment image as a detached container, mapping
# container port 30051 to host port 30051.
container = docker.from_env().containers.run(
    image="basic_example", detach=True, ports={30051: 30051})
try:
    # Connect to the gRPC server exposed on the mapped host port.
    connection = dm_env_rpc_connection.create_secure_channel_and_connect(
        "localhost:30051")
    # Create a world with default settings, join it, and get a dm_env back;
    # the second return value is not needed here.
    env, _ = dm_env_adaptor.create_and_join_world(
        connection, create_world_settings={}, join_world_settings={})
except BaseException:
    # Setup failed or was interrupted (e.g. Ctrl-C while waiting to connect):
    # stop the detached container so it does not keep running and keep host
    # port 30051 bound, then propagate the original error.
    container.stop()
    raise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UnityEnvWrapper(gym.Env): | |
def __init__(self, env_config): | |
self.vector_index = env_config.vector_index | |
self.worker_index = env_config.worker_index | |
self.worker_id = env_config["unity_worker_id"] + env_config.worker_index | |
# Name of the Unity environment binary to launch | |
env_name = '/home/jim/projects/unity_ray/basic_env_linux/basic_env_linux' | |
self.env = UnityEnv(env_name, worker_id=self.worker_id, use_visual=False, multiagent=False, no_graphics=True) # | |
self.action_space = self.env.action_space | |
self.observation_space = self.env.observation_space |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def make_env(env_id, log_dir, rank):
    """Return a zero-argument factory that builds one monitored Unity env.

    Suitable for vectorized-env constructors (e.g. ``SubprocVecEnv``) that
    take a list of callables and build each environment from one.

    Args:
        env_id: Unity environment binary passed to ``UnityEnv``.
        log_dir: directory handed to ``Monitor``.
        rank: Unity ``worker_id`` for this environment; presumably it must be
            distinct for every concurrently running env — confirm against
            the ML-Agents documentation.

    Returns:
        A callable taking no arguments that returns the ``Monitor``-wrapped
        ``UnityEnv``.
    """
    def _init():
        env = UnityEnv(env_id, worker_id=rank, use_visual=False)
        env = Monitor(env, log_dir, allow_early_resets=True)
        # The vectorized env uses whatever the factory returns, so the
        # wrapped environment has to be returned explicitly.
        return env
    return _init
# Build `num_env` Unity environments in subprocesses; ranks (Unity worker ids)
# run consecutively from `worker_id` to `worker_id + num_env - 1`.
env_id = "unity_ray/basic_env_linux/basic_env_linux"
num_env = 2
worker_id = 9
env = SubprocVecEnv([
    make_env(env_id, log_dir, rank)
    for rank in range(worker_id, worker_id + num_env)
])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create the Unity environment.
from gym_unity.envs import UnityEnv
env_id = "unity_ray/basic_env_linux/basic_env_linux"
unity_env = UnityEnv(env_id, worker_id=2, use_visual=False, no_graphics=False)
# Run Stable Baselines. The algorithms require a vectorized environment, so
# wrap the single env in a DummyVecEnv. The factory closes over `unity_env`
# rather than `env`: since `env` is rebound to the VecEnv on this same line,
# a `lambda: env` would return the VecEnv itself if it were ever called after
# the assignment.
env = DummyVecEnv([lambda: unity_env])
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
self.fc1_actor_ = tf.contrib.layers.fully_connected(tf.concat([self.state_, self.action_], axis=1), actor_hs_list[0], activation_fn=tf.nn.relu) #400 | |
self.fc2_actor_ = tf.contrib.layers.fully_connected(self.fc1_actor_, actor_hs_list[1], activation_fn=tf.nn.relu) | |
self.fc3_actor_ = tf.contrib.layers.fully_connected(self.fc2_actor_, action_dim, activation_fn=tf.nn.tanh) * 0.05 * max_action | |
self.actor_clip_ = tf.clip_by_value((self.fc3_actor_ + self.action_),-max_action, max_action) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#copies the mainQN values to the targetQN | |
#from Denny Britz's excellent RL repo | |
#https://github.com/dennybritz/reinforcement-learning/blob/master/DQN/Double%20DQN%20Solution.ipynb | |
def copy_model_parameters(sess, estimator1, estimator2): | |
""" | |
Copies the model parameters of one estimator to another. | |
Args: | |
sess: Tensorflow session instance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#in graph | |
target_network_update_ops = trfl.periodic_target_update(targetQN.get_qnetwork_variables(),mainQN.get_qnetwork_variables(), | |
update_period=2000,tau=1.0) | |
#in session | |
with tf.Session() as sess: | |
#.... | |
for ep in range(1, train_episodes): | |
#... | |
#update target q network |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#in graph | |
target_network_update_ops = trfl.update_target_variables(targetQN.get_qnetwork_variables(), | |
mainQN.get_qnetwork_variables(),tau=1.0/2000) | |
#in session | |
with tf.Session() as sess: | |
#.... | |
for ep in range(1, train_episodes): | |
#... | |
#update target q network |
NewerOlder