"""
Allow shifts and scales of the Poincare distance, which usually lies on the unit disc.

Gist by @AruniRC (last active April 24, 2019).
"""
import numpy as np
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.autograd import Function
from scipy.spatial.distance import pdist
from core.config import cfg
import nn as mynn
import utils.net as net_utils
import math
import gc
from joblib import Parallel, delayed
from modeling.sparse_activations import Sparsemax
from poincare_embeddings.hype import poincare
from hyperbolic_cones import my_poincare_model as mpm
DEBUG = False
class GradScaler(Function):
    """
    Gradient scaler layer

    Based off:
    https://discuss.pytorch.org/t/solved-reverse-gradients-in-backward-pass/3589/4
    """
    def __init__(self, scaler=0.0):
        self.scaler = scaler

    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        return grad_output * self.scaler


def grad_scale(x):
    return GradScaler()(x)
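# Minimal usage sketch (not part of the original gist): grad_scale / GradScaler is an
# identity in the forward pass but multiplies the incoming gradient by `scaler` in the
# backward pass, so the default scaler=0.0 blocks gradients to upstream layers. It uses
# the legacy (instantiated) autograd.Function calling style seen above, e.g.:
#   feat = GradScaler(scaler=0.1)(feat)  # forward: unchanged; backward: grads scaled by 0.1
#   feat = grad_scale(feat)              # forward: unchanged; backward: grads zeroed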
class Poincare(nn.Module):
    def __init__(self):
        super(Poincare, self).__init__()
        self.eps = 1e-5

    def forward(self, u, v):
        eps = self.eps
        squnorm = torch.clamp(torch.sum(u * u, dim=-1), 0, 1 - eps)
        sqvnorm = torch.clamp(torch.sum(v * v, dim=-1), 0, 1 - eps)
        sqdist = torch.sum(torch.pow(u - v, 2), dim=-1)
        # ctx.eps = eps
        # ctx.save_for_backward(u, v, squnorm, sqvnorm, sqdist)
        x = sqdist / ((1 - squnorm) * (1 - sqvnorm)) * 2 + 1
        # arcosh
        z = torch.sqrt(torch.pow(x, 2) - 1)
        return torch.log(x + z)
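# Quick sanity check (illustrative, not part of the original gist): Poincare.forward
# computes arcosh(1 + 2 * ||u - v||^2 / ((1 - ||u||^2) * (1 - ||v||^2))) row-wise, so
# the inputs are assumed to lie inside the unit ball.
if DEBUG:
    _pc = Poincare()
    _u = 0.1 * torch.rand(4, 8)  # 4 points in an 8-dim ball, well inside the boundary
    _v = 0.1 * torch.rand(4, 8)
    print(_pc(_u, _v).shape)     # expected: torch.Size([4])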
##### Self-Attention Relation Networks Recreation ######
class SelfAttnMat(nn.Module):
    """ Computes a self-attention matrix from visual appearance features """
    def __init__(self, feat_dim=2048, proj_dim=256, T=1.0, use_poincare=False):
        super(SelfAttnMat, self).__init__()
        self.proj_dim = proj_dim
        self.T = T
        self.sparsemax = Sparsemax(dim=2)
        self.d_k_sqrt = math.sqrt(self.proj_dim)
        self.proj_w1 = nn.Conv2d(feat_dim, self.proj_dim, 1, stride=1)  # [feat_dim x proj_dim]
        self.proj_w2 = nn.Conv2d(feat_dim, self.proj_dim, 1, stride=1)
        self._init_weights()
        self.use_poincare = use_poincare
        self.pm = poincare.PoincareManifold()
        self.my_pc = Poincare()
        # EDIT: scale and shift the distance on the Poincare disc
        self.poinc_scale = nn.Parameter(torch.tensor([1.0]))
        self.poinc_shift = nn.Parameter(torch.tensor([0.0]))

    def _init_weights(self):
        mynn.init.XavierFill(self.proj_w1.weight)
        init.constant_(self.proj_w1.bias, 0)
        mynn.init.XavierFill(self.proj_w2.weight)
        init.constant_(self.proj_w2.bias, 0)
    def forward(self, region_feature, num_imgs, iou_mat=[]):
        """ Return adjacency matrix as scaled dot-product self-attention """
        # Send scaled (or zero) gradients to rest of net
        region_feature = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(region_feature)
        # Project down the region features [n_img*n_region x n_dim x 1 x 1]
        feat_key = self.proj_w1(region_feature)
        feat_query = self.proj_w2(region_feature)
        # Reshape from (n_img*n_region, n_dim, 1, 1) to (n_img, n_region, n_dim, 1, 1)
        sz = feat_key.shape
        feat_key = feat_key.view(num_imgs, int(sz[0] / num_imgs), sz[1], sz[2], sz[3])
        feat_query = feat_query.view(num_imgs, int(sz[0] / num_imgs), sz[1], sz[2], sz[3])
        use_poincare = self.use_poincare
        if use_poincare:
            import time; start = time.time()
            device_id = feat_key.get_device()
            n_img = feat_key.shape[0]
            n_region = feat_key.shape[1]
            R_new = []
            for im in range(n_img):
                u = feat_key[im].squeeze(-1).squeeze(-1)    # [n_region x n_dim x 1 x 1] -> [n_region x n_dim]
                v = feat_query[im].squeeze(-1).squeeze(-1)  # [n_region x n_dim x 1 x 1] -> [n_region x n_dim]
                # normalize to unit ball
                u = F.normalize(u, p=1, dim=1)
                v = F.normalize(v, p=1, dim=1)
                # slow version: iterate through regions, filling the pairwise Poincare
                # distance matrix for this image
                A = torch.zeros((n_region, n_region), device=u.device)
                for i in range(n_region):
                    A[i, :] = self.pm.distance(u[i, :].unsqueeze(0).expand(n_region, u.shape[1]), v)
                # slow version sped up with multiprocessing -- pytorch DataLoader threads cry
                # pool = Parallel(n_jobs=2)(
                #     delayed(self.pm.distance)(u[i, :].unsqueeze(0).expand(n_region, u.shape[1]), v) for i in range(n_region)
                #     )
                # A = [(i, self.pm.distance(u[i, :].unsqueeze(0).expand(n_region, u.shape[1]), v)) for i in range(n_region)]
                # broadcast version -- poincare grad throws an error
                # A = self.pm.distance(u.unsqueeze(1), v.unsqueeze(1).transpose(0, 1))
                # broadcast with my Poincare (above): directly uses torch autograd, not the poincare grad
                # A = self.my_pc(u.unsqueeze(1), v.unsqueeze(1).transpose(0, 1))
                R_new.append(A.unsqueeze(0))
                del A
            R_new = torch.cat(R_new).cuda(device_id)  # [n_img x n_region x n_region]
            R_new = (1 - R_new)
            end = time.time(); print('==', end - start)
        else:
            feat_key = feat_key.unsqueeze(2)         # [n_img x n_region x 1 x n_dim x 1 x 1]
            feat_query = feat_query.unsqueeze(2)     # [n_img x n_region x 1 x n_dim x 1 x 1]
            feat_query = feat_query.transpose(1, 2)  # [n_img x 1 x n_region x n_dim x 1 x 1]
            # broadcast: [n_img x n_region x n_region x n_dim x 1 x 1]
            R_new = feat_key * feat_query
            R_new = R_new.squeeze(-1).squeeze(-1)    # [n_img x n_region x n_region x n_dim]
            R_new = R_new.sum(3) / self.d_k_sqrt     # self-attention/relation network scales by sqrt(d_k)
        if cfg.TRAIN.ATTN_IOU_THRESH:
            # mask out region pairs with IoU > TRAIN.IOU_THRESH
            assert num_imgs == 1  # TODO - extend to multiple images
            assert len(iou_mat) > 0
            device_id = R_new.get_device()
            mask = (iou_mat < cfg.TRAIN.IOU_THRESH).astype('float32')
            np.fill_diagonal(mask, 1.0)
            mask = Variable(torch.from_numpy(mask), requires_grad=False).cuda(device_id)
            mask = mask.view(R_new.shape)
            R_new = R_new * mask
            R_new = R_new.contiguous()
        R_new = R_new / self.T  # softmax temperature
        if cfg.TRAIN.SPARSEMAX:
            out = self.sparsemax(R_new)
        else:
            if cfg.TRAIN.DROPOUT > 0:
                R_new = F.dropout(R_new, p=cfg.TRAIN.DROPOUT, inplace=True)
            if use_poincare:
                # normalize each row of the (1 - distance) matrix to sum to 1
                out = R_new / R_new.sum(dim=2).unsqueeze(2)
                # EDIT: if using a "softmax" over Poincare distances, first rescale with
                #     R_new = (-self.poinc_scale * R_new) + self.poinc_shift
                # and then apply softmax instead of the row-sum normalization above
            else:
                out = F.softmax(R_new, 2)
        return out
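# Illustrative shapes for SelfAttnMat (an assumption inferred from the code above, not
# part of the original gist): given region features of shape
# [n_img * n_region, feat_dim, 1, 1], the module returns one row-normalized attention
# (adjacency) matrix per image:
#   attn = SelfAttnMat(feat_dim=2048, proj_dim=256, T=1.0)
#   A = attn(region_feature, num_imgs)   # -> [n_img, n_region, n_region]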
class SelfAttn_basic(nn.Module):
    """ Self Attention Network for visual context """
    def __init__(self, num_A=1, feat_dim=2048, input_feat=2048, output_feat=2048,
                 visual_proj_dim=256, combine='add'):
        """
        num_A       - Number of adjacency matrices (multiple attention heads)
        feat_dim    - Size of "appearance" features for each region (roi)
        input_feat  - Size of ROI-pooled features (can be different from feat_dim)
        output_feat - Size of the output from each attention head
        """
        super(SelfAttn_basic, self).__init__()
        self.num_A = num_A
        self.proj_dim = visual_proj_dim
        self.output_feat = int(output_feat / self.num_A)
        self.combine = combine
        if cfg.TRAIN.CONTEXT_BBOX:
            feat_dim += 4  # bbox coords are appended to the visual feature
        if self.num_A >= 1:
            assert cfg.TRAIN.ATTN_W  # multiple heads need down-projection with W
            for i in range(self.num_A):
                module_AdjMat = SelfAttnMat(feat_dim=feat_dim,
                                            proj_dim=self.proj_dim,
                                            T=cfg.TRAIN.SOFTMAX_T[i],
                                            use_poincare=(i in cfg.TRAIN.POINCARE))
                self.add_module('compute_AdjMat{}'.format(i), module_AdjMat)
                linear_out = nn.Conv2d(input_feat, self.output_feat, 1, stride=1)
                self._init_weights_multi(linear_out)
                self.add_module('linear_out{}'.format(i), linear_out)
        else:
            raise ValueError

    def _init_weights_multi(self, linear_out):
        mynn.init.XavierFill(linear_out.weight)
        init.constant_(linear_out.bias, 0)
    def forward(self, visual_feature, x, num_imgs, iou_mat=[], bboxes=[]):
        """
        Returns features incorporating visual context from all other rois

        visual_feature - appearance feature tensor [num_rois, feat_dim, 1, 1]
        x              - region (box) feature [num_rois, feat_dim, 1, 1]
        num_imgs       - number of images per batch
        iou_mat        - (optional) IoU between regions [num_rois, num_rois]

        visual_feature and "x" can come from the same or different CNN layers.
        Image IDs are typically obtained from rpn_ret['rois'] in model_builder.py
        """
        # Scale down or zero out gradients to the rest of the (pre-trained) network
        x = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(x)
        num_rois = int(visual_feature.shape[0] / num_imgs)
        if cfg.TRAIN.ATTN_IOU_THRESH:
            assert len(iou_mat) > 0
        if cfg.TRAIN.CONTEXT_BBOX:
            bboxes = bboxes.unsqueeze(-1).unsqueeze(-1).cuda(visual_feature.get_device())
            visual_feature = torch.cat((visual_feature, bboxes), 1)
        visual_feature = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(visual_feature)
        z = []
        for i in range(self.num_A):
            # z = A.X.W
            A_i = self._modules['compute_AdjMat{}'.format(i)](visual_feature, num_imgs,
                                                              iou_mat=iou_mat)
            z_i = torch.bmm(A_i, x.view(num_imgs, num_rois, -1, 1, 1).squeeze(-1).squeeze(-1))
            z_i = z_i.view(num_imgs * num_rois, -1)
            z_i = z_i.unsqueeze(-1).unsqueeze(-1)
            z_i = self._modules['linear_out{}'.format(i)](z_i)  # [n_img*n_region, output_feat, 1, 1]
            z.append(z_i)
        z = torch.cat(z, 1)  # [n_img*n_region, num_A*output_feat, 1, 1]
        if self.combine == 'add':
            y = x + z
        elif self.combine == 'concat':
            y = torch.cat([x, z], 1)
        else:
            raise NotImplementedError
        y = F.relu(y, inplace=True)
        return y
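# Illustrative usage (an assumption based on the forward() docstring above, not part of
# the original gist): each head computes z_i = A_i . X . W_i, the heads are concatenated
# along the channel dimension, and the result is combined with the input roi features:
#   ctx = SelfAttn_basic(num_A=2, feat_dim=2048, input_feat=2048, output_feat=2048)
#   y = ctx(visual_feature, x, num_imgs, iou_mat=iou_mat)  # [num_rois, 2048, 1, 1] for 'add'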
##### END: Self-Attention Relation Networks Recreation ######
def _gen_timing_signal(length, channels=64, min_timescale=1.0, max_timescale=1.0e3):
    """
    Generates a [1, length, channels] timing signal consisting of sinusoids

    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
    """
    position = np.arange(length)
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (float(num_timescales) - 1))
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales).astype(np.float64) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    'constant', constant_values=[0.0, 0.0])
    signal = signal.reshape([1, length, channels])
    return torch.from_numpy(signal).type(torch.FloatTensor)
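# Quick sanity check (illustrative, not part of the original gist): the signal stacks
# sin/cos features over channels // 2 geometrically spaced timescales.
if DEBUG:
    _sig = _gen_timing_signal(10, channels=64)
    print(_sig.shape)  # expected: torch.Size([1, 10, 64])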