@TrentBrick
Created June 30, 2022 17:23
Potts Model Closed Form Expectation
# for use with a batch of sequences
def energy_torch(self, inp):
    """
    Computes the Potts Hamiltonian energy in PyTorch.

    Takes the softmax over the sequences generated by the neural network,
    then computes the expected energy over this softmax in a vectorized way.

    Parameters
    ----------
    inp : torch.Tensor
        Flattened protein sequences output by the neural network that have
        already been softmaxed, of shape batch_size x (protein_length x 20).

    Returns
    -------
    torch.Tensor
        torch.float32 matrix of shape batch_size x 1.
    """
    if not self.is_discrete:
        batch_size = inp.shape[0]
        # assumes the input has shape [batch x (L * properties)]
        assert len(inp.shape) == 2, 'wrong shape!'
        inp = inp.view((batch_size, self.L, -1))  # the decoder expects a 3D tensor
        # convert to a probability distribution over the AAs,
        # then plug it into the score
        inp = self.decode(inp).view((batch_size, -1))  # [batch_size x distribution over AAs]
    # applying the vectorized EVH loss:
    h_val = torch.matmul(inp, self.h_torch)
    j_val = torch.unsqueeze(torch.sum(inp * torch.matmul(inp, self.J_torch), dim=-1) / 2, 1)
    evh = j_val + h_val
    return evh
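The closed form being applied above is: for sites drawn independently from per-site softmax distributions p, the expected Potts energy is E[H] = p·h + ½ pᵀJp. A minimal NumPy sketch below checks this against brute-force enumeration over all sequences on a tiny alphabet (the function name, the L=3, q=2 sizes, and the random parameters are illustrative, not from the gist; the ½ pᵀJp term matches the true pairwise expectation only when J is symmetric with zero within-site blocks, which the sketch assumes):

```python
import itertools
import numpy as np

def expected_potts_energy(p, h, J):
    """Closed-form expectation of the Potts energy E(x) = x.h + 0.5 * x.J.x
    when each site is drawn independently from its softmax distribution.

    p : (batch, L*q) per-site probability distributions, flattened
    h : (L*q, 1)     field parameters
    J : (L*q, L*q)   symmetric couplings with zero within-site blocks
    """
    h_val = p @ h                                          # (batch, 1)
    j_val = np.sum(p * (p @ J), axis=-1, keepdims=True) / 2
    return h_val + j_val

# sanity check against brute-force enumeration on a tiny alphabet
rng = np.random.default_rng(0)
L, q = 3, 2
h = rng.normal(size=(L * q, 1))
J = rng.normal(size=(L * q, L * q))
J = (J + J.T) / 2
for i in range(L):                     # zero the within-site blocks
    J[i*q:(i+1)*q, i*q:(i+1)*q] = 0.0

logits = rng.normal(size=(L, q))
p_site = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
p = p_site.reshape(1, -1)

closed = expected_potts_energy(p, h, J)[0, 0]

brute = 0.0
for seq in itertools.product(range(q), repeat=L):
    x = np.zeros(L * q)                # one-hot encoding of this sequence
    for i, a in enumerate(seq):
        x[i * q + a] = 1.0
    energy = x @ h.ravel() + 0.5 * x @ J @ x
    prob = np.prod([p_site[i, a] for i, a in enumerate(seq)])
    brute += prob * energy
```

The vectorized form avoids enumerating the q^L sequences entirely, which is what makes it usable as a differentiable loss on neural-network outputs.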