Skip to content

Instantly share code, notes, and snippets.

@sbalnojan
Created June 20, 2019 19:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sbalnojan/bd98a77740141211f46135751e200322 to your computer and use it in GitHub Desktop.
print("Now let's do some simple things within torch:")
# Score a few (source, destination) user pairs with the trained PBG embeddings
# using DotComparator.
#
# NOTE(review): `dictionary` is not created in this snippet and must already
# exist. It is read as dictionary["entities"]["user_id"] (a list supporting
# .index() by entity name) and dictionary["relations"] (a list of relation
# type names) -- presumably loaded from PBG's dictionary output; confirm.
import h5py
import torch
from torchbiggraph.model import DotComparator

# Row offsets (into the embedding matrix) of the users being compared.
src_entity_offset = dictionary["entities"]["user_id"].index("0")  # user "0"
dest_1_entity_offset = dictionary["entities"]["user_id"].index("7")  # user "7"
dest_2_entity_offset = dictionary["entities"]["user_id"].index("1")  # user "1"
rel_type_index = dictionary["relations"].index("follow")  # note we only have one...

with h5py.File("model/example_2/embeddings_user_id_0.v10.h5", "r") as hf:
    # Each row of the "embeddings" dataset is one entity's embedding vector.
    src_embedding = hf["embeddings"][src_entity_offset, :]
    dest_1_embedding = hf["embeddings"][dest_1_entity_offset, :]
    dest_2_embedding = hf["embeddings"][dest_2_entity_offset, :]
    # Whole "embeddings" dataset, read into memory (not used below).
    dest_embeddings = hf["embeddings"][...]


def score_pair(comparator, lhs_embedding, rhs_embedding):
    """Return the comparator's positive score for one lhs/rhs embedding pair.

    Each 1-D embedding is reshaped to (1, 1, dim), where dim is taken from the
    embedding itself rather than hard-coded. No negatives are needed, so empty
    (1, 0, dim) tensors are passed for both negative sides; the two negative
    score outputs are discarded.
    """
    dim = lhs_embedding.shape[-1]
    lhs_pos = comparator.prepare(torch.tensor(lhs_embedding.reshape([1, 1, dim])))
    rhs_pos = comparator.prepare(torch.tensor(rhs_embedding.reshape([1, 1, dim])))
    pos_scores, _, _ = comparator(
        lhs_pos,
        rhs_pos,
        torch.empty(1, 0, dim),  # Left-hand side negatives, not needed
        torch.empty(1, 0, dim),  # Right-hand side negatives, not needed
    )
    return pos_scores


comparator = DotComparator()
scores_1 = score_pair(comparator, src_embedding, dest_1_embedding)
scores_2 = score_pair(comparator, src_embedding, dest_2_embedding)
print(scores_1)
print(scores_2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment