This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import tensorflow as tf | |
def generator_1(): | |
while True: | |
time.sleep(0.1) | |
yield [1] | |
def generator_2(): | |
while True: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import tensorflow as tf | |
def generator_1(): | |
while True: | |
time.sleep(1) | |
yield [1] | |
def generator_2(): | |
while True: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import tensorflow as tf | |
def generator(): | |
while True: | |
time.sleep(1) | |
yield [1] | |
N_parallel = 10 | |
output_signature=tf.TensorSpec(shape=(1,), dtype=tf.uint8) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import tensorflow as tf | |
# tf.__version__ == 2.4.4 | |
def generator(): | |
while True: | |
time.sleep(1) | |
yield [1] | |
dataset = tf.data.Dataset.from_generator( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
states, actions, rewards, next_states, dones = minibatch | |
# TQ = reward + γ * max_a[Q(s, a)] | |
target_quantile_qvalues = self.make_target_distribution(rewards, next_states, dones) | |
with tf.GradientTape() as tape: | |
quantile_qvalues_all = self.qnetwork(states) # (B, A, N_ATOMS) | |
actions_onehot = tf.expand_dims( | |
tf.one_hot(actions, self.action_space, on_value=1., off_value=0.), axis=2) # (B, A, 1) | |
quantile_qvalues = tf.reduce_sum( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow_probability import distributions as tfd | |
""" | |
MPO: M-step | |
- πをexp(Q/τ)に似るように更新する | |
- 制約項 λ×KL(π_old || π) | |
- λはKL制約違反が閾値εを超えたら大きくする | |
i.e. (ε - KL) < 0 ならλを大きくする | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow_probability import distributions as tfd | |
""" | |
MPO: E-step | |
- exp(Q(s, a)/η) の算出 | |
- g(η)を最小化する方向に温度パラメータηの更新 | |
""" | |
# バッチサイズ、状態sごとにサンプリングするアクション数 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DreamerV2Agent: | |
""" 省略 """ | |
def rollout(self, training: bool): | |
env = gym.make(self.env_id) | |
#: グレスケ化 -> 解像度を(64, 64)へ | |
obs = self.preprocess_func(env.reset()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import tensorflow.keras.layers as kl | |
from tensorflow_probability import distributions as tfd | |
class WorldModel(tf.keras.Model): | |
def __init__(self, config): | |
super(WorldModel, self).__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# つづき | |
--- | |
apiVersion: v1 | |
kind: ConfigMap | |
metadata: | |
name: worker-config | |
namespace: default | |
data: | |
conncheck.py : | | |
import time |