Created June 29, 2024 07:50
Save johndpope/53e836be7b0546f6528af75ebbc1b290 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch


class MRLR:
    """Multi-resolution low-rank (MRLR) tensor decomposition.

    Each entry of ``partitions`` groups the modes of ``tensor``. The tensor is
    reshaped so that every group becomes a single mode, and a rank-``r``
    CP/PARAFAC model of that lower-order tensor is fitted by alternating least
    squares (ALS). Partitions are processed in order, each one fitting the
    residual left by the previous ones, so the per-partition approximations
    sum to the overall reconstruction.
    """

    def __init__(self, tensor, partitions, ranks):
        """
        Initialize the MRLR decomposition.

        Args:
            tensor (torch.Tensor): The input tensor to be decomposed.
            partitions (list of list of list of int): The partitions for
                multi-resolution decomposition. Every partition must use each
                mode of ``tensor`` exactly once, e.g. ``[[0], [1, 2]]``.
            ranks (list of int): The CP rank for each partition.

        Raises:
            ValueError: If ``partitions`` and ``ranks`` differ in length, or a
                partition does not use every mode of ``tensor`` exactly once.
        """
        # zip() would silently drop the extra partitions/ranks; fail loudly.
        if len(partitions) != len(ranks):
            raise ValueError(
                f"expected one rank per partition, got {len(partitions)} "
                f"partitions and {len(ranks)} ranks"
            )
        for partition in partitions:
            self._mode_order(partition, tensor.dim())
        self.tensor = tensor
        self.partitions = partitions
        self.ranks = ranks
        self.factors = self._initialize_factors()

    def _work_dtype(self):
        """Return the floating dtype used for factors and residuals."""
        # torch.randn and the least-squares updates need a floating dtype, so
        # non-floating inputs fall back to torch's default floating dtype.
        if self.tensor.is_floating_point():
            return self.tensor.dtype
        return torch.get_default_dtype()

    @staticmethod
    def _group_size(shape, mode_group):
        """Return the product of the sizes of the modes listed in ``mode_group``."""
        size = 1
        for mode in mode_group:
            size *= shape[mode]
        return size

    @staticmethod
    def _mode_order(partition, ndim):
        """Flatten ``partition`` into a permutation of ``range(ndim)``.

        Raises:
            ValueError: If the partition does not use each mode exactly once.
        """
        order = [mode for mode_group in partition for mode in mode_group]
        if sorted(order) != list(range(ndim)):
            raise ValueError(
                f"partition {partition} must use each of the {ndim} modes exactly once"
            )
        return order

    def _initialize_factors(self):
        """Initialize random starting factors for each partition.

        Returns:
            One list per partition, holding one ``(group_size, rank)`` tensor
            per mode group. ``decompose`` starts ALS from copies of these.
        """
        options = {"dtype": self._work_dtype(), "device": self.tensor.device}
        return [
            [
                torch.randn(self._group_size(self.tensor.shape, mode_group), rank, **options)
                for mode_group in partition
            ]
            for partition, rank in zip(self.partitions, self.ranks)
        ]

    def _unfold(self, tensor, partition):
        """Unfold ``tensor`` so that each mode group of ``partition`` becomes one mode.

        Modes are permuted into the partition's order before reshaping, so the
        groups need not be contiguous (e.g. ``[[1], [0, 2]]``).
        """
        order = self._mode_order(partition, tensor.dim())
        shape = [self._group_size(tensor.shape, mode_group) for mode_group in partition]
        return tensor.permute(order).reshape(shape)

    def _fold(self, unfolded, partition, original_shape):
        """Invert ``_unfold``: restore ``original_shape`` and the original mode order."""
        order = self._mode_order(partition, len(original_shape))
        permuted = unfolded.reshape([original_shape[mode] for mode in order])
        # inverse[m] is the position of original mode m inside `order`.
        inverse = [0] * len(order)
        for position, mode in enumerate(order):
            inverse[mode] = position
        return permuted.permute(inverse)

    def _parafac(self, tensor, rank, max_iter=100, tol=1e-4, init=None):
        """Fit a rank-``rank`` CP (PARAFAC) model to ``tensor`` by ALS.

        ``tensor`` is factorized as given: its own modes are the CP modes
        (``decompose`` passes the already-unfolded residual).

        Args:
            tensor (torch.Tensor): Floating tensor of order >= 1.
            rank (int): Number of rank-one components.
            max_iter (int): Maximum number of ALS sweeps.
            tol (float): Stop once every factor changes by less than ``tol``
                (Frobenius norm) during a sweep.
            init (list of torch.Tensor, optional): Starting factors, one
                ``(tensor.shape[k], rank)`` matrix per mode. Random if omitted.
                The given tensors are not modified.

        Returns:
            list of torch.Tensor: ``factors[k]`` has shape ``(tensor.shape[k], rank)``.
        """
        ndim = tensor.dim()
        options = {"dtype": tensor.dtype, "device": tensor.device}
        if init is None:
            factors = [torch.randn(size, rank, **options) for size in tensor.shape]
        else:
            factors = [factor.to(**options) for factor in init]
        # Carries the rank index in the einsum below even when there are no
        # other factors to contract (order-1 tensors).
        rank_ones = torch.ones(rank, **options)
        for _ in range(max_iter):
            # Factors are replaced, never mutated in place, so a shallow copy suffices.
            old_factors = list(factors)
            for mode in range(ndim):
                # Hadamard product of the Gram matrices of all other factors.
                gram = torch.ones(rank, rank, **options)
                for k, factor in enumerate(factors):
                    if k != mode:
                        gram *= factor.t() @ factor
                # MTTKRP in einsum sublist form: tensor modes are labelled
                # 0..ndim-1 and the rank index is labelled ndim. Every mode other
                # than `mode` is contracted with its factor, which avoids building
                # an explicit unfolding and Khatri-Rao product.
                operands = [tensor, list(range(ndim)), rank_ones, [ndim]]
                for k, factor in enumerate(factors):
                    if k != mode:
                        operands += [factor, [k, ndim]]
                mttkrp = torch.einsum(*operands, [mode, ndim])
                factors[mode] = mttkrp @ torch.pinverse(gram)
            if all(torch.norm(f - old_f) < tol for f, old_f in zip(factors, old_factors)):
                break
        return factors

    @staticmethod
    def _cp_to_tensor(factors):
        """Build the full tensor ``sum_r outer(factors[0][:, r], ..., factors[-1][:, r])``."""
        ndim = len(factors)
        operands = []
        for k, factor in enumerate(factors):
            operands += [factor, [k, ndim]]
        # The rank label `ndim` is absent from the output, so it is summed over.
        return torch.einsum(*operands, list(range(ndim)))

    def decompose(self, max_iter=100, tol=1e-4):
        """Perform MRLR decomposition.

        Partitions are fitted in order, each one to the residual left by the
        previous ones. ALS for partition ``i`` starts from ``self.factors[i]``,
        so repeated calls (e.g. ``decompose()`` then ``reconstruct()``) start
        from the same point.

        Args:
            max_iter (int): Maximum ALS sweeps per partition.
            tol (float): ALS convergence tolerance per partition.

        Returns:
            list of torch.Tensor: One approximation per partition, each shaped
            like ``self.tensor``. Their sum approximates ``self.tensor``.
        """
        residual = self.tensor.to(self._work_dtype()).clone()
        approximations = []
        for partition, rank, init in zip(self.partitions, self.ranks, self.factors):
            unfolded = self._unfold(residual, partition)
            factors = self._parafac(unfolded, rank, max_iter, tol, init=init)
            approximation = self._fold(self._cp_to_tensor(factors), partition, residual.shape)
            approximations.append(approximation)
            residual -= approximation
        return approximations

    def reconstruct(self, max_iter=100, tol=1e-4):
        """Reconstruct the tensor as the sum of all per-partition approximations."""
        return sum(self.decompose(max_iter, tol))
def normalized_frobenius_error(original, approximation):
    """Compute the Normalized Frobenius Error.

    Returns ``||original - approximation|| / ||original||`` as a 0-dim tensor,
    where ``torch.norm`` with default arguments takes the 2-norm over all
    entries. If ``original`` is all zeros, the division yields ``nan``/``inf``.
    """
    residual_norm = torch.norm(original - approximation)
    reference_norm = torch.norm(original)
    return residual_norm / reference_norm
# Example usage
if __name__ == "__main__":
    # Fixed seed so the sample tensor and the random ALS starting factors are
    # reproducible from run to run.
    torch.manual_seed(0)
    # Create a sample 5x201x61 tensor
    tensor = torch.randn(5, 201, 61)
    # Define partitions for multi-resolution decomposition
    partitions = [
        [[0], [1, 2]],  # Matrix unfolding: 5 x (201 * 61)
        [[0], [1], [2]],  # Full tensor: 3-way CP of the residual of the first level
    ]
    # Define ranks for each partition.
    # NOTE: a 5 x 12261 matrix has rank <= 5, so rank 10 lets the first level
    # alone represent the tensor (almost) exactly; the second level then only
    # fits round-off. Lower the first rank to see the levels share the work.
    ranks = [10, 5]
    # Create MRLR object and perform decomposition
    mrlr = MRLR(tensor, partitions, ranks)
    approximations = mrlr.decompose()
    # Reconstruct from the approximations just computed; calling
    # mrlr.reconstruct() here would run the whole decomposition a second time.
    reconstructed = sum(approximations)
    # Compute error
    error = normalized_frobenius_error(tensor, reconstructed)
    print(f"Normalized Frobenius Error: {error.item()}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.