@cycyyy
Created March 6, 2021 03:52
from collections import OrderedDict
from tqdm import tqdm
import torch.nn.functional as F
import torch.utils.data as td
from torch import nn
import torch
import numpy as np
from deepctr_torch.inputs import SparseFeat, DenseFeat
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
dpath = os.getenv('BBPATH', '..')
import prepare_data
# torch.set_deterministic(True)
torch.manual_seed(0)
np.random.seed(0)
gpu = torch.device("cuda:0")
cpu = torch.device("cpu")
device = gpu
small_dataset = False
print(device, small_dataset)
use_dram = True
class SparseOnlyModel(torch.nn.Module):
    def __init__(self, feature_columns, hidden_size, batch_size, binary=False, dim=128, cache_ratio=0.4):
        super(SparseOnlyModel, self).__init__()
        self.binary = binary
        # Full embedding table definition
        input_size = 0
        total_embed_size = 0
        for feature_column in feature_columns:
            input_size += dim
            total_embed_size += feature_column.vocabulary_size
        if use_dram:
            # Full table lives in host DRAM as a plain tensor; only hot rows go to the GPU.
            self.embedding_table = torch.Tensor(total_embed_size, dim)
            nn.init.normal_(self.embedding_table)
        else:
            # Full table lives on the GPU as a trainable nn.Embedding.
            self.embedding_table = nn.Embedding(
                total_embed_size, dim, sparse=True)
        # Cache: a small GPU-resident embedding table holding the hot rows.
        self.cache_size = int(total_embed_size * cache_ratio)
        # self.cache_size = 24
        self.cache_idx_to_real_idx = torch.LongTensor(
            self.cache_size).fill_(-1)  # cache slot -> real row (-1 = empty)
        self.LRU_cache_idx = OrderedDict()  # real row -> cache slot, kept in LRU order
        for i in range(0, self.cache_size):
            # Seed every slot with a placeholder key; -(i + 1) keeps placeholders strictly
            # negative so they are never mistaken for real rows on eviction.
            self.LRU_cache_idx[-(i + 1)] = i
        # MLP head definition
        self.cache_table = nn.Embedding(self.cache_size, dim, sparse=True)
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size[1], hidden_size[2])
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_size[2], 1)
        if binary:
            self.sigmoid = nn.Sigmoid()
    def fetch(self, real_idx, cache_idx):
        # Copy rows that just became cached from the DRAM table into the GPU cache.
        if len(real_idx) == 0:
            return
        assert len(real_idx) == len(cache_idx)
        self.cache_table.weight.data[cache_idx] = self.embedding_table[real_idx].to(
            gpu)

    def evict(self, real_idx, cache_idx):
        # Write evicted cache rows back to the DRAM table so their updates are not lost.
        if len(real_idx) == 0:
            return
        assert len(real_idx) == len(cache_idx)
        self.embedding_table[real_idx] = self.cache_table.weight.data[cache_idx].to(
            cpu)
    def update_cache(self, unique_idx):
        # Split the batch's unique rows into already-cached and uncached ones.
        # unique_idx = unique_idx.to(gpu)
        # combined = torch.cat((unique_idx, self.cache_idx_to_real_idx))
        # uniques, counts = combined.unique(return_counts=True)
        # uncached_idx = uniques[counts == 1]
        # cached_idx = uniques[counts > 1]
        cached_idx = np.intersect1d(unique_idx, self.cache_idx_to_real_idx)
        uncached_idx = np.setdiff1d(unique_idx, cached_idx)
        assert len(uncached_idx) + len(cached_idx) < self.cache_size
        # Cache hits: refresh their recency in the LRU ordering.
        for idx in cached_idx:
            self.LRU_cache_idx.move_to_end(idx)
        # Cache misses: reclaim the least recently used slots for the new rows.
        cache_idx = torch.LongTensor(len(uncached_idx))
        evict_cache_idx = []
        evict_real_idx = []
        for i in range(0, len(uncached_idx)):
            last_item = self.LRU_cache_idx.popitem(last=False)
            cache_idx[i] = last_item[1]
            if last_item[0] >= 0:
                # The slot held a real row; write it back to DRAM before reuse.
                evict_cache_idx.append(last_item[1])
                evict_real_idx.append(last_item[0])
            self.LRU_cache_idx[uncached_idx[i]] = last_item[1]
            self.cache_idx_to_real_idx[last_item[1]] = uncached_idx[i]
        self.evict(torch.LongTensor(evict_real_idx),
                   torch.LongTensor(evict_cache_idx))
        self.fetch(torch.from_numpy(uncached_idx), cache_idx)
    def load_hybrid_embeds(self, x):
        # Look up embeddings through the GPU cache, paging rows in from DRAM on miss.
        unique, idx_in_unique = torch.unique(
            x, sorted=True, return_inverse=True)
        unique = unique.to(cpu).type(torch.LongTensor)
        self.update_cache(unique)
        # Translate real row ids into their current cache slots.
        cache_slots = torch.LongTensor(len(unique))
        for i in range(0, len(unique)):
            cache_slots[i] = self.LRU_cache_idx[unique[i].item()]
        x = self.cache_table(cache_slots[idx_in_unique.to(cpu)].to(gpu))
        return x

    def forward(self, x):
        if use_dram:
            x = self.load_hybrid_embeds(x)
        else:
            x = self.embedding_table(x)
        # Flatten the per-feature embeddings and run the MLP head.
        x = torch.reshape(x, (x.shape[0], x.shape[1] * x.shape[2]))
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.relu3(x)
        x = self.fc4(x)
        if self.binary:
            return self.sigmoid(x)
        return x
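
# Illustrative sketch (not part of the original gist): the LRU bookkeeping that
# update_cache() relies on, shown in isolation. Real rows map to cache slots in an
# OrderedDict; a hit refreshes recency, a miss reclaims the least recently used slot.
# The cache size and row ids below are made-up values for demonstration only.
def _lru_slot_demo(cache_size=4):
    lru = OrderedDict((-(i + 1), i) for i in range(cache_size))  # real row -> slot
    for real_row in [10, 11, 10, 12, 13, 14]:
        if real_row in lru:
            lru.move_to_end(real_row)                   # hit: mark most recently used
            slot = lru[real_row]
        else:
            victim_row, slot = lru.popitem(last=False)  # miss: take the LRU slot
            lru[real_row] = slot
        print("row", real_row, "-> slot", slot, dict(lru))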
def get_movielens():
    return prepare_data.build_movielens1m(path=dpath+"/movielens/ml-1m", cache_folder=dpath+"/.cache")


def get_criteo():
    return prepare_data.build_criteo(path=dpath+"/criteo/train.txt", cache_folder=dpath+"/.cache")
    # return prepare_data.build_avazu(path=dpath+"/avazu/train", cache_folder=dpath+"/.cache")


def generate_input():
    if small_dataset:
        feature_columns, _, raw_data, input_data, target = get_movielens()
    else:
        feature_columns, _, raw_data, input_data, target = get_criteo()
    y = raw_data[target].to_numpy()
    del raw_data
    feature_list = []
    x = []
    for feature_column in feature_columns:
        if isinstance(feature_column, SparseFeat):
            feature_list.append(feature_column)
            x.append(input_data[feature_column.embedding_name].to_numpy())
    x = np.array(x).T[:]
    y = y[:]
    # Offset each feature's local ids so all features index one shared embedding table.
    accum_idx = 0
    for i in range(1, len(feature_list)):
        accum_idx += feature_list[i-1].vocabulary_size
        x[:, i] += accum_idx
    train_tensor_data = td.TensorDataset(
        torch.from_numpy(x), torch.from_numpy(y))
    return train_tensor_data, feature_list
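
# Worked example of the id offsetting above (hypothetical vocabulary sizes): with three
# sparse features of 10, 5, and 8 rows, feature 0 keeps ids 0-9, feature 1's local id 3
# becomes 10 + 3 = 13, and feature 2's local id 2 becomes 10 + 5 + 2 = 17, so every row
# gets a unique index into the shared embedding table.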
def train(batch_size, epoch):
    train_tensor_data, feature_list = generate_input()
    train_loader = td.DataLoader(
        dataset=train_tensor_data, batch_size=batch_size)
    # MovieLens ratings are regressed; Criteo clicks are a binary target.
    if small_dataset:
        binary = False
    else:
        binary = True
    model = SparseOnlyModel(
        feature_list, [512, 256, 64], batch_size, binary).to(gpu)
    # print(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    if small_dataset:
        loss_func = F.mse_loss
    else:
        loss_func = F.binary_cross_entropy
    for e in range(epoch):
        total_loss = 0.0
        with tqdm(enumerate(train_loader), total=len(train_loader)) as t:
            for index, (x, y) in t:
                optimizer.zero_grad()
                x = x.to(gpu)
                pred_y = model(x).squeeze(1)
                y = y.to(gpu).float()
                loss = loss_func(pred_y, y)
                total_loss += loss.item()
                loss.backward()
                optimizer.step()
        print(e, ":", total_loss / len(train_loader))


if small_dataset:
    train(2048, 10)
else:
    train(8192, 1)
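
# To run (assuming a CUDA device and the prepare_data helpers are importable): point the
# BBPATH environment variable at a directory containing movielens/ml-1m and criteo/train.txt,
# then toggle small_dataset above to choose between the MovieLens-1M regression run
# (batch 2048, 10 epochs) and the Criteo binary classification run (batch 8192, 1 epoch).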