# Created July 3, 2024 18:50
# Save ricsi98/852d147463f5ee62d1b10d0aae0119a8 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an
# editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from typing import Literal

import numpy as np
import scipy
from scipy.sparse import csr_array
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_scipy_sparse_matrix
# Supported adjacency normalizations: "sym" (D^-1/2 A D^-1/2), "row" (D^-1 A),
# or None to leave the adjacency untouched.
NormType = Literal["sym", "row", None]


def _normalize_adj(adj: csr_array, normalization: NormType):
    """Scale a sparse adjacency matrix by its node degrees.

    Degrees are the row sums of ``adj``. Nodes with zero degree receive a
    scaling factor of 0 (instead of ``inf``), so their rows/columns stay zero.

    Args:
        adj: Square sparse adjacency matrix of shape ``(N, N)``.
        normalization: ``"sym"`` for ``D^-1/2 A D^-1/2``, ``"row"`` for
            ``D^-1 A``, or ``None`` to skip normalization.

    Returns:
        The normalized sparse matrix, or ``adj`` itself (same object) when
        ``normalization`` is ``None``.

    Raises:
        ValueError: If ``normalization`` is not ``"sym"``, ``"row"`` or ``None``.
    """
    if normalization is None:
        return adj
    if normalization == "sym":
        exponent = -0.5
    elif normalization == "row":
        exponent = -1.0
    else:
        # Fail loudly instead of silently returning None for a typo'd mode.
        raise ValueError(
            f"normalization must be 'sym', 'row' or None, got {normalization!r}"
        )

    deg = np.array(adj.sum(1)).flatten()
    n = len(deg)
    # Zero-degree (isolated) nodes give 0**exponent == inf; silence numpy's
    # divide-by-zero warning and map those factors to 0 below.
    with np.errstate(divide="ignore"):
        scale = np.power(deg, exponent)
    scale[np.isinf(scale)] = 0.0
    idx = np.arange(n)
    d_mat = csr_array((scale, (idx, idx)), shape=(n, n))

    if normalization == "sym":
        return d_mat.dot(adj).dot(d_mat)
    return d_mat.dot(adj)
def _adj2laplacian(adj: csr_array):
    """Build the combinatorial Laplacian ``L = D - A`` of a sparse matrix.

    ``D`` is the diagonal matrix holding the row sums of ``adj``.

    Args:
        adj: Square sparse matrix of shape ``(N, N)``.

    Returns:
        The sparse matrix ``D - adj``.
    """
    degrees = np.asarray(adj.sum(axis=1)).ravel()
    n_nodes = degrees.shape[0]
    diag_idx = np.arange(n_nodes)
    degree_matrix = csr_array(
        (degrees, (diag_idx, diag_idx)), shape=(n_nodes, n_nodes)
    )
    return degree_matrix - adj
def _repeated_products(operator, x, times):
    """Return ``[op @ x, op @ (op @ x), ...]`` with ``times`` entries (none if ``times <= 0``)."""
    products = []
    current = x
    for _ in range(times):
        current = operator.dot(current)
        products.append(current)
    return products


def poly_features(
    adj: csr_array,
    features: np.ndarray,
    identity: bool = True,
    adj_powers: int = 2,
    laplacian_powers: int = 2,
    normalization: NormType = "sym",
):
    """Concatenate node features propagated by powers of the adjacency and Laplacian.

    The output is ``[X, A X, ..., A^k X, L X, ..., L^m X]`` stacked column-wise,
    where ``A`` is ``adj`` after ``normalization`` and ``L`` is the Laplacian
    ``D - A`` built from that normalized ``A``.

    Args:
        adj: Square sparse adjacency matrix of shape ``(N, N)``.
        features: Node feature matrix with ``N`` rows; dense ndarray or any
            scipy sparse matrix/array.
        identity: Whether to include the raw ``features`` as the first block.
        adj_powers: Number of adjacency propagation steps ``k`` (``<= 0`` skips).
        laplacian_powers: Number of Laplacian propagation steps ``m`` (``<= 0`` skips).
        normalization: Adjacency normalization, ``"sym"``, ``"row"`` or ``None``.

    Returns:
        A sparse matrix (via ``scipy.sparse.hstack``) when ``features`` is
        sparse, otherwise a dense ndarray (via ``np.hstack``).

    Raises:
        ValueError: If no block is selected (``identity`` is False and both
            power counts are ``<= 0``), or if ``normalization`` is invalid.
    """
    adj = _normalize_adj(adj, normalization)

    blocks = [features] if identity else []
    blocks.extend(_repeated_products(adj, features, adj_powers))
    if laplacian_powers > 0:
        # NOTE(review): the Laplacian is built from the *normalized* matrix's
        # row sums. For "row" this equals I - A on non-isolated nodes, but for
        # "sym" it is not the normalized Laplacian I - D^-1/2 A D^-1/2 —
        # confirm this is intended.
        laplacian = _adj2laplacian(adj)
        blocks.extend(_repeated_products(laplacian, features, laplacian_powers))

    if not blocks:
        raise ValueError(
            "poly_features has nothing to stack: set identity=True or a "
            "positive adj_powers/laplacian_powers"
        )
    # issparse covers every scipy sparse format (csr/coo/... arrays and
    # matrices), not just csr_array, which np.hstack cannot concatenate.
    if scipy.sparse.issparse(features):
        return scipy.sparse.hstack(blocks)
    return np.hstack(blocks)
# Demo: polynomial features for the Cora citation graph loaded through
# Planetoid (root directory "data/Planetoid").
data = Planetoid(root="data/Planetoid", name="Cora")[0]
# Pass num_nodes explicitly: when it is omitted, PyG infers the matrix size
# from the largest node index in edge_index, so trailing isolated nodes would
# be dropped and the adjacency would no longer match the rows of data.x.
adj_matrix = to_scipy_sparse_matrix(
    data.edge_index, num_nodes=data.num_nodes
).tocsr()
x = data.x.numpy()
x_poly = poly_features(adj_matrix, x, adj_powers=2, laplacian_powers=2)
print(x_poly.shape)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.