This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
edge_paths = [os.path.join(DATA_DIR, name) for name in FILENAMES.values()] | |
from torchbiggraph.converters.import_from_tsv import convert_input_data | |
convert_input_data( | |
CONFIG_PATH, | |
edge_paths, | |
lhs_col=0, | |
rhs_col=1, | |
rel_col=None, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Training step: parse the PBG config, point it at the training edges,
# then run the trainer.
import attr
from torchbiggraph.config import parse_config
from torchbiggraph.train import train

train_config = parse_config(CONFIG_PATH)

# convert_path is defined elsewhere in this project; presumably it maps the
# raw edge-list file to its converted location -- confirm against its definition.
train_edges_file = os.path.join(DATA_DIR, FILENAMES['train'])
train_path = [convert_path(train_edges_file)]

# Config objects are replaced rather than mutated: evolve returns a copy
# with only edge_paths overridden.
train_config = attr.evolve(train_config, edge_paths=train_path)
train(train_config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchbiggraph.eval import do_eval | |
# Evaluation step: reuse the training configuration, swapping in the test
# edges, and hand the result to the evaluator.
test_edges_file = os.path.join(DATA_DIR, FILENAMES['test'])
eval_path = [convert_path(test_edges_file)]
eval_config = attr.evolve(train_config, edge_paths=eval_path)
do_eval(eval_config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import h5py | |
# Load the dictionary file from DATA_DIR; the lookups below use its
# "entities" -> "user_id" list.
dictionary_file = os.path.join(DATA_DIR, "dictionary.json")
with open(dictionary_file, "rt") as tf:
    dictionary = json.load(tf)

# An entity's offset is its position in the per-type entity list.
user_id = "0"
user_id_list = dictionary["entities"]["user_id"]
offset = user_id_list.index(user_id)
print("our offset for user_id ", user_id, " is: ", offset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import random

# Adapted from
# https://github.com/facebookresearch/PyTorch-BigGraph/blob/master/torchbiggraph/examples/livejournal.py

# Edge-list file name for each dataset split, keyed by split name.
FILENAMES = {
    'train': 'train.txt',
    'test': 'test.txt',
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print("Now let's do some simple things within torch:")
from torchbiggraph.model import DotComparator

# Every entity looked up here is a "user_id"; an entity's position in this
# list is the row offset later used to index the embeddings table.
user_id_entities = dictionary["entities"]["user_id"]
src_entity_offset = user_id_entities.index("0")  # source user "0"
dest_1_entity_offset = user_id_entities.index("7")  # first destination user, "7"
dest_2_entity_offset = user_id_entities.index("1")  # second destination user, "1"
rel_type_index = dictionary["relations"].index("follow")  # note we only have one...
with h5py.File("model/example_2/embeddings_user_id_0.v10.h5", "r") as hf: | |
src_embedding = hf["embeddings"][src_entity_offset, :] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print("finally, let's do some ranking...")
entity_count = 8
embedding_dim = 10  # last-axis width shared by every tensor shape below

# Score the single source embedding against all destination embeddings.
# `comparator`, `src_embedding` and `dest_embeddings` are defined elsewhere;
# the reshapes assume they hold 1 x embedding_dim and
# entity_count x embedding_dim values respectively -- TODO confirm.
scores, _, _ = comparator(
    # Expand the one source row so it lines up with every destination row.
    comparator.prepare(
        torch.tensor(src_embedding.reshape([1, 1, embedding_dim]))
    ).expand(1, entity_count, embedding_dim),
    comparator.prepare(
        torch.tensor(dest_embeddings.reshape([1, entity_count, embedding_dim]))
    ),
    torch.empty(1, 0, embedding_dim),  # Left-hand side negatives, not needed
    torch.empty(1, 0, embedding_dim),  # Right-hand side negatives, not needed
)

# Entity offsets ordered from highest to lowest score.
permutation = scores.flatten().argsort(descending=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
containers:
  build-env:
    image: python:3.7
    working_directory: /src
    environment:
      PYTHONPATH: "/src"
    run_as_current_user:
      enabled: true
      home_directory: /home/container-user
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
containers:
  build-env:
    image: python:3.7
    volumes:
      - local: .
        container: /src
        options: cached
      - local: .pip-cache
        container: /src/.pip-cache
        options: cached
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
containers:
  build-env:
    image: python:3.7
    volumes:
      - local: .
        container: /src
        options: cached
      - local: .pip-cache
        container: /src/.pip-cache
        options: cached