DCNv2 (Deep & Cross Network v2) implementation for PyTorch.
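Each cross layer computes x_{l+1} = x0 * W_l(x_l) + x_l: an element-wise product of the original input embedding x0 with a learned linear map of the current state, plus a residual connection. The cross network is then combined with a plain MLP either in parallel (outputs concatenated, then projected) or stacked (MLP on top of the cross output), as in the DCNv2 paper.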
from typing import List, Type

import torch
from torch import nn
class Transpose(nn.Module):
    """Wraps Tensor.transpose so it can be used inside nn.Sequential."""
    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)
class MLP(nn.Module):
    """Stack of Linear -> activation -> (BatchNorm) -> Dropout blocks."""
    def __init__(self,
                 hidden_sizes: List[int],
                 act_fn: Type[nn.Module] = nn.ReLU,
                 batch_norm: bool = True,
                 bias: bool = False,
                 dropout: float = 0.5):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features=in_size, out_features=out_size, bias=bias),
                act_fn(),
                # BatchNorm1d normalizes dim 1, so swap the feature dim in
                # and out: (batch, seq, features) -> (batch, features, seq).
                Transpose(-1, -2),
                nn.BatchNorm1d(out_size) if batch_norm else nn.Identity(),
                Transpose(-1, -2),
                nn.Dropout(dropout))
            for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:])])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
class CrossNet(nn.Module):
    """Stack of DCNv2 cross layers: x_{l+1} = x0 * W_l(x_l) + x_l."""
    def __init__(self,
                 embedding_dim: int,
                 layers: int,
                 batch_norm: bool = True,
                 bias: bool = False):
        super().__init__()
        # Build each layer separately: multiplying a one-element list by
        # `layers` would reuse a single module and share weights across layers.
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embedding_dim,
                          embedding_dim,
                          bias=bias),
                Transpose(-1, -2),
                nn.BatchNorm1d(embedding_dim) if batch_norm else nn.Identity(),
                Transpose(-1, -2))
            for _ in range(layers)])

    def forward(self, features):
        # features: (batch_size, seq_len, embedding_dim)
        x0 = features
        x = x0
        for layer in self.layers:
            # Element-wise interaction with the original input, plus residual.
            x = x0 * layer(x) + x
        return x
class DCNv2(nn.Module):
    """DCNv2 with either a parallel or a stacked cross/MLP structure."""
    def __init__(self,
                 embedding_dim: int,
                 out_dim: int,
                 cross_layers: int,
                 mlp_sizes: List[int],
                 structure: str = 'parallel',  # or 'stacked'
                 act_fn: Type[nn.Module] = nn.ReLU,
                 batch_norm: bool = True,
                 bias: bool = False):
        super().__init__()
        self.cross_net = CrossNet(embedding_dim=embedding_dim,
                                  layers=cross_layers,
                                  batch_norm=batch_norm,
                                  bias=bias)
        self.structure = structure
        if structure == 'parallel':
            # Cross and MLP branches run side by side; their outputs are
            # concatenated and projected to out_dim.
            self.mlp = MLP(hidden_sizes=[embedding_dim] + mlp_sizes,
                           act_fn=act_fn,
                           batch_norm=batch_norm,
                           bias=bias,
                           dropout=0)
            self.projection = nn.Linear(mlp_sizes[-1] + embedding_dim, out_dim)
        elif structure == 'stacked':
            # The MLP consumes the cross network's output directly.
            self.mlp = nn.Sequential(
                act_fn(),
                MLP(hidden_sizes=[embedding_dim] + mlp_sizes + [out_dim],
                    act_fn=act_fn,
                    batch_norm=batch_norm,
                    bias=bias,
                    dropout=0))
        else:
            raise ValueError(f'No such DCNv2 structure named {structure!r}!')

    def forward(self, features):
        if self.structure == 'parallel':
            return self.projection(
                torch.cat([self.cross_net(features), self.mlp(features)], dim=-1))
        elif self.structure == 'stacked':
            return self.mlp(self.cross_net(features))
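A minimal usage sketch; the batch size, sequence length, and layer sizes below are illustrative assumptions, not part of the original gist.

# Smoke test: run both structures on random input and check output shapes.
if __name__ == '__main__':
    batch_size, seq_len, embedding_dim = 8, 16, 32
    features = torch.randn(batch_size, seq_len, embedding_dim)

    parallel = DCNv2(embedding_dim=embedding_dim,
                     out_dim=4,
                     cross_layers=3,
                     mlp_sizes=[64, 64],
                     structure='parallel')
    print(parallel(features).shape)  # torch.Size([8, 16, 4])

    stacked = DCNv2(embedding_dim=embedding_dim,
                    out_dim=4,
                    cross_layers=3,
                    mlp_sizes=[64, 64],
                    structure='stacked')
    print(stacked(features).shape)  # torch.Size([8, 16, 4])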