# Skip to content
#
# Instantly share code, notes, and snippets.

import os
from sqlalchemy import create_engine, func
from sqlalchemy.orm.session import sessionmaker
import psycopg2
import logging
def get_db_connection():
PGHOST = os.getenv("DB_HOST","localhost")
PGDATABASE = "postgres"
# Container definitions (this layout looks like a batect config — confirm).
# The source repeated `containers:` three times with its indentation lost;
# duplicate top-level keys are invalid YAML (or last-wins, silently dropping
# `volumes`), so the fragments are merged into a single `build-env` entry.
containers:
  build-env:
    image: python:3.7
    volumes:
      # Project source mounted at /src.
      - local: .
        container: /src
        options: cached
      # Persistent pip cache inside the mounted tree.
      - local: .pip-cache
        container: /src/.pip-cache
        options: cached
    working_directory: /src
    environment:
      # Make modules under /src importable without installing them.
      PYTHONPATH: "/src"
    run_as_current_user:
      enabled: true
      home_directory: /home/container-user
# Rank the destination entities against the source entity.
# `comparator` is presumably the DotComparator imported further below —
# confirm it is instantiated before this runs.
print("finally, let's do some ranking...")
entity_count = 8
# Width of every embedding vector; all tensors below must agree on it.
embedding_dim = 10
scores, _, _ = comparator(
    # Repeat the single source embedding once per candidate destination.
    comparator.prepare(
        torch.tensor(src_embedding.reshape([1, 1, embedding_dim]))
    ).expand(1, entity_count, embedding_dim),
    comparator.prepare(
        torch.tensor(dest_embeddings.reshape([1, entity_count, embedding_dim]))
    ),
    torch.empty(1, 0, embedding_dim),  # Left-hand side negatives, not needed
    torch.empty(1, 0, embedding_dim),  # Right-hand side negatives, not needed
)
# Destination indices ordered from highest to lowest score.
permutation = scores.flatten().argsort(descending=True)
print("Now let's do some simple things within torch:")
from torchbiggraph.model import DotComparator

# An entity's index in dictionary["entities"]["user_id"] is its offset,
# i.e. the row used to index the embeddings table below.
src_entity_offset = dictionary["entities"]["user_id"].index("0")  # user "0"
dest_1_entity_offset = dictionary["entities"]["user_id"].index("7")  # user "7"
dest_2_entity_offset = dictionary["entities"]["user_id"].index("1")  # user "1"
rel_type_index = dictionary["relations"].index("follow")  # note we only have one...

# Read only the source entity's row. The `with` body must be indented,
# otherwise the file does not parse.
# NOTE(review): dest_embeddings (used by the ranking code) is not loaded here.
with h5py.File("model/example_2/embeddings_user_id_0.v10.h5", "r") as hf:
    src_embedding = hf["embeddings"][src_entity_offset, :]
import os
import random
# Adapted from
# https://github.com/facebookresearch/PyTorch-BigGraph/blob/master/torchbiggraph/examples/livejournal.py

# Edge-list file for each dataset split, joined onto DATA_DIR by the callers.
FILENAMES = dict(
    train='train.txt',
    test='test.txt',
)
import json
import h5py
# Load the entity/relation name dictionary. JSON text is UTF-8, so say so
# explicitly rather than relying on the platform's default encoding.
with open(os.path.join(DATA_DIR, "dictionary.json"), "rt", encoding="utf-8") as tf:
    dictionary = json.load(tf)

user_id = "0"
# The entity's position in the user_id list is its offset.
# list.index raises ValueError if user_id is absent.
offset = dictionary["entities"]["user_id"].index(user_id)
# One f-string avoids the doubled spaces that print's separator added
# around each argument.
print(f"our offset for user_id {user_id} is: {offset}")
import attr
from torchbiggraph.eval import do_eval

# Evaluate on the held-out test split: reuse the training config and swap
# only its edge_paths. `attr` is imported here because the script otherwise
# imports it only after this point.
# NOTE(review): relies on `train_config` (built from CONFIG_PATH further down
# in this file) and presumably on checkpoints written by train(); run this
# after the training step.
eval_path = [convert_path(os.path.join(DATA_DIR, FILENAMES['test']))]
eval_config = attr.evolve(train_config, edge_paths=eval_path)
do_eval(eval_config)
import attr
from torchbiggraph.config import parse_config
from torchbiggraph.train import train

# Start from the config at CONFIG_PATH, then point its edge_paths at the
# training split.
train_config = parse_config(CONFIG_PATH)
train_path = [
    convert_path(os.path.join(DATA_DIR, FILENAMES['train'])),
]
train_config = attr.evolve(train_config, edge_paths=train_path)

# Run training with the overridden config.
train(train_config)