Skip to content

Instantly share code, notes, and snippets.

@simonepri
Last active June 23, 2019 08:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simonepri/f150224b918b6145e3e1e02574c5c3c7 to your computer and use it in GitHub Desktop.
pbg-tools
import json
import os
import random
import sys
from typing import Dict, Optional

import attr

from torchbiggraph.config import ConfigSchema
from torchbiggraph.converters.export_to_tsv import make_tsv
from torchbiggraph.converters.import_from_tsv import convert_input_data
from torchbiggraph.eval import do_eval
from torchbiggraph.filtered_eval import FilteredRankingEvaluator
from torchbiggraph.schema import DeepTypeError
from torchbiggraph.train import train
def random_split_file(
    filenames: Dict[str, str],
    from_line: int = 0,
    to_line: int = -1,
    train_split: float = 0.75,
    valid_split: float = 0.10,
    overwrite: bool = False,
):
    """Shuffle an edgelist file and split it into train/valid/test files.

    Args:
        filenames: Mapping with keys 'edgelist' (input path) and
            'train', 'valid', 'test' (output paths).
        from_line: First line of the edgelist to use (0-based).
        to_line: One past the last line to use; -1 means "through the end".
        train_split: Fraction of the selected lines written to the train file.
        valid_split: Fraction of the *remaining* (non-train) lines written to
            the valid file; whatever is left goes to the test file.
        overwrite: When False and all three output files already exist,
            the shuffle/split is skipped.

    Exits the process with status 1 if the edgelist file is missing.
    """
    edgelist_file = filenames['edgelist']
    train_file = filenames['train']
    valid_file = filenames['valid']
    test_file = filenames['test']
    output_paths = [train_file, valid_file, test_file]
    if not overwrite and all(os.path.exists(path) for path in output_paths):
        print('Found some files that indicate that the input data '
              'has already been shuffled and split, not doing it again.')
        print('These files are: %s' % ', '.join(output_paths))
        return
    if not os.path.exists(edgelist_file):
        # Fixed typo in the original message ("does not exists").
        print('The edgelist file does not exist')
        print('The path provided was: %s' % edgelist_file)
        sys.exit(1)
        # (unreachable `return` after sys.exit removed)
    print('Shuffling and splitting train/test file. This may take a while.')
    print('Reading data from file: %s' % edgelist_file)
    with open(edgelist_file, 'rt') as in_tf:
        lines = in_tf.readlines()
    if from_line != 0 or to_line != -1:
        # Bug fix: with the sentinel to_line == -1 the original computed
        # lines[from_line:-1], silently dropping the last line whenever only
        # from_line was customized.  Map -1 to "end of list" explicitly.
        end = None if to_line == -1 else to_line
        lines = lines[from_line:end]
    print('Shuffling data')
    random.shuffle(lines)
    train_split_len = int(len(lines) * train_split)
    # NOTE: valid_split is intentionally a fraction of the NON-train remainder,
    # not of the whole dataset.
    valid_split_len = int((len(lines) - train_split_len) * valid_split)
    print('Splitting to train, validation and test files')
    with open(train_file, 'wt') as out_tf_train:
        out_tf_train.writelines(lines[:train_split_len])
    with open(valid_file, 'wt') as out_tf_valid:
        out_tf_valid.writelines(lines[train_split_len:train_split_len + valid_split_len])
    with open(test_file, 'wt') as out_tf_test:
        out_tf_test.writelines(lines[train_split_len + valid_split_len:])
    print('Total examples: %d' % len(lines))
    print('Train examples: %d' % train_split_len)
    print('Valid examples: %d' % valid_split_len)
    print('Test examples: %d' % (len(lines) - train_split_len - valid_split_len))
def convert_path(
    fname: str
) -> str:
    """Return the partitioned-data directory name derived from *fname*.

    The extension is stripped and '_partitioned' is appended, e.g.
    'data/edges.tsv' -> 'data/edges_partitioned'.
    """
    root, _ext = os.path.splitext(fname)
    return root + '_partitioned'
def run_training(config: ConfigSchema, edges_paths: Dict[str, str], filtered: bool = False):
    """Train a PBG model on the partitioned train split.

    *filtered* is accepted for signature parity with run_evaluation but is
    not used during training.
    """
    training_config = attr.evolve(
        config,
        edge_paths=[convert_path(edges_paths['train'])],
    )
    train(training_config)
def run_evaluation(
    config: ConfigSchema,
    edges_paths: Dict[str, str],
    filtered: bool = False,
    all_negs: bool = True
):
    """Evaluate a trained PBG model on the partitioned test split.

    When *all_negs* is set, every relation is evaluated against all
    negatives and uniform negative sampling is disabled.  When *filtered*
    is set, known positives from the train/valid/test splits are filtered
    out of the rankings via FilteredRankingEvaluator.
    """
    eval_path = [convert_path(edges_paths['test'])]
    if all_negs:
        patched_relations = [
            attr.evolve(rel, all_negs=all_negs) for rel in config.relations
        ]
        eval_config = attr.evolve(
            config,
            edge_paths=eval_path,
            relations=patched_relations,
            num_uniform_negs=0,
        )
    else:
        eval_config = attr.evolve(config, edge_paths=eval_path)
    if not filtered:
        do_eval(eval_config)
        return
    filter_paths = [
        convert_path(edges_paths['test']),
        convert_path(edges_paths['valid']),
        convert_path(edges_paths['train']),
    ]
    do_eval(eval_config, evaluator=FilteredRankingEvaluator(eval_config, filter_paths))
def parse_config(
    config_dict: Dict
) -> ConfigSchema:
    """Build a ConfigSchema from a raw dict, exiting the process on error.

    A malformed configuration (DeepTypeError) is reported on stderr and
    terminates the process with exit status 1.
    """
    try:
        return ConfigSchema.from_dict(config_dict)
    except DeepTypeError as err:
        print("Error in the configuration file, aborting.", file=sys.stderr)
        print(str(err), file=sys.stderr)
        sys.exit(1)
def input_from_tsv(
    config: ConfigSchema,
    edges_paths: Dict[str, str],
    cols: Dict[str, int]
):
    """Split the raw edgelist and convert each split to PBG's binary format.

    *cols* maps 'lhs'/'rel'/'rhs' to their column indices in the TSV input.
    """
    random_split_file(edges_paths, train_split=0.95, valid_split=0.45)
    split_files = [edges_paths[name] for name in ('test', 'valid', 'train')]
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        split_files,
        lhs_col=cols['lhs'],
        rel_col=cols['rel'],
        rhs_col=cols['rhs'],
    )
def output_to_tsv(
    config: ConfigSchema,
    embs_paths: Dict[str, str]
):
    """Export trained entity and relation embeddings to TSV files.

    Reads the entity/relation dictionary produced during conversion and
    writes the embeddings to embs_paths['ent'] and embs_paths['rel'].
    """
    dictionary_path = os.path.join(config.entity_path, 'dictionary.json')
    with open(dictionary_path, "rt") as dict_in:
        dictionary = json.load(dict_in)
    with open(embs_paths['ent'], "wt+") as ent_out:
        with open(embs_paths['rel'], "wt+") as rel_out:
            make_tsv(
                config.checkpoint_path,
                dictionary["relations"],
                dictionary["entities"],
                ent_out,
                rel_out,
            )
def run_pbg(
    config_dict: Dict,
    edges_paths: Dict[str, str],
    embs_paths: Dict[str, str],
    run_train: bool = True,
    run_eval: bool = True,
    filtered: bool = False,
    all_negs: bool = True,
    cols: Optional[Dict[str, int]] = None
):
    """End-to-end PBG pipeline: convert input, train, evaluate, export.

    Args:
        config_dict: Raw configuration dict, validated via parse_config.
        edges_paths: Paths for 'edgelist', 'train', 'valid', 'test' files.
        embs_paths: Output paths for 'ent' and 'rel' embedding TSV files.
        run_train: Whether to run the training phase.
        run_eval: Whether to run the evaluation phase.
        filtered: Forwarded to run_evaluation (filtered ranking).
        all_negs: Forwarded to run_evaluation (evaluate against all negatives).
        cols: Column indices for 'lhs'/'rel'/'rhs' in the TSV input;
            defaults to {'lhs': 0, 'rel': 1, 'rhs': 2}.
    """
    # Bug fix: the original used a mutable dict literal as the default for
    # *cols*, which is created once and shared across all calls; use the
    # None sentinel instead (behavior-compatible for existing callers).
    if cols is None:
        cols = {'lhs': 0, 'rel': 1, 'rhs': 2}
    config = parse_config(config_dict)
    input_from_tsv(config, edges_paths, cols)
    if run_train:
        run_training(config, edges_paths)
    if run_eval:
        run_evaluation(config, edges_paths, filtered=filtered, all_negs=all_negs)
    output_to_tsv(config, embs_paths)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment