Skip to content

Instantly share code, notes, and snippets.

@acceptable-security
Created June 26, 2021 02:49
Show Gist options
  • Save acceptable-security/2228fc1bbdc0fef7b5b40e69cc072def to your computer and use it in GitHub Desktop.
Save acceptable-security/2228fc1bbdc0fef7b5b40e69cc072def to your computer and use it in GitHub Desktop.
A really shoddy datalog engine I wrote in a few minutes
from dataclasses import dataclass, field
from typing import Dict, Generator, List, Set, Tuple, Union
@dataclass(frozen=True)
class Constant:
    """An immutable, hashable Datalog constant wrapping a raw value."""

    value: Union[int, str, float]

    def __str__(self) -> str:
        # Render just the wrapped payload, e.g. ``a`` rather than ``Constant(value='a')``.
        return f"{self.value!s}"
@dataclass
class Term:
    """One argument of an atom: either a variable or a constant.

    A ``str`` value names a variable; a ``Constant`` value is a literal.
    """

    value: Union[Constant, str]  # str -> variable

    def __str__(self) -> str:
        # Variables print as their name, constants as their payload.
        return f"{self.value!s}"
@dataclass
class Atom:
    """A relation applied to a list of terms, e.g. ``parent(Z, X)``."""

    relation: str
    term_list: List[Term]

    def __str__(self) -> str:
        arguments = ", ".join(str(term) for term in self.term_list)
        return f"{self.relation}({arguments})"
@dataclass
class Rule:
    """A Datalog rule: ``implied`` is derived wherever every body atom matches.

    Rendered as ``head :- body1, body2.``.
    """

    implied: Atom  # the rule head
    atom_list: List[Atom]  # the rule body (a conjunction)

    def __str__(self) -> str:
        body = ", ".join(str(atom) for atom in self.atom_list)
        return f"{self.implied!s} :- {body}."

    @property
    def relation(self):
        """Name of the relation this rule derives rows for (its head's)."""
        return self.implied.relation
@dataclass
class Fact:
    """A ground row asserted for ``relation``, rendered as ``parent(a, b).``."""

    relation: str
    constants: List[Constant]

    def __str__(self) -> str:
        values = ", ".join(str(constant) for constant in self.constants)
        return f"{self.relation}({values})."
@dataclass
class Program():
    """A Datalog program: declared relations plus facts and rules.

    ``relationships`` lists ``(relation_name, column_types)`` pairs. Only the
    names are used by :meth:`execute`; the column types are not checked.
    ``stmts`` holds the program's :class:`Fact` and :class:`Rule` statements.
    """

    relationships: List[Tuple[str, List[type]]]
    stmts: List[Union[Fact, Rule]] = field(default_factory=list)

    # Database rows hold raw constant payloads (``Constant.value``), not
    # ``Constant`` wrappers; bindings map variable names to such payloads.
    # These are unannotated, so ``@dataclass`` does not turn them into fields.
    _Value = Union[int, str, float]
    _Row = Tuple[_Value, ...]
    _Bindings = Dict[str, _Value]

    def __str__(self) -> str:
        return "\n\n".join(map(str, self.stmts))

    def execute(self) -> Dict[str, Set[_Row]]:
        """Evaluate the program bottom-up until a fixed point is reached.

        The database starts from the facts. Every rule is then applied
        repeatedly until a full pass over the rules derives no new row
        (naive evaluation).

        Returns:
            A dict mapping every declared relation name to the set of rows
            that hold for it. Rows are tuples of raw values
            (``Constant.value``), not ``Constant`` objects. Facts whose
            relation is not declared in ``relationships`` are ignored.

        Raises:
            KeyError: if evaluation reaches a body atom, or derives a head
                row, whose relation is not declared. Also raised if a derived
                head mentions a variable the body never binds (an unsafe
                rule).
        """
        facts = [stmt for stmt in self.stmts if isinstance(stmt, Fact)]
        rules = [stmt for stmt in self.stmts if isinstance(stmt, Rule)]

        database = {
            relation: {
                tuple(constant.value for constant in fact.constants)
                for fact in facts
                if fact.relation == relation
            }
            for relation, _ in self.relationships
        }

        changed = True
        while changed:
            changed = False
            for rule in rules:
                # Collect before inserting: a recursive rule reads the very
                # set it writes, and a set must not change while iterated.
                derived = [
                    self._instantiate(rule.implied, bindings)
                    for bindings in self._match_body(database, rule.atom_list, {})
                ]
                for row in derived:
                    if row not in database[rule.relation]:
                        database[rule.relation].add(row)
                        changed = True
        return database

    @classmethod
    def _match_body(
        cls,
        database: Dict[str, Set[_Row]],
        atoms: List[Atom],
        bindings: _Bindings,
    ) -> Generator[_Bindings, None, None]:
        """Yield every extension of ``bindings`` that satisfies all ``atoms``."""
        if not atoms:
            yield bindings
            return
        first, rest = atoms[0], atoms[1:]
        for row in database[first.relation]:
            extended = cls._unify(first, row, bindings)
            if extended is not None:
                yield from cls._match_body(database, rest, extended)

    @classmethod
    def _unify(cls, atom: Atom, row: _Row, bindings: _Bindings) -> Union[_Bindings, None]:
        """Return ``bindings`` extended so that ``atom`` matches ``row``, else None.

        Every position is checked against the row. Constants and
        already-bound variables must equal the row's value there. Unbound
        variables are bound to it, so a variable repeated inside ``atom``
        must see the same value each time. ``bindings`` itself is not
        mutated.
        """
        if len(row) != len(atom.term_list):
            return None
        extended = dict(bindings)
        for term, value in zip(atom.term_list, row):
            if isinstance(term.value, str):  # a variable
                if term.value not in extended:
                    extended[term.value] = value
                elif extended[term.value] != value:
                    return None
            elif cls._literal(term) != value:
                return None
        return extended

    @classmethod
    def _instantiate(cls, head: Atom, bindings: _Bindings) -> _Row:
        """Build the head row for one body match.

        Raises KeyError if a head variable is unbound.
        """
        return tuple(
            bindings[term.value] if isinstance(term.value, str) else cls._literal(term)
            for term in head.term_list
        )

    @staticmethod
    def _literal(term: Term) -> _Value:
        """Return the raw payload of a non-variable ``term``."""
        # Rows store ``Constant.value``, so unwrap before comparing/emitting.
        value = term.value
        return value.value if isinstance(value, Constant) else value
if __name__ == "__main__":
    # Demo: two children of parent ``a``. The sibling rule has no inequality
    # check, so it pairs each child with every child, itself included.
    declared = [
        ('parent', [str, str]),
        ('sibling', [str, str]),
    ]
    statements = [
        Fact('parent', [Constant('a'), Constant('b')]),
        Fact('parent', [Constant('a'), Constant('c')]),
        Rule(
            Atom('sibling', [Term('X'), Term('Y')]),
            [
                Atom('parent', [Term('Z'), Term('X')]),
                Atom('parent', [Term('Z'), Term('Y')]),
            ],
        ),
    ]
    program = Program(declared, statements)
    print(program.execute())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment