# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @johndpope
# Created June 29, 2024 07:50
# Show Gist options
#   • Save johndpope/53e836be7b0546f6528af75ebbc1b290 to your computer and use it in GitHub Desktop.
# Save johndpope/53e836be7b0546f6528af75ebbc1b290 to your computer and use it in GitHub Desktop.
import torch  # the snippet used ``torch`` everywhere but never imported it


class MRLR:
    """Multi-resolution low-rank (MRLR) tensor decomposition.

    The tensor is approximated as a sum of low-rank terms, one per
    partition.  Each partition groups the tensor modes into "super modes";
    the current residual is unfolded accordingly, fitted with a CP/PARAFAC
    model by alternating least squares, folded back, and subtracted from
    the residual before the next partition is processed.

    Attributes:
        tensor (torch.Tensor): The tensor being decomposed.
        partitions (list of list of list): Mode groupings, one per level.
        ranks (list of int): CP rank used for each partition.
        factors (list of list of torch.Tensor): For every partition, one
            ``(group_size, rank)`` matrix per mode group.  Random until
            ``decompose`` has run, fitted afterwards.
    """

    def __init__(self, tensor, partitions, ranks):
        """
        Initialize the MRLR decomposition.
        Args:
            tensor (torch.Tensor): The input tensor to be decomposed.
            partitions (list of list of list): The partitions for multi-resolution decomposition.
            ranks (list of int): The ranks for each partition.
        Raises:
            ValueError: If ``partitions`` and ``ranks`` differ in length, or
                a partition does not list every tensor mode exactly once.
        """
        if len(partitions) != len(ranks):
            # zip() would otherwise silently drop the unmatched entries.
            raise ValueError(
                f"got {len(partitions)} partitions but {len(ranks)} ranks"
            )
        self.tensor = tensor
        self.partitions = partitions
        self.ranks = ranks
        for partition in partitions:
            # Fail early on malformed partitions instead of deep inside ALS.
            self._mode_order(partition, len(tensor.shape))
        self.factors = self._initialize_factors()
        # Set by decompose(); lets reconstruct() reuse the fitted factors.
        self._fitted = False

    @staticmethod
    def _mode_order(partition, ndim):
        """Return the flattened mode order of ``partition``, validated.

        Negative mode indices are normalised.  Raises ``ValueError`` unless
        the partition contains each of the ``ndim`` modes exactly once.
        """
        order = [
            mode + ndim if mode < 0 else mode
            for mode_group in partition
            for mode in mode_group
        ]
        if sorted(order) != list(range(ndim)):
            raise ValueError(
                f"partition {partition!r} must contain each of the {ndim} "
                "tensor modes exactly once"
            )
        return order

    @staticmethod
    def _group_sizes(shape, partition):
        """Return the product of the mode sizes for each mode group."""
        sizes = []
        for mode_group in partition:
            size = 1
            for mode in mode_group:
                size *= shape[mode]
            sizes.append(size)
        return sizes

    @staticmethod
    def _factor_dtype(tensor):
        """Return the dtype used for factor matrices fitted to ``tensor``."""
        # torch.randn and the least-squares updates need a floating dtype.
        if tensor.is_floating_point():
            return tensor.dtype
        return torch.get_default_dtype()

    def _initialize_factors(self):
        """Initialize random factors for each partition."""
        dtype = self._factor_dtype(self.tensor)
        device = self.tensor.device
        factors = []
        for partition, rank in zip(self.partitions, self.ranks):
            sizes = self._group_sizes(self.tensor.shape, partition)
            factors.append(
                [torch.randn(size, rank, dtype=dtype, device=device) for size in sizes]
            )
        return factors

    def _unfold(self, tensor, partition):
        """Unfold the tensor according to the given partition.

        Modes are first permuted into the order in which the partition
        lists them, so that each mode group becomes one contiguous axis
        even when the partition is not in natural mode order.
        """
        order = self._mode_order(partition, len(tensor.shape))
        shape = self._group_sizes(tensor.shape, partition)
        return tensor.permute(order).reshape(shape)

    def _fold(self, unfolded, partition, original_shape):
        """Fold the unfolded tensor back to its original shape.

        Exact inverse of ``_unfold``: split every grouped axis back into
        its modes, then undo the mode permutation.
        """
        order = self._mode_order(partition, len(original_shape))
        permuted_shape = [original_shape[mode] for mode in order]
        inverse = [0] * len(order)
        for position, mode in enumerate(order):
            inverse[mode] = position
        return unfolded.reshape(permuted_shape).permute(inverse)

    def _khatri_rao(self, matrices):
        """Column-wise Kronecker product of a non-empty list of matrices.

        Rows are ordered row-major (the first matrix index varies slowest),
        matching ``reshape`` on the corresponding tensor axes.
        """
        result = matrices[0]
        for matrix in matrices[1:]:
            result = (result.unsqueeze(1) * matrix.unsqueeze(0)).reshape(
                -1, result.shape[1]
            )
        return result

    def _cp_to_tensor(self, factors):
        """Build the dense tensor described by a list of CP factor matrices."""
        shape = [factor.shape[0] for factor in factors]
        if len(factors) == 1:
            return factors[0].sum(dim=1)
        # Contract the rank axis with one matmul rather than rank-many
        # explicit outer products.
        leading = self._khatri_rao(factors[:-1])
        return (leading @ factors[-1].t()).reshape(shape)

    def _parafac(self, tensor, rank, max_iter=100, tol=1e-4, init=None):
        """Perform PARAFAC (CP) decomposition by alternating least squares.

        Args:
            tensor (torch.Tensor): The (already unfolded) tensor to fit.
            rank (int): Number of rank-one components.
            max_iter (int): Maximum number of ALS sweeps.
            tol (float): Stop once every factor changed by less than this
                (Frobenius norm) during a sweep.
            init (list of torch.Tensor, optional): Starting factors, one
                ``(size, rank)`` matrix per axis of ``tensor``.  Random
                when omitted.
        Returns:
            list of torch.Tensor: One ``(size, rank)`` factor per axis.
        """
        dtype = self._factor_dtype(tensor)
        tensor = tensor.to(dtype)
        device = tensor.device
        if init is None:
            factors = [
                torch.randn(size, rank, dtype=dtype, device=device)
                for size in tensor.shape
            ]
        else:
            factors = [factor.clone() for factor in init]
        for _ in range(max_iter):
            # Factors are replaced, never mutated, so a shallow copy is enough.
            previous = list(factors)
            for mode in range(len(factors)):
                others = [f for i, f in enumerate(factors) if i != mode]
                # Hadamard product of the Gram matrices of the other factors.
                gram = torch.ones(rank, rank, dtype=dtype, device=device)
                for other in others:
                    gram = gram * (other.t() @ other)
                if others:
                    khatri_rao = self._khatri_rao(others)
                else:
                    # Single-axis tensor: nothing to combine with.
                    khatri_rao = torch.ones(1, rank, dtype=dtype, device=device)
                # movedim keeps the remaining axes in natural order, which is
                # the row ordering produced by _khatri_rao.
                unfolded = tensor.movedim(mode, 0).reshape(tensor.shape[mode], -1)
                factors[mode] = unfolded @ khatri_rao @ torch.pinverse(gram)
            if all(torch.norm(f - old_f) < tol for f, old_f in zip(factors, previous)):
                break
        return factors

    def decompose(self, max_iter=100, tol=1e-4):
        """Perform MRLR decomposition.

        Each partition is fitted to the residual left by the previous ones.
        ALS starts from the matrices currently in ``self.factors`` and the
        fitted factors are written back there.

        Args:
            max_iter (int): Maximum ALS sweeps per partition.
            tol (float): ALS convergence tolerance.
        Returns:
            list of torch.Tensor: One approximation per partition, each
            with the shape of the input tensor.
        """
        residual = self.tensor.to(dtype=self._factor_dtype(self.tensor), copy=True)
        approximations = []
        for index, (partition, rank) in enumerate(zip(self.partitions, self.ranks)):
            unfolded = self._unfold(residual, partition)
            factors = self._parafac(
                unfolded, rank, max_iter, tol, init=self.factors[index]
            )
            self.factors[index] = factors
            folded_approximation = self._fold(
                self._cp_to_tensor(factors), partition, residual.shape
            )
            approximations.append(folded_approximation)
            residual = residual - folded_approximation
        self._fitted = True
        return approximations

    def reconstruct(self):
        """Reconstruct the tensor from its decomposition.

        Runs ``decompose`` first if it has not been run yet; otherwise the
        stored factors are reused, so repeated calls give the same result.
        """
        if not getattr(self, "_fitted", False):
            self.decompose()
        terms = [
            self._fold(self._cp_to_tensor(factors), partition, self.tensor.shape)
            for partition, factors in zip(self.partitions, self.factors)
        ]
        return sum(terms)
def normalized_frobenius_error(original, approximation):
    """Compute the Normalized Frobenius Error.

    Args:
        original (torch.Tensor): Reference tensor.
        approximation (torch.Tensor): Tensor compared against ``original``.
    Returns:
        torch.Tensor: Scalar tensor ``||original - approximation|| / ||original||``.
    """
    # Local import: the file never imports torch at module level.
    import torch

    return torch.norm(original - approximation) / torch.norm(original)
# Example usage
if __name__ == "__main__":
    # Local import: the file never imports torch at module level.
    import torch

    # Create a sample 5x201x61 tensor
    tensor = torch.randn(5, 201, 61)
    # Define partitions for multi-resolution decomposition
    partitions = [
        [[0], [1, 2]],  # Matrix unfolding
        [[0], [1], [2]],  # Full tensor
    ]
    # Define ranks for each partition
    ranks = [10, 5]
    # Create MRLR object and perform decomposition
    mrlr = MRLR(tensor, partitions, ranks)
    approximations = mrlr.decompose()
    # Reconstruct the tensor
    reconstructed = mrlr.reconstruct()
    # Compute error
    error = normalized_frobenius_error(tensor, reconstructed)
    print(f"Normalized Frobenius Error: {error.item()}")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment