@lzqlzzq
Last active November 1, 2023 10:18
An xDeepFM module implementation for PyTorch
from typing import List
from itertools import chain
import torch
from torch import nn
class Transpose(nn.Module):
    """Swaps two dimensions, so transposes can be used inside nn.Sequential."""
    def __init__(self, *args):
        super().__init__()
        self.args = args

    def forward(self, x):
        return x.transpose(*self.args)
class MLP(nn.Module):
    """Plain feed-forward network applied to the last dimension of the input."""
    def __init__(self,
                 hidden_sizes: List[int],
                 act_fn: nn.Module = nn.ReLU,
                 batch_norm: bool = False,
                 bias: bool = False,
                 dropout: float = .5):
        super().__init__()
        # Each block: Linear -> activation -> (optional) BatchNorm -> Dropout.
        # BatchNorm1d normalizes over the channel dimension, so the feature axis
        # is moved into the channel position and back with the Transpose helper.
        self.layers = nn.Sequential(*chain(*[
            (nn.Linear(in_features=in_size, out_features=out_size, bias=bias),
             act_fn(),
             Transpose(-1, -2),
             nn.BatchNorm1d(out_size) if batch_norm else nn.Identity(),
             Transpose(-1, -2),
             nn.Dropout(dropout))
            for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:])]))

    def forward(self, x):
        return self.layers(x)
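
# Usage sketch (not from the original gist; the sizes are arbitrary assumptions):
#   MLP([16, 64, 8], batch_norm=True) maps a (batch, num_fields, 16)
#   tensor to (batch, num_fields, 8).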
class CIN(nn.Module):
    """Compressed Interaction Network: each layer builds the outer product of the
    raw features with the previous layer's output and compresses it back to
    `input_dim` channels with a 1x1 convolution."""
    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 num_layers: int,
                 act_fn: nn.Module = nn.ReLU,
                 batch_norm: bool = False,
                 bias: bool = False):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                Transpose(-1, -2),
                nn.Conv1d(in_channels=input_dim * input_dim,
                          out_channels=input_dim,
                          kernel_size=1,
                          stride=1,
                          dilation=1,
                          bias=bias),
                act_fn(),
                nn.BatchNorm1d(input_dim) if batch_norm else nn.Identity(),
                Transpose(-1, -2))
            for _ in range(num_layers)])
        self.projection = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        # x: (batch, num_fields, input_dim)
        features = [x.unsqueeze(-2)]   # (batch, num_fields, 1, input_dim)
        x0 = x.unsqueeze(-1)           # (batch, num_fields, input_dim, 1)
        for layer in self.layers:
            # Outer product of the raw input with the previous layer's output.
            h = x0 * features[-1]      # (batch, num_fields, input_dim, input_dim)
            h = h.reshape(h.shape[0], h.shape[1], h.shape[-1] * h.shape[-2])
            h = layer(h)               # compressed back to (batch, num_fields, input_dim)
            features.append(h.unsqueeze(-2))
        features.pop(0)                # drop the raw input, keep only layer outputs
        features = torch.cat(features, dim=-2)
        pooled_feature = torch.sum(features, dim=-2)   # sum-pool over CIN layers
        return self.projection(pooled_feature)
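
# Usage sketch (not from the original gist; the sizes are arbitrary assumptions):
#   CIN(input_dim=16, output_dim=4, num_layers=3) maps a
#   (batch, num_fields, 16) tensor to (batch, num_fields, 4).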
class XDeepFM(nn.Module):
    """xDeepFM block: explicit feature interactions from the CIN branch plus
    implicit interactions from the MLP branch, combined by summation."""
    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 cin_layers: int,
                 mlp_layers: List[int],
                 act_fn: nn.Module = nn.ReLU,
                 batch_norm: bool = False,
                 bias: bool = False):
        super().__init__()
        self.cin = CIN(input_dim,
                       output_dim,
                       cin_layers,
                       act_fn,
                       batch_norm,
                       bias)
        # The MLP branch runs with dropout disabled (dropout=0).
        self.mlp = MLP([input_dim] + mlp_layers + [output_dim],
                       act_fn,
                       batch_norm,
                       bias,
                       0)

    def forward(self, x):
        return self.cin(x) + self.mlp(x)
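
# Minimal smoke test (not part of the original gist). The batch size, field
# count, and layer sizes below are arbitrary assumptions chosen only to
# illustrate the expected (batch, num_fields, input_dim) input layout.
if __name__ == "__main__":
    model = XDeepFM(input_dim=16,
                    output_dim=4,
                    cin_layers=3,
                    mlp_layers=[64, 32],
                    batch_norm=True)
    x = torch.randn(8, 10, 16)   # (batch=8, num_fields=10, input_dim=16)
    y = model(x)
    print(y.shape)               # expected: torch.Size([8, 10, 4])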