-
-
Save RaulPPelaez/36b6a3a4bbdb0c373beaf3c1376e8f49 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tinygrad import Tensor, nn | |
from tinygrad.dtype import dtypes | |
import numpy as np | |
from typing import Optional | |
from utils import * | |
import pytest | |
# Letters available for the trailing feature axes. The first six keep the
# original "klhmno" order so previously generated formulas are unchanged; the
# rest extend the pool so higher-rank features no longer get a truncated
# (malformed) subscript. "i" (nodes) and "j" (edges) are reserved.
_FEATURE_SUBSCRIPT_LETTERS = (
    "klhmno" + "pqrstuvwxyz" + "abcdefg" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
)


def get_einsum_subscript(extra_dims: int) -> str:
    """Return the einsum formula contracting a (node, edge) mask with edge features.

    The formula has the form ``"ij,j<F>->i<F>"``: ``i`` indexes nodes, ``j``
    indexes edges, and ``<F>`` holds one letter per trailing feature axis.
    ``extra_dims`` is the feature rank minus the mask rank (2), so ``<F>`` gets
    ``extra_dims + 1`` letters; ``-1`` means 1-D features and an empty ``<F>``.

    Args:
        extra_dims: Feature rank minus 2; must be in
            ``[-1, len(_FEATURE_SUBSCRIPT_LETTERS) - 1]``.

    Returns:
        The einsum formula string, e.g. ``"ij,jk->ik"`` for ``extra_dims == 0``.

    Raises:
        ValueError: If ``extra_dims`` is outside the supported range. Previously
            such values silently produced a malformed formula.
    """
    n_letters = extra_dims + 1
    if not 0 <= n_letters <= len(_FEATURE_SUBSCRIPT_LETTERS):
        raise ValueError(
            f"extra_dims must be between -1 and "
            f"{len(_FEATURE_SUBSCRIPT_LETTERS) - 1}, got {extra_dims}"
        )
    letters = _FEATURE_SUBSCRIPT_LETTERS[:n_letters]
    return f"ij,j{letters}->i{letters}"
def aggregate_edge_features(n_nodes: int, edge_features: Tensor, edge_index: Tensor):
    """Scatter-sum edge features onto nodes with a single tinygrad einsum.

    A boolean incidence mask is built by comparing ``edge_index`` against the
    node ids ``0..n_nodes-1`` (entry ``[i, e]`` is True when
    ``edge_index[e] == i``; assumes ``edge_index`` is 1-D, giving a
    ``(n_nodes, n_edges)`` mask). The mask is contracted with ``edge_features``
    over the edge axis, so row ``i`` of the result sums the features of every
    edge whose index equals ``i``. Indices outside ``[0, n_nodes)`` match no row.
    """
    node_ids = Tensor.arange(n_nodes).unsqueeze(-1)
    incidence = edge_index == node_ids
    formula = get_einsum_subscript(edge_features.ndim - incidence.ndim)
    return Tensor.einsum(formula, incidence, edge_features)
def aggregate_edge_features_np(
    n_nodes: int, edge_features: np.ndarray, edge_index: np.ndarray
) -> np.ndarray:
    """NumPy reference implementation of the edge-to-node scatter-sum.

    Row ``i`` of the result is the sum of every ``edge_features[e]`` with
    ``edge_index[e] == i``. Edges whose index falls outside ``[0, n_nodes)``
    match no node and are dropped.

    Args:
        n_nodes: Number of output rows (nodes).
        edge_features: Array whose first axis indexes edges; any number of
            trailing feature axes is accepted.
        edge_index: 1-D array holding the destination node of each edge.

    Returns:
        Array of shape ``(n_nodes, *edge_features.shape[1:])``.
    """
    edge_index = np.asarray(edge_index)
    # Incidence mask: mask[i, e] is True iff edge e points at node i.
    mask = edge_index == np.arange(n_nodes).reshape(-1, 1)
    # The ellipsis subscript keeps this reference independent of
    # get_einsum_subscript (which the tinygrad version under test relies on),
    # and it works for any number of trailing feature axes.
    return np.einsum("ij,j...->i...", mask, edge_features)
@pytest.mark.parametrize("ndims", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("last_dim_size", [1, 10])
@pytest.mark.parametrize("n_nodes", [1, 10])
@pytest.mark.parametrize("n_edges_per_node", [1, 10])
def test_aggregate_edge_features(ndims, last_dim_size, n_nodes, n_edges_per_node):
    """Check the tinygrad aggregation against the NumPy reference.

    Edge features have shape ``(n_edges, 1, ..., 1, last_dim_size)`` with
    ``ndims`` trailing axes. Each edge gets a random destination node.
    """
    # Fixed seed so a failing parameter combination is reproducible.
    Tensor.manual_seed(0)
    n_edges = n_edges_per_node * n_nodes
    # Random destination node per edge; presumably in [0, n_nodes) after the
    # int cast — confirm against tinygrad's Tensor.uniform semantics.
    edge_index = Tensor.uniform(n_edges, low=0, high=n_nodes, dtype=dtypes.int)
    feature_shape = [1] * (ndims - 1) + [last_dim_size]
    edge_features = Tensor.rand(n_edges, *feature_shape)
    node_features = aggregate_edge_features(n_nodes, edge_features, edge_index)
    reference = aggregate_edge_features_np(
        n_nodes, edge_features.numpy(), edge_index.numpy()
    )
    # assert_allclose also requires matching shapes (np.allclose broadcasts,
    # which could hide a shape bug); tolerances match np.allclose's defaults.
    np.testing.assert_allclose(
        node_features.numpy(), reference, rtol=1e-5, atol=1e-8
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment