@prigoyal · Created June 5, 2019 18:49
Taskonomy architectures.py
import torch.nn as nn
from torch.nn import Parameter, ModuleList
import torch.nn.functional as F
import torch
import multiprocessing
import numpy as np
import os
from gym import spaces
from torchvision.models import resnet18
from teas.baselines.utils import init, init_normc_
from teas.preprocess import transforms
import torchvision as vision
from teas.models.taskonomyencoder import TaskonomyEncoder, TASKONOMY_TASKS
from .sparsely_gated_mixture_of_experts import SparselyGatedMoELayer
init_ = lambda m: init(m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0),
nn.init.calculate_gain('relu'))
DEFAULT_ENCODER_PATHS = ['/home/bradleyemi/taskonomy_data/{}_encoder.dat'.format(task)
for task in TASKONOMY_TASKS]
class FrameStacked(nn.Module):
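    """Applies `net` independently to each of `n_stack` frames packed along the
    channel dimension and concatenates the per-frame outputs back along dim=1.
    If `parallel` is set and the batch is small enough, the frames are folded
    into the batch dimension and run through `net` in a single call."""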
def __init__(self, net, n_stack, parallel=False, max_parallel=50):
super().__init__()
self.net = net
self.n_stack = n_stack
self.parallel = parallel
self.max_parallel = max_parallel
def forward(self, x):
xs = torch.chunk(x, self.n_stack, dim=1)
if self.parallel and len(x) <= self.max_parallel:
xs = torch.cat(xs, dim=0)
res = self.net(xs)
res = torch.chunk(res, self.n_stack, dim=0)
else:
res = [self.net(x) for x in xs]
# if is_cuda(self.net):
# res = torch.nn.parallel.parallel_apply([self.net for _ in xs],
# [(xi,) for xi in xs])
# else:
# res = [self.net(x) for x in xs]
out = torch.cat(res, dim=1)
return out
class AtariNatureEncoder(nn.Module): # pylint: disable=too-many-instance-attributes
""" VAE encoder """
def __init__(self, img_channels, latent_size):
super().__init__()
self.latent_size = latent_size
#self.img_size = img_size
self.img_channels = img_channels
init_ = lambda m: init(m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0),
nn.init.calculate_gain('relu'))
self.conv1 = init_(nn.Conv2d(img_channels, 32, 8, stride=4))
self.conv2 = init_(nn.Conv2d(32, 64, 4, stride=2))
self.conv3 = init_(nn.Conv2d(64, 32, 3, stride=1))
self.flatten = Flatten()
self.fc1 = init_(nn.Linear(32*7*7, latent_size))
def forward(self, x): # pylint: disable=arguments-differ
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = self.flatten(x)
x = F.relu(self.fc1(x))
return x
class TaskonomyMoE(nn.Module):
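    """Mixture-of-experts over frozen Taskonomy encoders. A small Nature-CNN gate
    (AtariNatureEncoder) computes gating features from the RGB observation, a
    SparselyGatedMoELayer mixes the frame-stacked frozen expert encoders applied
    to the 'taskonomy' sensor, and the mixed features are post-processed by a
    TaskonomyFeaturesOnlyNet."""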
def __init__(self, img_channels, latent_size, use_map=False, encoder_paths=DEFAULT_ENCODER_PATHS, n_gating_features=128, k=1, temperature=1.0):
super().__init__()
self.img_channels = img_channels
self.gate_features = AtariNatureEncoder(img_channels, n_gating_features)
print("LOADING EXPERTS")
print(encoder_paths)
self.experts = ModuleList(self._load_encoders_parallel(encoder_paths))
self.use_map = use_map
print("experts loaded")
print("using k = ", k)
self.mixture_of_experts = SparselyGatedMoELayer(n_gating_features, self.experts, k=k, temperature=temperature)
self.postprocessing = TaskonomyFeaturesOnlyNet(img_channels // 3, use_map=self.use_map)
self.accepts_sensordict = True
def _load_encoder(self, encoder_path):
net = TaskonomyEncoder() #.cuda()
net.eval()
checkpoint = torch.load(encoder_path)
net.load_state_dict(checkpoint['state_dict'])
net = FrameStacked(net, n_stack=int(self.img_channels/3))
for p in net.parameters():
p.requires_grad = False
# net = Compose(nn.GroupNorm(32, 32, affine=False), net)
return net
def _load_encoders_parallel(self, encoder_paths, n_processes=None):
'''
n_processes = len(encoder_paths) if n_processes is None else min(len(encoder_paths), n_processes)
n_parallel = min(multiprocessing.cpu_count(), n_processes)
pool = multiprocessing.Pool(min(n_parallel, n_processes))
experts = pool.map(self._load_encoder, encoder_paths)
pool.close()
pool.join()
experts = [e.cuda() for e in experts]
'''
experts = [self._load_encoder(path).cuda() for path in encoder_paths]
# experts = [self._load_encoder(encoder_paths[0]).cuda()]
return experts
def forward(self, sensors):
x_gate = self.gate_features(sensors['rgb_filled'])
expert_output = self.mixture_of_experts(x_gate, sensors['taskonomy'])
x = {}
x['taskonomy'] = expert_output
if self.use_map:
x['map'] = sensors['map']
return self.postprocessing(x)
class TaskonomyUnconditionedMoE(TaskonomyMoE):
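    """Variant of TaskonomyMoE whose gate (IgnoreInput) is independent of the
    observation: the expert mixture weights are a single learned softmax."""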
def __init__(self, img_channels, latent_size, encoder_paths=DEFAULT_ENCODER_PATHS, n_gating_features=128, k=1, temperature=1.0):
        # Skip TaskonomyMoE.__init__ (it requires img_channels/latent_size and builds
        # its own gate and experts); everything is re-created for this variant below.
        nn.Module.__init__(self)
self.img_channels = img_channels
self.experts = ModuleList(self._load_encoders_parallel(encoder_paths))
print("experts loaded")
self.gate_features = IgnoreInput(n_gating_features)
self.mixture_of_experts = SparselyGatedMoELayer(n_gating_features, self.experts, k=k)
self.postprocessing = TaskonomyFrozenFeaturesRelu(img_channels, latent_size)
self.accepts_sensordict = True
class IgnoreInput(nn.Module):
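    """Gate that ignores its input and returns a learned softmax weight vector,
    repeated once per batch element."""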
def __init__(self, n_experts):
super().__init__()
self.weights = Parameter(torch.Tensor(n_experts))
def forward(self, x):
sft = F.softmax(self.weights, dim=0)
return torch.stack([sft for _ in range(x.shape[0])], dim=0)
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
def atari_nature(num_inputs, num_outputs=512):
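    """Atari-style conv encoder. The final Linear assumes 84x84 inputs, which give
    a 32x7x7 feature map after the three convolutions."""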
init_ = lambda m: init(m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0),
nn.init.calculate_gain('relu'))
return nn.Sequential(
init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
nn.ReLU(),
init_(nn.Conv2d(32, 64, 4, stride=2)),
nn.ReLU(),
init_(nn.Conv2d(64, 32, 3, stride=1)),
nn.ReLU(),
Flatten(),
init_(nn.Linear(32 * 7 * 7, num_outputs)), # 512 original outputs
nn.ReLU()
)
def atari_conv(num_inputs):
init_ = lambda m: init(m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0),
nn.init.calculate_gain('relu'))
return nn.Sequential(
init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
nn.ReLU(),
init_(nn.Conv2d(32, 64, 4, stride=2)),
nn.ReLU(),
init_(nn.Conv2d(64, 32, 3, stride=1)),
nn.ReLU())
def atari_small_conv(num_inputs):
init_ = lambda m: init(m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0),
nn.init.calculate_gain('relu'))
return nn.Sequential(
init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
nn.ReLU(),
init_(nn.Conv2d(32, 64, 4, stride=2)),
nn.ReLU())
def atari_nature_vae(num_inputs, num_outputs=512):
init_ = lambda m: init(m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0),
nn.init.calculate_gain('relu'))
    return nn.Sequential(
init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
nn.ReLU(),
init_(nn.Conv2d(32, 64, 4, stride=2)),
nn.ReLU(),
init_(nn.Conv2d(64, 32, 3, stride=1)),
nn.ReLU(),
Flatten(),
init_(nn.Linear(32 * 7 * 7, num_outputs)),
nn.ReLU()
)
class FramestackResnet(nn.Module):
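    """Runs a pretrained ResNet-18 over each RGB frame of a channel-stacked input
    (folded into the batch dimension) and concatenates the per-frame features."""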
def __init__(self, n_frames):
super(FramestackResnet, self).__init__()
self.n_frames = n_frames
self.resnet = resnet18(pretrained=True)
def forward(self, x):
assert x.shape[1] / 3 == self.n_frames, "Dimensionality mismatch of input, is n_frames set right?"
num_observations = x.shape[0]
reshaped = x.reshape((x.shape[0] * self.n_frames, 3, x.shape[2], x.shape[3]))
features = self.resnet(reshaped)
return features.reshape((num_observations, features.shape[0] * features.shape[1] // num_observations))
def is_cuda(model):
return next(model.parameters()).is_cuda
def task_encoder(checkpoint_path):
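    """Builds a TaskonomyEncoder in eval mode, optionally loading weights from
    `checkpoint_path` (pass None to keep the random initialization)."""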
net = TaskonomyEncoder()
net.eval()
print(checkpoint_path)
    if checkpoint_path is not None:
path_pth_ckpt = os.path.join(checkpoint_path)
checkpoint = torch.load(path_pth_ckpt)
net.load_state_dict(checkpoint['state_dict'])
return net
class AtariNet(nn.Module):
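    """Policy encoder over raw pixels: an Atari-style conv tower on the RGB frame
    stack, with optional target-image channels concatenated before conv1 and an
    optional map tower whose features are concatenated before the FC layers."""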
def __init__(self, n_frames, use_map=False, use_target=True,
output_size=512):
super(AtariNet, self).__init__()
self.n_frames = n_frames
self.use_map = use_map
self.output_size = output_size
self.use_target = use_target
if self.use_map:
self.map_tower = atari_conv(num_inputs=self.n_frames)
self.map_channels = 1
else:
self.map_channels = 0
if self.use_target:
self.target_channels = 3
else:
self.target_channels = 0
self.image_tower = atari_small_conv(num_inputs=self.n_frames*3)
self.conv1 = nn.Conv2d(64 + (self.n_frames * self.target_channels), 32, 3, stride=1)
self.flatten = Flatten()
self.fc1 = init_(nn.Linear(32 * 7 * 7 * (self.map_channels + 1), 1024))
self.fc2 = init_(nn.Linear(1024, self.output_size))
def forward(self, x):
x_rgb = x['rgb_filled']
x_rgb = self.image_tower(x_rgb)
if self.use_target:
x_rgb = torch.cat([x_rgb, x["target"]], dim=1)
x_rgb = F.relu(self.conv1(x_rgb))
if self.use_map:
x_map = x['map']
x_map = self.map_tower(x_map)
x_rgb = torch.cat([x_rgb, x_map], dim=1)
x = self.flatten(x_rgb)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return x
class TaskonomyFeaturesOnlySingleSensorNet(nn.Module):
# Taskonomy features only, taskonomy encoder frozen
def __init__(self, n_frames, use_map=False,
output_size=512):
        super(TaskonomyFeaturesOnlySingleSensorNet, self).__init__()
self.n_frames = n_frames
self.output_size = output_size
self.use_map = use_map
if self.use_map:
self.map_tower = atari_conv(self.n_frames)
self.map_channels = 1
else:
self.map_channels = 0
self.conv1 = nn.Conv2d(self.n_frames * 8, 32, 4, stride=2)
self.flatten = Flatten()
self.fc1 = init_(nn.Linear(32 * 7 * 7 * (self.map_channels + 1), 1024))
self.fc2 = init_(nn.Linear(1024, self.output_size))
def forward(self, x):
x_taskonomy = x['taskonomy']
x_taskonomy = F.relu(self.conv1(x_taskonomy))
if self.use_map:
x_map = x['map']
x_map = self.map_tower(x_map)
x_taskonomy = torch.cat([x_map, x_taskonomy], dim=1)
x = self.flatten(x_taskonomy)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return x
class TaskonomyFeaturesOnlyNet(nn.Module):
# Taskonomy features only, taskonomy encoder frozen
def __init__(self, n_frames, use_map=False, use_target=True,
output_size=512):
super(TaskonomyFeaturesOnlyNet, self).__init__()
self.n_frames = n_frames
self.output_size = output_size
self.use_map = use_map
self.use_target = use_target
if self.use_map:
self.map_tower = atari_conv(self.n_frames)
self.map_channels = 1
else:
self.map_channels = 0
if self.use_target:
self.target_channels = 3
else:
self.target_channels = 0
self.conv1 = nn.Conv2d(self.n_frames * (8 + self.target_channels), 32, 4, stride=2)
self.flatten = Flatten()
self.fc1 = init_(nn.Linear(32 * 7 * 7 * (self.map_channels + 1), 1024))
self.fc2 = init_(nn.Linear(1024, self.output_size))
def forward(self, x):
x_taskonomy = x['taskonomy']
if self.use_target:
x_taskonomy = torch.cat([x_taskonomy, x["target"]], dim=1)
x_taskonomy = F.relu(self.conv1(x_taskonomy))
if self.use_map:
x_map = x['map']
x_map = self.map_tower(x_map)
x_taskonomy = torch.cat([x_map, x_taskonomy], dim=1)
x = self.flatten(x_taskonomy)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return x
class TaskonomyFeaturesPixelsNet(nn.Module):
# Taskonomy features and pixels, taskonomy encoder frozen
def __init__(self, n_frames, use_map=False,
output_size=512):
super(TaskonomyFeaturesPixelsNet, self).__init__()
self.n_frames = n_frames
self.output_size = output_size
self.use_map = use_map
if self.use_map:
self.map_tower = atari_conv(self.n_frames)
self.map_channels = 1
else:
self.map_channels = 0
self.image_tower = atari_conv(self.n_frames * 3)
self.conv1 = nn.Conv2d(self.n_frames * 8, 32, 4, stride=2)
self.flatten = Flatten()
self.fc1 = init_(nn.Linear(32 * 7 * 7 * (self.map_channels + 2), 1024))
self.fc2 = init_(nn.Linear(1024, self.output_size))
def forward(self, x):
x_rgb = x['rgb_filled']
x_rgb = self.image_tower(x_rgb)
x_taskonomy = x['taskonomy']
x_taskonomy = F.relu(self.conv1(x_taskonomy))
cat_image = torch.cat([x_rgb, x_taskonomy], dim=1)
if self.use_map:
x_map = x['map']
x_map = self.map_tower(x_map)
cat_image = torch.cat([cat_image, x_map], dim=1)
x = self.flatten(cat_image)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return x
RESCALE_0_1_NEG1_POS1 = vision.transforms.Normalize([0.5,0.5,0.5], [0.5, 0.5, 0.5])
RESCALE_0_255_NEG1_POS1 = vision.transforms.Normalize([127.5,127.5,127.5], [255, 255, 255])
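# Note: with std 255 the second transform maps [0, 255] to [-0.5, 0.5] rather than
# [-1, 1]; a std of 127.5 would be needed for a true [-1, 1] rescale.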
def pixels_as_state(output_size, dtype=np.float32):
    ''' pixels_as_state
    Args:
        output_size: a tuple CxWxH for the stacked pixel observation
        dtype: of the output (must be np, not torch)
    Returns:
        a thunk which takes an observation space and returns (transform, output Box space)
    '''
def _thunk(obs_space):
obs_shape = obs_space.shape
obs_min_wh = min(obs_shape[:2])
        output_wh = output_size[-2:]  # output (W, H)
processed_env_shape = output_size
base_pipeline = vision.transforms.Compose([
vision.transforms.ToPILImage(),
vision.transforms.CenterCrop([obs_min_wh, obs_min_wh]),
vision.transforms.Resize(output_wh)])
grayscale_pipeline = vision.transforms.Compose([
vision.transforms.Grayscale(),
vision.transforms.ToTensor(),
RESCALE_0_1_NEG1_POS1,
])
rgb_pipeline = vision.transforms.Compose([
vision.transforms.ToTensor(),
RESCALE_0_1_NEG1_POS1,
])
def pipeline(x):
base = base_pipeline(x)
rgb = rgb_pipeline(base)
gray = grayscale_pipeline(base)
n_rgb = output_size[0] // 3
n_gray = output_size[0] % 3
return torch.cat([rgb] * n_rgb + [gray] * n_gray)
return pipeline, spaces.Box(-1, 1, output_size, dtype)
return _thunk
def taskonomy_features_transform(task_path, dtype=np.float32):
    ''' taskonomy_features_transform
    Args:
        task_path: path to a Taskonomy encoder checkpoint, the string 'None', or 'pixels_as_state'
        dtype: of the output (must be np, not torch)
    Returns:
        a thunk which takes an observation space and returns (transform, output Box space)
    '''
_rescale_thunk = transforms.rescale_centercrop_resize((3, 256, 256))
_pixels_as_state_thunk = pixels_as_state((8, 16, 16))
if task_path != 'pixels_as_state':
net = TaskonomyEncoder().cuda()
net.eval()
if task_path != 'None':
checkpoint = torch.load(task_path)
# net.load_state_dict(checkpoint['state_dict'])
net.load_state_dict(checkpoint['state_dict'], strict=False)
def encode(x):
if task_path == 'pixels_as_state':
return x
with torch.no_grad():
return net(x)
def _taskonomy_features_transform_thunk(obs_space):
rescale, _ = _rescale_thunk(obs_space)
pixels_as_state, _ = _pixels_as_state_thunk(obs_space)
def pipeline(x):
x = rescale(x).view(1, 3, 256, 256)
x = torch.Tensor(x).cuda()
x = encode(x)
return x.cpu()
def pixels_as_state_pipeline(x):
return pixels_as_state(x).cpu()
if task_path == 'pixels_as_state':
return pixels_as_state_pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)
else:
return pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)
return _taskonomy_features_transform_thunk
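# Usage sketch (hypothetical checkpoint path and gym-style observation space):
#   thunk = taskonomy_features_transform('/path/to/encoder.dat')
#   pipeline, feature_space = thunk(env.observation_space)
#   features = pipeline(obs)   # Taskonomy features on CPU; advertised space is (8, 16, 16)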