@ShoufaChen
Created August 12, 2022 11:03
# Modified by Shoufa Chen,
import math
import random
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from fvcore.nn import sigmoid_focal_loss_jit
from slowfast.models.losses import focal_loss_wo_logits_jit
from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes
from slowfast.datasets.cv2_transform import clip_boxes_tensor
_DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16)
class ResNetRoIHead(nn.Module):
"""
ResNe(X)t RoI head.
"""
def __init__(
self,
cfg,
dim_in,
num_classes,
pool_size,
resolution,
scale_factor,
dropout_rate=0.0,
act_func="softmax",
aligned=True,
dim_before_proj=2048,
use_fpn=False,
):
"""
The `__init__` method of any subclass should also contain these
arguments.
ResNetRoIHead takes p pathways as input where p in [1, infty].
Args:
dim_in (list): the list of channel dimensions of the p inputs to the
ResNetHead.
num_classes (int): the channel dimension of the output, i.e. the
number of classes to predict.
pool_size (list): the list of kernel sizes for the p spatiotemporal
poolings, given per pathway as (temporal kernel size, spatial
kernel size, spatial kernel size).
resolution (list): the list of spatial output size from the ROIAlign.
scale_factor (list): the list of ratios by which the input boxes are
scaled for each pathway.
dropout_rate (float): dropout rate. If equal to 0.0, perform no
dropout.
act_func (string): activation function to use. 'softmax': applies
softmax on the output. 'sigmoid': applies sigmoid on the output.
aligned (bool): if False, use the legacy implementation. If True,
align the results more precisely.
Note:
Given a continuous coordinate c, its two neighboring pixel indices
(in our pixel model) are computed by floor (c - 0.5) and ceil
(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
indices [0] and [1] (which are sampled from the underlying signal at
continuous coordinates 0.5 and 1.5). But the original roi_align
(aligned=False) does not subtract the 0.5 when computing neighboring
pixel indices and therefore it uses pixels with a slightly incorrect
alignment (relative to our pixel model) when performing bilinear
interpolation.
With `aligned=True`, we first appropriately scale the ROI and then
shift it by -0.5 prior to calling roi_align. This produces the
correct neighbors; it makes a negligible difference to the model's
performance when ROIAlign is used together with conv layers.
"""
super(ResNetRoIHead, self).__init__()
assert (
len({len(pool_size), len(dim_in)}) == 1
), "pathway dimensions are not consistent."
self.cfg = cfg
self.use_fpn = use_fpn
self.gt_boxes_prob = cfg.MODEL.SparseRCNN.GT_BOXES_PROB
self.num_pathways = len(pool_size)
self.device = torch.device(cfg.MODEL.DEVICE)
self.use_action_heads = cfg.MODEL.SparseRCNN.NUM_ACT_HEADS > 0
# move the 1x1 conv (dim_before_proj -> 256) from the backbone into the head
if self.use_action_heads:
self.proj_to_256 = nn.Conv3d(dim_before_proj, 256, kernel_size=1)
for pathway in range(self.num_pathways):
temporal_pool = nn.AvgPool3d(
[pool_size[pathway][0], 1, 1], stride=1
)
self.add_module("s{}_tpool".format(pathway), temporal_pool)
pooler = ROIPooler(
output_size=resolution[pathway],
scales=[1.0 / scale_factor[pathway]],
sampling_ratio=2,
pooler_type=cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
)
self.add_module("s{}_roi".format(pathway), pooler)
if self.use_fpn:
keyframe_pooler = self._init_box_pooler(cfg)
self.add_module("s{}_keyroi".format(pathway), keyframe_pooler)
if pathway == 0:
rcnn_head = RCNNHead(cfg)
head_series = _get_clones(rcnn_head, cfg.MODEL.SparseRCNN.NUM_HEADS)
self.add_module("s{}_headseries".format(pathway), head_series)
if self.use_action_heads:
act_rcnn_head = RCNNHead(cfg, origin=False)
act_head_series = _get_clones(act_rcnn_head, cfg.MODEL.SparseRCNN.NUM_ACT_HEADS)
self.add_module("s{}_actheadseries".format(pathway), act_head_series)
temp_head = RCNNHead3D(cfg)
temp_head_series = _get_clones(temp_head, cfg.MODEL.SparseRCNN.NUM_ACT_HEADS)
self.add_module("s{}_tempheadseries".format(pathway), temp_head_series)
spatial_pool = nn.MaxPool2d(resolution[pathway], stride=1)
self.add_module("s{}_spool".format(pathway), spatial_pool)
if self.num_pathways == 2 and self.use_action_heads:
proj = nn.Conv2d(512, 256, kernel_size=1)
self.add_module("concat_proj", proj)
self.return_intermediate = cfg.MODEL.SparseRCNN.DEEP_SUPERVISION
self.use_focal = cfg.MODEL.SparseRCNN.USE_FOCAL
self.num_classes = num_classes
if not self.use_action_heads:
if dropout_rate > 0:
self.dropout = nn.Dropout(dropout_rate)
self.ori_projection = nn.Linear(sum(dim_in), num_classes, bias=True)
if self.use_focal:
prior_prob = cfg.MODEL.SparseRCNN.PRIOR_PROB
self.bias_value = -math.log((1 - prior_prob) / prior_prob)
self._reset_parameters()
@staticmethod
def _init_box_pooler(cfg):
pooler_resolution = 7
pooler_scales = tuple([1.0 / x for x in [4, 8, 16, 32]])
sampling_ratio = 2
pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
box_pooler = ROIPooler(
output_size=pooler_resolution,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
pooler_type=pooler_type,
)
return box_pooler
def _reset_parameters(self):
# init all parameters.
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, images_whwh, fpn_features, inputs, init_bboxes, init_features, act_init_features, temp_init_features, criterion=None, targets=None):
assert (
len(inputs) == self.num_pathways
), "Input tensor does not contain {} pathway".format(self.num_pathways)
inter_class_logits = []
inter_action_logits = []
inter_pred_bboxes = []
# channel reduction: keep a copy of the pre-reduction features first
feat_pre_reduce = [feat for feat in inputs]
if self.use_action_heads:
inputs[0] = self.proj_to_256(inputs[0])
bs = len(feat_pre_reduce[0])  # batch size: leading dim of the first pathway
bboxes = init_bboxes
# (100, 256) -> (1, 100 * bs, 256)
init_features = init_features[None].repeat(1, bs, 1)
proposal_features = init_features.clone()
if self.use_action_heads:
act_init_features = act_init_features[None].repeat(1, bs, 1)
act_proposal_features = act_init_features.clone()
temp_init_features = temp_init_features[None].repeat(1, bs, 1)
temp_proposal_features = temp_init_features.clone()
if self.use_fpn:
keyframe = fpn_features
else:
# we only take the keyframe from the Slow pathway
assert fpn_features is None, "Check Logic"
num_frame = inputs[0].shape[2]
if self.cfg.MODEL.SparseRCNN.KEYWAY:
raise ValueError("Use FPN feature, or below KEYWAY maybe bug")
keyframe = [inputs[0][:, :, -1]]
inputs[0] = inputs[0][:, :, :-1]
feat_pre_reduce[0] = feat_pre_reduce[0][:, :, :-1]
else:
keyframe = [inputs[0][:, :, num_frame//2]]
pool_out = []
for pathway in range(self.num_pathways):
t_pool = getattr(self, "s{}_tpool".format(pathway))
out = t_pool(inputs[pathway] if self.use_action_heads else feat_pre_reduce[pathway])
assert out.shape[2] == 1
out = torch.squeeze(out, 2)
pool_out.append(out)
if self.use_action_heads:
out = torch.cat(pool_out, dim=1)
if self.num_pathways == 2:
out = self.concat_proj(out)
# Below this line we assume pathway 0; for SlowFast, the two pathways have been concatenated into one.
pathway = 0
key_roi_align = getattr(self, "s{}_keyroi".format(pathway)) if self.use_fpn else getattr(self, "s{}_roi".format(pathway)) # noqa
for rcnn_head in getattr(self, "s{}_headseries".format(pathway)):
class_logits, pred_bboxes, proposal_features, jitter_pred_bboxes = rcnn_head(keyframe, bboxes, proposal_features, key_roi_align, images_whwh=images_whwh)
if self.return_intermediate:
inter_class_logits.append(class_logits)
inter_pred_bboxes.append(pred_bboxes)
bboxes = pred_bboxes.detach()
if self.cfg.MODEL.SparseRCNN.JITTER_BOX:
ava_box = jitter_pred_bboxes.detach()
else:
ava_box = bboxes
roi_align = getattr(self, "s{}_roi".format(pathway))
if self.training:
# person-detector loss plus the matching results (indices, idx)
losses, indices, idx = self.person_detector_loss(inter_class_logits, inter_pred_bboxes, criterion, targets)
# With probability self.gt_boxes_prob, replace the matched predicted boxes with the GT boxes
if random.random() < self.gt_boxes_prob:  # random.random() is uniform on [0, 1)
ava_box = ava_box.clone()
ava_box[idx] = torch.cat([t['boxes_xyxy'][i] for t, (_, i) in zip(targets, indices)], dim=0)
if self.use_action_heads:
for act_rcnn_head, temp_head in zip(getattr(self, "s{}_actheadseries".format(pathway)),
getattr(self, "s{}_tempheadseries".format(pathway))):
# for two pathways, use the Fast pathway (inputs[-1]) as the source of temporal features
temp_helper, temp_proposal_features = temp_head(inputs[-1], ava_box, temp_proposal_features, roi_align)
action_logits, act_proposal_features = act_rcnn_head([out], ava_box, act_proposal_features, roi_align, temp_helper)
if self.return_intermediate:
inter_action_logits.append(action_logits)
else:
N, nr_boxes = bboxes.shape[:2]
s_pool_out = []
proposal_boxes = [Boxes(b) for b in bboxes]
for i, po in enumerate(pool_out):
roi_align = getattr(self, "s{}_roi".format(i))
out = roi_align([po], proposal_boxes)
s_pool_out.append(F.adaptive_max_pool2d(out, output_size=(1, 1)))
x = torch.cat(s_pool_out, 1)
if hasattr(self, "dropout"):
x = self.dropout(x)
x = x.view(N, nr_boxes, -1)
action_logits = self.ori_projection(x)
inter_action_logits.append(action_logits)
if self.training:
if self.cfg.MODEL.SparseRCNN.JHMDB_LOSS:
act_loss = self.jhmdb_act_loss(inter_action_logits, targets, indices, idx)
else:
act_loss = self.action_cls_loss(inter_action_logits, targets, indices, idx)
losses.update(act_loss)
return losses
# eval
return dict(pred_logits=inter_class_logits[-1],
pred_boxes=inter_pred_bboxes[-1],
pred_actions=inter_action_logits[-1])
def person_detector_loss(self, outputs_class, outputs_coord, criterion, targets):
if self.return_intermediate:
output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
'aux_outputs': [{'pred_logits': a, 'pred_boxes': b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]}
else:
raise NotImplementedError
loss_dict, indices, idx = criterion(output, targets)
return loss_dict, indices, idx
def action_cls_loss(self, output_action, targets, indices, idx):
losses = {}
target_actions_o = torch.cat([t["actions"][J] for t, (_, J) in zip(targets, indices)])
for i, action_logits in enumerate(output_action):
action = action_logits[idx]
if not self.cfg.MODEL.SparseRCNN.SOFTMAX_POSE:
if self.cfg.MODEL.LOSS_FUNC == 'focal_action':
act_loss = sigmoid_focal_loss_jit(action, target_actions_o, alpha=0.25, reduction='mean')
else:
act_loss = F.binary_cross_entropy_with_logits(action, target_actions_o)  # logits version; the Sigmoid is removed from the model
else:
pose_pred = F.softmax(action[:, :14], dim=-1)  # the first 14 channels are pose labels
other_pred = torch.sigmoid(action[:, 14:])  # torch.sigmoid: F.sigmoid is deprecated
action = torch.cat([pose_pred, other_pred], dim=-1)
if self.cfg.MODEL.LOSS_FUNC == 'focal_action':
act_loss = focal_loss_wo_logits_jit(action, target_actions_o, alpha=0.25, reduction='mean')
else:
act_loss = F.binary_cross_entropy(action, target_actions_o)
losses.update({'loss_bce' + f'_{i}': act_loss})
losses['loss_bce'] = losses.pop('loss_bce' + f'_{i}')  # rename the last loss key so the AVA meter picks it up
return losses
def jhmdb_act_loss(self, output_action, targets, indices, idx):
losses = {}
target_actions_o = torch.cat([t["actions"][J] for t, (_, J) in zip(targets, indices)])
for i, action_logits in enumerate(output_action):
action = action_logits[idx]
label = target_actions_o.argmax(dim=1)
act_loss = F.cross_entropy(action, label)
losses.update({'loss_bce' + f'_{i}': act_loss})
losses['loss_bce'] = losses.pop('loss_bce' + f'_{i}')  # rename the last loss key so the AVA meter picks it up
return losses
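# Illustrative sketch (hypothetical shapes; not part of the original model) of how
# the matcher outputs used in the losses above line up: `indices` holds per-image
# (pred_idx, tgt_idx) pairs, and `idx` is assumed to be the batched flat index
# (batch_idx, src_idx) built from them, so `action_logits[idx]` matches the order
# of torch.cat([t["actions"][J] for t, (_, J) in zip(targets, indices)]).
def _demo_matched_target_gather():
    action_logits = torch.randn(2, 100, 80)  # (bs, nr_boxes, num_actions)
    indices = [(torch.tensor([3, 7]), torch.tensor([0, 1])),
               (torch.tensor([5]), torch.tensor([0]))]
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for src, _ in indices])
    idx = (batch_idx, src_idx)
    matched = action_logits[idx]  # (3, 80): image 0 preds 3 and 7, then image 1 pred 5
    return matched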
class RCNNHead(nn.Module):
def __init__(self, cfg, scale_clamp: float = _DEFAULT_SCALE_CLAMP,
bbox_weights=(2.0, 2.0, 1.0, 1.0), origin=True):
super().__init__()
d_model = cfg.MODEL.SparseRCNN.HIDDEN_DIM
num_classes = cfg.MODEL.SparseRCNN.NUM_CLASSES
num_actions = cfg.MODEL.NUM_CLASSES
dim_feedforward = cfg.MODEL.SparseRCNN.DIM_FEEDFORWARD
nhead = cfg.MODEL.SparseRCNN.NHEADS
dropout = cfg.MODEL.SparseRCNN.DROPOUT
activation = cfg.MODEL.SparseRCNN.ACTIVATION
self.d_model = d_model
self.jitter_box = cfg.MODEL.SparseRCNN.JITTER_BOX
# dynamic.
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.inst_interact = DynamicConv(cfg)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.is_origin = origin # origin = False for action_proposal_feature head
self.combine_mode = cfg.MODEL.SparseRCNN.ST_COMBINE
# cls.
if self.is_origin:
num_cls = cfg.MODEL.SparseRCNN.NUM_CLS
cls_module = list()
for _ in range(num_cls):
cls_module.append(nn.Linear(d_model, d_model, False))
cls_module.append(nn.LayerNorm(d_model))
cls_module.append(nn.ReLU(inplace=True))
self.cls_module = nn.ModuleList(cls_module)
# reg.
num_reg = cfg.MODEL.SparseRCNN.NUM_REG
reg_module = list()
for _ in range(num_reg):
reg_module.append(nn.Linear(d_model, d_model, False))
reg_module.append(nn.LayerNorm(d_model))
reg_module.append(nn.ReLU(inplace=True))
self.reg_module = nn.ModuleList(reg_module)
else:
# act.
if self.combine_mode == 'MHA':
self.st_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
if self.combine_mode == 'concat':
fc_concat = 2
else:
fc_concat = 1
num_act = cfg.MODEL.SparseRCNN.NUM_ACT
assert num_act > 0, "at least 1, but got num_act {}".format(num_act)
act_dim = cfg.MODEL.SparseRCNN.ACT_FC_DIM
act_module = list()
act_module.append(nn.Linear(d_model * fc_concat, act_dim, False))
act_module.append(nn.LayerNorm(act_dim))
act_module.append(nn.ReLU(inplace=True))
for _ in range(num_act - 1):
act_module.append(nn.Linear(act_dim, act_dim, False))
act_module.append(nn.LayerNorm(act_dim))
act_module.append(nn.ReLU(inplace=True))
self.act_module = nn.ModuleList(act_module)
# pred.
self.use_focal = cfg.MODEL.SparseRCNN.USE_FOCAL
if self.use_focal:
self.class_logits = nn.Linear(d_model, num_classes)
raise NotImplementedError  # the focal-loss classifier branch is not implemented in this head
else:
assert num_classes == 1, "Check Person Detector num_classes {}".format(num_classes)
if self.is_origin:
self.class_logits = nn.Linear(d_model, num_classes + 1)
else:
self.action_logits = nn.Linear(act_dim, num_actions)
# self.act = nn.Sigmoid()
if self.is_origin:
self.bboxes_delta = nn.Linear(d_model, 4)
self.scale_clamp = scale_clamp
self.bbox_weights = bbox_weights
def forward(self, features, bboxes, pro_features, pooler, temp_helper=None, images_whwh=None):
"""
:param bboxes: (N, nr_boxes, 4)
:param pro_features: (N, nr_boxes, d_model)
"""
N, nr_boxes = bboxes.shape[:2]
# roi_feature.
proposal_boxes = list()
for b in range(N):
proposal_boxes.append(Boxes(bboxes[b]))
# proposal_boxes: List[Boxes], Boxes(100); features: List[Tensor] (N, d_model, H', W')
""" where M is the total number of boxes aggregated over all N batch images.
batch first because every image may have different boxes
"""
roi_features = pooler(features, proposal_boxes) # roi_features (M=N*nr_boxes, d_model, 7, 7)
roi_features = roi_features.view(N * nr_boxes, self.d_model, -1).permute(2, 0, 1)
# self_att.
pro_features = pro_features.view(N, nr_boxes, self.d_model).permute(1, 0, 2) # (nr_boxes, N, d_model)
pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
pro_features = pro_features + self.dropout1(pro_features2)
pro_features = self.norm1(pro_features)
# inst_interact. (nr_boxes, N, d_model) => (N, nr_boxes, d_model) => (1, N*nr_boxes, d_model)
pro_features = pro_features.view(nr_boxes, N, self.d_model).permute(1, 0, 2).reshape(1, N * nr_boxes, self.d_model)
pro_features2 = self.inst_interact(pro_features, roi_features) # (N*nr_boxes, d_model)
pro_features = pro_features + self.dropout2(pro_features2) # broadcast (1, N*nr_boxes, d_model)
obj_features = self.norm2(pro_features)
# obj_feature.
obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
obj_features = obj_features + self.dropout3(obj_features2)
obj_features = self.norm3(obj_features) # (1, N*nr_boxes, d_model)
# (N*nr_boxes, 1, d_model) => (N*nr_boxes, d_model)
fc_feature = obj_features.transpose(0, 1).reshape(N * nr_boxes, -1)
if self.is_origin:
cls_feature = fc_feature.clone()
reg_feature = fc_feature.clone()
for cls_layer in self.cls_module:
cls_feature = cls_layer(cls_feature)
for reg_layer in self.reg_module:
reg_feature = reg_layer(reg_feature)
class_logits = self.class_logits(cls_feature)
bboxes_deltas = self.bboxes_delta(reg_feature)
pred_bboxes, jitter_pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4), with_jitter=self.jitter_box, images_whwh=images_whwh, N=N, nr_boxes=nr_boxes)
if jitter_pred_bboxes is not None:
return class_logits.view(N, nr_boxes, -1), pred_bboxes.view(N, nr_boxes, -1), obj_features, jitter_pred_bboxes
else:
return class_logits.view(N, nr_boxes, -1), pred_bboxes.view(N, nr_boxes, -1), obj_features, None
else:
act_feature = self.combine_action_feat(fc_feature, temp_helper, N, nr_boxes).clone()
for act_layer in self.act_module:
act_feature = act_layer(act_feature)
action_logits = self.action_logits(act_feature)
return action_logits.view(N, nr_boxes, -1), obj_features
def combine_action_feat(self, spatio_feat, tempo_feat, N=None, nr_boxes=None):
if self.combine_mode == 'sum':
return spatio_feat + tempo_feat
elif self.combine_mode == 'concat':
return torch.cat([spatio_feat, tempo_feat], dim=-1)
elif self.combine_mode == 'MHA': # MultiHeadAttention
tempo_feat = tempo_feat.view(N, nr_boxes, self.d_model).permute(1, 0, 2)  # (nr_boxes, N, d_model)
spatio_feat = spatio_feat.view(N, nr_boxes, self.d_model).permute(1, 0, 2)
st_feature = self.st_attn(tempo_feat, spatio_feat, value=spatio_feat)[0]
return st_feature.permute(1, 0, 2).reshape(N*nr_boxes, self.d_model)  # .view() would fail: permute makes the tensor non-contiguous
elif self.combine_mode == 'none':
return spatio_feat
else:
raise NotImplementedError("Check combine type {}".format(self.combine_mode))
def apply_deltas(self, deltas, boxes, with_jitter=False, images_whwh=None, N=None, nr_boxes=None):
"""
Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
Args:
deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
deltas[i] represents k potentially different class-specific
box transformations for the single box boxes[i].
boxes (Tensor): boxes to transform, of shape (N, 4)
"""
boxes = boxes.to(deltas.dtype)
widths = boxes[:, 2] - boxes[:, 0]
heights = boxes[:, 3] - boxes[:, 1]
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
wx, wy, ww, wh = self.bbox_weights
dx = deltas[:, 0::4] / wx
dy = deltas[:, 1::4] / wy
dw = deltas[:, 2::4] / ww
dh = deltas[:, 3::4] / wh
# Prevent sending too large values into torch.exp()
dw = torch.clamp(dw, max=self.scale_clamp)
dh = torch.clamp(dh, max=self.scale_clamp)
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]
pred_boxes = torch.zeros_like(deltas)
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2
if not with_jitter:
return pred_boxes, None
assert images_whwh is not None
jitter_pred_box = torch.zeros_like(deltas)
if not self.training:
# https://github.com/MVIG-SJTU/AlphAction/blob/master/alphaction/structures/bounding_box.py#L197
x_scale = 0.1
y_scale = 0.05
jitter_pred_box[:, 0::4] = pred_ctr_x - 0.5 * pred_w * (1 + x_scale)
jitter_pred_box[:, 1::4] = pred_ctr_y - 0.5 * pred_h * (1 + y_scale)
jitter_pred_box[:, 2::4] = pred_ctr_x + 0.5 * pred_w * (1 + x_scale)
jitter_pred_box[:, 3::4] = pred_ctr_y + 0.5 * pred_h * (1 + y_scale)
jitter_pred_box = jitter_pred_box.view(N, nr_boxes, -1)
for idx, (boxes_per_image, curr_whwh) in enumerate(zip(jitter_pred_box, images_whwh)):
jitter_pred_box[idx] = clip_boxes_tensor(boxes_per_image, curr_whwh[1], curr_whwh[0])
return pred_boxes, jitter_pred_box
else:
# https://github.com/MVIG-SJTU/AlphAction/blob/master/alphaction/structures/bounding_box.py#L226
jitter_x_out, jitter_x_in, jitter_y_out, jitter_y_in = 0.2, 0.1, 0.1, 0.05
device = pred_boxes.device
def torch_uniform(rows, a=0.0, b=1.0):
return torch.rand(rows, 1, dtype=torch.float32, device=device) * (b - a) + a
num_boxes = N * nr_boxes
jitter_pred_box[:, 0::4] = pred_ctr_x - 0.5 * pred_w + pred_w * torch_uniform(num_boxes, -jitter_x_out, jitter_x_in)
jitter_pred_box[:, 1::4] = pred_ctr_y - 0.5 * pred_h + pred_h * torch_uniform(num_boxes, -jitter_y_out, jitter_y_in)
jitter_pred_box[:, 2::4] = pred_ctr_x + 0.5 * pred_w + pred_w * torch_uniform(num_boxes, -jitter_x_in, jitter_x_out)
jitter_pred_box[:, 3::4] = pred_ctr_y + 0.5 * pred_h + pred_h * torch_uniform(num_boxes, -jitter_y_in, jitter_y_out)
jitter_pred_box = jitter_pred_box.view(N, nr_boxes, -1)
for idx, (_, curr_whwh) in enumerate(zip(jitter_pred_box, images_whwh)):
# clamp each coordinate column ([:, k]); indexing [k] would clamp a single box instead
jitter_pred_box[idx][:, 0].clamp_(min=0, max=curr_whwh[0] - 1)
jitter_pred_box[idx][:, 1].clamp_(min=0, max=curr_whwh[1] - 1)
jitter_pred_box[idx][:, 2] = torch.max(torch.clamp(jitter_pred_box[idx][:, 2], max=curr_whwh[0]-1), jitter_pred_box[idx][:, 0] + 1)
jitter_pred_box[idx][:, 3] = torch.max(torch.clamp(jitter_pred_box[idx][:, 3], max=curr_whwh[1]-1), jitter_pred_box[idx][:, 1] + 1)
jitter_pred_box[idx] = clip_boxes_tensor(jitter_pred_box[idx], curr_whwh[1], curr_whwh[0])
return pred_boxes, jitter_pred_box
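# Worked toy example (illustrative only; never called by the model) of the delta
# parameterization that `apply_deltas` inverts: (dx, dy) move the box center in
# units of its width/height, (dw, dh) scale it in log-space, and all four are
# divided by `bbox_weights` first.
def _demo_apply_deltas_math():
    x1, y1, x2, y2 = 10.0, 10.0, 30.0, 50.0  # w=20, h=40, ctr=(20, 30)
    wx, wy, ww, wh = 2.0, 2.0, 1.0, 1.0      # default bbox_weights
    dx, dy, dw, dh = 2.0, 2.0, 0.0, 0.0      # raw network outputs
    w, h = x2 - x1, y2 - y1
    ctr_x = x1 + 0.5 * w + (dx / wx) * w     # 20 + 1.0 * 20 = 40
    ctr_y = y1 + 0.5 * h + (dy / wy) * h     # 30 + 1.0 * 40 = 70
    pw = math.exp(dw / ww) * w               # size unchanged: 20
    ph = math.exp(dh / wh) * h               # size unchanged: 40
    return (ctr_x - 0.5 * pw, ctr_y - 0.5 * ph,
            ctr_x + 0.5 * pw, ctr_y + 0.5 * ph)  # (30, 50, 50, 90)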
class DynamicConv(nn.Module):
def __init__(self, cfg, origin=True):
super().__init__()
self.hidden_dim = cfg.MODEL.SparseRCNN.HIDDEN_DIM
self.dim_dynamic = cfg.MODEL.SparseRCNN.DIM_DYNAMIC
self.num_dynamic = cfg.MODEL.SparseRCNN.NUM_DYNAMIC
self.num_params = self.hidden_dim * self.dim_dynamic
self.dynamic_layer = nn.Linear(self.hidden_dim, self.num_dynamic * self.num_params)
self.norm1 = nn.LayerNorm(self.dim_dynamic)
self.norm2 = nn.LayerNorm(self.hidden_dim)
self.activation = nn.ReLU(inplace=True)
pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
if origin:
num_output = self.hidden_dim * pooler_resolution ** 2
else:
num_output = self.hidden_dim * cfg.DATA.NUM_FRAMES
self.out_layer = nn.Linear(num_output, self.hidden_dim)
self.norm3 = nn.LayerNorm(self.hidden_dim)
def forward(self, pro_features, roi_features):
'''
pro_features: (1, N * nr_boxes, self.d_model)
roi_features: (49, N * nr_boxes, self.d_model)
'''
features = roi_features.permute(1, 0, 2) # (N*nr_boxes, 49, 256)
parameters = self.dynamic_layer(pro_features).permute(1, 0, 2) # (N*nr_boxes, 1, 32768)
param1 = parameters[:, :, :self.num_params].view(-1, self.hidden_dim, self.dim_dynamic) # (N*nr_boxes, 256, 64)
param2 = parameters[:, :, self.num_params:].view(-1, self.dim_dynamic, self.hidden_dim) # (N*nr_boxes, 64, 256)
features = torch.bmm(features, param1) # (N*nr_boxes, 49, 64)
features = self.norm1(features)
features = self.activation(features)
features = torch.bmm(features, param2) # (N*nr_boxes, 49, 256)
features = self.norm2(features)
features = self.activation(features)
features = features.flatten(start_dim=1) # (N*nr_boxes, 49*256)
features = self.out_layer(features)
features = self.norm3(features)
features = self.activation(features)
return features
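# Shape sketch (illustrative sizes) of the two bmm "dynamic convolutions" above:
# each box gets its own pair of filters generated from its proposal feature,
# applied to its 7x7 = 49 pooled ROI tokens.
def _demo_dynamic_conv_shapes():
    M, S, C, D = 4, 49, 256, 64      # boxes, tokens, hidden_dim, dim_dynamic
    feats = torch.randn(M, S, C)     # pooled ROI features per box
    param1 = torch.randn(M, C, D)    # per-box filter 1 (from dynamic_layer)
    param2 = torch.randn(M, D, C)    # per-box filter 2
    x = torch.bmm(feats, param1)     # (M, 49, 64)
    x = torch.bmm(x, param2)         # (M, 49, 256)
    return x.flatten(start_dim=1)    # (M, 49*256), fed to out_layer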
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class RCNNHead3D(nn.Module):
def __init__(self, cfg, scale_clamp: float = _DEFAULT_SCALE_CLAMP,
bbox_weights=(2.0, 2.0, 1.0, 1.0), origin=True):
super().__init__()
d_model = cfg.MODEL.SparseRCNN.HIDDEN_DIM
num_classes = cfg.MODEL.SparseRCNN.NUM_CLASSES
num_actions = cfg.MODEL.NUM_CLASSES
dim_feedforward = cfg.MODEL.SparseRCNN.DIM_FEEDFORWARD
nhead = cfg.MODEL.SparseRCNN.NHEADS
dropout = cfg.MODEL.SparseRCNN.DROPOUT
activation = cfg.MODEL.SparseRCNN.ACTIVATION
self.d_model = d_model
self.jitter_box = cfg.MODEL.SparseRCNN.JITTER_BOX
# dynamic.
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.inst_interact = DynamicConv(cfg, origin=False)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.is_origin = origin # origin = False for action_proposal_feature head
def forward(self, features, bboxes, pro_features, pooler):
"""
:param bboxes: (N, nr_boxes, 4)
:param pro_features: (N, nr_boxes, d_model)
"""
N, nr_boxes = bboxes.shape[:2]
nr_frames = features.shape[2]
# roi_feature.
proposal_boxes = list()
for b in range(N):
proposal_boxes.append(Boxes(bboxes[b]))
roi_feats = []
# only consider the Slow pathway
for k in range(nr_frames):
frame_roi_features = pooler([features[:, :, k]], proposal_boxes)
frame_roi_features = F.adaptive_avg_pool2d(frame_roi_features, output_size=(1, 1))
roi_feats.append(frame_roi_features)
roi_features = torch.stack(roi_feats, dim=2) # (N*nr_boxes, d_model, nr_frames, 1, 1)
# roi_features = pooler(features, proposal_boxes)
roi_features = roi_features.view(N * nr_boxes, self.d_model, -1).permute(2, 0, 1)
# self_att.
pro_features = pro_features.view(N, nr_boxes, self.d_model).permute(1, 0, 2)
pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
pro_features = pro_features + self.dropout1(pro_features2)
pro_features = self.norm1(pro_features)
# inst_interact.
pro_features = pro_features.view(nr_boxes, N, self.d_model).permute(1, 0, 2).reshape(1, N * nr_boxes, self.d_model)
pro_features2 = self.inst_interact(pro_features, roi_features)
pro_features = pro_features + self.dropout2(pro_features2)
obj_features = self.norm2(pro_features)
# obj_feature.
obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
obj_features = obj_features + self.dropout3(obj_features2)
obj_features = self.norm3(obj_features)
fc_feature = obj_features.transpose(0, 1).reshape(N * nr_boxes, -1)
return fc_feature, obj_features
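# Shape sketch (illustrative sizes, not called by the model) of how RCNNHead3D
# builds temporal tokens: each frame's ROI is pooled to 1x1, the frames are
# stacked over time, and its DynamicConv then sees (nr_frames, N*nr_boxes, d_model)
# instead of the 49 spatial tokens used by RCNNHead.
def _demo_per_frame_roi_shapes():
    N, nr_boxes, d_model, nr_frames = 2, 100, 256, 8
    per_frame = [torch.randn(N * nr_boxes, d_model, 1, 1) for _ in range(nr_frames)]
    roi = torch.stack(per_frame, dim=2)                 # (N*nr_boxes, d_model, nr_frames, 1, 1)
    roi = roi.view(N * nr_boxes, d_model, -1).permute(2, 0, 1)
    return roi                                          # (nr_frames, N*nr_boxes, d_model)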
class X3DHead(nn.Module):
""" X3D head before the global average pooling
copy-paste from slowfast/head_helper.py
Only keep layers before the global average pooling layer
"""
def __init__(
self,
dim_in,
dim_inner,
inplace_relu=True,
eps=1e-5,
bn_mmt=0.1,
norm_module=nn.BatchNorm3d,
):
super(X3DHead, self).__init__()
self.eps = eps
self.bn_mmt = bn_mmt
self.inplace_relu = inplace_relu
self._construct_head(dim_in, dim_inner, norm_module)
def _construct_head(self, dim_in, dim_inner, norm_module):
self.conv_5 = nn.Conv3d(
dim_in,
dim_inner,
kernel_size=(1, 1, 1),
stride=(1, 1, 1),
padding=(0, 0, 0),
bias=False,
)
self.conv_5_bn = norm_module(
num_features=dim_inner, eps=self.eps, momentum=self.bn_mmt
)
self.conv_5_relu = nn.ReLU(self.inplace_relu)
def forward(self, inputs):
# In its current design, the X3D head is only usable for a single
# pathway input.
assert len(inputs) == 1, "Input tensor does not contain 1 pathway"
x = self.conv_5(inputs[0])
x = self.conv_5_bn(x)
x = self.conv_5_relu(x)
return x
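if __name__ == "__main__":
    # Minimal smoke test for X3DHead, the one module here that can be built
    # without a detectron2/slowfast cfg. Sizes are illustrative.
    head = X3DHead(dim_in=192, dim_inner=432)
    clip = torch.randn(1, 192, 4, 7, 7)   # (N, C, T, H, W)
    print(head([clip]).shape)             # torch.Size([1, 432, 4, 7, 7])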
@JonathanFlores2503
Will the complete code be available, or could you share it with me, please?

@thanhmcisai
How can I use this code with my custom dataset? Could you provide a tutorial?
