Skip to content

Instantly share code, notes, and snippets.

@aamedina
Created September 19, 2023 22:17
Show Gist options
  • Save aamedina/20d42bda351175ea088fad0dc328e2a5 to your computer and use it in GitHub Desktop.
Save aamedina/20d42bda351175ea088fad0dc328e2a5 to your computer and use it in GitHub Desktop.
An extremely basic generic RDF loader for PyG HeteroData
# extended from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/entities.html#Entities
import os
import os.path as osp
from collections import defaultdict
from urllib.parse import urlparse

import rdflib as rdf
import requests
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from torch_geometric.data import HeteroData, InMemoryDataset
class RDFGraph(InMemoryDataset):
    """An RDF graph dataset stored as a single PyG ``HeteroData`` object.

    The RDF file at ``url`` is downloaded once into ``<root>/<name>/raw`` and
    parsed with rdflib. The tensors built from it are cached in
    ``<root>/<name>/processed/data.pt``. ``name`` is the downloaded file name
    up to its first ``.`` (e.g. ``d3fend`` for ``.../d3fend.owl``).

    Every distinct subject/object term becomes a node, and every distinct
    predicate becomes a relation id. The stored object has three attributes:

    * ``edge_index``: ``long`` tensor of shape ``(2, num_triples)`` holding
      ``(subject, object)`` node indices.
    * ``edge_type``: ``long`` tensor of shape ``(num_triples,)`` holding the
      predicate index of each edge.
    * ``node_type``: ``float`` one-hot matrix of shape
      ``(num_nodes, num_labels)``, built from each node's last
      ``/``-separated segment (see :meth:`process`).

    Args:
        root: Directory under which ``<name>/raw`` and ``<name>/processed``
            are created.
        url: HTTP(S) URL of the RDF file. It is required; the ``None``
            default is kept only for signature compatibility.
        transform: Optional callable applied to the data on every access.
        pre_transform: Optional callable applied once, before the processed
            data is saved to disk.

    Raises:
        ValueError: If ``url`` is missing, or no dataset name can be derived
            from its path.
    """

    #: Connect/read timeout in seconds passed to ``requests.get``.
    download_timeout = 60

    def __init__(self, root, url=None, transform=None, pre_transform=None):
        if url is None:
            raise ValueError(
                "RDFGraph requires a `url` to download the RDF file from.")
        # The base-class constructor may call download()/process(), which
        # read self.url and self.name (via raw_dir/processed_dir), so both
        # must be set before super().__init__().
        self.url = url
        self.name = self._url_file_name(url).split('.')[0]
        if not self.name:
            raise ValueError(f"Cannot derive a dataset name from url {url!r}.")
        super().__init__(root, transform, pre_transform)
        # weights_only=False is needed to unpickle HeteroData on torch >= 2.6,
        # where the default became True. This uses pickle, so only load
        # processed files that this class wrote itself.
        self.data, self.slices = torch.load(
            self.processed_paths[0], weights_only=False)

    @staticmethod
    def _url_file_name(url):
        """Return the last path segment of ``url``, ignoring query and fragment."""
        return urlparse(url).path.rsplit('/', 1)[-1]

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return [self._url_file_name(self.url)]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        """Stream ``self.url`` into ``raw_paths[0]``.

        The body is first written to a ``.part`` file. It is renamed into
        place only after a complete, successful transfer, so an interrupted
        download never leaves a truncated file at the final path.

        Raises:
            RuntimeError: If the server answers with a status other than 200.
        """
        target = self.raw_paths[0]
        partial = target + '.part'
        with requests.get(self.url, allow_redirects=True, stream=True,
                          timeout=self.download_timeout) as response:
            if response.status_code != 200:
                raise RuntimeError("Failed to download dataset.")
            with open(partial, 'wb') as f:
                # Stream in 1 MiB chunks instead of buffering the whole body.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial, target)

    def process(self):
        """Parse the raw RDF file and save its tensors to ``processed_paths[0]``."""
        graph = rdf.Graph().parse(self.raw_paths[0])

        # Sort the term sets so that indices do not depend on set iteration
        # order (string hashing is randomised per process). Blank-node ids are
        # regenerated on every parse, so graphs containing blank nodes are
        # still not fully reproducible.
        relations = sorted(set(graph.predicates()), key=repr)
        nodes = sorted(set(graph.subjects()) | set(graph.objects()), key=repr)
        relation_index = {rel: i for i, rel in enumerate(relations)}
        node_index = {node: i for i, node in enumerate(nodes)}

        edge_pairs = []
        edge_types = []
        for s, p, o in graph:
            edge_pairs.append((node_index[s], node_index[o]))
            edge_types.append(relation_index[p])

        # One-hot label per node, taken from its last '/'-separated segment.
        # This is NOT rdf:type: distinct terms that end in the same segment
        # share a column. The matrix is also dense
        # (num_nodes x num_labels), so it can be very large for big graphs.
        node_types = MultiLabelBinarizer().fit_transform(
            [[str(node).split('/')[-1]] for node in nodes])

        # reshape(-1, 2) keeps edge_index shaped (2, 0) for an empty graph
        # instead of collapsing to (0,).
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2)
        # NOTE(review): passed as plain kwargs, these attributes presumably
        # land in HeteroData's global store rather than in typed node/edge
        # stores. Confirm against the PyG HeteroData docs.
        data = HeteroData(
            edge_index=edge_index.t().contiguous(),
            edge_type=torch.tensor(edge_types, dtype=torch.long),
            node_type=torch.tensor(node_types, dtype=torch.float),
        )
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])
@aamedina
Copy link
Author

# Example: download, parse and cache two public RDF vocabularies under /tmp.
SCHEMA_ORG_URL = "https://schema.org/version/latest/schemaorg-current-http.ttl"
D3FEND_URL = "https://d3fend.mitre.org/ontologies/d3fend.owl"

schema = RDFGraph(root='/tmp', url=SCHEMA_ORG_URL)
d3fend = RDFGraph(root='/tmp', url=D3FEND_URL)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment