@sbalnojan, created June 20, 2019 19:50
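import torch

# `comparator`, `src_embedding`, `dest_embeddings` and `dictionary` are not defined
# in this snippet; they are assumed to come from earlier steps (e.g. a trained
# PyTorch-BigGraph model and its entity dictionary).
print("Finally, let's do some ranking...")
entity_count = 8  # number of candidate (destination) entities
embedding_dim = 10
scores, _, _ = comparator(
    # Source embedding, repeated once per candidate: shape (1, entity_count, embedding_dim)
    comparator.prepare(torch.tensor(src_embedding.reshape([1, 1, embedding_dim]))).expand(1, entity_count, embedding_dim),
    # All candidate embeddings: shape (1, entity_count, embedding_dim)
    comparator.prepare(torch.tensor(dest_embeddings.reshape([1, entity_count, embedding_dim]))),
    torch.empty(1, 0, embedding_dim),  # Left-hand side negatives, not needed
    torch.empty(1, 0, embedding_dim),  # Right-hand side negatives, not needed
)
# Indices of the candidates, sorted from highest to lowest score
permutation = scores.flatten().argsort(descending=True)
top_entities = [dictionary["entities"]["user_id"][index] for index in permutation.tolist()]
print(top_entities)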
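The comparator call scores every candidate against the source embedding, and `argsort(descending=True)` turns those scores into a ranking. A minimal pure-Python sketch of the same idea, assuming a plain dot-product comparator (PyTorch-BigGraph's `DotComparator` scores pairs this way; the snippet itself does not say which comparator it uses). It drops the batch dimension and the `prepare` step, and the embeddings and user IDs are hypothetical, for illustration only.

```python
from typing import List, Sequence


def dot_scores(src: Sequence[float], candidates: Sequence[Sequence[float]]) -> List[float]:
    """Score each candidate by its dot product with the source embedding."""
    return [sum(s * c for s, c in zip(src, cand)) for cand in candidates]


def rank_entities(src: Sequence[float], candidates: Sequence[Sequence[float]],
                  names: Sequence[str]) -> List[str]:
    """Return candidate names ordered from highest to lowest score."""
    scores = dot_scores(src, candidates)
    # Plays the role of scores.argsort(descending=True): indices sorted by score, high to low
    permutation = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [names[i] for i in permutation]


# Hypothetical 3-dimensional embeddings and user IDs (not from the original gist)
src = [1.0, 0.0, 2.0]
candidates = [
    [0.5, 1.0, 0.0],   # score 0.5
    [1.0, 0.0, 1.0],   # score 3.0
    [-1.0, 2.0, 0.0],  # score -1.0
    [0.0, 0.0, 1.0],   # score 2.0
]
names = ["user_a", "user_b", "user_c", "user_d"]
print(rank_entities(src, candidates, names))
# → ['user_b', 'user_d', 'user_a', 'user_c']
```

With a different comparator (for example a cosine one), only the scoring function would change; the descending sort over scores stays the same.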