Created February 14, 2020 11:51
Save heiner/4306ff6ba6e7b163f4c4e56fd186bbe9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timeit | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Flags: | |
pass | |
FLAGS = Flags() | |
FLAGS.use_lstm = False # True or false. | |
FLAGS.model = "resnet" # One of ["test", "shallow", "resnet"]. | |
FLAGS.device = "cpu" | |
FLAGS.batch_size = 4 | |
OBSERVATION_SHAPE = [4, 84, 84] | |
def nest_map(f, n):
    """Apply ``f`` to every leaf of a nested structure, preserving its shape.

    Tuples (including namedtuples), lists and dicts are traversed recursively;
    anything else is treated as a leaf and passed to ``f``.

    Args:
        f: Callable applied to each leaf.
        n: A leaf, or an arbitrarily nested tuple/list/dict of leaves.

    Returns:
        A structure with the same tuple/list types as ``n`` (dicts come back as
        plain ``dict``) and ``f`` applied to every leaf.
    """
    if isinstance(n, tuple) and hasattr(n, "_fields"):
        # namedtuple constructors take the fields positionally, not an iterable.
        return type(n)(*(nest_map(f, sn) for sn in n))
    if isinstance(n, (tuple, list)):
        return type(n)(nest_map(f, sn) for sn in n)
    if isinstance(n, dict):
        return {k: nest_map(f, v) for k, v in n.items()}
    return f(n)
class TestModel(nn.Module):
    """Minimal policy/value model: one linear layer over flattened observations.

    Policy logits and a scalar baseline are read off the linear layer's output.
    There is no recurrent core, so ``initial_state`` is an empty tuple and the
    core state passed to ``forward`` is returned unchanged.
    """

    def __init__(self, num_actions, core_output_size=256):
        super(TestModel, self).__init__()
        if FLAGS.use_lstm:
            raise ValueError("Test model cannot use LSTM.")
        self.num_actions = num_actions
        # Input size is hard-coded for 4x84x84 observations.
        self.linear = nn.Linear(4 * 84 * 84, core_output_size)
        self.policy = nn.Linear(core_output_size, num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def initial_state(self, batch_size=1):
        """Return the (empty) recurrent state; ``batch_size`` is ignored."""
        return tuple()

    def forward(self, last_actions, env_outputs, core_state, unroll=False):
        """Run the model and sample an action per batch entry.

        Args:
            last_actions: Unused by this model.
            env_outputs: ``(observation, reward, done)``. With ``unroll=False``
                each tensor is per-step ``[B, ...]`` and a ``T=1`` dim is added;
                with ``unroll=True`` they are ``[T, B, ...]``.
            core_state: Returned unchanged.
            unroll: Whether inputs already carry a leading time dimension.

        Returns:
            ``(outputs, core_state)`` where ``outputs`` has ``action`` and
            ``baseline`` of shape ``[T, B]`` and ``policy_logits`` of shape
            ``[T, B, num_actions]`` (time dim dropped when ``unroll=False``).
        """
        if not unroll:
            # [T=1, B, ...].
            env_outputs = nest_map(lambda t: t.unsqueeze(0), env_outputs)
        observation, reward, done = env_outputs
        T, B, *_ = observation.shape

        x = observation.reshape(T * B, -1)
        x = x.float() / 255.0
        core_output = self.linear(x)

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)
        action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)
        outputs = dict(action=action, policy_logits=policy_logits, baseline=baseline)
        if not unroll:
            # Drop the time dim out of place instead of squeeze_()-ing the
            # output views in place.
            outputs = nest_map(lambda t: t.squeeze(0), outputs)
        return outputs, core_state
class ShallowModel(nn.Module):
    """Three-conv-layer actor-critic network with an optional LSTM core.

    Observations are scaled to [0, 1] and passed through three conv layers and
    a 512-unit fully connected layer. The result is concatenated with the last
    reward and a one-hot encoding of the last action. If ``FLAGS.use_lstm`` is
    set, that vector goes through an ``nn.LSTMCell``. Policy logits and a scalar
    baseline are read off the final core output.
    """

    def __init__(self, num_actions, core_output_size=256):
        # NOTE: ``core_output_size`` is accepted for signature parity with
        # TestModel, but it is recomputed below and its value is never used.
        super(ShallowModel, self).__init__()
        self.num_actions = num_actions
        self.use_lstm = FLAGS.use_lstm

        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Fully connected layer; 3136 = 64 * 7 * 7, the conv3 output for 84x84 input.
        self.fc = nn.Linear(3136, 512)

        # FC output size + one-hot of last action + last reward.
        core_output_size = self.fc.out_features + num_actions + 1

        if self.use_lstm:
            self.core = nn.LSTMCell(core_output_size, 256)
            core_output_size = 256

        self.policy = nn.Linear(core_output_size, self.num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def initial_state(self, batch_size=1):
        """Return a zero ``(h, c)`` LSTM state, or ``()`` without an LSTM core.

        The tensors are created on the same device as the model's parameters,
        so they can be passed straight to ``forward``.
        """
        if not self.use_lstm:
            return tuple()
        device = self.policy.weight.device
        return tuple(
            torch.zeros(batch_size, self.core.hidden_size, device=device)
            for _ in range(2)
        )

    def forward(self, last_actions, env_outputs, core_state, unroll=False):
        """Compute actions, policy logits and baseline values.

        Args:
            last_actions: Integer action tensor with ``T * B`` elements.
            env_outputs: ``(observation, reward, done)``. With ``unroll=False``
                each is per-step ``[B, ...]`` and a ``T=1`` dim is added; with
                ``unroll=True`` they are ``[T, B, ...]``. ``done`` must be bool.
            core_state: ``(h, c)`` from ``initial_state`` when using an LSTM,
                otherwise ``()``.
            unroll: Whether inputs already carry a leading time dimension.

        Returns:
            ``(outputs, core_state)``. ``outputs`` holds ``action`` and
            ``baseline`` ``[T, B]`` and ``policy_logits`` ``[T, B, num_actions]``
            (time dim dropped when ``unroll=False``). Actions are sampled in
            training mode and argmax'd in eval mode.
        """
        if not unroll:
            # [T=1, B, ...].
            env_outputs = nest_map(lambda t: t.unsqueeze(0), env_outputs)
        observation, reward, done = env_outputs
        T, B = reward.shape

        x = torch.flatten(observation, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(T * B, -1)
        x = F.relu(self.fc(x))

        one_hot_last_action = F.one_hot(
            last_actions.view(T * B), self.num_actions
        ).float()
        reward = reward.view(T * B, 1).float()
        core_input = torch.cat([x, reward, one_hot_last_action], dim=1)

        if self.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~done).float().unsqueeze(-1)  # [T, B, 1] for broadcasting.
            for input_t, notdone_t in zip(core_input.unbind(), notdone.unbind()):
                # Zero the recurrent state for batch entries whose `done` is set.
                core_state = nest_map(notdone_t.mul, core_state)
                h, c = self.core(input_t, core_state)
                core_state = (h, c)
                core_output_list.append(h)  # [[B, H], [B, H], ...].
            core_output = torch.cat(core_output_list)  # [T * B, H].
        else:
            core_output = core_input

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        outputs = {
            "action": action,
            "policy_logits": policy_logits,
            "baseline": baseline,
        }
        if not unroll:
            # Drop the time dim out of place rather than mutating views in place.
            outputs = nest_map(lambda t: t.squeeze(0), outputs)
        return outputs, core_state
class Model(nn.Module): | |
def __init__(self, num_actions, use_lstm=None): | |
super(Model, self).__init__() | |
self.num_actions = num_actions | |
self.use_lstm = use_lstm | |
if use_lstm is None: | |
self.use_lstm = FLAGS.use_lstm | |
self.feat_convs = [] | |
self.resnet1 = [] | |
self.resnet2 = [] | |
self.convs = [] | |
input_channels = 4 | |
for num_ch in [16, 32, 32]: | |
feats_convs = [] | |
feats_convs.append( | |
nn.Conv2d( | |
in_channels=input_channels, | |
out_channels=num_ch, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) | |
) | |
feats_convs.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) | |
self.feat_convs.append(nn.Sequential(*feats_convs)) | |
input_channels = num_ch | |
for i in range(2): | |
resnet_block = [] | |
resnet_block.append(nn.ReLU()) | |
resnet_block.append( | |
nn.Conv2d( | |
in_channels=input_channels, | |
out_channels=num_ch, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) | |
) | |
resnet_block.append(nn.ReLU()) | |
resnet_block.append( | |
nn.Conv2d( | |
in_channels=input_channels, | |
out_channels=num_ch, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) | |
) | |
if i == 0: | |
self.resnet1.append(nn.Sequential(*resnet_block)) | |
else: | |
self.resnet2.append(nn.Sequential(*resnet_block)) | |
self.feat_convs = nn.ModuleList(self.feat_convs) | |
self.resnet1 = nn.ModuleList(self.resnet1) | |
self.resnet2 = nn.ModuleList(self.resnet2) | |
self.fc = nn.Linear(3872, 256) | |
# FC output size + last reward + one-hot of last action. | |
core_output_size = self.fc.out_features + 1 + num_actions | |
if self.use_lstm: | |
self.core = nn.LSTMCell(core_output_size, 256) | |
core_output_size = 256 | |
self.policy = nn.Linear(core_output_size, self.num_actions) | |
self.baseline = nn.Linear(core_output_size, 1) | |
def initial_state(self, batch_size=1): | |
if not self.use_lstm: | |
return tuple() | |
return tuple(torch.zeros(batch_size, self.core.hidden_size) for _ in range(2)) | |
def forward(self, last_actions, env_outputs, core_state, unroll=False): | |
if not unroll: | |
# [T=1, B, ...]. | |
env_outputs = nest_map(lambda t: t.unsqueeze(0), env_outputs) | |
observation, reward, done = env_outputs | |
T, B, *_ = observation.shape | |
x = torch.flatten(observation, 0, 1) # Merge time and batch. | |
x = x.float() / 255.0 | |
res_input = None | |
for i, fconv in enumerate(self.feat_convs): | |
x = fconv(x) | |
res_input = x | |
x = self.resnet1[i](x) | |
x += res_input | |
res_input = x | |
x = self.resnet2[i](x) | |
x += res_input | |
x = F.relu(x) | |
x = x.view(T * B, -1) | |
x = F.relu(self.fc(x)) | |
one_hot_last_action = F.one_hot( | |
last_actions.view(T * B), self.num_actions | |
).float() | |
reward = reward.view(T * B, 1).float() | |
core_input = torch.cat([x, reward, one_hot_last_action], dim=1) | |
if self.use_lstm: | |
core_input = core_input.view(T, B, -1) | |
core_output_list = [] | |
notdone = (~done).float() | |
notdone.unsqueeze_(-1) # [T, B, H=1] for broadcasting. | |
for input_t, notdone_t in zip(core_input.unbind(), notdone.unbind()): | |
core_state = nest_map(notdone_t.mul, core_state) | |
output_t, core_state = self.core(input_t, core_state) | |
core_state = (output_t, core_state) # nn.LSTMCell is a bit weird. | |
core_output_list.append(output_t) # [[B, H], [B, H], ...]. | |
core_output = torch.cat(core_output_list) # [T * B, H]. | |
else: | |
core_output = core_input | |
policy_logits = self.policy(core_output) | |
baseline = self.baseline(core_output) | |
if self.training: | |
action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) | |
else: | |
# Don't sample when testing. | |
action = torch.argmax(policy_logits, dim=1) | |
policy_logits = policy_logits.view(T, B, self.num_actions) | |
baseline = baseline.view(T, B) | |
action = action.view(T, B) | |
outputs = dict(action=action, policy_logits=policy_logits, baseline=baseline) | |
if not unroll: | |
outputs = nest_map(lambda t: t.squeeze(0), outputs) | |
return outputs, core_state | |
def create_model(num_actions):
    """Instantiate the network selected by ``FLAGS.model``.

    ``"test"`` picks TestModel and ``"shallow"`` picks ShallowModel. Any other
    value, including ``"resnet"``, falls back to the residual ``Model``.
    """
    if FLAGS.model == "test":
        model_cls = TestModel
    elif FLAGS.model == "shallow":
        model_cls = ShallowModel
    else:
        model_cls = Model
    return model_cls(num_actions)
def main():
    """Benchmark single-step model inference and print steps per second.

    Builds the model selected by ``FLAGS.model`` and runs it 10000 times on a
    fixed batch-of-one, non-unrolled dummy input. Throughput is printed every
    100 steps.
    """
    num_actions = 6
    model = create_model(num_actions).to(device=FLAGS.device)

    # Zeros rather than torch.empty so the input (and sampled actions) do not
    # depend on uninitialized memory.
    dummy_env_output = (
        torch.zeros([1] + OBSERVATION_SHAPE, dtype=torch.uint8),  # observation
        torch.zeros(1, dtype=torch.float64),  # reward
        torch.zeros(1, dtype=torch.bool),  # done
    )
    dummy_model_input = dict(
        last_actions=torch.zeros([1], dtype=torch.int64),
        env_outputs=dummy_env_output,
        core_state=model.initial_state(1),
    )
    dummy_model_input = nest_map(lambda t: t.to(FLAGS.device), dummy_model_input)

    last_steps_done = 0
    last_time = timeit.default_timer()
    for step in range(10000):
        model_output, _ = model(**dummy_model_input)
        model_output["action"].item()  # sync.
        if step and step % 100 == 0:
            current_time = timeit.default_timer()
            # Forward passes 0..step have completed, i.e. step + 1 of them.
            # Counting them this way avoids undercounting the first interval by one.
            steps_done = step + 1
            sps = (steps_done - last_steps_done) / (current_time - last_time)
            last_steps_done = steps_done
            last_time = current_time
            print("Step %i @ %.1f SPS." % (step, sps))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment