# Last active April 24, 2019 19:07
# Allow shifts and scales of the Poincare distance, which usually lies on the unit disc.
import numpy as np
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.autograd import Function
from scipy.spatial.distance import pdist
from core.config import cfg
import nn as mynn
import as net_utils
import numpy as np
import math
import gc
from joblib import Parallel,delayed
from modeling.sparse_activations import Sparsemax
from poincare_embeddings.hype import poincare
from hyperbolic_cones import my_poincare_model as mpm
# Module-level debug switch; nothing in the visible part of this file reads it.
DEBUG = False
class _GradScaleFunction(Function):
    """Autograd op: identity in the forward pass, scaled gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, scaler):
        # Stash the factor on the context so backward can read it.
        ctx.scaler = scaler
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; `scaler` is a plain number and
        # therefore receives no gradient.
        return grad_output * ctx.scaler, None


class GradScaler(object):
    """Gradient scaler layer.

    Calling an instance on a tensor returns a view of that tensor unchanged,
    while the gradient flowing back through it is multiplied by ``scaler``.
    With the default ``scaler=0.0`` the gradient is zeroed out.

    The call pattern ``GradScaler(scaler=s)(x)`` is preserved. The autograd
    logic lives in a static-method ``Function`` invoked through ``.apply``,
    because recent PyTorch releases refuse to run legacy-style Functions that
    keep state in ``__init__`` and define non-static ``forward``/``backward``.

    Args:
        scaler: multiplicative factor applied to the incoming gradient.
    """

    def __init__(self, scaler=0.0):
        self.scaler = scaler

    def __call__(self, x):
        return _GradScaleFunction.apply(x, self.scaler)

    def forward(self, x):
        """Same as calling the instance: identity with a scaled gradient."""
        return self(x)

    def backward(self, grad_output):
        """Return ``grad_output`` multiplied by ``self.scaler``."""
        return grad_output * self.scaler
def grad_scale(x):
    """Pass ``x`` through a ``GradScaler`` constructed with its default factor.

    The forward value is ``x`` itself; only the gradient flowing back through
    the returned tensor is affected by the scaler.
    """
    return GradScaler()(x)
class Poincare(nn.Module):
    """Poincare-ball distance built from plain torch ops (autograd supplies the gradient).

    Args:
        eps: margin that keeps the clamped squared norms strictly below 1 so
            the denominator ``(1 - |u|^2) * (1 - |v|^2)`` never reaches zero.
            Defaults to ``1e-5``.
    """

    def __init__(self, eps=1e-5):
        super(Poincare, self).__init__()
        self.eps = eps

    def forward(self, u, v):
        """Return ``arcosh(1 + 2 * |u - v|^2 / ((1 - |u|^2) * (1 - |v|^2)))``.

        The reduction runs over the last dimension, so ``u`` and ``v`` may
        carry any broadcast-compatible leading dimensions; the result has the
        broadcast shape without the last dimension.

        NOTE(review): for identical points ``x == 1`` and ``sqrt(x^2 - 1)`` is
        evaluated at zero, where its derivative is unbounded; the forward
        value (0) is fine but the autograd gradient there should not be
        trusted.
        """
        upper = 1 - self.eps
        # Only the norms used in the denominator are clamped; the squared
        # distance in the numerator uses the raw inputs.
        squnorm = torch.clamp(torch.sum(u * u, dim=-1), 0, upper)
        sqvnorm = torch.clamp(torch.sum(v * v, dim=-1), 0, upper)
        sqdist = torch.sum(torch.pow(u - v, 2), dim=-1)
        x = sqdist / ((1 - squnorm) * (1 - sqvnorm)) * 2 + 1
        # arcosh(x) = log(x + sqrt(x^2 - 1))
        z = torch.sqrt(torch.pow(x, 2) - 1)
        return torch.log(x + z)
##### Self-Attention Relation Networks Recreation ######
class SelfAttnMat(nn.Module):
    """Compute a self-attention (adjacency) matrix from visual appearance features.

    Two similarity modes are supported: scaled dot-product between projected
    key/query features, or (``use_poincare=True``) one minus the Poincare
    distance between L1-normalised projections.

    NOTE(review): this class was recovered from a copy whose indentation and
    several tokens were lost. The following were reconstructed from context
    and must be confirmed against the upstream file:
      * the attribute name ``self.pc`` for ``poincare.PoincareManifold()`` and
        the ``self.pc.distance(...)`` call in the per-row loop;
      * ``R_new.append(A)`` followed by ``torch.stack(R_new)``;
      * the ``else:`` separating the hyperbolic and dot-product branches, and
        the ``else:`` in front of the final ``F.softmax``.

    Args:
        feat_dim: channel size of the incoming region features.
        proj_dim: channel size of the key/query projections.
        T: softmax temperature; the scores are divided by it.
        use_poincare: use the Poincare-distance similarity instead of the
            scaled dot product.
    """

    def __init__(self, feat_dim=2048, proj_dim=256, T=1.0, use_poincare=False):
        super(SelfAttnMat, self).__init__()
        self.proj_dim = proj_dim
        self.T = T
        self.sparsemax = Sparsemax(dim=2)
        self.d_k_sqrt = math.sqrt(self.proj_dim)
        # 1x1 convolutions acting as linear maps feat_dim -> proj_dim
        self.proj_w1 = nn.Conv2d(feat_dim, self.proj_dim, 1, stride=1)
        self.proj_w2 = nn.Conv2d(feat_dim, self.proj_dim, 1, stride=1)
        self.use_poincare = use_poincare
        # NOTE(review): attribute name reconstructed -- confirm upstream.
        self.pc = poincare.PoincareManifold()
        # Plain-autograd Poincare distance; per the original notes,
        # self.my_pc(u.unsqueeze(1), v.unsqueeze(0)) is a broadcast
        # alternative to the per-row loop in forward().
        self.my_pc = Poincare()
        # EDIT: scale and shift the distance on poincare disc
        # (only referenced from a commented-out line in the original forward).
        self.poinc_scale = nn.Parameter(torch.tensor([1.0]))
        self.poinc_shift = nn.Parameter(torch.tensor([0.0]))

    def _init_weights(self):
        """Zero the biases of both projection layers."""
        init.constant_(self.proj_w1.bias, 0)
        init.constant_(self.proj_w2.bias, 0)

    def forward(self, region_feature, num_imgs, iou_mat=[]):
        """Return the adjacency matrix as row-normalised self-attention.

        Args:
            region_feature: region features, [n_img*n_region x n_dim x 1 x 1].
            num_imgs: number of images in the batch (currently must be 1).
            iou_mat: numpy IoU matrix between regions; required (non-empty).
                The list default is never mutated.

        Returns:
            Tensor [n_img x n_region x n_region].
        """
        # Send scaled (or zero) gradients to rest of net
        region_feature = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(region_feature)

        # Project down the region features [n_img*n_region x n_dim x 1 x 1]
        feat_key = self.proj_w1(region_feature)
        feat_query = self.proj_w2(region_feature)

        # Reshape from (n_img*n_region, n_dim, 1, 1) to (n_img, n_region, n_dim, 1, 1)
        sz = feat_key.shape
        regions_per_img = int(sz[0] / num_imgs)
        feat_key = feat_key.view(num_imgs, regions_per_img, sz[1], sz[2], sz[3])
        feat_query = feat_query.view(num_imgs, regions_per_img, sz[1], sz[2], sz[3])

        use_poincare = self.use_poincare
        if use_poincare:
            import time
            start = time.time()
            n_img = feat_key.shape[0]
            n_region = feat_key.shape[1]
            R_new = []
            for im in range(n_img):
                u = feat_key[im].squeeze(-1).squeeze(-1)    # [n_region x n_dim]
                v = feat_query[im].squeeze(-1).squeeze(-1)  # [n_region x n_dim]
                # L1-normalise each row (L2 norm <= L1 norm = 1, so the rows
                # stay inside the closed unit ball)
                u = F.normalize(u, p=1, dim=1)
                v = F.normalize(v, p=1, dim=1)
                # The original wrote into `A` while its allocation was
                # commented out (NameError); allocate per image on u's device.
                A = u.new_zeros((n_region, n_region))
                # slow version: one row of pairwise distances at a time
                for i in range(n_region):
                    A[i, :] = self.pc.distance(
                        u[i, :].unsqueeze(0).expand(n_region, u.shape[1]), v)
                R_new.append(A)
            R_new = torch.stack(R_new)  # [n_img x n_region x n_region]
            # turn the distance into a similarity
            R_new = 1 - R_new
            end = time.time()
            print('==', end - start)
        else:
            feat_key = feat_key.unsqueeze(2)         # [n_img x n_region x 1 x n_dim x 1 x 1]
            feat_query = feat_query.unsqueeze(2)     # [n_img x n_region x 1 x n_dim x 1 x 1]
            feat_query = feat_query.transpose(1, 2)  # [n_img x 1 x n_region x n_dim x 1 x 1]
            # broadcast: [n_img x n_region x n_region x n_dim x 1 x 1]
            R_new = feat_key * feat_query
            R_new = R_new.squeeze(-1).squeeze(-1)    # [n_img x n_region x n_region x n_dim]
            # self-attention/relation network has sqrt(d_k)
            R_new = R_new.sum(3) / self.d_k_sqrt

        # Multiply by a 0/1 mask: 1 where IoU < TRAIN.IOU_THRESH and on the
        # diagonal, 0 elsewhere.
        assert num_imgs == 1  # TODO - extend to multiple images
        assert len(iou_mat) > 0
        mask = (iou_mat < cfg.TRAIN.IOU_THRESH).astype('float32')
        np.fill_diagonal(mask, 1.0)
        # Follow R_new's device instead of the original hard-coded .cuda().
        mask = torch.from_numpy(mask).to(R_new.device)
        mask = mask.view(R_new.shape)
        R_new = R_new * mask
        R_new = R_new.contiguous()
        R_new = R_new / self.T  # softmax temperature

        # NOTE(review): as recovered, this result is overwritten by both
        # branches below; kept to preserve the visible statement order.
        out = self.sparsemax(R_new)

        if cfg.TRAIN.DROPOUT > 0:
            # Pass training= explicitly; F.dropout defaults to training=True
            # and would otherwise also drop entries in eval mode.
            R_new = F.dropout(R_new, p=cfg.TRAIN.DROPOUT,
                              training=self.training, inplace=True)

        if use_poincare:
            # rows are normalised to sum to 1
            out = R_new / R_new.sum(dim=2).unsqueeze(2)
            # EDIT: If you are using "softmax" poincare
            # R_new = (-self.poinc_scale * R_new) + self.poinc_shift
            # Then do softmax instead of row-sum 1
        else:
            out = F.softmax(R_new, 2)
        return out
class SelfAttn_basic(nn.Module):
    """Self Attention Network for visual context.

    NOTE(review): this class was recovered from a copy whose indentation and
    several tokens were lost. The ``torch.cat`` calls, ``z.append(z_i)``, the
    trailing ``iou_mat`` argument of the adjacency call and the ``else:``
    lines were reconstructed from context -- confirm against upstream.
    """

    def __init__(self, num_A=1, feat_dim=2048, input_feat=2048, output_feat=2048,
                 visual_proj_dim=256, combine='add'):
        """
        num_A - Number of adjacency matrices (multi attention heads)
        feat_dim - Size of "appearance" features for each region (roi)
        input_feat - Size of ROI-pooled features (can be different from feat_dim)
        output_feat - Total output size; each head produces output_feat / num_A
        visual_proj_dim - Key/query projection size used by every head
        combine - 'add' or 'concat': how the context is merged with ``x``

        Raises ValueError when ``num_A`` is smaller than 1.
        """
        super(SelfAttn_basic, self).__init__()
        # Validate before dividing by num_A so a zero head count reports the
        # intended ValueError rather than a ZeroDivisionError.
        if num_A < 1:
            raise ValueError('num_A must be >= 1, got {}'.format(num_A))
        self.num_A = num_A
        self.proj_dim = visual_proj_dim
        self.output_feat = int(output_feat / self.num_A)
        self.combine = combine
        feat_dim += 4  # bbox coords are appended to visual feat

        assert cfg.TRAIN.ATTN_W  # multi-heads need down-projection with W
        for i in range(self.num_A):
            # NOTE(review): the recovered call did not forward proj_dim, which
            # left visual_proj_dim without effect; forwarded here (identical
            # result for the default of 256) -- confirm against upstream.
            module_AdjMat = SelfAttnMat(feat_dim=feat_dim,
                                        proj_dim=self.proj_dim,
                                        use_poincare=(i in cfg.TRAIN.POINCARE))
            self.add_module('compute_AdjMat{}'.format(i), module_AdjMat)
            linear_out = nn.Conv2d(input_feat, self.output_feat, 1, stride=1)
            self.add_module('linear_out{}'.format(i), linear_out)

    def _init_weights_multi(self, linear_out):
        """Zero the bias of one output projection layer."""
        init.constant_(linear_out.bias, 0)

    def forward(self, visual_feature, x, num_imgs, iou_mat=[], bboxes=[]):
        """Returns features incorporating visual context from all other rois.

        visual_feature - appearance feature tensor [num_rois, feat_dim, 1, 1]
        x - region (box) feature [num_rois, feat_dim, 1, 1]
        num_imgs - number of images per batch
        iou_mat - IoU between regions [num_rois, num_rois]; must be non-empty
        bboxes - box coordinates tensor; two trailing singleton dims are added
                 and it is concatenated to visual_feature along dim 1

        Visual feature and "x" can be from same or different CNN layers.
        Image IDs are typically obtained from rpn_ret['rois'].
        The list defaults are never mutated.
        """
        # Scale-down or zero-out gradients to rest of the (pre-trained) network
        x = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(x)
        num_rois = int(visual_feature.shape[0] / num_imgs)
        assert len(iou_mat) > 0

        # Follow visual_feature's device instead of the original hard-coded .cuda().
        bboxes = bboxes.unsqueeze(-1).unsqueeze(-1).to(visual_feature.device)
        visual_feature = torch.cat((visual_feature, bboxes), 1)
        visual_feature = GradScaler(scaler=cfg.TRAIN.GRL_SCALER)(visual_feature)

        # [n_img x n_region x n_feat] view of x, shared by all heads
        x_flat = x.view(num_imgs, num_rois, -1)

        z = []
        for i in range(self.num_A):
            # z = A.X.W
            adj_module = getattr(self, 'compute_AdjMat{}'.format(i))
            A_i = adj_module(visual_feature, num_imgs, iou_mat=iou_mat)
            z_i = torch.bmm(A_i, x_flat)
            z_i = z_i.view(num_imgs * num_rois, -1)
            z_i = z_i.unsqueeze(-1).unsqueeze(-1)
            # [n_img*n_region, output_features, 1, 1]
            z_i = getattr(self, 'linear_out{}'.format(i))(z_i)
            z.append(z_i)
        z = torch.cat(z, 1)  # [n_img*n_region, num_A*output_features, 1, 1]

        if self.combine == 'add':
            y = x + z
        elif self.combine == 'concat':
            y = torch.cat([x, z], 1)
        else:
            raise NotImplementedError
        y = F.relu(y, inplace=True)
        return y
##### END: Self-Attention Relation Networks Recreation ######
def _gen_timing_signal(length, channels=64, min_timescale=1.0, max_timescale=1.0e3):
Generates a [1, length, channels] timing signal consisting of sinusoids
Adapted from:
position = np.arange(length)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(float(num_timescales) - 1))
inv_timescales = min_timescale * np.exp(
np.arange(num_timescales).astype(np.float) * -log_timescale_increment)
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, channels % 2]],
'constant', constant_values=[0.0, 0.0])
signal = signal.reshape([1, length, channels])
return torch.from_numpy(signal).type(torch.FloatTensor)
