Skip to content

Instantly share code, notes, and snippets.

@samehkamaleldin
Last active August 27, 2020 15:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samehkamaleldin/6eafc237ad2fce563199c10cc6b8d46f to your computer and use it in GitHub Desktop.
Save samehkamaleldin/6eafc237ad2fce563199c10cc6b8d46f to your computer and use it in GitHub Desktop.
Generate the PSE (polypharmacy side effect) dataset from the Decagon data files
# -*- coding: utf-8 -*-
import csv
import gzip
import math
import os
from collections import Counter

import numpy as np
def export_entry(entry, out_fd):
    """Write every (subject, predicate, object) triple in *entry* to *out_fd*.

    Each triple becomes one line whose three fields are separated by tabs.

    :param entry: iterable of 3-element sequences (lists, tuples or numpy rows)
    :param out_fd: writable text file-like object
    :return: None
    """
    line_template = "{}\t{}\t{}\n"
    for triple in entry:
        subj, pred, obj = triple
        out_fd.write(line_template.format(subj, pred, obj))
def generate_random_splits(data, nb_splits=10):
    """Shuffle *data* along its first axis and yield it in *nb_splits* chunks.

    Chunks hold ``ceil(len(data) / nb_splits)`` rows each until the data runs
    out. The last non-empty chunk may be shorter, and any chunks after it are
    empty. Exactly *nb_splits* arrays are always yielded.

    The row permutation is drawn from NumPy's global RNG
    (``np.random.shuffle``) when iteration starts, so results are
    reproducible under ``np.random.seed``.

    :param data: array-like
        dataset; converted with ``np.asarray`` and split along axis 0, so
        both 1-D and N-D inputs are accepted
    :param nb_splits: int
        number of splits; must be >= 1
    :return: generator of np.ndarray
    :raises ValueError: if nb_splits < 1. This is a generator, so the error
        surfaces on the first ``next()``.
    """
    if nb_splits < 1:
        raise ValueError("nb_splits must be >= 1, got %r" % (nb_splits,))
    data = np.asarray(data)
    data_size = len(data)
    data_indices = np.arange(data_size)
    np.random.shuffle(data_indices)
    # Apply the permutation once instead of re-copying the array per chunk.
    shuffled = data[data_indices]
    # Integer ceiling division: no float rounding, and no dependence on
    # whether "/" is true division.
    split_size = -(-data_size // nb_splits)
    for idx in range(nb_splits):
        start = idx * split_size
        # Slicing only axis 0 works for 1-D data as well as 2-D tables.
        yield shuffled[start:min(data_size, start + split_size)]
def _read_csv_rows(path):
    """Return the data rows of the CSV file at *path*, header row dropped.

    The csv module is used instead of ``line.split(",")`` so that a quoted
    field containing a comma stays one field. Blank lines are skipped.

    :param path: str, path of a comma-separated file whose first line is a
        header
    :return: list of lists of str, one inner list per data row
    """
    with open(path, newline="", encoding="utf-8") as csv_fd:
        reader = csv.reader(csv_fd)
        next(reader, None)  # drop the header row
        return [row for row in reader if row]


def _split_side_effect_facts(se_ploy, min_facts):
    """Drop rare side effects and split the rest per side effect.

    :param se_ploy: list of [drug, side_effect, drug] triples
    :param min_facts: side effects with fewer than this many triples are
        dropped
    :return: tuple (kept, train, valid, test, ignored_rels), where:
        kept -- triples of the retained side effects, in input order;
        train / valid / test -- lists of triples. For each side effect, the
        first tenth (rounded down) of its shuffled triples goes to valid, the
        next tenth to test, and the rest to train;
        ignored_rels -- set of dropped side-effect relations
    """
    rel_counts = Counter(p for _, p, _ in se_ploy)
    ignored_rels = {rel for rel, count in rel_counts.items() if count < min_facts}
    kept = [[s, p, o] for s, p, o in se_ploy if p not in ignored_rels]

    # Visit side effects in sorted order. Each group's shuffle draws from the
    # global NumPy RNG, so the group order decides the split. Iterating a set
    # of strings would tie that order to the per-process string hash seed and
    # defeat np.random.seed.
    kept_rels = sorted(rel for rel in rel_counts if rel not in ignored_rels)
    groups = {rel: [] for rel in kept_rels}
    for s, p, o in kept:
        groups[p].append([s, p, o])

    train, valid, test = [], [], []
    for rel in kept_rels:
        facts = np.array(groups[rel])
        np.random.shuffle(facts)  # permutes rows (axis 0) in place
        tenth = len(facts) // 10
        valid.extend(facts[:tenth].tolist())
        test.extend(facts[tenth:2 * tenth].tolist())
        train.extend(facts[2 * tenth:].tolist())
    return kept, train, valid, test, ignored_rels


def main(decagon_dp="./decagon_data", kg_dp="./kg", seed=1234, min_facts=500):
    """Build the polypharmacy side-effect knowledge graph from Decagon CSVs.

    Reads the bio-decagon-*.csv files from *decagon_dp* and writes
    tab-separated triple files into *kg_dp*, creating it if missing:

    * ``ploypharmacy_facts.txt`` -- all kept (drug, side effect, drug) facts,
      plus ``_train`` / ``_valid`` / ``_test`` splits of them;
    * ``drug_se_facts.txt`` -- gene-gene interactions, drug targets,
      side-effect categories, and single-drug side effects whose side effect
      never appears as a polypharmacy relation.

    Side effects with fewer than *min_facts* polypharmacy facts are dropped.
    Calling ``main()`` with no arguments reproduces the script's defaults.

    :param decagon_dp: directory containing the bio-decagon-*.csv files
    :param kg_dp: output directory
    :param seed: seed for NumPy's global RNG, making the splits reproducible
    :param min_facts: minimum number of facts a side effect needs to be kept
    """
    np.random.seed(seed)

    def decagon_rows(file_name):
        return _read_csv_rows(os.path.join(decagon_dp, file_name))

    ppi_triples = [["GENE:%s" % g1, "INTERACT_WITH", "GENE:%s" % g2]
                   for g1, g2 in decagon_rows("bio-decagon-ppi.csv")]
    dt_triples = [["DRUG:%s" % d, "DRUG_TARGET", "GENE:%s" % g]
                  for d, g in decagon_rows("bio-decagon-targets-all.csv")]
    se_cats = [["SE:%s" % se, "SE_CATEGORY", "CAT:%s" % cat.replace(" ", "_")]
               for se, _, cat in decagon_rows("bio-decagon-effectcategories.csv")]
    se_ploy = [["DRUG:%s" % d1, "SE:%s" % se, "DRUG:%s" % d2]
               for d1, d2, se, _ in decagon_rows("bio-decagon-combo.csv")]

    # Single-drug side effects are kept only if the side effect is never a
    # polypharmacy relation. This is checked before rare side effects are
    # dropped.
    ploy_se_uniq = {se for _, se, _ in se_ploy}
    mono_triples = [["DRUG:%s" % row[0], "DRUG_SIDE_EFFECT", "SE:%s" % row[1]]
                    for row in decagon_rows("bio-decagon-mono.csv")]
    mono_triples = [t for t in mono_triples if t[2] not in ploy_se_uniq]

    se_ploy, se_ploy_train, se_ploy_valid, se_ploy_test, ignored_rels = \
        _split_side_effect_facts(se_ploy, min_facts)
    print("Ignored %d side effects" % len(ignored_rels))

    os.makedirs(kg_dp, exist_ok=True)

    def write_triples(file_name, *triple_lists):
        # "with" guarantees the file is closed even if a write fails.
        with open(os.path.join(kg_dp, file_name), "w", encoding="utf-8") as out_fd:
            for triples in triple_lists:
                export_entry(triples, out_fd)

    # Output names, including the "ploypharmacy" spelling, are unchanged.
    # NOTE(review): other tools may read these files by name.
    write_triples("ploypharmacy_facts.txt", se_ploy)
    write_triples("ploypharmacy_facts_train.txt", se_ploy_train)
    write_triples("ploypharmacy_facts_valid.txt", se_ploy_valid)
    write_triples("ploypharmacy_facts_test.txt", se_ploy_test)
    write_triples("drug_se_facts.txt", ppi_triples, dt_triples, se_cats, mono_triples)
if __name__ == "__main__":
    # Build the dataset only when executed directly, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment