Taskonomy architectures.py
import torch.nn as nn
from torch.nn import Parameter, ModuleList
import torch.nn.functional as F
import torch
import multiprocessing
import numpy as np
import os
from gym import spaces
from torchvision.models import resnet18
import torchvision as vision
from teas.models.taskonomyencoder import TaskonomyEncoder, TASKONOMY_TASKS
from teas.baselines.utils import init, init_normc_
from teas.preprocess import transforms
from .sparsely_gated_mixture_of_experts import SparselyGatedMoELayer
init_ = lambda m: init(m,
                       nn.init.orthogonal_,
                       lambda x: nn.init.constant_(x, 0),
                       nn.init.calculate_gain('relu'))

DEFAULT_ENCODER_PATHS = ['/home/bradleyemi/taskonomy_data/{}_encoder.dat'.format(task)
                         for task in TASKONOMY_TASKS]

class FrameStacked(nn.Module):
    """Applies `net` to each of `n_stack` frames (chunked along the channel
    dim) and concatenates the per-frame outputs along the channel dim."""
    def __init__(self, net, n_stack, parallel=False, max_parallel=50):
        super().__init__()
        self.net = net
        self.n_stack = n_stack
        self.parallel = parallel
        self.max_parallel = max_parallel

    def forward(self, x):
        xs = torch.chunk(x, self.n_stack, dim=1)
        if self.parallel and len(x) <= self.max_parallel:
            # Fold the frames into the batch dim, run the net once, unfold.
            xs = torch.cat(xs, dim=0)
            res = self.net(xs)
            res = torch.chunk(res, self.n_stack, dim=0)
        else:
            res = [self.net(x) for x in xs]
            # Alternative kept from the original:
            # if is_cuda(self.net):
            #     res = torch.nn.parallel.parallel_apply([self.net for _ in xs],
            #                                            [(xi,) for xi in xs])
            # else:
            #     res = [self.net(x) for x in xs]
        out = torch.cat(res, dim=1)
        return out

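# Example (hedged sketch): wrapping a frame-level encoder so it consumes
# stacked frames. resnet18 is illustrative here; any module mapping
# (N, 3, H, W) -> (N, F) works.
#
#   stacked = FrameStacked(resnet18(pretrained=True), n_stack=4)
#   frames = torch.randn(2, 12, 224, 224)  # 4 RGB frames on the channel dim
#   feats = stacked(frames)                # -> (2, 4 * 1000)
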
class AtariNatureEncoder(nn.Module):  # pylint: disable=too-many-instance-attributes
    """Nature-CNN image encoder (the Mnih et al. 2015 Atari architecture),
    mapping stacked frames to a `latent_size`-dim vector."""
    def __init__(self, img_channels, latent_size):
        super().__init__()
        self.latent_size = latent_size
        self.img_channels = img_channels
        init_ = lambda m: init(m,
                               nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))
        self.conv1 = init_(nn.Conv2d(img_channels, 32, 8, stride=4))
        self.conv2 = init_(nn.Conv2d(32, 64, 4, stride=2))
        self.conv3 = init_(nn.Conv2d(64, 32, 3, stride=1))
        self.flatten = Flatten()
        self.fc1 = init_(nn.Linear(32 * 7 * 7, latent_size))

    def forward(self, x):  # pylint: disable=arguments-differ
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        return x

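# Shape sketch (hedged; the 32*7*7 fc1 input implies the standard 84x84 Atari input):
#   (N, C, 84, 84) -conv1(8,s4)-> (N, 32, 20, 20) -conv2(4,s2)-> (N, 64, 9, 9)
#   -conv3(3,s1)-> (N, 32, 7, 7) -flatten-> (N, 1568) -fc1-> (N, latent_size)
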
class TaskonomyMoE(nn.Module):
    """Sparsely-gated mixture of frozen Taskonomy encoders; the gate is
    conditioned on Nature-CNN features of the raw RGB observation."""
    def __init__(self, img_channels, latent_size, use_map=False,
                 encoder_paths=DEFAULT_ENCODER_PATHS, n_gating_features=128,
                 k=1, temperature=1.0):
        super().__init__()
        self.img_channels = img_channels
        self.gate_features = AtariNatureEncoder(img_channels, n_gating_features)
        print("LOADING EXPERTS")
        print(encoder_paths)
        self.experts = ModuleList(self._load_encoders_parallel(encoder_paths))
        self.use_map = use_map
        print("experts loaded")
        print("using k = ", k)
        self.mixture_of_experts = SparselyGatedMoELayer(n_gating_features, self.experts,
                                                        k=k, temperature=temperature)
        self.postprocessing = TaskonomyFeaturesOnlyNet(img_channels // 3, use_map=self.use_map)
        self.accepts_sensordict = True

    def _load_encoder(self, encoder_path):
        net = TaskonomyEncoder()  # .cuda()
        net.eval()
        checkpoint = torch.load(encoder_path)
        net.load_state_dict(checkpoint['state_dict'])
        net = FrameStacked(net, n_stack=self.img_channels // 3)
        for p in net.parameters():
            p.requires_grad = False  # experts stay frozen
        # net = Compose(nn.GroupNorm(32, 32, affine=False), net)
        return net

    def _load_encoders_parallel(self, encoder_paths, n_processes=None):
        # Loads sequentially; the multiprocessing version is kept below for reference.
        '''
        n_processes = len(encoder_paths) if n_processes is None else min(len(encoder_paths), n_processes)
        n_parallel = min(multiprocessing.cpu_count(), n_processes)
        pool = multiprocessing.Pool(min(n_parallel, n_processes))
        experts = pool.map(self._load_encoder, encoder_paths)
        pool.close()
        pool.join()
        experts = [e.cuda() for e in experts]
        '''
        experts = [self._load_encoder(path).cuda() for path in encoder_paths]
        return experts

    def forward(self, sensors):
        x_gate = self.gate_features(sensors['rgb_filled'])
        expert_output = self.mixture_of_experts(x_gate, sensors['taskonomy'])
        x = {'taskonomy': expert_output}
        if self.use_map:
            x['map'] = sensors['map']
        return self.postprocessing(x)

class TaskonomyUnconditionedMoE(TaskonomyMoE):
    """Mixture of frozen Taskonomy encoders with input-independent gating."""
    def __init__(self, img_channels, latent_size, encoder_paths=DEFAULT_ENCODER_PATHS,
                 n_gating_features=128, k=1, temperature=1.0):
        # Skip TaskonomyMoE.__init__ (a bare super().__init__() would fail on
        # its required args); initialize nn.Module directly and rebuild.
        nn.Module.__init__(self)
        self.img_channels = img_channels
        self.experts = ModuleList(self._load_encoders_parallel(encoder_paths))
        print("experts loaded")
        self.gate_features = IgnoreInput(n_gating_features)
        self.mixture_of_experts = SparselyGatedMoELayer(n_gating_features, self.experts, k=k)
        # NOTE: TaskonomyFrozenFeaturesRelu is not defined in this file; it is
        # assumed to come from the surrounding codebase.
        self.postprocessing = TaskonomyFrozenFeaturesRelu(img_channels, latent_size)
        self.accepts_sensordict = True

class IgnoreInput(nn.Module):
    """Learned, input-independent gating logits (uniform at initialization)."""
    def __init__(self, n_experts):
        super().__init__()
        # torch.zeros rather than the uninitialized torch.Tensor: softmax of
        # zeros gives a uniform gate and avoids garbage/NaN initial values.
        self.weights = Parameter(torch.zeros(n_experts))

    def forward(self, x):
        sft = F.softmax(self.weights, dim=0)
        return torch.stack([sft for _ in range(x.shape[0])], dim=0)

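# Example (hedged sketch): with zero-initialized logits the gate is uniform
# regardless of the input; only the batch size of `x` is used.
#   gate = IgnoreInput(4)
#   gate(torch.randn(2, 128)).shape   # -> (2, 4); each row is [0.25] * 4
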
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

def atari_nature(num_inputs, num_outputs=512):
    init_ = lambda m: init(m,
                           nn.init.orthogonal_,
                           lambda x: nn.init.constant_(x, 0),
                           nn.init.calculate_gain('relu'))
    return nn.Sequential(
        init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
        nn.ReLU(),
        init_(nn.Conv2d(32, 64, 4, stride=2)),
        nn.ReLU(),
        init_(nn.Conv2d(64, 32, 3, stride=1)),
        nn.ReLU(),
        Flatten(),
        init_(nn.Linear(32 * 7 * 7, num_outputs)),  # 512 outputs in the original
        nn.ReLU()
    )

def atari_conv(num_inputs):
    init_ = lambda m: init(m,
                           nn.init.orthogonal_,
                           lambda x: nn.init.constant_(x, 0),
                           nn.init.calculate_gain('relu'))
    return nn.Sequential(
        init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
        nn.ReLU(),
        init_(nn.Conv2d(32, 64, 4, stride=2)),
        nn.ReLU(),
        init_(nn.Conv2d(64, 32, 3, stride=1)),
        nn.ReLU())

def atari_small_conv(num_inputs):
    init_ = lambda m: init(m,
                           nn.init.orthogonal_,
                           lambda x: nn.init.constant_(x, 0),
                           nn.init.calculate_gain('relu'))
    return nn.Sequential(
        init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
        nn.ReLU(),
        init_(nn.Conv2d(32, 64, 4, stride=2)),
        nn.ReLU())

def atari_nature_vae(num_inputs, num_outputs=512):
    init_ = lambda m: init(m,
                           nn.init.orthogonal_,
                           lambda x: nn.init.constant_(x, 0),
                           nn.init.calculate_gain('relu'))
    return nn.Sequential(
        init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
        nn.ReLU(),
        init_(nn.Conv2d(32, 64, 4, stride=2)),
        nn.ReLU(),
        init_(nn.Conv2d(64, 32, 3, stride=1)),
        nn.ReLU(),
        Flatten(),
        init_(nn.Linear(32 * 7 * 7, num_outputs)),
        nn.ReLU()
    )

class FramestackResnet(nn.Module):
    def __init__(self, n_frames):
        super(FramestackResnet, self).__init__()
        self.n_frames = n_frames
        self.resnet = resnet18(pretrained=True)

    def forward(self, x):
        assert x.shape[1] == 3 * self.n_frames, \
            "Dimensionality mismatch of input; is n_frames set right?"
        num_observations = x.shape[0]
        # Fold frames into the batch dim, encode, then restore the batch dim.
        reshaped = x.reshape((num_observations * self.n_frames, 3, x.shape[2], x.shape[3]))
        features = self.resnet(reshaped)
        return features.reshape(num_observations, -1)

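# Note (hedged): resnet18(pretrained=True) keeps its 1000-way ImageNet head,
# so each frame contributes a 1000-dim vector:
#   enc = FramestackResnet(n_frames=4)
#   enc(torch.randn(2, 12, 224, 224)).shape   # -> (2, 4000)
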
def is_cuda(model):
    return next(model.parameters()).is_cuda

def task_encoder(checkpoint_path):
    net = TaskonomyEncoder()
    net.eval()
    print(checkpoint_path)
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint['state_dict'])
    return net

class AtariNet(nn.Module):
    # Pixels-only baseline: Nature-CNN towers over RGB (and optionally map),
    # with an optional target map concatenated before the final conv.
    def __init__(self, n_frames, use_map=False, use_target=True,
                 output_size=512):
        super(AtariNet, self).__init__()
        self.n_frames = n_frames
        self.use_map = use_map
        self.output_size = output_size
        self.use_target = use_target
        if self.use_map:
            self.map_tower = atari_conv(num_inputs=self.n_frames)
            self.map_channels = 1
        else:
            self.map_channels = 0
        if self.use_target:
            self.target_channels = 3
        else:
            self.target_channels = 0
        self.image_tower = atari_small_conv(num_inputs=self.n_frames * 3)
        self.conv1 = nn.Conv2d(64 + (self.n_frames * self.target_channels), 32, 3, stride=1)
        self.flatten = Flatten()
        self.fc1 = init_(nn.Linear(32 * 7 * 7 * (self.map_channels + 1), 1024))
        self.fc2 = init_(nn.Linear(1024, self.output_size))

    def forward(self, x):
        x_rgb = x['rgb_filled']
        x_rgb = self.image_tower(x_rgb)
        if self.use_target:
            x_rgb = torch.cat([x_rgb, x["target"]], dim=1)
        x_rgb = F.relu(self.conv1(x_rgb))
        if self.use_map:
            x_map = x['map']
            x_map = self.map_tower(x_map)
            x_rgb = torch.cat([x_rgb, x_map], dim=1)
        x = self.flatten(x_rgb)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

class TaskonomyFeaturesOnlySingleSensorNet(nn.Module):
    # Taskonomy features only (no target input), taskonomy encoder frozen
    def __init__(self, n_frames, use_map=False,
                 output_size=512):
        super(TaskonomyFeaturesOnlySingleSensorNet, self).__init__()
        self.n_frames = n_frames
        self.output_size = output_size
        self.use_map = use_map
        if self.use_map:
            self.map_tower = atari_conv(self.n_frames)
            self.map_channels = 1
        else:
            self.map_channels = 0
        self.conv1 = nn.Conv2d(self.n_frames * 8, 32, 4, stride=2)
        self.flatten = Flatten()
        self.fc1 = init_(nn.Linear(32 * 7 * 7 * (self.map_channels + 1), 1024))
        self.fc2 = init_(nn.Linear(1024, self.output_size))

    def forward(self, x):
        x_taskonomy = x['taskonomy']
        x_taskonomy = F.relu(self.conv1(x_taskonomy))
        if self.use_map:
            x_map = x['map']
            x_map = self.map_tower(x_map)
            x_taskonomy = torch.cat([x_map, x_taskonomy], dim=1)
        x = self.flatten(x_taskonomy)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

class TaskonomyFeaturesOnlyNet(nn.Module):
    # Taskonomy features only, taskonomy encoder frozen
    def __init__(self, n_frames, use_map=False, use_target=True,
                 output_size=512):
        super(TaskonomyFeaturesOnlyNet, self).__init__()
        self.n_frames = n_frames
        self.output_size = output_size
        self.use_map = use_map
        self.use_target = use_target
        if self.use_map:
            self.map_tower = atari_conv(self.n_frames)
            self.map_channels = 1
        else:
            self.map_channels = 0
        if self.use_target:
            self.target_channels = 3
        else:
            self.target_channels = 0
        self.conv1 = nn.Conv2d(self.n_frames * (8 + self.target_channels), 32, 4, stride=2)
        self.flatten = Flatten()
        self.fc1 = init_(nn.Linear(32 * 7 * 7 * (self.map_channels + 1), 1024))
        self.fc2 = init_(nn.Linear(1024, self.output_size))

    def forward(self, x):
        x_taskonomy = x['taskonomy']
        if self.use_target:
            x_taskonomy = torch.cat([x_taskonomy, x["target"]], dim=1)
        x_taskonomy = F.relu(self.conv1(x_taskonomy))
        if self.use_map:
            x_map = x['map']
            x_map = self.map_tower(x_map)
            x_taskonomy = torch.cat([x_map, x_taskonomy], dim=1)
        x = self.flatten(x_taskonomy)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

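# Shape sketch (hedged; assumes the (8, 16, 16) per-frame taskonomy features
# produced by taskonomy_features_transform below, plus 84x84 maps):
#   x['taskonomy'] is (N, n_frames*8, 16, 16); with use_target, a target map
#   of shape (N, n_frames*3, 16, 16) is concatenated on channels. conv1
#   (k=4, s=2) maps 16x16 -> 7x7, so each tower contributes 32*7*7 features.
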
class TaskonomyFeaturesPixelsNet(nn.Module):
    # Taskonomy features and pixels, taskonomy encoder frozen
    def __init__(self, n_frames, use_map=False,
                 output_size=512):
        super(TaskonomyFeaturesPixelsNet, self).__init__()
        self.n_frames = n_frames
        self.output_size = output_size
        self.use_map = use_map
        if self.use_map:
            self.map_tower = atari_conv(self.n_frames)
            self.map_channels = 1
        else:
            self.map_channels = 0
        self.image_tower = atari_conv(self.n_frames * 3)
        self.conv1 = nn.Conv2d(self.n_frames * 8, 32, 4, stride=2)
        self.flatten = Flatten()
        self.fc1 = init_(nn.Linear(32 * 7 * 7 * (self.map_channels + 2), 1024))
        self.fc2 = init_(nn.Linear(1024, self.output_size))

    def forward(self, x):
        x_rgb = x['rgb_filled']
        x_rgb = self.image_tower(x_rgb)
        x_taskonomy = x['taskonomy']
        x_taskonomy = F.relu(self.conv1(x_taskonomy))
        cat_image = torch.cat([x_rgb, x_taskonomy], dim=1)
        if self.use_map:
            x_map = x['map']
            x_map = self.map_tower(x_map)
            cat_image = torch.cat([cat_image, x_map], dim=1)
        x = self.flatten(cat_image)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

RESCALE_0_1_NEG1_POS1 = vision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# (x - 127.5) / 127.5 maps [0, 255] to [-1, 1], as the name says.
RESCALE_0_255_NEG1_POS1 = vision.transforms.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])

def pixels_as_state(output_size, dtype=np.float32):
    ''' Uses the (center-cropped, resized) pixels themselves as the state features.
    Args:
        output_size: a tuple CxWxH
        dtype: of the output (must be np, not torch)
    Returns:
        a thunk which takes an observation space and returns (transform, output_space)
    '''
    def _thunk(obs_space):
        obs_shape = obs_space.shape
        obs_min_wh = min(obs_shape[:2])
        output_wh = output_size[-2:]  # the WxH of the output
        base_pipeline = vision.transforms.Compose([
            vision.transforms.ToPILImage(),
            vision.transforms.CenterCrop([obs_min_wh, obs_min_wh]),
            vision.transforms.Resize(output_wh)])
        grayscale_pipeline = vision.transforms.Compose([
            vision.transforms.Grayscale(),
            vision.transforms.ToTensor(),
            RESCALE_0_1_NEG1_POS1,
        ])
        rgb_pipeline = vision.transforms.Compose([
            vision.transforms.ToTensor(),
            RESCALE_0_1_NEG1_POS1,
        ])

        def pipeline(x):
            base = base_pipeline(x)
            rgb = rgb_pipeline(base)
            gray = grayscale_pipeline(base)
            # Fill the requested channel count with repeated RGB, then grayscale.
            n_rgb = output_size[0] // 3
            n_gray = output_size[0] % 3
            return torch.cat([rgb] * n_rgb + [gray] * n_gray)
        return pipeline, spaces.Box(-1, 1, output_size, dtype)
    return _thunk

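# Example (hedged sketch): an (8, 16, 16) "pixels as state" space, as used by
# taskonomy_features_transform below; 2 RGB copies + 2 grayscale = 8 channels.
#   thunk = pixels_as_state((8, 16, 16))
#   pipeline, box = thunk(spaces.Box(0, 255, (84, 84, 3), np.uint8))
#   pipeline(np.zeros((84, 84, 3), np.uint8)).shape   # -> (8, 16, 16)
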
def taskonomy_features_transform(task_path, dtype=np.float32):
    ''' Frozen Taskonomy features as observations.
    Args:
        task_path: path to an encoder checkpoint, the string 'None' for a
            randomly initialized encoder, or 'pixels_as_state'
        dtype: of the output (must be np, not torch)
    Returns:
        a thunk which takes an observation space and returns (transform, output_space)
    '''
    _rescale_thunk = transforms.rescale_centercrop_resize((3, 256, 256))
    _pixels_as_state_thunk = pixels_as_state((8, 16, 16))
    if task_path != 'pixels_as_state':
        net = TaskonomyEncoder().cuda()
        net.eval()
        if task_path != 'None':  # string sentinel from the config, not None
            checkpoint = torch.load(task_path)
            net.load_state_dict(checkpoint['state_dict'], strict=False)

    def encode(x):
        if task_path == 'pixels_as_state':
            return x
        with torch.no_grad():
            return net(x)

    def _taskonomy_features_transform_thunk(obs_space):
        rescale, _ = _rescale_thunk(obs_space)
        pixels_as_state_fn, _ = _pixels_as_state_thunk(obs_space)

        def pipeline(x):
            x = rescale(x).view(1, 3, 256, 256)
            x = x.cuda()
            x = encode(x)
            return x.cpu()

        def pixels_as_state_pipeline(x):
            return pixels_as_state_fn(x).cpu()
        if task_path == 'pixels_as_state':
            return pixels_as_state_pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)
        else:
            return pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)
    return _taskonomy_features_transform_thunk
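
# Example (hedged sketch): the checkpoint path and `env` below are illustrative.
#   thunk = taskonomy_features_transform('/path/to/autoencoding_encoder.dat')
#   pipeline, box = thunk(env.observation_space)
#   feats = pipeline(obs)   # CPU tensor of taskonomy features; box is (8, 16, 16)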