@hengck23
Created May 6, 2020 14:10
adanet
from common import *  # assumed to provide timer(), time_to_str() and os used in run_check_train()
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from my_tabnet.sparsemax import Sparsemax
class Identity(torch.nn.Module):
    def forward(self, x):
        return x
# class IdentityEmbedding(torch.nn.Module):
#     def forward(self, x):
#         batch_size = len(x)
#         x = x.view(batch_size, 1).float()
#         return x
# 'From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification' - Andre F. T. Martins
# https://arxiv.org/pdf/1602.02068.pdf
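# A minimal sketch of the sparsemax projection described in the paper above, for reference
# only. Assumption: the imported my_tabnet.sparsemax.Sparsemax computes the same mapping;
# this illustrative helper (sparsemax_reference) is not used by the model below.
def sparsemax_reference(z, dim=-1):
    # sort scores in descending order along `dim`
    z_sorted, _ = torch.sort(z, dim=dim, descending=True)
    k = torch.arange(1, z.size(dim) + 1, device=z.device, dtype=z.dtype)
    view = [1] * z.dim()
    view[dim] = -1
    k = k.view(view)
    # support size: largest k such that 1 + k*z_(k) > cumsum(z_(1..k))
    z_cumsum = z_sorted.cumsum(dim)
    support = (1 + k * z_sorted) > z_cumsum
    k_support = support.sum(dim=dim, keepdim=True).clamp(min=1)
    # threshold tau chosen so that the projected values sum to 1
    tau = (z_cumsum.gather(dim, k_support - 1) - 1) / k_support.to(z.dtype)
    return torch.clamp(z - tau, min=0)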
def initialize_non_glu(module, in_dim, out_dim):
    gain_value = np.sqrt((in_dim + out_dim) / np.sqrt(4 * in_dim))
    torch.nn.init.xavier_normal_(module.weight, gain=gain_value)
    # torch.nn.init.zeros_(module.bias)
    return
def initialize_glu(module, in_dim, out_dim):
    gain_value = np.sqrt((in_dim + out_dim) / np.sqrt(in_dim))
    torch.nn.init.xavier_normal_(module.weight, gain=gain_value)
    # torch.nn.init.zeros_(module.bias)
    return
#---
# 'Train longer, generalize better: closing the generalization gap in large batch training of neural networks' - Elad Hoffer et al., arXiv 2017
# https://arxiv.org/abs/1705.08741
# 'Four Things Everyone Should Know to Improve Batch Normalization' - Cecilia Summers et al., arXiv 2019
# https://arxiv.org/pdf/1906.03548.pdf
class GhostBatchNorm1d(torch.nn.Module):
    def __init__(self, in_dim, ghost_size=128, momentum=0.01):
        super(GhostBatchNorm1d, self).__init__()
        self.ghost_size = ghost_size
        self.bn = nn.BatchNorm1d(in_dim, momentum=momentum)
        # self.gn = nn.GroupNorm(in_dim, ghost_size)

    def forward(self, x):
        # normalize each ghost sub-batch independently, sharing the affine
        # parameters and running statistics of a single BatchNorm1d layer
        batch_size = len(x)
        chunk = x.chunk(int(np.ceil(batch_size / self.ghost_size)), 0)
        x = [self.bn(c) for c in chunk]
        x = torch.cat(x, dim=0)
        # x = self.bn(x)
        return x
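# For example, with the default ghost_size=128 a batch of 300 rows is split into
# ceil(300/128) = 3 chunks of 100 rows each (torch.chunk splits as evenly as possible),
# so batch statistics are computed per ghost sub-batch during training while the affine
# parameters and running statistics stay shared.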
# Gated Linear Unit
class GLU(torch.nn.Module):
    def __init__(self, in_dim, out_dim, ghost_size=128, momentum=0.02):
        super(GLU, self).__init__()
        self.fc = nn.Linear(in_dim, 2 * out_dim, bias=False)
        self.bn = GhostBatchNorm1d(2 * out_dim, ghost_size=ghost_size, momentum=momentum)
        initialize_glu(self.fc, in_dim, 2 * out_dim)

    def forward(self, x):
        batch_size = len(x)
        x = self.fc(x)
        x = self.bn(x)
        z, a = x.chunk(2, dim=1)
        x = z * torch.sigmoid(a)
        return x
class FeatureTransformer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, num_glu=2,
                 ghost_size=128, momentum=0.02):
        super(FeatureTransformer, self).__init__()
        self.num_glu = num_glu

        param = {
            'ghost_size': ghost_size,
            'momentum': momentum,
        }
        self.glu = torch.nn.ModuleList(
            [GLU(in_dim, out_dim, **param)]
            + [GLU(out_dim, out_dim, **param) for i in range(1, num_glu)]
        )

    def forward(self, x):
        scale = np.sqrt(0.5)  # scale the residual sum so its variance stays stable
        x = self.glu[0](x)
        for i in range(1, self.num_glu):
            x = scale * (x + self.glu[i](x))
        return x
class AttentiveTransformer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, ghost_size=128, momentum=0.02):
        super(AttentiveTransformer, self).__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        self.bn = GhostBatchNorm1d(out_dim, ghost_size=ghost_size, momentum=momentum)
        self.sparsemax = Sparsemax(dim=-1)  # Sparsemax
        initialize_non_glu(self.fc, in_dim, out_dim)

    def forward(self, prior, x):
        x = self.fc(x)
        x = self.bn(x)
        x = x * prior
        x = self.sparsemax(x)
        return x
############################################################
def do_embed(embedding, z):
    z_t = z.T.long()
    z = []
    for i in range(len(z_t)):
        z.append(embedding[i](z_t[i]))
    z = torch.cat(z, 1)
    return z
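# e.g. with category_dim = [(4, 2), (4, 2), (1, 1)] as used below, do_embed maps a
# (batch, 3) LongTensor of category indices to a (batch, 2+2+1) = (batch, 5) float
# tensor of concatenated embeddings.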
class TabNet(torch.nn.Module):
    def __init__(self,
                 numeric_dim   = 3,
                 category_dim  = [
                     (4, 2),
                     (4, 2),
                     (1, 1),
                 ],
                 out_dim       = 1,
                 decision_dim  = 8,
                 attention_dim = 8,
                 num_step      = 3,
                 num_glu       = 4,
                 num_share     = 2,
                 gamma         = 1.3,
                 ghost_size    = 128,
                 momentum      = 0.02
                 ):
        super(TabNet, self).__init__()
        dim = decision_dim + attention_dim

        self.decision_dim = decision_dim
        self.attention_dim = attention_dim
        self.num_step = num_step
        self.gamma = gamma

        # ----
        # self.embedding = Identity()
        f_dim = sum(embed_dim for (in_dim, embed_dim) in category_dim) + numeric_dim
        self.embedding = torch.nn.ModuleList([
            nn.Embedding(in_dim, embed_dim) for (in_dim, embed_dim) in category_dim
        ])
        self.bn = nn.BatchNorm1d(f_dim, momentum=0.01)  # 0.10
        # self.bn = GhostBatchNorm1d(f_dim, ghost_size=ghost_size, momentum=momentum)

        # ----
        self.first_transformer = FeatureTransformer(
            f_dim, dim,
            num_glu=num_glu,
            ghost_size=ghost_size,
            momentum=momentum
        )

        self.feature_transformer = torch.nn.ModuleList()
        self.attentive_transformer = torch.nn.ModuleList()
        for i in range(num_step):
            t = FeatureTransformer(
                f_dim, dim,
                num_glu=num_glu,
                ghost_size=ghost_size,
                momentum=momentum
            )
            a = AttentiveTransformer(
                attention_dim, f_dim,
                ghost_size=ghost_size,
                momentum=momentum
            )
            self.feature_transformer.append(t)
            self.attentive_transformer.append(a)

        # ---- share the weights of the first num_share GLU layers across all decision steps
        if num_share > 0:
            for i in range(num_step):
                for n in range(num_share):
                    del self.feature_transformer[i].glu[n].fc.weight
                    self.feature_transformer[i].glu[n].fc.weight = \
                        self.first_transformer.glu[n].fc.weight

        # ----
        self.final = nn.Linear(decision_dim, out_dim, bias=False)
        initialize_non_glu(self.final, decision_dim, out_dim)
    def forward(self, numeric, category):
        splitter = lambda x: (x[:, :self.decision_dim], x[:, self.decision_dim:])

        if category is not None:
            z = do_embed(self.embedding, category)
            x = torch.cat([numeric, z], 1)
        else:
            x = numeric
        f = self.bn(x)

        # -----
        prior = torch.ones_like(f)
        t = self.first_transformer(f)
        _, attention = splitter(t)

        mask = {}
        residual = 0
        for i in range(self.num_step):
            # sparse feature-selection mask for this decision step
            m = self.attentive_transformer[i](prior, attention)
            mask[i] = m

            # relax the prior so features used here are discouraged (gamma > 1) in later steps
            prior = (self.gamma - m) * prior
            t = self.feature_transformer[i](m * f)
            decision, attention = splitter(t)
            residual = residual + F.relu(decision, inplace=True)

        x = self.final(residual)
        return x, mask
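# A small CPU-only smoke test of the TabNet module above (illustrative only; the original
# run_check_train() further below assumes a CUDA device). run_check_cpu is a name
# introduced here for illustration.
def run_check_cpu():
    numeric = torch.randn(16, 3)
    category = torch.stack([
        torch.randint(0, 4, (16,)),
        torch.randint(0, 4, (16,)),
        torch.zeros(16, dtype=torch.long),
    ], dim=1)
    net = TabNet(numeric_dim=3, category_dim=[(4, 2), (4, 2), (1, 1)], out_dim=10).eval()
    with torch.no_grad():
        logit, mask = net(numeric, category)
    print('logit:', logit.shape)               # expected: torch.Size([16, 10])
    print('mask :', len(mask), mask[0].shape)  # expected: 3 masks of shape (16, 8)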
def criterion_sparsity_regularization_entropy(mask):
    # entropy of the step masks: L_sparse = -(1/num_mask) * sum_i mean_b sum_j m_i[b, j] * log(m_i[b, j] + eps)
    epsilon = 1e-15
    num_mask = len(mask)

    loss = 0
    for i in range(num_mask):
        m = mask[i]
        loss -= (m * torch.log(m + epsilon)).sum(dim=1).mean() / num_mask
    return loss
def criterion_cross_entropy(logit,truth):
    batch_size, dim = logit.shape
    truth = truth.view(-1)
    loss = F.cross_entropy(logit, truth)
    return loss
def metric_accuracy(logit, truth):
    predict = torch.argmax(logit, 1)
    accuracy = (truth == predict).float().mean().item()
    return accuracy
#######################################################################################
def print_state_dict(state_dict):
    print('*** print key *** ')
    keys = list(state_dict.keys())
    # keys = sorted(keys)
    for k in keys:
        if any(s in k for s in [
            'num_batches_tracked',
            # '.kernel',
            # '.gamma',
            # '.beta',
            # '.running_mean',
            # '.running_var',
        ]):
            continue
        p = state_dict[k].data.cpu().numpy()
        print(' \'%s\',\t%s,' % (k, tuple(p.shape)))
    print('')
def run_check_train():
    num_class = 10
    batch_size = 10

    category_dim = [
        (4, 2),
        (4, 2),
        (1, 1),
    ]
    numeric_dim = 3
    out_dim = num_class
    decision_dim = 8
    attention_dim = 8

    truth = np.random.choice(num_class, batch_size)
    numeric = np.random.uniform(-1, 1, (batch_size, numeric_dim))
    category = np.zeros((batch_size, len(category_dim)))
    for i, (in_dim, embed_dim) in enumerate(category_dim):
        category[:, i] = np.random.choice(in_dim, batch_size)

    # ---
    numeric = torch.from_numpy(numeric).float().cuda()
    category = torch.from_numpy(category).long().cuda()
    truth = torch.from_numpy(truth).long().cuda()

    net = TabNet(
        numeric_dim,
        category_dim,
        out_dim,
        decision_dim,
        attention_dim,
    ).cuda()
    # print(net)
    # print_state_dict(net.state_dict())

    net = net.eval()
    with torch.no_grad():
        logit, mask = net(numeric, category)
        print('logit:', logit.shape)
        print('mask:', len(mask), mask[0].shape)

        loss = criterion_cross_entropy(logit, truth)
        loss_mask = criterion_sparsity_regularization_entropy(mask)
        print('loss:', loss.item())
        print('loss_mask:', loss_mask.item())
        print('')
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.001)
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=0.1, momentum=0.9, weight_decay=0.0001)

    # ---
    lambda_sparse = 1e-3
    clip_value = 1

    print('batch_size =', batch_size)
    print('----------------------------------------------------')
    print('[iter ]  loss      mask    |  acc     | ')
    print('----------------------------------------------------')
    # [00075]  0.00939,  0.20384 | 1.00000 |  0 hr 00 min

    start_timer = timer()
    i = 0
    while i <= 125:
        # with torch.autograd.set_detect_anomaly(True):
        net.train()
        optimizer.zero_grad()

        logit, mask = net(numeric, category)
        loss = criterion_cross_entropy(logit, truth)
        loss_mask = criterion_sparsity_regularization_entropy(mask)

        (loss + lambda_sparse * loss_mask).backward()
        # (loss).backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), clip_value)
        optimizer.step()

        # ---
        accuracy = metric_accuracy(logit, truth)
        if i % 25 == 0:
            print(
                '[%05d] %8.5f, %8.5f | ' % (i, loss.item(), loss_mask.item(),) +
                '%0.5f | ' % (accuracy) +
                '%s' % (time_to_str((timer() - start_timer), 'min'))
            )
        i = i + 1
    print('')
    # if 1:
    #     for i in range(2):
    #         for n in range(2):
    #             print(id(net.feature_transformer[i].glu[n].fc.weight),
    #                   id(net.first_transformer.glu[n].fc.weight))
    #             print((net.feature_transformer[i].glu[n].fc.weight),
    #                   (net.first_transformer.glu[n].fc.weight))
    if 1:
        probability = F.softmax(logit, 1)
        probability = probability.data.cpu().numpy()
        predict = np.argsort(-probability, 1)
        truth = truth.data.cpu().numpy()
        for i, m in mask.items():
            mask[i] = m.data.cpu().numpy()

        for b in range(batch_size):
            print('%d ------------- ' % b)
            print('truth', truth[b])
            print('predict', predict[b][0])
            print('top')
            for i in range(3):
                print('\t %2d  %0.5f' % (predict[b][i], probability[b][predict[b][i]]))
            print('')
# main #################################################################
if __name__ == '__main__':
    print('%s: calling main function ... ' % os.path.basename(__file__))

    run_check_train()
    print('\nsuccess!')