Created April 8, 2021 12:51.
Save ricardoV94/87ed8447a639949132ffd688a606170a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
import aesara | |
import aesara.tensor as at | |
from aesara.tensor import basic_opt | |
from aesara.tensor.nnet import sigmoid, softplus | |
from aesara.tensor.nnet.sigm import _is_1, _skip_mul_1 | |
from aesara import config | |
from aesara.graph.opt import PatternSub | |
from aesara.tensor.math import exp, log1p, neg, sub, true_div | |
from aesara.graph.basic import graph_inputs as at_inputs | |
from aesara.graph.fg import FunctionGraph | |
from aesara.graph.optdb import Query | |
from aesara.compile import optdb | |
from aesara.tensor.type import values_eq_approx_remove_inf | |
# Inspecting graph rewrites does not require compiling graphs to C;
# uncomment the next line to skip C compilation and save time.
# config.cxx = ""
# Hand-written sigmoid, used to check whether rewrites also fire on the
# expanded expression rather than only on the built-in `sigmoid` Op.
def sigmoid2(x):
    """Return 1 / (1 + exp(-x)), built from elementary ops instead of ``sigmoid``."""
    exp_neg_x = at.exp(-x)
    return 1 / (1 + exp_neg_x)
# Different versions of the logit function that can reasonably appear in a graph,
# either because the user wrote them that way or because of automatic rewrites.
def logit(x):
    """Logit written as the log of the odds ratio: log(x / (1 - x))."""
    odds = x / (1 - x)
    return at.log(odds)
def logit2(x):
    """Logit written as a difference of logs: log(x) - log(1 - x)."""
    log_p = at.log(x)
    log_q = at.log(1 - x)
    return log_p - log_q
def logit3(x):
    """Logit using ``log1p`` for the complement: log(x) - log1p(-x)."""
    log_p = at.log(x)
    return log_p - at.log1p(-x)
def optimize_graphs(*graphs, verbose=False):
    """Run Aesara's ``fast_run`` rewrites over ``graphs`` and pretty-print the results.

    Parameters
    ----------
    *graphs
        Output variables to rewrite. They are wrapped in a ``FunctionGraph``
        with ``clone=False``, so the rewrites act on these variables directly
        rather than on a copy.
    verbose
        If True, enable Aesara's ``optimizer_verbose`` flag while rewriting.

    Returns
    -------
    list
        The rewritten outputs of the function graph. Each one is also printed.
    """
    outputs = list(graphs)
    inputs = list(at_inputs(outputs))
    fgraph = FunctionGraph(inputs, outputs, clone=False)
    # This query is not limited to canonicalization: it selects every
    # rewrite registered with the "fast_run" tag.
    fast_run_opt = optdb.query(Query(include=['fast_run']))
    with config.change_flags(optimizer_verbose=verbose):
        fast_run_opt(fgraph)
    # Print every output; previously only the first one was shown.
    for out in fgraph.outputs:
        print(aesara.pp(out))
    return fgraph.outputs
# Rewrite: 1 - sigmoid(x) -> sigmoid(-x)
# Aesara's sigm.py defines the same rewrite but keeps it disabled.
# `_is_1` presumably restricts "y" to a constant equal to one, and
# `_skip_mul_1` presumably lets the matcher look through multiplications by 1
# (both helpers come from sigm.py; semantics not verified here).
_1msigm = PatternSub(
    in_pattern=(sub, {"pattern": "y", "constraint": _is_1}, (sigmoid, "x")),
    out_pattern=(sigmoid, (neg, "x")),
    allow_multiple_clients=True,
    skip_identities_fn=_skip_mul_1,
    values_eq_approx=values_eq_approx_remove_inf,
)
# Rewrite: -log1p(-sigmoid(x)) -> softplus(x)
# This ought to be redundant: Aesara already rewrites log1p(-sigmoid(x)),
# but not its negation. Registering this pattern did not fix that either.
# NOTE(review): possibly because canonicalization turns neg(...) into a
# multiplication by -1 before stabilize runs — unverified.
neg_log1p_neg_sigmoid = PatternSub(
    in_pattern=(neg, (log1p, (neg, (sigmoid, "x")))),
    out_pattern=(softplus, "x"),
    allow_multiple_clients=True,
    values_eq_approx=values_eq_approx_remove_inf,
)
# Rewrite: sigmoid(x) / (1 - sigmoid(x)) -> exp(x)
sigm_over_1msigm = PatternSub(
    in_pattern=(
        true_div,
        (sigmoid, "x"),
        (sub, {"pattern": "y", "constraint": _is_1}, (sigmoid, "x")),
    ),
    out_pattern=(exp, "x"),
    allow_multiple_clients=True,
    values_eq_approx=values_eq_approx_remove_inf,
    # skip_identities_fn=_skip_mul_1,
)
# Rewrite: sigmoid(x) / sigmoid(-x) -> exp(x)
# Observed to have no effect on the test graphs.
sigm_over_sigm_negx = PatternSub(
    in_pattern=(true_div, (sigmoid, "x"), (sigmoid, (neg, "x"))),
    out_pattern=(exp, "x"),
    allow_multiple_clients=True,
    values_eq_approx=values_eq_approx_remove_inf,
    # skip_identities_fn=_skip_mul_1,
)
# Rewrite: softplus(x) - softplus(-x) -> x
# (log(1 + e^x) - log(1 + e^-x) == x.) Also observed not to fire.
# The misspelled name is kept because it is also the registry key.
sub_uselusse_softplus = PatternSub(
    in_pattern=(sub, (softplus, "x"), (softplus, (neg, "x"))),
    out_pattern="x",
    allow_multiple_clients=True,
    values_eq_approx=values_eq_approx_remove_inf,
    # skip_identities_fn=_skip_mul_1,
)
# Register the experimental rewrites in the stabilize database, in this order.
# (A `logexp` rewrite was also tried here; it is not defined in this script.)
for _rewrite, _rewrite_name in (
    (neg_log1p_neg_sigmoid, "neg_log1p_neg_sigmoid"),
    (_1msigm, "1msigm"),
    (sigm_over_1msigm, "sigm_over_1msigm"),
    (sigm_over_sigm_negx, "sigm_over_sigm_negx"),
    (sub_uselusse_softplus, "sub_uselusse_softplus"),
):
    basic_opt.register_stabilize(_rewrite, name=_rewrite_name)
del _rewrite, _rewrite_name
x = at.vector('x')

# log(sigmoid(x)) and related expressions should be rewritten into softplus
# form, e.g. log(sigmoid(x)) -> -softplus(-x).
# All cases work with the built-in sigmoid, but some fail with the
# hand-written sigmoid2.
print('Testing: sigmoid -> softplus')
y = at.log(sigmoid2(x))  # works
optimize_graphs(y)
y = at.log(1 - sigmoid2(x))  # fails without the built-in sigmoid
optimize_graphs(y)
y = at.log1p(-sigmoid2(x))  # fails without the built-in sigmoid
optimize_graphs(y)
# The outer negation makes this fail, even with neg_log1p_neg_sigmoid registered.
y = -at.log1p(-sigmoid(x))
optimize_graphs(y)

# 1 - sigmoid(x) -> sigmoid(-x)
# sigm.py defines this rewrite but disables it due to conflicts with other operations.
print('\nTesting: 1 - sigmoid(x) -> sigmoid(-x)')
y = 1 - sigmoid(x)  # works
optimize_graphs(y)
y = 1 - sigmoid2(x)  # works only with skip_identities_fn=_skip_mul_1
optimize_graphs(y)

# sigmoid(x) / (1 - sigmoid(x)) -> exp(x),
# or equivalently sigmoid(x) / sigmoid(-x) -> exp(x)
print('\nTesting: sigmoid(x) / sigmoid(-x) -> exp(x)')
y = sigmoid(x) / (1 - sigmoid(x))  # works if _1msigm is enabled
optimize_graphs(y)
y = sigmoid(x) / sigmoid(-x)  # fails
optimize_graphs(y)
y = sigmoid2(x) / (1 - sigmoid2(x))  # fails
optimize_graphs(y)
y = sigmoid2(x) / sigmoid2(-x)  # fails even more, for unclear reasons
optimize_graphs(y)

# logit is the inverse of sigmoid, so every logit variant should reduce to x.
print('\nTesting: logit(sigmoid(x)) -> x')
y = logit(sigmoid(x))  # works
optimize_graphs(y)
y = logit2(sigmoid(x))  # fails
optimize_graphs(y)
y = logit3(sigmoid(x))  # works
optimize_graphs(y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.