Skip to content

Instantly share code, notes, and snippets.

View sbalnojan's full-sized avatar

Sven Balnojan sbalnojan

View GitHub Profile
# Build one path per entry in FILENAMES, each joined onto DATA_DIR.
edge_paths = [os.path.join(DATA_DIR, name) for name in FILENAMES.values()]
from torchbiggraph.converters.import_from_tsv import convert_input_data
# Pass the config path and the edge files to the TSV importer.
# Presumably lhs_col/rhs_col select the columns holding the two endpoints of
# each edge, and rel_col=None means there is no relation column -- confirm
# against the torchbiggraph importer documentation.
# NOTE(review): this call is cut off in the visible source (no closing
# parenthesis), so any further arguments are unknown.
convert_input_data(
CONFIG_PATH,
edge_paths,
lhs_col=0,
rhs_col=1,
rel_col=None,
import attr
from torchbiggraph.config import parse_config
from torchbiggraph.train import train

# Parse the configuration file, then replace its edge_paths with the path
# for the 'train' entry of FILENAMES (as mapped by convert_path, which is
# defined outside this fragment) and run training with the result.
train_config = parse_config(CONFIG_PATH)
train_path = [
    convert_path(os.path.join(DATA_DIR, FILENAMES["train"])),
]
train_config = attr.evolve(train_config, edge_paths=train_path)
train(train_config)
from torchbiggraph.eval import do_eval

# Evaluate with a copy of the training configuration whose edge_paths is
# swapped for the 'test' entry of FILENAMES (as mapped by convert_path,
# which is defined outside this fragment).
eval_path = [
    convert_path(os.path.join(DATA_DIR, FILENAMES["test"])),
]
eval_config = attr.evolve(train_config, edge_paths=eval_path)
do_eval(eval_config)
import json
import h5py

# Load dictionary.json from DATA_DIR. The body of the ``with`` statement must
# be indented, otherwise this fragment is not valid Python.
with open(os.path.join(DATA_DIR, "dictionary.json"), "rt") as tf:
    dictionary = json.load(tf)

# An entity's offset is its position in the "user_id" entity list.
# list.index raises ValueError if the id is not present.
user_id = "0"
offset = dictionary["entities"]["user_id"].index(user_id)
print("our offset for user_id ", user_id, " is: ", offset)
import os
import random
"""
adapted from https://github.com/facebookresearch/PyTorch-BigGraph/blob/master/torchbiggraph/examples/livejournal.py
"""
# File name of the edge list for each dataset split, keyed by split name.
FILENAMES = {
    "train": "train.txt",
    "test": "test.txt",
}
print("Now let's do some simple things within torch:")
from torchbiggraph.model import DotComparator

# Offsets (positions in the "user_id" entity list) of the source user and of
# two destination users.
src_entity_offset = dictionary["entities"]["user_id"].index("0")  # source: user "0"
dest_1_entity_offset = dictionary["entities"]["user_id"].index("7")  # destination: user "7"
dest_2_entity_offset = dictionary["entities"]["user_id"].index("1")  # destination: user "1"
rel_type_index = dictionary["relations"].index("follow")  # note we only have one...

# Read the source user's row from the "embeddings" dataset of the HDF5 file.
with h5py.File("model/example_2/embeddings_user_id_0.v10.h5", "r") as hf:
    src_embedding = hf["embeddings"][src_entity_offset, :]
print("finally, let's do some ranking...")
entity_count = 8
# Last-axis size shared by every tensor shape below; named once so the
# shapes cannot drift apart.
embedding_dim = 10

# Score the single source embedding against all destination embeddings.
# comparator, src_embedding and dest_embeddings come from outside this
# fragment. The destination reshape uses entity_count rather than repeating
# the literal 8, so changing entity_count keeps the shapes consistent.
scores, _, _ = comparator(
    comparator.prepare(
        torch.tensor(src_embedding.reshape([1, 1, embedding_dim]))
    ).expand(1, entity_count, embedding_dim),
    comparator.prepare(
        torch.tensor(dest_embeddings.reshape([1, entity_count, embedding_dim]))
    ),
    torch.empty(1, 0, embedding_dim),  # Left-hand side negatives, not needed
    torch.empty(1, 0, embedding_dim),  # Right-hand side negatives, not needed
)

# Indices into the flattened scores, ordered from highest to lowest score.
permutation = scores.flatten().argsort(descending=True)
# Container definition with its nesting restored: without indentation,
# `enabled` and `home_directory` parse as top-level keys rather than as
# children of `run_as_current_user`.
# NOTE(review): layout assumed to follow the batect `containers` schema -- confirm.
containers:
  build-env:
    image: python:3.7
    working_directory: /src
    environment:
      PYTHONPATH: "/src"
    run_as_current_user:
      enabled: true
      home_directory: /home/container-user
# Container definition with its nesting restored: without indentation,
# `container` and `options` parse as top-level keys rather than as fields of
# each volume entry.
# NOTE(review): layout assumed to follow the batect `containers` schema -- confirm.
containers:
  build-env:
    image: python:3.7
    volumes:
      - local: .
        container: /src
        options: cached
      - local: .pip-cache
        container: /src/.pip-cache
        options: cached
# Container definition with its nesting restored: each volume entry groups
# its own `local`, `container` and `options` fields, which the flattened form
# would otherwise spill into the top-level mapping.
# NOTE(review): layout assumed to follow the batect `containers` schema -- confirm.
containers:
  build-env:
    image: python:3.7
    volumes:
      - local: .
        container: /src
        options: cached
      - local: .pip-cache
        container: /src/.pip-cache
        options: cached