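# Gist by @heiner, created February 14, 2020.
#
# Benchmarks forward-pass throughput (steps per second) for three policy/value
# network variants ("test", "shallow", "resnet") on dummy Atari-sized
# observations. Configure the run via the FLAGS object below.
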
import timeit

import torch
from torch import nn
from torch.nn import functional as F


class Flags:
    pass


FLAGS = Flags()
FLAGS.use_lstm = False  # True or False.
FLAGS.model = "resnet"  # One of ["test", "shallow", "resnet"].
FLAGS.device = "cpu"
FLAGS.batch_size = 4

OBSERVATION_SHAPE = [4, 84, 84]


def nest_map(f, n):
    """Apply f to every leaf of a nested tuple/list/dict structure."""
    if isinstance(n, (tuple, list)):
        return n.__class__(nest_map(f, sn) for sn in n)
    elif isinstance(n, dict):
        return {k: nest_map(f, v) for k, v in n.items()}
    else:
        return f(n)
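

# A minimal sanity-check model: a single linear layer over the flattened
# observation. It ignores the last action and reward and has no LSTM core.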
class TestModel(nn.Module):
    def __init__(self, num_actions, core_output_size=256):
        super(TestModel, self).__init__()
        if FLAGS.use_lstm:
            raise ValueError("Test model cannot use LSTM.")
        self.num_actions = num_actions
        self.linear = nn.Linear(4 * 84 * 84, core_output_size)
        self.policy = nn.Linear(core_output_size, num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def initial_state(self, batch_size=1):
        return tuple()

    def forward(self, last_actions, env_outputs, core_state, unroll=False):
        if not unroll:
            # [T=1, B, ...].
            env_outputs = nest_map(lambda t: t.unsqueeze(0), env_outputs)

        observation, reward, done = env_outputs

        T, B, *_ = observation.shape
        x = observation.reshape(T * B, -1)
        x = x.float() / 255.0

        core_output = self.linear(x)
        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        outputs = dict(action=action, policy_logits=policy_logits, baseline=baseline)

        if not unroll:
            for t in outputs.values():
                t.squeeze_(0)

        return outputs, core_state
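

# A shallow torso: three convolutional layers plus a 512-unit fully connected
# layer (as in the classic DQN architecture), with an optional LSTMCell core
# over [FC output, last reward, one-hot last action].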
class ShallowModel(nn.Module):
    def __init__(self, num_actions, core_output_size=256):
        super(ShallowModel, self).__init__()
        self.num_actions = num_actions
        self.use_lstm = FLAGS.use_lstm

        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Fully connected layer.
        self.fc = nn.Linear(3136, 512)

        # FC output size + one-hot of last action + last reward.
        core_output_size = self.fc.out_features + num_actions + 1

        if self.use_lstm:
            self.core = nn.LSTMCell(core_output_size, 256)
            core_output_size = 256

        self.policy = nn.Linear(core_output_size, self.num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def initial_state(self, batch_size=1):
        if not self.use_lstm:
            return tuple()
        return tuple(torch.zeros(batch_size, self.core.hidden_size) for _ in range(2))

    def forward(self, last_actions, env_outputs, core_state, unroll=False):
        if not unroll:
            # [T=1, B, ...].
            env_outputs = nest_map(lambda t: t.unsqueeze(0), env_outputs)

        observation, reward, done = env_outputs

        T, B = reward.shape
        x = torch.flatten(observation, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(T * B, -1)
        x = F.relu(self.fc(x))

        one_hot_last_action = F.one_hot(
            last_actions.view(T * B), self.num_actions
        ).float()
        reward = reward.view(T * B, 1).float()
        core_input = torch.cat([x, reward, one_hot_last_action], dim=1)

        if self.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~done).float()
            notdone.unsqueeze_(-1)  # [T, B, H=1] for broadcasting.
            for input_t, notdone_t in zip(core_input.unbind(), notdone.unbind()):
                core_state = nest_map(notdone_t.mul, core_state)
                output_t, core_state = self.core(input_t, core_state)
                core_state = (output_t, core_state)  # nn.LSTMCell is a bit weird.
                core_output_list.append(output_t)  # [[B, H], [B, H], ...].
            core_output = torch.cat(core_output_list)  # [T * B, H].
        else:
            core_output = core_input

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        outputs = {
            "action": action,
            "policy_logits": policy_logits,
            "baseline": baseline,
        }
        if not unroll:
            for t in outputs.values():
                t.squeeze_(0)
        return outputs, core_state
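

# The deeper residual torso: three conv + max-pool sections, each followed by
# two residual blocks (in the style of the IMPALA "deep" network), then a
# 256-unit fully connected layer and the same optional LSTMCell core.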
class Model(nn.Module):
    def __init__(self, num_actions, use_lstm=None):
        super(Model, self).__init__()
        self.num_actions = num_actions
        self.use_lstm = use_lstm
        if use_lstm is None:
            self.use_lstm = FLAGS.use_lstm

        self.feat_convs = []
        self.resnet1 = []
        self.resnet2 = []
        self.convs = []

        input_channels = 4
        for num_ch in [16, 32, 32]:
            feats_convs = []
            feats_convs.append(
                nn.Conv2d(
                    in_channels=input_channels,
                    out_channels=num_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            feats_convs.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            self.feat_convs.append(nn.Sequential(*feats_convs))

            input_channels = num_ch

            for i in range(2):
                resnet_block = []
                resnet_block.append(nn.ReLU())
                resnet_block.append(
                    nn.Conv2d(
                        in_channels=input_channels,
                        out_channels=num_ch,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    )
                )
                resnet_block.append(nn.ReLU())
                resnet_block.append(
                    nn.Conv2d(
                        in_channels=input_channels,
                        out_channels=num_ch,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    )
                )
                if i == 0:
                    self.resnet1.append(nn.Sequential(*resnet_block))
                else:
                    self.resnet2.append(nn.Sequential(*resnet_block))

        self.feat_convs = nn.ModuleList(self.feat_convs)
        self.resnet1 = nn.ModuleList(self.resnet1)
        self.resnet2 = nn.ModuleList(self.resnet2)

        self.fc = nn.Linear(3872, 256)

        # FC output size + last reward + one-hot of last action.
        core_output_size = self.fc.out_features + 1 + num_actions

        if self.use_lstm:
            self.core = nn.LSTMCell(core_output_size, 256)
            core_output_size = 256

        self.policy = nn.Linear(core_output_size, self.num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def initial_state(self, batch_size=1):
        if not self.use_lstm:
            return tuple()
        return tuple(torch.zeros(batch_size, self.core.hidden_size) for _ in range(2))

    def forward(self, last_actions, env_outputs, core_state, unroll=False):
        if not unroll:
            # [T=1, B, ...].
            env_outputs = nest_map(lambda t: t.unsqueeze(0), env_outputs)

        observation, reward, done = env_outputs

        T, B, *_ = observation.shape
        x = torch.flatten(observation, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0

        res_input = None
        for i, fconv in enumerate(self.feat_convs):
            x = fconv(x)
            res_input = x
            x = self.resnet1[i](x)
            x += res_input
            res_input = x
            x = self.resnet2[i](x)
            x += res_input

        x = F.relu(x)
        x = x.view(T * B, -1)
        x = F.relu(self.fc(x))

        one_hot_last_action = F.one_hot(
            last_actions.view(T * B), self.num_actions
        ).float()
        reward = reward.view(T * B, 1).float()
        core_input = torch.cat([x, reward, one_hot_last_action], dim=1)

        if self.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~done).float()
            notdone.unsqueeze_(-1)  # [T, B, H=1] for broadcasting.
            for input_t, notdone_t in zip(core_input.unbind(), notdone.unbind()):
                core_state = nest_map(notdone_t.mul, core_state)
                output_t, core_state = self.core(input_t, core_state)
                core_state = (output_t, core_state)  # nn.LSTMCell is a bit weird.
                core_output_list.append(output_t)  # [[B, H], [B, H], ...].
            core_output = torch.cat(core_output_list)  # [T * B, H].
        else:
            core_output = core_input

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        outputs = dict(action=action, policy_logits=policy_logits, baseline=baseline)
        if not unroll:
            outputs = nest_map(lambda t: t.squeeze(0), outputs)
        return outputs, core_state


def create_model(num_actions):
    if FLAGS.model == "test":
        return TestModel(num_actions)
    if FLAGS.model == "shallow":
        return ShallowModel(num_actions)
    return Model(num_actions)


def main():
    num_actions = 6

    model = create_model(num_actions)
    model = model.to(device=FLAGS.device)

    dummy_env_output = (
        torch.empty([1] + OBSERVATION_SHAPE, dtype=torch.uint8),
        torch.zeros(1, dtype=torch.float64),
        torch.zeros(1, dtype=torch.bool),
    )

    # T = FLAGS.unroll_length
    B = FLAGS.batch_size

    dummy_model_input = dict(
        last_actions=torch.zeros([1], dtype=torch.int64),
        env_outputs=dummy_env_output,
        core_state=model.initial_state(1),
    )
    dummy_model_input = nest_map(lambda t: t.to(FLAGS.device), dummy_model_input)

    initial_agent_state = nest_map(lambda t: t.to(FLAGS.device), model.initial_state(B))

    last_step = 0
    last_time = timeit.default_timer()

    for step in range(10000):
        model_output, _ = model(**dummy_model_input)
        model_output["action"].item()  # sync.

        if step and step % 100 == 0:
            current_time = timeit.default_timer()
            sps = (step - last_step) / (current_time - last_time)
            last_step = step
            last_time = current_time
            print("Step %i @ %.1f SPS." % (step, sps))


if __name__ == "__main__":
    main()
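

# Example usage (the filename is illustrative; assuming this script is saved as
# benchmark_models.py):
#
#   python benchmark_models.py
#
# Switch variants by editing the FLAGS block above, e.g. FLAGS.model = "shallow",
# FLAGS.use_lstm = True, or FLAGS.device = "cuda".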