@jparkhill · Created April 30, 2021
Two ansatz modules for PyTorch PDEs.
import torch as tch

class Neural_Density(tch.nn.Module):
    """
    A neural model of a time-dependent probability density on a
    vector-valued state space, i.e. rho(t, {x_0, x_1, ..., x_{state_dim}}).
    For now normalization is not enforced; it could be, e.g. with a
    Gaussian mixture.
    """
    def __init__(self, state_dim, hidden_dim=64):
        super(Neural_Density, self).__init__()
        self.input_dim = state_dim + 1  # state dimensions plus time.
        self.state_dim = state_dim
        self.net = tch.nn.Sequential(
            tch.nn.Linear(self.input_dim, hidden_dim),
            tch.nn.Softplus(),
            tch.nn.Linear(hidden_dim, 1),
            tch.nn.Softplus(),  # density is positive.
        )

    def forward(self, t, x):
        # Evaluate the (unnormalized) density at the argument (t, x).
        return self.net(tch.cat([t.unsqueeze(-1), x], -1)).squeeze()
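A minimal usage sketch (the batch size and state dimension below are illustrative assumptions, not part of the gist):

# Evaluate rho(t, x) on a batch of 32 points in a 2-dimensional state space.
rho = Neural_Density(state_dim=2)
t = tch.rand(32)       # one time per batch element.
x = tch.randn(32, 2)   # batch of state vectors.
p = rho(t, x)          # positive (unnormalized) densities, shape (32,).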
class Reshape(tch.nn.Module):
    def __init__(self, shp):
        super(Reshape, self).__init__()
        self.shape = shp

    def forward(self, x):
        return x.view(self.shape)
class Gaussian_Mixture_Density(tch.nn.Module):
    def __init__(self, state_dim,
                 m_dim=1,
                 hidden_dim=16,
                 ):
        """
        A network whose feed-forward sub-nets parameterize a Gaussian
        mixture density over the state space as a function of time.
        """
        super(Gaussian_Mixture_Density, self).__init__()
        # rho(t, x) is the density; the sub-nets map t to the mixture
        # parameters (weights, means, covariance factors).
        input_dim = 1
        output_dim = state_dim
        self.output_dim = output_dim
        self.m_dim = m_dim
        mixture_dim = output_dim * m_dim
        # Number of strictly-lower-triangular (correlation) entries.
        self.n_corr = int(self.output_dim * (self.output_dim - 1) / 2.)
        self.sftpls = tch.nn.Softplus()
        self.sftmx = tch.nn.Softmax(dim=-1)
        # Off-diagonal entries of the Cholesky factor.
        self.corr_net = tch.nn.Sequential(
            tch.nn.Linear(input_dim, hidden_dim),
            tch.nn.Tanh(),
            tch.nn.Linear(hidden_dim, self.n_corr * m_dim),
            Reshape((-1, m_dim, self.n_corr)),
        )
        # Positive diagonal (standard-deviation) entries.
        self.std_net = tch.nn.Sequential(
            tch.nn.Linear(input_dim, hidden_dim),
            tch.nn.SELU(),
            tch.nn.Linear(hidden_dim, mixture_dim),
            tch.nn.Softplus(10.),
            Reshape((-1, m_dim, self.output_dim)),
        )
        # Component means.
        self.mu_net = tch.nn.Sequential(
            tch.nn.Linear(input_dim, hidden_dim),
            tch.nn.Tanh(),
            tch.nn.Linear(hidden_dim, mixture_dim),
            Reshape((-1, m_dim, self.output_dim)),
        )
        # Mixture weights (normalized by the softmax).
        self.pi_net = tch.nn.Sequential(
            tch.nn.Linear(input_dim, hidden_dim),
            tch.nn.SELU(),
            tch.nn.Linear(hidden_dim, m_dim),
            tch.nn.Tanh(),
            tch.nn.Softmax(dim=-1),
        )
        # Note: attribute assignment already registers these sub-modules,
        # so explicit add_module calls are unnecessary.
    def pi(self, x):
        return self.pi_net(x)

    def mu(self, x):
        return self.mu_net(x)

    def L(self, x):
        """
        Constructs the lower-triangular Cholesky factor of each
        component's covariance matrix.
        """
        batch_size = x.shape[0]
        L = tch.zeros(batch_size, self.m_dim, self.output_dim, self.output_dim)
        # Scatter the positive std_net outputs onto the diagonal.
        b_inds = tch.arange(batch_size).unsqueeze(1).unsqueeze(1).repeat(1, self.m_dim, self.output_dim).flatten()
        m_inds = tch.arange(self.m_dim).unsqueeze(1).unsqueeze(0).repeat(batch_size, 1, self.output_dim).flatten()
        s_inds = tch.arange(self.output_dim).unsqueeze(0).unsqueeze(0).repeat(batch_size, self.m_dim, 1).flatten()
        L[b_inds, m_inds, s_inds, s_inds] = self.std_net(x).flatten()
        if self.output_dim > 1:
            # Scatter the corr_net outputs below the diagonal.
            t_inds = tch.tril_indices(self.output_dim, self.output_dim, -1)
            txs = t_inds[0].flatten()
            tys = t_inds[1].flatten()
            bb_inds = tch.arange(batch_size).unsqueeze(1).unsqueeze(1).repeat(1, self.m_dim, txs.shape[0]).flatten()
            mt_inds = tch.arange(self.m_dim).unsqueeze(1).unsqueeze(0).repeat(batch_size, 1, txs.shape[0]).flatten()
            xt_inds = txs.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.m_dim, 1).flatten()
            yt_inds = tys.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.m_dim, 1).flatten()
            L[bb_inds, mt_inds, xt_inds, yt_inds] = self.corr_net(x).flatten()
        return L
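    # For reference: with a positive diagonal, L is a valid Cholesky factor,
    # and each component covariance is Sigma = L @ L.T, which is exactly
    # what MultivariateNormal(..., scale_tril=L) expects below.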
    def get_distribution(self, x):
        pi_distribution = tch.distributions.Categorical(self.pi(x))
        GMM = tch.distributions.mixture_same_family.MixtureSameFamily(
            pi_distribution,
            tch.distributions.MultivariateNormal(self.mu(x), scale_tril=self.L(x)),
        )
        return GMM

    def forward(self, t, x):
        return self.get_distribution(t.unsqueeze(-1)).log_prob(x).exp()

    def rsample(self, t, sample_shape=128):
        """
        Returns samples from the Gaussian mixture, with the sample
        dimension moved last, i.e. batch x dim x samp.
        """
        # MixtureSameFamily does not support rsample, so .sample() is used
        # and gradients do not flow through these samples.
        samps_ = self.get_distribution(t).sample(sample_shape=[sample_shape])
        samps = samps_.permute(1, 2, 0)
        return samps

    def mean(self, t):
        return self.get_distribution(t.unsqueeze(-1)).mean

    def std(self, t):
        return tch.sqrt(self.get_distribution(t.unsqueeze(-1)).variance)
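A corresponding usage sketch for the mixture module (again, the shapes and argument values are illustrative assumptions, not part of the gist):

# A 3-component mixture over a 2-dimensional state, at a batch of 8 times.
gmm = Gaussian_Mixture_Density(state_dim=2, m_dim=3)
t = tch.rand(8)                     # batch of times.
x = tch.randn(8, 2)                 # batch of states.
p = gmm(t, x)                       # density values rho(t, x), shape (8,).
# rsample expects t already unsqueezed to (batch, 1), unlike forward.
samps = gmm.rsample(t.unsqueeze(-1), sample_shape=64)  # shape (8, 2, 64).
mu, sig = gmm.mean(t), gmm.std(t)   # per-time mean and std, shape (8, 2).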