@lzqlzzq · Created November 5, 2023 04:59
DCNv2 (Deep & Cross Network v2) implementation for PyTorch.
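The cross layer follows the DCN-V2 paper (Wang et al., 2021): x_{l+1} = x_0 ⊙ (W_l x_l + b_l) + x_l, where ⊙ is the elementwise product. In this implementation the bias is off by default and an optional BatchNorm follows the linear map instead.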
from typing import List, Type

import torch
from torch import nn


class Transpose(nn.Module):
    """Swap two dimensions; lets BatchNorm1d be used inside nn.Sequential."""
    def __init__(self, *args):
        super().__init__()
        self.args = args

    def forward(self, x):
        return x.transpose(*self.args)


class MLP(nn.Module):
    """Feed-forward stack: Linear -> activation -> BatchNorm -> Dropout per layer."""
    def __init__(self,
                 hidden_sizes: List[int],
                 act_fn: Type[nn.Module] = nn.ReLU,
                 batch_norm: bool = True,
                 bias: bool = False,
                 dropout: float = 0.5):
        super().__init__()
        # One block per consecutive pair in hidden_sizes. BatchNorm1d normalizes
        # over dim 1, so the feature axis is transposed into place and back.
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features=in_size, out_features=out_size, bias=bias),
                act_fn(),
                Transpose(-1, -2),
                nn.BatchNorm1d(out_size) if batch_norm else nn.Identity(),
                Transpose(-1, -2),
                nn.Dropout(dropout))
            for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:])])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class CrossNet(nn.Module):
    """Stack of DCNv2 cross layers: x_{l+1} = x_0 * (W_l x_l) + x_l."""
    def __init__(self,
                 embedding_dim: int,
                 layers: int,
                 batch_norm: bool = True,
                 bias: bool = False):
        super().__init__()
        # Build each cross layer separately. Multiplying a one-element list by
        # `layers` (as in `[...] * layers`) would register the same module,
        # and therefore the same weights, at every position.
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim, bias=bias),
                Transpose(-1, -2),
                nn.BatchNorm1d(embedding_dim) if batch_norm else nn.Identity(),
                Transpose(-1, -2))
            for _ in range(layers)])

    def forward(self, features):
        # features: (batch_size, seq_len, embedding_dim)
        x0 = features
        x = x0
        for layer in self.layers:
            x = x0 * layer(x) + x  # elementwise product with the input, plus residual
        return x


class DCNv2(nn.Module):
    def __init__(self,
                 embedding_dim: int,
                 out_dim: int,
                 cross_layers: int,
                 mlp_sizes: List[int],
                 structure: str = 'parallel',  # or 'stacked'
                 act_fn: Type[nn.Module] = nn.ReLU,
                 batch_norm: bool = True,
                 bias: bool = False):
        super().__init__()
        self.cross_net = CrossNet(embedding_dim=embedding_dim,
                                  layers=cross_layers,
                                  batch_norm=batch_norm,
                                  bias=bias)
        self.structure = structure
        if structure == 'parallel':
            # Cross network and MLP run side by side; their outputs are
            # concatenated and projected to out_dim.
            self.mlp = MLP(hidden_sizes=[embedding_dim] + mlp_sizes,
                           act_fn=act_fn,
                           batch_norm=batch_norm,
                           bias=bias,
                           dropout=0)
            self.projection = nn.Linear(mlp_sizes[-1] + embedding_dim, out_dim)
        elif structure == 'stacked':
            # The MLP consumes the cross network's output directly.
            self.mlp = nn.Sequential(
                act_fn(),
                MLP(hidden_sizes=[embedding_dim] + mlp_sizes + [out_dim],
                    act_fn=act_fn,
                    batch_norm=batch_norm,
                    bias=bias,
                    dropout=0))
        else:
            raise ValueError(f'No such DCNv2 structure named {structure}!')

    def forward(self, features):
        if self.structure == 'parallel':
            return self.projection(
                torch.cat([self.cross_net(features), self.mlp(features)], dim=-1))
        elif self.structure == 'stacked':
            return self.mlp(self.cross_net(features))
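

A minimal smoke test, assuming inputs are already-embedded features of shape (batch_size, seq_len, embedding_dim); all sizes below are arbitrary:

if __name__ == '__main__':
    model = DCNv2(embedding_dim=32,
                  out_dim=8,
                  cross_layers=3,
                  mlp_sizes=[64, 64],
                  structure='parallel')
    features = torch.randn(4, 10, 32)  # (batch_size, seq_len, embedding_dim)
    print(model(features).shape)       # torch.Size([4, 10, 8])

With structure='stacked' the output shape is the same; the MLP simply consumes the cross network's output rather than running beside it.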