@talesa
Created May 10, 2019 16:52
import torch
import torch.nn as nn

# `expm` and `expm_frechet` (a differentiable matrix exponential and its
# Frechet derivative operating on torch tensors) are assumed to be defined
# elsewhere; they are not included in this gist.


class OrthogonalLinear(nn.Module):
    """Implements a non-square linear layer with orthogonal columns."""

    def __init__(self, input_size, output_size, lr_factor=0.1):
        super(OrthogonalLinear, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.max_size = max(self.input_size, self.output_size)
        self.lr_factor = lr_factor
        # The trainable parameter holds the strict upper triangle of a
        # skew-symmetric matrix A. Its gradient is written manually in
        # orthogonal_kernel_grad_hook, so autograd should never reach it;
        # the hook below only exists to flag that invariant being broken.
        self.log_orthogonal_kernel = nn.Parameter(torch.Tensor(self.max_size, self.max_size))
        self.log_orthogonal_kernel.register_hook(
            lambda grad: print("This should not be executed"))
        # The orthogonal matrix B = expm(A) is kept as a buffer; the hook on
        # it pulls the gradient w.r.t. B back to a gradient w.r.t. A.
        self.register_buffer('orthogonal_kernel',
                             torch.empty(self.max_size, self.max_size, requires_grad=True))
        self.orthogonal_kernel.register_hook(self.orthogonal_kernel_grad_hook)
        self.log_orthogonal_kernel.data = \
            torch.as_tensor(self.skew_initializer(self.max_size),
                            dtype=self.log_orthogonal_kernel.dtype,
                            device=self.log_orthogonal_kernel.device)
        self.orthogonal_kernel.data = self._B

    @staticmethod
    def skew_initializer(size):
        # The original initializer is not included in the gist; this is a
        # plausible stand-in. Only the strict upper triangle matters, since
        # that is all `_A` reads.
        return torch.empty(size, size).uniform_(-0.01, 0.01)

    @property
    def _A(self):
        # Skew-symmetric matrix from the strict upper triangle of the
        # trainable parameter P: A = triu(P, 1) - triu(P, 1)^T.
        A = self.log_orthogonal_kernel.data
        A = A.triu(diagonal=1)
        return A - A.t()

    @property
    def _B(self):
        # Orthogonal matrix as the exponential of a skew-symmetric matrix.
        return expm(self._A)

    def orthogonal_kernel_grad_hook(self, orthogonal_kernel_grad):
        # Convert the Euclidean gradient w.r.t. B = expm(A) into a gradient
        # w.r.t. the upper-triangular parametrization of A, using the Frechet
        # derivative of the matrix exponential.
        A = self._A
        B = self.orthogonal_kernel.data
        G = orthogonal_kernel_grad
        BtG = B.t().mm(G)
        # Project B^T G onto its skew-symmetric part.
        grad = 0.5 * (BtG - BtG.t())
        frechet_deriv = B.mm(expm_frechet(-A, grad))
        self.log_orthogonal_kernel.grad = \
            self.lr_factor * (frechet_deriv - frechet_deriv.t()).triu(diagonal=1)
        return None

    def forward(self, input):
        # Rebuild B = expm(A) from the current parameters on every forward
        # pass, and clear any stale gradient left on the buffer.
        self.orthogonal_kernel.data = self._B
        if self.orthogonal_kernel.grad is not None:
            self.orthogonal_kernel.grad.data.zero_()
        return input.matmul(self.orthogonal_kernel[:self.input_size, :self.output_size])
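
For context: the layer parametrizes an orthogonal matrix as B = expm(A) with A skew-symmetric, so the optimizer only ever sees the triangular parameters, and the hook pulls gradients on B back to them. A minimal usage sketch follows, assuming `expm` and `expm_frechet` are available as noted above; the sizes, data, and optimizer here are illustrative, not part of the original gist.

# Minimal usage sketch (hypothetical sizes and optimizer; not from the gist).
layer = OrthogonalLinear(input_size=16, output_size=8)
opt = torch.optim.SGD([layer.log_orthogonal_kernel], lr=1e-2)

x = torch.randn(4, 16)
loss = layer(x).pow(2).sum()
loss.backward()   # hook on orthogonal_kernel fills log_orthogonal_kernel.grad
opt.step()        # updates the triangular parameters; forward() rebuilds B
opt.zero_grad()

Note that only `log_orthogonal_kernel` is passed to the optimizer: `orthogonal_kernel` is a buffer, recomputed from the parameters on each forward pass, so stepping on it directly would break orthogonality.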