Heterogeneous (AKA multi-view) factor modeling in pytorch.
'''
Heterogeneous factor modeling.

Fits a heterogeneous (multi-view) factor model where each column may be:
1) Binary
2) Categorical
3) Gaussian

Everything is fit via alternating minimization and stochastic gradient descent.
The code relies on PyTorch for SGD; a demo is included at the bottom of the file.

Author: Wesley Tansey
Date: 8/9/2019
'''
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from scipy.stats import norm
from utils import batches, ilogit
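# `batches` and `ilogit` live in the author's separate `utils` module, which is not
# included in this gist. As a hedged sketch (an assumption inferred from how they are
# used below, not the author's actual implementation), they behave roughly like:
#
#     def ilogit(x):
#         # Logistic sigmoid: map logits to probabilities in (0, 1).
#         return 1.0 / (1.0 + np.exp(-x))
#
#     def batches(indices, batch_size, shuffle=True):
#         # Yield chunks of `indices` of size `batch_size`, optionally shuffled.
#         order = np.random.permutation(indices) if shuffle else np.asarray(indices)
#         for start in range(0, len(order), batch_size):
#             yield order[start:start + batch_size]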
class HomogeneousFactorModel(nn.Module):
    def __init__(self, X, k, row_embeddings):
        super(HomogeneousFactorModel, self).__init__()
        # Handle missing data
        if np.ma.is_masked(X):
            self.present = torch.BoolTensor(~X.mask)
        else:
            self.present = torch.BoolTensor(np.ones(X.shape, dtype=bool))
    def forward(self, tidx):
        raise NotImplementedError
    def row_mode(self):
        raise NotImplementedError
    def col_mode(self):
        raise NotImplementedError
    def prob(self, i, j, vals):
        raise NotImplementedError
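# GaussianFactorModel: continuous columns are standardized to zero mean and unit
# variance using the observed column means/stds. The predicted (standardized) mean for
# entry (i, j) is the inner product of row embedding r_i and the column's mean
# embedding; a per-entry scale is parameterized through softplus terms, although
# prob() below scores values against a standard normal after standardization.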
class GaussianFactorModel(HomogeneousFactorModel):
    def __init__(self, X, k, row_embeddings):
        super(GaussianFactorModel, self).__init__(X, k, row_embeddings)
        self.row_embeddings = row_embeddings
        self.mean_embeddings = nn.Embedding(X.shape[1], k)
        self.std_embeddings = nn.Embedding(X.shape[1], k)
        # self.std_embeddings = nn.Embedding.from_pretrained(torch.FloatTensor(np.random.normal(0,1/np.sqrt(k),size=(X.shape[1], k))))
        self.softplus = nn.Softplus()
        self.means = torch.FloatTensor(X.mean(axis=0))
        self.stds = torch.FloatTensor(X.std(axis=0))
        self.labels = (torch.FloatTensor(X) - self.means[None]) / self.stds[None]
    def forward(self, tidx):
        '''Return the mean and standard deviation of the tidx entries.'''
        # Note: the labels are standardized in __init__, so the mean is predicted on the
        # standardized scale (no + self.means offset), matching prob() below.
        return (((self.row_embeddings(tidx)[:,None] * self.mean_embeddings.weight[None]).sum(dim=2),
                 (self.softplus(self.row_embeddings(tidx)[:,None]) * self.softplus(self.std_embeddings.weight[None])).sum(dim=2)),# + self.stds[None])),
                self.present[tidx])
    def row_mode(self):
        # Freeze the column-side parameters while the row embeddings are updated
        self.mean_embeddings.weight.requires_grad = False
        self.std_embeddings.weight.requires_grad = False
    def col_mode(self):
        self.mean_embeddings.weight.requires_grad = True
        self.std_embeddings.weight.requires_grad = True
    def prob(self, i, j, vals):
        # Standardize
        vals = (vals - self.means.data.numpy()[j]) / self.stds.data.numpy()[j]
        # Get the mean embedding
        r = self.row_embeddings.weight.data.numpy()[i]
        mu = r.dot(self.mean_embeddings.weight.data.numpy()[j])
        # std_c = self.std_embeddings.weight.data.numpy()[j]
        # std_offset = self.stds.data.numpy()[j]
        # std = np.log1p(np.exp(r.dot(std_c + std_offset)))
        std = 1 # Assume standard normal after standardizing
        return norm.pdf(vals, mu, scale=std)
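# BinaryFactorModel: binary columns are encoded as 0/1 labels. The logit for entry
# (i, j) is the inner product of row embedding r_i and column embedding c_j, trained
# with a Bernoulli (BCE-with-logits) likelihood; prob() applies the logistic sigmoid.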
class BinaryFactorModel(HomogeneousFactorModel):
    def __init__(self, X, k, row_embeddings):
        super(BinaryFactorModel, self).__init__(X, k, row_embeddings)
        self.row_embeddings = row_embeddings
        self.col_embeddings = nn.Embedding(X.shape[1], k)
        self.labels = torch.FloatTensor(X)
    def forward(self, tidx):
        '''Return the logits for the tidx entries.'''
        return ((self.row_embeddings(tidx)[:,None] * self.col_embeddings.weight[None]).sum(dim=2),
                self.present[tidx])
    def row_mode(self):
        # Freeze the column embeddings while the row embeddings are updated
        self.col_embeddings.weight.requires_grad = False
    def col_mode(self):
        self.col_embeddings.weight.requires_grad = True
    def prob(self, i, j, vals):
        p = ilogit(self.row_embeddings.weight.data.numpy()[i].dot(
            self.col_embeddings.weight.data.numpy()[j]))
        return p*vals + (1-p) * (1-vals)
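# CategoricalFactorModel: each categorical column gets a (d, k) embedding block, where
# d is the largest number of distinct values across the categorical columns. The logits
# for entry (i, j) are the inner products of row embedding r_i with that column's d
# category embeddings, trained with a cross-entropy loss; prob() applies a softmax.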
class CategoricalFactorModel(HomogeneousFactorModel):
    def __init__(self, X, k, row_embeddings):
        super(CategoricalFactorModel, self).__init__(X, k, row_embeddings)
        self.row_embeddings = row_embeddings
        self.k = k
        self.d = max([len(np.ma.unique(X[~X.mask[:,i],i])) for i in range(X.shape[1])])
        self.col_embeddings = nn.Parameter(torch.FloatTensor(np.random.normal(size=(X.shape[1], self.d, self.k))))
        self.labels = torch.LongTensor(X)
    def forward(self, tidx):
        '''Return the softmax logits for the tidx entries.'''
        # return ((self.row_embeddings(tidx)[:,None,None] * self.col_embeddings.weight.view(-1, self.d, self.k)[None]).sum(dim=3),
        #         self.present[tidx])
        return ((self.row_embeddings(tidx)[:,None,None] * self.col_embeddings[None]).sum(dim=3),
                self.present[tidx])
    def row_mode(self):
        # col_embeddings is a plain nn.Parameter, so toggling requires_grad freezes it directly
        self.col_embeddings.requires_grad = False
    def col_mode(self):
        self.col_embeddings.requires_grad = True
    def prob(self, i, j, vals):
        logits = self.col_embeddings.data.numpy()[j].dot(self.row_embeddings.weight.data.numpy()[i])
        p = np.exp(logits) / np.sum(np.exp(logits))
        return p[vals]
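# HeterogeneousFactorModel: routes each column to a binary, categorical, or Gaussian
# sub-model based on its number of distinct observed values (<= 2 binary, fewer than
# min_continuous categorical, otherwise Gaussian), while all three sub-models share a
# single row-embedding table.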
class HeterogeneousFactorModel(nn.Module):
    def __init__(self, X, k, min_continuous=10):
        super(HeterogeneousFactorModel, self).__init__()
        X = np.ma.masked_invalid(X) # ensures a full boolean mask, even for unmasked input
        self.row_embeddings = nn.Embedding(X.shape[0], k)
        # Count the unique values in each column
        self.vals = [np.ma.sort(np.ma.unique(X[~X.mask[:,i],i])) for i in range(X.shape[1])]
        self.nvals = np.array([len(v) for v in self.vals])
        for j in range(X.shape[1]):
            if np.any(np.isnan(self.vals[j])):
                raise Exception('Column {} contains unmasked NaN values.'.format(j))
        # Find the binary columns (2 values)
        self.bin_mask = self.nvals <= 2
        self.bin_cols = np.arange(X.shape[1])[self.bin_mask]
        if len(self.bin_cols) > 0:
            self.X_bin = np.ma.array([X[:,i] == X[:,i].max() for i in range(X.shape[1]) if self.bin_mask[i]],
                                     mask=[X.mask[:,i] for i in range(X.shape[1]) if self.bin_mask[i]]).T
            self.factor_bin = BinaryFactorModel(self.X_bin, k, self.row_embeddings)
        else:
            print('No binary columns found.')
        # Find the categorical columns (2 < d < min_continuous values)
        self.cat_mask = (self.nvals > 2) & (self.nvals < min_continuous)
        self.cat_cols = np.arange(X.shape[1])[self.cat_mask]
        if len(self.cat_cols) > 0:
            self.X_cat = np.ma.array([(X[:,i:i+1] > self.vals[i][None]).sum(axis=1) for i in range(X.shape[1]) if self.cat_mask[i]],
                                     mask=[X.mask[:,i] for i in range(X.shape[1]) if self.cat_mask[i]]).T.astype(int)
            self.factor_cat = CategoricalFactorModel(self.X_cat, k, self.row_embeddings)
        else:
            print('No categorical columns found.')
        # Find the continuous (Gaussian) columns (>= min_continuous values)
        self.con_mask = self.nvals >= min_continuous
        self.con_cols = np.arange(X.shape[1])[self.con_mask]
        if len(self.con_cols) > 0:
            self.X_con = np.ma.array(X[:,self.con_mask], mask=X.mask[:,self.con_mask])
            self.factor_con = GaussianFactorModel(self.X_con, k, self.row_embeddings)
        else:
            print('No gaussian columns found.')
    def row_mode(self):
        self.row_embeddings.weight.requires_grad = True
        self.factor_bin.row_mode()
        self.factor_cat.row_mode()
        self.factor_con.row_mode()
    def col_mode(self):
        self.row_embeddings.weight.requires_grad = False
        self.factor_bin.col_mode()
        self.factor_cat.col_mode()
        self.factor_con.col_mode()
    def forward(self, tidx):
        bin_logits = self.factor_bin(tidx) if len(self.bin_cols) > 0 else None
        cat_logits = self.factor_cat(tidx) if len(self.cat_cols) > 0 else None
        con_logits = self.factor_con(tidx) if len(self.con_cols) > 0 else None
        return (bin_logits, cat_logits, con_logits)
    def prob(self, i, j, vals):
        if self.bin_mask[j]:
            c = self.bin_mask[:j].sum()
            v = np.argmax(self.vals[j][None] == vals[:,None], axis=1)
            return self.factor_bin.prob(i, c, v)
        if self.cat_mask[j]:
            c = self.cat_mask[:j].sum()
            v = np.argmax(self.vals[j][None] == vals[:,None], axis=1)
            return self.factor_cat.prob(i, c, v)
        if self.con_mask[j]:
            c = self.con_mask[:j].sum()
            return self.factor_con.prob(i, c, vals)
        raise Exception('Column {} did not qualify as any valid type.'.format(j))
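# fit_factor_model: alternating minimization by SGD. Even passes update the shared row
# embeddings with all column parameters frozen; odd passes update the column parameters
# with the row embeddings frozen. The objective sums a BCE-with-logits loss (binary), a
# cross-entropy loss (categorical), and a Gaussian negative log-likelihood weighted by
# con_weight (continuous), computed only over the observed (non-missing) entries.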
def fit_factor_model(X, k, mf_epochs=5000, lr=1e-1, weight_decay=0,
                     lr_decay=0.96, lr_step=50, batchsize=10,
                     verbose=True,
                     min_continuous=10, con_weight=0.01,
                     **kwargs):
    import sys
    # Create the model
    model = HeterogeneousFactorModel(X, k, min_continuous=min_continuous)
    # Setup the different losses
    bin_loss_raw = nn.BCEWithLogitsLoss(reduction='none')
    cat_loss_raw = nn.CrossEntropyLoss(reduction='none')
    bin_loss = lambda predicted, target, present: bin_loss_raw(predicted[present], target[present]).mean() # (bin_loss_raw(predicted, target)*present).sum() / present.sum()
    cat_loss = lambda predicted, target, present: cat_loss_raw(predicted[present], target[present]).mean() # (cat_loss_raw(predicted, target)*present).sum() / present.sum()
    con_loss = lambda loc, scale, target, present: -(torch.distributions.Normal(loc[present], scale).log_prob(target[present])).mean() * con_weight # -(torch.distributions.Normal(loc, scale).log_prob(target) * present).sum() / present.sum()
    # Sample stochastically over rows
    train_indices = np.arange(X.shape[0])
    # Track progress
    losses = np.zeros(mf_epochs)
    # Train the model
    for epoch in range(mf_epochs*2):
        if verbose and (epoch % 2) == 1:
            print('\t\tEpoch {}'.format(epoch//2+1))
            sys.stdout.flush()
        # Train the row embeddings on even epochs and the columns on odd epochs
        if epoch % 2 == 0:
            model.row_mode()
        else:
            model.col_mode()
        # Setup the SGD method
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
        train_loss = torch.Tensor([0])
        for batch_idx, batch in enumerate(batches(train_indices, batchsize, shuffle=True)):
            if verbose and (batch_idx % 100 == 0):
                print('\t\t\tBatch {}'.format(batch_idx))
            tidx = autograd.Variable(torch.LongTensor(batch), requires_grad=False)
            # Set the model to training mode
            model.train()
            # Reset the gradient
            model.zero_grad()
            # Get the model predictions for the rows in this batch
            (bin_logits, bin_present), (cat_logits, cat_present), (con_out, con_present) = model(tidx)
            loss = bin_loss(bin_logits, model.factor_bin.labels[tidx], bin_present)
            loss += cat_loss(cat_logits.view(-1, model.factor_cat.d), model.factor_cat.labels[tidx].view(-1), cat_present.view(-1))
            loss += con_loss(con_out[0], 1, model.factor_con.labels[tidx], con_present)
            # Calculate gradients
            loss.backward()
            # Apply the update
            optimizer.step()
            # Track the loss
            train_loss += loss.data
        # Track the total loss
        losses[epoch // 2] += train_loss.numpy()
        scheduler.step()
        if verbose and (epoch % 2) == 1:
            print('Loss: {}'.format(train_loss))
        if (epoch % 2) == 1 and ((epoch // 2) % lr_step) == 0:
            lr *= lr_decay
    return model
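# Minimal usage sketch (my_data is a hypothetical (n_rows, n_cols) array, optionally a
# numpy masked array with missing entries masked):
#
#     model = fit_factor_model(my_data, k=5)
#     p = model.prob(i=0, j=3, vals=np.array([0., 1.])) # probabilities for column 3 of row 0
#
# The demo below builds synthetic binary, categorical, and Gaussian columns from random
# embeddings, fits the model, and compares the fitted per-entry distributions to the truth.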
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from utils import ilogit
    # Generate some fake data from something similar to the model
    nbin = 10
    ncat = 11
    ncon = 12
    N = 100
    M = 4
    P = nbin+ncat+ncon
    K = 6
    # Create the embeddings
    print('Creating embeddings')
    bin_embeds = np.random.normal(0, 1, size=(nbin, K))
    cat_embeds = np.random.normal(0, 1/np.sqrt(M), size=(ncat, M, K))
    con_embeds = np.random.normal(0, 0.5, size=(ncon, 2, K))
    row_embeds = np.random.normal(0, 1, size=(N, K))
    ########## Create the data ##########
    print('Creating data')
    X = np.zeros((N, P))
    # Binary samples
    print('\tCreating binary samples')
    logits = np.einsum('nk,mk->nm', row_embeds, bin_embeds)
    bin_probs = ilogit(logits)
    X[:,:nbin] = np.random.random(size=(N, nbin)) <= bin_probs
    # Categorical samples
    print('\tCreating categorical samples')
    logits = np.einsum('nk,cmk->ncm', row_embeds, cat_embeds)
    cat_probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    for i in range(N):
        for j in range(nbin, nbin+ncat):
            X[i,j] = np.random.choice(M, p=cat_probs[i,j-nbin])
    # Continuous samples
    print('\tCreating Gaussian samples')
    congits = np.einsum('nk,cmk->ncm', row_embeds, con_embeds)
    con_probs = norm.pdf(np.linspace(-5, 5, 100)[None,None], congits[:,:,0:1], scale=np.log1p(np.exp(congits[:,:,1:2])) + 0.1)
    for j in range(nbin+ncat, nbin+ncat+ncon):
        X[:,j] = np.random.normal(congits[:,j-ncat-nbin,0], np.log1p(np.exp(congits[:,j-ncat-nbin,1])) + 0.1)
    print('Masking some random bits of the data')
    X_mask = np.random.choice(X.shape[0], size=3), np.random.choice(X.shape[1], size=3)
    mask = np.zeros(X.shape, dtype='bool')
    X[X_mask[0], X_mask[1]] = np.nan
    X[0,0] = np.nan
    X = np.ma.array(X, mask=np.isnan(X))
    ########### Fit a factor model ###########
    print('Fitting factor model')
    factor_model = fit_factor_model(X, K, min_continuous=M+1, mf_epochs=5000)
    ########### Plot some example results ###########
    print('Plotting results')
    fig, axarr = plt.subplots(8, 12, figsize=(60,40), sharex=False, sharey=False)
    for i in range(axarr.shape[0]):
        for j in range(axarr.shape[1]):
            ax = axarr[i,j]
            if j < 4:
                # Binary
                ax.bar(np.arange(2)+0.3, [(1-bin_probs[i,j]), bin_probs[i,j]], width=0.3, color='black')
                ax.bar(np.arange(2)+0.65, factor_model.prob(i, j, np.arange(2)), width=0.3, color='orange')
                ax.axvline(X[i,j]+0.5, color='red', ls='--')
                ax.set_xlim([0,2])
                ax.set_ylim([0,1])
            elif j < 8:
                # Categorical
                ax.bar(np.arange(M)+1-0.7, cat_probs[i,j-4], width=0.3, color='black')
                ax.bar(np.arange(M)+1-0.35, factor_model.prob(i, j-4+nbin, np.arange(M)), width=0.3, color='orange')
                ax.axvline(X[i,j+nbin-4]+0.5, color='red', ls='--')
                ax.set_xlim([0, M])
                ax.set_ylim([0,1])
            else:
                # Continuous
                x_min, x_max = X[:,nbin+ncat:nbin+ncat+4].min()*1.1, X[:,nbin+ncat:nbin+ncat+4].max()*1.1
                ax.plot(np.linspace(x_min, x_max, 100), con_probs[i,j-8], color='black')
                ax.plot(np.linspace(x_min, x_max, 100), factor_model.prob(i, j-8+nbin+ncat, np.linspace(x_min, x_max, 100)), color='orange')
                ax.axvline(X[i,j+nbin+ncat-8], color='red', ls='--')
                ax.set_xlim([x_min, x_max])
    plt.savefig('plots/factor-demo.pdf', bbox_inches='tight')
    plt.close()