@lzqlzzq
Created November 6, 2023 15:28
Modified AutoInt (Automatic Feature Interaction) for embedding aggregation (PyTorch implementation)
# https://arxiv.org/pdf/1810.11921.pdf
from typing import Dict

import torch
from torch import nn
import torch.nn.functional as F


class AutoInt(nn.Module):
    """AutoInt-style multi-head self-attention over feature fields,
    used here to aggregate a dict of per-feature embeddings."""
    def __init__(self,
                 features: Dict[str, int],  # feature name -> embedding dim
                 attention_dim: int,
                 output_dim: int,
                 head_num: int):
        super().__init__()
        self.feature_dims = sum(features.values())
        self.feature_names = sorted(features.keys())
        self.feature_num = len(features)
        self.attention_dim = attention_dim
        self.head_num = head_num
        self.output_dim = output_dim

        # Per-feature Q/K/V projections, each producing head_num heads of size attention_dim
        self.attention = nn.ModuleDict({
            f'{feature_name}_attn': nn.ModuleDict({
                'attn_q': nn.Linear(feature_dim, attention_dim * head_num, bias=False),
                'attn_k': nn.Linear(feature_dim, attention_dim * head_num, bias=False),
                'attn_v': nn.Linear(feature_dim, attention_dim * head_num, bias=False)
            }) for feature_name, feature_dim in features.items()
        })
        # Projects the concatenated attention heads to the output dimension
        self.attn_proj = nn.Linear(self.feature_num * attention_dim * head_num, output_dim)
        # Residual projection of the raw concatenated embeddings
        self.emb_proj = nn.Linear(self.feature_dims, output_dim)
    def forward(self, x: Dict[str, torch.Tensor]):
        # Each x[name] has shape (batch_size, seq_len, feature_dim)
        batch_size, seq_len, _ = x[self.feature_names[0]].shape

        # Stack per-feature projections and split heads:
        # (head_num, batch_size * seq_len, feature_num, attention_dim)
        query = torch.stack([self.attention[f'{k}_attn']['attn_q'](x[k]) for k in self.feature_names],
                            dim=-2).reshape(-1, self.feature_num, self.head_num, self.attention_dim).permute(2, 0, 1, 3)
        key = torch.stack([self.attention[f'{k}_attn']['attn_k'](x[k]) for k in self.feature_names],
                          dim=-2).reshape(-1, self.feature_num, self.head_num, self.attention_dim).permute(2, 0, 1, 3)
        value = torch.stack([self.attention[f'{k}_attn']['attn_v'](x[k]) for k in self.feature_names],
                            dim=-2).reshape(-1, self.feature_num, self.head_num, self.attention_dim).permute(2, 0, 1, 3)

        # Unscaled dot-product attention across the feature fields
        attn_weight = F.softmax(torch.matmul(query, key.transpose(-1, -2)), dim=-1)
        attn_values = torch.matmul(attn_weight, value)

        # (head_num, batch_size * seq_len, feature_num, attention_dim)
        #   -> (batch_size, seq_len, feature_num * head_num * attention_dim)
        attn_values = attn_values.permute(1, 2, 0, 3).reshape(batch_size, seq_len, -1)
        projected_attn = self.attn_proj(attn_values)

        # Residual branch: concatenate the raw embeddings in the same sorted
        # feature order as the attention branch so the layout seen by emb_proj
        # is deterministic regardless of the input dict's insertion order
        res_emb = self.emb_proj(torch.cat([x[k] for k in self.feature_names], dim=-1))
        return F.relu(res_emb + projected_attn)
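
A minimal usage sketch follows (not part of the original gist); the feature names `pitch` and `duration` and all sizes below are made-up assumptions, chosen only to illustrate the expected input and output shapes.

# Usage sketch (hypothetical feature names and sizes, for illustration only)
features = {'pitch': 64, 'duration': 32}  # feature name -> embedding dim
model = AutoInt(features, attention_dim=16, output_dim=128, head_num=4)

batch_size, seq_len = 8, 50
x = {name: torch.randn(batch_size, seq_len, dim) for name, dim in features.items()}

out = model(x)
print(out.shape)  # torch.Size([8, 50, 128])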