Simple example of Deep Q-Learning for CartPole-v1
The first script trains a DQN with an experience replay buffer and saves the result to cartpole-replay.ckpt.
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import NamedTuple, Union

n_actions = 2
n_states = 4
lr = 3e-4
discount = 0.95
batch_size = 128
epochs = 5000
class Net(nn.Sequential):
    # Rough per-dimension scales used to normalize the observation
    norm_vector = torch.tensor([0.5, 1.0, 0.21, 0.5])

    def __init__(self, in_feats=n_states, out_feats=n_actions, hidden=32):
        super().__init__(
            nn.Linear(in_feats, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, out_feats),
        )

    def forward(self, state):
        x = state / self.norm_vector  # normalize; also turns a numpy observation into a tensor
        y = super().forward(x)
        return y
class Experience(NamedTuple):
    '''A single transition: (state, action, reward, next_state, done)'''
    state: np.ndarray
    action: int
    reward: float
    next_state: np.ndarray
    done: bool

class Memory(object):
    '''A fixed-capacity buffer of records, overwritten circularly'''
    def __init__(self, buffer_size: int):
        self.buffer_size = buffer_size
        self.buffer: list[Union[Experience, None]] = [None for _ in range(self.buffer_size)]
        self.count = 0
    def append(self, exp: Experience):
        '''Add a record; if the buffer is full, replace the oldest one'''
        self.buffer[self.count % self.buffer_size] = exp
        self.count += 1

    def sample(self, k: int):
        '''Randomly pick k experiences (with replacement) and return them batched'''
        if self.count < self.buffer_size:
            pool = self.buffer[:self.count]
        else:
            pool = self.buffer
        exp: list[Experience] = random.choices(pool, k=k)  # type: ignore
        # Pack into tensors
        states = torch.from_numpy(np.array([e.state for e in exp]))
        actions = torch.tensor([e.action for e in exp])
        rewards = torch.tensor([e.reward for e in exp])
        next_states = torch.from_numpy(np.array([e.next_state for e in exp]))
        dones = torch.tensor([e.done for e in exp], dtype=torch.float)
        return states, actions, rewards, next_states, dones
memory = Memory(batch_size * 10)
net = Net()
optimizer = torch.optim.AdamW(net.parameters(), lr=lr, amsgrad=True)
criterion = nn.SmoothL1Loss()  # L1 turned out to work better than L2 here

env = gym.make("CartPole-v1")
state, info = env.reset()
batch_index = torch.arange(batch_size)
for t in tqdm(range(epochs)):
    # Sampling =========================================
    epsilon = 1 - t / epochs  # linearly anneal epsilon
    # Fill the buffer on the first iteration, then add batch_size//4 steps per iteration
    for _ in range(batch_size // 4 if t >= 1 else batch_size):
        if random.random() < epsilon:  # exploration
            action = random.randint(0, n_actions - 1)
        else:
            action = net(state).squeeze().argmax().item()
        org_state = state
        state, reward, terminated, truncated, info = env.step(action)
        # Custom shaped reward
        reward = -20.0 if terminated else 0
        if abs(state[2]) > 0.1:  # limit the pole angle
            reward += -1.0
        if abs(state[0]) > 0.3:  # limit the cart position
            reward += -2.0
        if reward >= 0:
            reward = 1.0
        # Store in memory
        memory.append(Experience(org_state, action, reward, state, terminated))
        if terminated or truncated:
            state, info = env.reset()
    # Training =========================================
    states, actions, rewards, next_states, dones = memory.sample(batch_size)
    # Forward pass: TD target r + discount * max_a' Q(s', a'), zeroed at terminal states
    pred_q = net(states)[batch_index, actions]
    target_q = ((1 - dones) * net(next_states).max(dim=-1).values * discount + rewards).detach()
    loss = criterion(pred_q, target_q)
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(net.parameters(), 1)
    optimizer.step()
env.close()

# Save a checkpoint
torch.save(net.state_dict(), "cartpole-replay.ckpt")

# Watch the learned policy
env = gym.make("CartPole-v1", render_mode="human")
state, info = env.reset()
with torch.no_grad():
    for t in range(2000):
        row = net(state).squeeze()
        action = row.argmax().item()
        print(row)  # print Q-values for inspection
        state, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            state, info = env.reset()
env.close()
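To reuse the trained policy later without retraining, the saved state dict can be loaded back into a fresh Net. A minimal sketch, assuming the Net class above is in scope:

net = Net()
net.load_state_dict(torch.load("cartpole-replay.ckpt"))
net.eval()  # inference mode; mostly convention here, since the net has no dropout or batchnorm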
The second script is a simpler variant without experience replay: it accumulates the TD loss online over each batch of environment steps and saves the result to cartpole.ckpt.
from random import random, randint
import gymnasium as gym
import torch
import torch.nn as nn
from tqdm import tqdm

n_actions = 2
n_states = 4
lr = 3e-4
discount = 0.95
batch_size = 128
epochs = 5000
class Net(nn.Sequential):
    # Rough per-dimension scales used to normalize the observation
    norm_vector = torch.tensor([0.5, 1.0, 0.21, 0.5])

    def __init__(self, in_feats=n_states, out_feats=n_actions, hidden=32):
        super().__init__(
            nn.Linear(in_feats, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, out_feats),
        )

    def forward(self, state):
        x = state / self.norm_vector  # normalize; also turns a numpy observation into a tensor
        y = super().forward(x)
        return y
net = Net()
optimizer = torch.optim.AdamW(net.parameters(), lr=lr, amsgrad=True)
criterion = nn.SmoothL1Loss()  # L1 turned out to work better than L2 here

env = gym.make("CartPole-v1")
state, info = env.reset()
for t in tqdm(range(epochs)):
    # Forward pass: accumulate the loss over a batch of environment steps
    loss = 0.0
    epsilon = 1 - t / epochs  # linearly anneal epsilon
    for _ in range(batch_size):
        # Choose an action =========================
        row = net(state)
        if random() < epsilon:  # exploration
            action = randint(0, n_actions - 1)
        else:
            action = row.squeeze().argmax().item()
        # Execute the action =========================
        state, reward, terminated, truncated, info = env.step(action)
        # Custom shaped reward
        reward = -20.0 if terminated else 0
        if abs(state[2]) > 0.1:  # limit the pole angle
            reward += -1.0
        if abs(state[0]) > 0.3:  # limit the cart position
            reward += -2.0
        if reward >= 0:
            reward = 1.0
        # Compute the loss =========================
        with torch.no_grad():
            if terminated:
                curr_q = torch.tensor(reward)
            else:
                curr_q = net(state).max() * discount + reward
        loss += criterion(row[action], curr_q)
        if terminated or truncated:
            state, info = env.reset()
    # Backward pass
    optimizer.zero_grad()
    (loss / batch_size).backward()
    torch.nn.utils.clip_grad_value_(net.parameters(), 1)
    optimizer.step()
env.close()
# Save a checkpoint
torch.save(net.state_dict(), "cartpole.ckpt")

# Watch the learned policy
env = gym.make("CartPole-v1", render_mode="human")
state, info = env.reset()
with torch.no_grad():
    for t in range(2000):
        row = net(state).squeeze()
        action = row.argmax().item()
        print(row)  # print Q-values for inspection
        state, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            state, info = env.reset()
env.close()
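Rendering shows the policy qualitatively; to put a number on it, one can average episode length over a few runs. A minimal sketch, under the assumption that a trained net from either script is still in scope (CartPole-v1 truncates episodes at 500 steps, so longer averages are better):

import gymnasium as gym
import torch

env = gym.make("CartPole-v1")
lengths = []
with torch.no_grad():
    for _ in range(10):
        state, info = env.reset()
        steps = 0
        while True:
            action = net(state).squeeze().argmax().item()
            state, reward, terminated, truncated, info = env.step(action)
            steps += 1
            if terminated or truncated:
                break
        lengths.append(steps)
env.close()
print("mean episode length:", sum(lengths) / len(lengths))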