Last active
August 9, 2021 14:11
-
-
Save msetzu/48e6bcdc5f1eb3b48e062d8ffad47af1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from numpy import finfo, inf | |
from models import Rule | |
def _premise_to_bounds(feature, value, feature_names):
    """Translate one LORE premise into a GLocalX ``(index, (lower, upper))`` pair.

    Bounds follow the ``lower < x <= upper`` convention implied by the range
    (``'lower< f <=upper'``), ``'>lower'`` and ``'<=upper'`` formats.

    Args:
        feature (str): Feature name as written in the LORE premise.
        value (str): Condition on the feature, e.g. ``'>3.5'``, ``'<=2'``,
            ``'1.5< age <=4'``, a one-hot category label, or a numeric literal.
        feature_names (list): Feature names read from the info file.

    Returns:
        (tuple or None): ``(feature index, (lower, upper))``, or ``None`` when
            `value` matches none of the recognised formats (e.g. ``'<3'``);
            such premises are dropped silently.

    Raises:
        ValueError: if the referenced feature is not in `feature_names`, or a
            bound cannot be parsed as a float (typically the case for ``'>='``
            conditions, which are not supported).
    """
    from numpy import nextafter

    is_range = '< ' in value and ' <=' in value
    # lower and upper bound on a known feature
    if is_range and feature in feature_names:
        lower = float(value.split('< ')[0])
        upper = float(value.split(' <=')[1])
        return feature_names.index(feature), (lower, upper)
    # lower but no upper bound; '>=' is not supported and falls through
    if value.startswith('>') and not value.startswith('>='):
        f_idx = feature_names.index(feature)
        return f_idx, (float(value.split('>')[1]), inf)
    # no lower but upper bound
    if value.startswith('<='):
        f_idx = feature_names.index(feature)
        return f_idx, (-inf, float(value.split('<=')[1]))
    # one-hot category or exact numeric value. Reaching here with `is_range`
    # true means the feature itself is unknown, so the range text is treated
    # as the label of a one-hot encoded bin.
    if '<' not in value or is_range:
        one_hot_name = feature + '=' + value
        if one_hot_name in feature_names:
            return feature_names.index(one_hot_name), (0.5, 1 + finfo(float).eps)
        f_idx = feature_names.index(feature)
        number = float(value)
        # nextafter yields the largest float strictly below `number`.
        # `number - eps` rounds back to `number` (e.g. 3.0, or any
        # |number| >= 4), which produced the empty interval (number, number].
        return f_idx, (nextafter(number, -inf), number)
    return None


def lore_to_glocalx(json_file, info_file, class_name='class'):
    """Load LORE rules from `json_file` and convert them into `Rule` objects.

    Args:
        json_file (str): Path to the JSON file holding the LORE rules, each
            encoded as ``[consequence, premises, infos]``.
        info_file (str): Path to the JSON info file containing the rules'
            metadata (``class_values`` and ``feature_names``).
        class_name (str): Name of the class feature found in the rules.
            Defaults to 'class'.

    Returns:
        (set): Set of `Rule` objects. Duplicate rules are collapsed and the
            loading order is not preserved. Entries that are not 3-element
            rules, or whose consequence lacks `class_name`, are skipped.

    Raises:
        KeyError: if the info file lacks ``class_values`` or ``feature_names``.
        ValueError: if a rule's class value is not in ``class_values``, or a
            premise references an unknown feature or an unparsable bound.
    """
    with open(json_file, 'r', encoding='utf-8') as rules_log, \
            open(info_file, 'r', encoding='utf-8') as info_log:
        rules = json.load(rules_log)
        infos = json.load(info_log)
    class_values = infos['class_values']
    feature_names = infos['feature_names']

    output_rules = []
    for rule in rules:
        # rules are [consequence, premises, infos]; anything else (including
        # empty entries) is skipped
        if len(rule) != 3 or class_name not in rule[0]:
            continue
        class_value = rule[0][class_name]
        try:
            consequence = class_values.index(class_value)
        except ValueError as err:
            raise ValueError(
                'Class value {!r} not found in class_values {!r}'.format(class_value, class_values)
            ) from err

        actual_premises = {}
        for feature, value in rule[1].items():
            parsed = _premise_to_bounds(feature, value, feature_names)
            if parsed is not None:
                f_idx, bounds = parsed
                actual_premises[f_idx] = bounds
        transformed_rule = Rule(premises=actual_premises, consequence=consequence, names=feature_names)
        output_rules.append(transformed_rule)

    # Collapsing into a set drops duplicate rules (Rule must be hashable).
    return set(output_rules)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment