We are kicking off our very first community sprint!
The community sprint revolves around adding missing type hints and TorchScript support for various functions across PyG, aiming to improve and clean up our core codebase. Each individual contribution is designed to take only around 30 minutes to two hours to complete.
The sprint begins Wednesday, October 12th with a kick-off meeting at 8am PST.
The community sprint will last 2 weeks, and we will have another live hangout when the sprint has completed.
If you are interested in helping out, please also join our PyG Slack channel #community-sprint-type-hints for more information.
Type hints are currently used inconsistently in the torch-geometric repository, and it would be nice to make them complete and consistent across all datasets, models and utilities.
Adding type hint support in models also helps us to improve our TorchScript coverage across layers and models provided in nn.*.
Take a look at the current implementation of contains_isolated_nodes:
def contains_isolated_nodes(edge_index, num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_index, _ = remove_self_loops(edge_index)
    return torch.unique(edge_index.view(-1)).numel() < num_nodes
Adding type hint support to the function signature helps us to better understand its input and output, improving code readability:
def contains_isolated_nodes(
    edge_index: Tensor,
    num_nodes: Optional[int] = None,
) -> bool:
    ...
Importantly, type hints also let us use the function as part of a TorchScript model.
Without them, TorchScript expects every unannotated argument to be a PyTorch tensor (which is clearly not the case for the num_nodes argument), so running the function through torch.jit.script will fail:
import torch
from torch_geometric.utils import contains_isolated_nodes

# Script the un-annotated function and call it without `num_nodes`:
contains_isolated_nodes = torch.jit.script(contains_isolated_nodes)
contains_isolated_nodes(torch.tensor([[0, 1, 0], [1, 0, 0]]))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File ".../pytorch_geometric/torch_geometric/utils/isolated.py", line 29, in contains_isolated_nodes
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
                ~~~~~~~~~~~~~~~ <--- HERE
    edge_index, _ = remove_self_loops(edge_index)
    return torch.unique(edge_index.view(-1)).numel() < num_nodes
RuntimeError: Cannot input a tensor of dimension other than 0 as a scalar argument
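For comparison, here is a minimal sketch of how an annotated function behaves under torch.jit.script. It deliberately avoids PyG internals: has_isolated_nodes is a hypothetical stand-in written only for this illustration (it is not the PyG implementation), it inlines simplified versions of the node-count inference and self-loop removal, and it assumes a non-empty edge_index.

```python
from typing import Optional

import torch
from torch import Tensor


def has_isolated_nodes(edge_index: Tensor,
                       num_nodes: Optional[int] = None) -> bool:
    # Hypothetical stand-in for illustration, not the PyG implementation.
    if num_nodes is None:
        # Infer the node count from the largest node index
        # (assumes `edge_index` is non-empty).
        num_nodes = int(edge_index.max()) + 1
    # Drop self-loops, i.e. edges whose source equals their target.
    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]
    # A node is isolated if it appears in none of the remaining edges.
    return torch.unique(edge_index.reshape(-1)).numel() < num_nodes


# With `Optional[int]` declared, TorchScript accepts `None` and plain integers.
jit_fn = torch.jit.script(has_isolated_nodes)

edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])
print(jit_fn(edge_index))     # False: nodes 0 and 1 are connected
print(jit_fn(edge_index, 3))  # True: node 2 has no edges
```

The `Optional[int]` annotation is what allows TorchScript to treat `num_nodes` as either `None` or an integer instead of assuming a tensor.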
See here for a basic example to follow.
- Ensure you have read our contributing guidelines.
- Claim the functionality/model you want to improve. More information on this will follow soon.
- Implement the changes as in pyg-team/pytorch_geometric#5603. Ideally, ensure in a test that the model/function is convertible to a TorchScript program and that it produces the same output. It is okay to add type hint support for multiple functions/models within a single PR as long as you have assigned yourself to each of them and the number of changed files stays reasonable to ease reviewing (e.g., no more than 10 touched files).
- Open a PR to the PyG repository and name it: "[Type Hints] {model_name/function_name}". In addition, ensure that documentation is rendered properly (CI will build the documentation automatically for you). Afterwards, add your PR number to the "Improved type hint support" line in CHANGELOG.md.
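The steps above mention testing that a model converts to a TorchScript program and yields the same output. A rough sketch of what such a test could look like for a module is given below. ScaledLinear and test_scaled_linear_jit are hypothetical names invented for this illustration and do not exist in PyG; the sketch only relies on plain PyTorch.

```python
from typing import Optional

import torch
from torch import Tensor


class ScaledLinear(torch.nn.Module):
    # Hypothetical toy layer, used only to illustrate the testing pattern.
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x: Tensor, scale: Optional[float] = None) -> Tensor:
        out = self.lin(x)
        if scale is not None:
            # Inside this branch, `scale` is known to be a `float`.
            out = out * scale
        return out


def test_scaled_linear_jit():
    torch.manual_seed(0)
    model = ScaledLinear(4, 2)
    x = torch.randn(3, 4)

    # The scripted module should compile and match the eager output.
    jit = torch.jit.script(model)
    assert torch.allclose(jit(x), model(x))
    assert torch.allclose(jit(x, 0.5), model(x, 0.5))
```

Comparing the eager and scripted outputs with and without the optional argument exercises both branches of the type-refined code path.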
- If you are unfamiliar with how type hints work, you can read the Python library documentation on them, but it is probably even easier to just look at another PR that added them.
- The types will usually be obvious, e.g., Tensor, bool, int, float. Wrap them within Optional[*] whenever the argument can be None.
- Specialized PyG type hints (e.g., Adj) are defined in typing.py.
- In some rare cases, type hints are challenging to add, e.g., whenever a model/function supports a Union of different types (which may be the case for functions/models that also support SparseTensor, e.g., edge_index: Union[Tensor, SparseTensor]). In that case, TorchScript support can be achieved via the torch.jit._overload decorator. See here for an example.
- The corresponding tests of PyG models and functions can be found in the test/ directory. For example, tests for torch_geometric/utils/isolated.py can be found in test/utils/test_isolated.py. You can run individual test files via pytest test/utils/test_isolated.py. It is only necessary to test TorchScript support for the nn.* and utils.* packages. No changes are necessary for the datasets.* and transforms.* packages.
- Ensure that the TorchScript variant compiles and achieves the same output:
      from torch_geometric.testing import is_full_test
      ...
      edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 2, 0]])
      assert contains_isolated_nodes(edge_index)
      if is_full_test():
          jit = torch.jit.script(contains_isolated_nodes)
          assert jit(edge_index)

  Note that we generally gate TorchScript tests behind is_full_test(), which guarantees that TorchScript tests are only run nightly. You can enable full tests locally via the FULL_TEST=1 environment variable, e.g., FULL_TEST=1 pytest test/utils/test_isolated.py. It is only necessary to ensure TorchScript support for the nn.* and utils.* packages. No changes are necessary for the datasets.* and transforms.* packages.
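To make the torch.jit._overload tip more concrete, here is a small sketch of the general pattern, under several assumptions. scale and scale_twice are hypothetical functions invented for this illustration, and the Union here is Union[float, Tensor] rather than PyG's Union[Tensor, SparseTensor] so that the example runs with plain PyTorch. Note that torch.jit._overload is a private PyTorch API, so its behavior may differ between versions. As far as we understand, an overloaded function is resolved when it is called from another scripted function, which is why the example goes through scale_twice.

```python
import torch
from torch import Tensor


# Each overload only declares a signature; its body must be `pass`.
@torch.jit._overload
def scale(x, factor):  # noqa: F811
    # type: (Tensor, float) -> Tensor
    pass


@torch.jit._overload
def scale(x, factor):  # noqa: F811
    # type: (Tensor, Tensor) -> Tensor
    pass


def scale(x, factor):  # noqa: F811
    # Shared implementation, compiled once per declared overload.
    return x * factor


@torch.jit.script
def scale_twice(x: Tensor) -> Tensor:
    # The first call resolves to the `float` overload,
    # the second one to the `Tensor` overload.
    out = scale(x, 2.0)
    return scale(out, torch.tensor(3.0))


x = torch.tensor([1.0, 2.0])
print(scale_twice(x))
```

In eager mode, the overload declarations are simply shadowed by the final definition, so the same function remains usable outside TorchScript.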
This list may be incomplete. If you find a function that is still missing type hints/TorchScript tests, please let us know or add them on your own.
- nn.MetaLayer
- nn.InstanceNorm
- nn.GraphSizeNorm
- nn.MessageNorm
- nn.DiffGroupNorm
- nn.TopKPooling
- nn.SAGPooling
- nn.EdgePooling
- nn.PANPooling
- nn.max_pool (TorchScript support not possible)
- nn.avg_pool (TorchScript support not possible)
- nn.max_pool_x
- nn.max_pool_neighbor_x (TorchScript support not possible)
- nn.avg_pool_x
- nn.avg_pool_neighbor_x (TorchScript support not possible)
- nn.Node2Vec
- nn.DeepGraphInfomax
- nn.InnerProductEncoder
- nn.GAE
- nn.VGAE
- nn.ARGA
- nn.ARGVA
- nn.SignedGCN
- nn.RENet
- nn.GraphUNet
- nn.SchNet
- nn.DimeNet
- nn.GNNExplainer
- nn.DeepGCNLayer
- nn.AttentiveFP
- nn.DenseGCNConv
- nn.DenseGINConv
- nn.DenseGraphConv
- nn.DenseSAGEConv
- nn.dense_diff_pool
- nn.dense_mincut_pool
- datasets.NELL
- datasets.PPI
- datasets.Reddit
- datasets.Reddit2
- datasets.Yelp
- datasets.QM7b
- datasets.ZINC
- datasets.MoleculeNet
- datasets.MNISTSuperpixels
- datasets.ShapeNet
- datasets.ModelNet
- datasets.SHREC2016
- datasets.TOSCA
- datasets.PCPNetDataset
- datasets.S3DIS
- datasets.ICEWS18
- datasets.WILLOWObjectClass
- datasets.PascalVOCKeypoints
- datasets.PascalPF
- datasets.SNAPDataset
- datasets.SuiteSparseMatrixCollection
- datasets.WordNet18
- datasets.WebKB
- datasets.JODIEDataset
- datasets.MixHopSyntheticDataset
- datasets.UPFD
- transforms.Distance
- transforms.Cartesian
- transforms.LocalCartesian
- transforms.Polar
- transforms.Spherical
- transforms.PointPairFeatures
- transforms.OneHotDegree
- transforms.TargetIndegree
- transforms.RandomJitter
- transforms.RandomFlip
- transforms.RandomScale
- transforms.RandomRotate
- transforms.RandomShear
- transforms.KNNGraph
- transforms.FaceToEdge
- transforms.SamplePoints
- transforms.FixedPoints
- transforms.ToDense
- transforms.LaplacianLambdaMax
- transforms.ToSLIC
- transforms.GDC
- transforms.SIGN
- transforms.SVDFeatureReduction
- utils.remove_isolated_nodes
- utils.get_laplacian
- utils.to_dense_adj
- utils.dense_to_sparse
- utils.normalized_cut
- utils.grid
- utils.geodesic_distance (TorchScript support not possible)
- utils.tree_decomposition (TorchScript support not possible)
- utils.to_scipy_sparse_matrix (TorchScript support not possible)
- utils.from_scipy_sparse_matrix (TorchScript support not possible)
- utils.to_networkx (TorchScript support not possible)
- utils.from_networkx (TorchScript support not possible)
- utils.erdos_renyi_graph
- utils.stochastic_blockmodel_graph
- utils.barabasi_albert_graph
- utils.train_test_split_edges