Skip to content

Instantly share code, notes, and snippets.

import time
import tensorflow as tf
def generator_1():
while True:
time.sleep(0.1)
yield [1]
def generator_2():
while True:
import time
import tensorflow as tf
def generator_1():
while True:
time.sleep(1)
yield [1]
def generator_2():
while True:
@horoiwa
horoiwa / gen2.py
Last active November 9, 2022 14:06
import time
import tensorflow as tf
def generator():
while True:
time.sleep(1)
yield [1]
N_parallel = 10
output_signature=tf.TensorSpec(shape=(1,), dtype=tf.uint8)
@horoiwa
horoiwa / gen_1.py
Last active November 9, 2022 13:28
import time
import tensorflow as tf
# tf.__version__ == 2.4.4
def generator():
while True:
time.sleep(1)
yield [1]
dataset = tf.data.Dataset.from_generator(
states, actions, rewards, next_states, dones = minibatch
# TQ = reward + γ * max_a[Q(s, a)]
target_quantile_qvalues = self.make_target_distribution(rewards, next_states, dones)
with tf.GradientTape() as tape:
quantile_qvalues_all = self.qnetwork(states) # (B, A, N_ATOMS)
actions_onehot = tf.expand_dims(
tf.one_hot(actions, self.action_space, on_value=1., off_value=0.), axis=2) # (B, A, 1)
quantile_qvalues = tf.reduce_sum(
import tensorflow as tf
from tensorflow_probability import distributions as tfd
"""
MPO: M-step
- πをexp(Q/τ)に似るように更新する
- 制約項 λ×KL(π_old || π)
- λはKL制約違反が閾値εを超えたら大きくする
  i.e. (ε - KL) < 0 ならλを大きくする
"""
import tensorflow as tf
from tensorflow_probability import distributions as tfd
"""
MPO: E-step
- exp(Q(s, a)/η) の算出
- g(η)を最小化する方向に温度パラメータηの更新
"""
# バッチサイズ、状態sごとにサンプリングするアクション数
class DreamerV2Agent:
""" 省略 """
def rollout(self, training: bool):
env = gym.make(self.env_id)
#: グレスケ化 -> 解像度を(64, 64)へ
obs = self.preprocess_func(env.reset())
import tensorflow as tf
import tensorflow.keras.layers as kl
from tensorflow_probability import distributions as tfd
class WorldModel(tf.keras.Model):
def __init__(self, config):
super(WorldModel, self).__init__()
# つづき
---
apiVersion: v1
kind: ConfigMap
metadata:
name: worker-config
namespace: default
data:
conncheck.py : |
import time