This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def removeNoisyData(data):
    """Return a copy of ``data`` keeping only entries usable as template variables.

    An entry is dropped when its key is a member of the module-level
    ``SNAPSHOT`` set, when its key is the literal ``"SNAPSHOT"``, or when its
    value is callable.

    Args:
        data: Mapping of names to values (callers in this file pass
            ``globals()``, ``locals()`` or a plain dict).

    Returns:
        A new ``dict`` with the surviving ``(name, value)`` pairs, in the
        iteration order of ``data``.
    """
    return {
        name: value
        for name, value in data.items()
        # The key "SNAPSHOT" is tested separately because the set is not
        # guaranteed to contain its own name.
        if name not in SNAPSHOT and name != "SNAPSHOT" and not callable(value)
    }
def expr(s, data=None):
    """Substitute backtick-prefixed names in ``s`` with their values.

    For every ``name`` that survives ``removeNoisyData`` filtering, each
    occurrence of a backtick immediately followed by ``name`` is replaced by
    the corresponding value.

    Args:
        s: Template string containing references such as a backtick
            followed by a variable name.
        data: Mapping of names to values. When ``None``, this module's
            ``globals()`` are used.

    Returns:
        The template with every known reference replaced.

    NOTE(review): matching is plain substring replacement performed in the
    mapping's iteration order, so a name that is a prefix of another name
    can consume part of the longer reference. Callers should avoid such
    name pairs.
    """
    variables = removeNoisyData(globals() if data is None else data)
    result = s
    for name, value in variables.items():
        target = '`{}'.format(name)
        if target in result:
            # str.replace() only accepts strings; convert so that numeric
            # (or other non-string) values do not raise TypeError.
            result = result.replace(target, str(value))
    return result
def sum(expression, bottom=None, top=None, data=None):
    r"""Build a LaTeX ``\sum`` term around ``expression``.

    ``expression`` is first expanded with ``expr(expression, data)``;
    ``bottom`` and ``top`` are inserted as given, without expansion.

    Args:
        expression: Template for the summand.
        bottom: Optional lower bound, rendered as a subscript.
        top: Optional upper bound, rendered as a superscript.
        data: Variable mapping forwarded to ``expr``.

    Returns:
        The LaTeX source string for the summation.

    NOTE: this function shadows the built-in ``sum`` within the module; the
    name is kept because existing callers depend on it.
    """
    _expression = expr(expression, data)
    # Raw literals keep the leading backslash without relying on Python
    # passing the unrecognised escape "\s" through unchanged.
    if bottom is None and top is None:
        return r"\sum{{ {} }}".format(_expression)
    if bottom is None:
        return r"\sum^{{ {} }} {{ {} }}".format(top, _expression)
    if top is None:
        return r"\sum_{{ {} }} {{ {} }}".format(bottom, _expression)
    # Both bounds supplied: previously no branch matched and None was returned.
    return r"\sum_{{ {} }}^{{ {} }} {{ {} }}".format(bottom, top, _expression)
# Names already present in the module namespace at this point; entries with
# these names are never treated as template variables.
SNAPSHOT = set(globals())
def pp(s):
    """Print ``s`` after expansion by ``expr``, framed by newlines."""
    expanded = expr(s)
    print("\n{}\n".format(expanded))
################################################################################################################ | |
# Demo using module-level variables: build three summation terms over
# ``t_in_V`` and print the relation A = B - C.
# Raw literals are used because "\m" and "\i" are not valid escape sequences.
Mq = r"\mathcal{M}_{q}(t)"
Md = r"\mathcal{M}_{d}(t)"
t_in_V = r't \in V'


def s(x):
    """Wrap template ``x`` in a summation whose lower bound is ``t_in_V``."""
    return sum(x, bottom=t_in_V)


A = s(r'`Mq \cdot log \frac{`Mq}{`Md}')
B = s(r"`Mq \cdot log `Mq")
C = s(r"`Mq \cdot log `Md")
pp("`A = `B - `C")
def KL_distance(p, q, bottom=None):
    """Render the relation A = B - C for the summation terms built from p and q.

    Three summations with lower bound ``bottom`` are built from templates
    that reference ``p`` and ``q``, then combined into a single string.

    Args:
        p: Text substituted for the first distribution reference.
        q: Text substituted for the second distribution reference.
        bottom: Optional lower bound passed to ``sum``.

    Returns:
        The expanded string produced by ``expr``.
    """
    # Explicit namespace, filled in the order the variables come into being,
    # so each call sees exactly the names defined before it.
    names = {"p": p, "q": q, "bottom": bottom}
    names["A"] = sum(r'`p \cdot log \frac{`p}{`q}', bottom=bottom, data=names)
    names["B"] = sum(r"`p \cdot log `p", bottom=bottom, data=names)
    names["C"] = sum(r"`p \cdot log `q", bottom=bottom, data=names)
    return expr("`A = `B - `C", data=names)
# Demo of KL_distance with explicit arguments.
# Raw literals are used because "\m" and "\i" are not valid escape sequences.
Mq = r"\mathcal{M}_{q}(t)"
Md = r"\mathcal{M}_{d}(t)"
print(KL_distance(p=Mq, q=Md, bottom=r"t \in V"))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment