Created
April 8, 2021 15:05
-
-
Save ricardoV94/72dd91d91e25620bfe980c0d98373678 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import aesara | |
import aesara.tensor as at | |
from aesara.tensor import basic_opt | |
from aesara.tensor.nnet import sigmoid, softplus | |
from aesara import config | |
from aesara.graph.opt import PatternSub | |
from aesara.tensor.math import add, exp, mul, log, log1p, neg, sub, true_div | |
from aesara.graph.basic import graph_inputs as at_inputs | |
from aesara.graph.fg import FunctionGraph | |
from aesara.graph.optdb import Query | |
from aesara.compile import optdb | |
# The graphs in this script are only rewritten and pretty-printed, never
# compiled into functions, so there is no point spending time compiling C.
# An empty compiler string disables aesara's C++ backend.
config.cxx = ""
def sigmoid2(x):
    """Build the logistic sigmoid of ``x`` by hand, as ``1 / (1 + exp(-x))``.

    This is used instead of the built-in ``sigmoid`` Op, so that the rewrites
    have to recognise the explicit arithmetic form.

    The alternative spelling ``exp(x) / (1 + exp(x))`` also works. The one
    exception is ``sigmoid2(logit(x))``, which nevertheless still returns
    correct results.
    """
    denominator = 1 + at.exp(-x)
    return 1 / denominator
# Several spellings of the logit function that can reasonably appear in a
# graph, either because the user wrote them that way or as the result of
# automatic rewrites.
def logit(x):
    """Build ``logit(x)`` as the log of the odds: ``log(x / (1 - x))``."""
    odds = x / (1 - x)
    return at.log(odds)
def logit2(x):
    """Build ``logit(x)`` as a difference of logs: ``log(x) - log(1 - x)``."""
    log_p = at.log(x)
    log_q = at.log(1 - x)
    return log_p - log_q
def logit3(x):
    """Build ``logit(x)`` using ``log1p``: ``log(x) - log1p(-x)``."""
    log_p = at.log(x)
    log_q = at.log1p(-x)
    return log_p - log_q
def optimize_graphs(*graphs, verbose=False):
    """Run the ``fast_run`` rewrites on ``graphs`` and pretty-print each result.

    All graphs are placed in one ``FunctionGraph`` whose inputs are the root
    variables reachable from ``graphs``. The ``FunctionGraph`` is built with
    ``clone=False``, so it operates on the caller's graph objects rather than
    on copies.

    Parameters
    ----------
    *graphs
        Output variables of the graphs to rewrite.
    verbose : bool, optional
        When true, ``optimizer_verbose`` is enabled while the rewrites run.

    Returns
    -------
    list
        The rewritten output variables, in the same order as ``graphs``.
    """
    inputs = list(at_inputs(graphs))
    outputs = list(graphs)
    fgraph = FunctionGraph(inputs, outputs, clone=False)
    # This queries every optimizer tagged "fast_run", not just canonicalization.
    fast_run_opt = optdb.query(Query(include=["fast_run"]))
    with config.change_flags(optimizer_verbose=verbose):
        fast_run_opt(fgraph)
    # Print every rewritten output. Previously only outputs[0] was shown,
    # which silently hid all graphs after the first.
    for rewritten in fgraph.outputs:
        print(aesara.pp(rewritten))
    return fgraph.outputs
def register_pattern(pattern):
    """Register ``pattern`` in the canonicalize, stabilize and specialize stages.

    Registration happens in that order, and nothing is returned.
    """
    registrars = (
        basic_opt.register_canonicalize,
        basic_opt.register_stabilize,
        basic_opt.register_specialize,
    )
    for register in registrars:
        register(pattern)
# Rewrite: log1p(-sigmoid(x)) -> -softplus(x)
# This duplicates the rewrite of the same name in sigm.py. The original does
# not seem to be registered in enough optimizer stages; for instance, it fails
# on -log1p(-sigmoid(x)). Here it is registered in the canonicalize,
# stabilize and specialize stages.
log1p_neg_sigmoid = PatternSub(
    (
        log1p,
        (neg, (sigmoid, "x")),
    ),
    (neg, (softplus, "x")),
    name="log1p_neg_sigmoid",
    allow_multiple_clients=True,
)
register_pattern(log1p_neg_sigmoid)
# 1 - sigmoid(x) -> sigmoid(-x)
# This rewrite exists in sigm.py but is disabled there.
# It doesn't provide a sizeable speedup, but it could make some of the
# downstream rewrites more straightforward.
# local_one_minus_sigm = PatternSub(
#     (sub, 1, (sigmoid, "x")),
#     (sigmoid, (neg, "x")),
#     allow_multiple_clients=True,
#     name="local_one_minus_sigm",
# )
# register_pattern(local_one_minus_sigm)
# sigmoid(x) / sigmoid(-x) -> exp(x)
# This came about as a happy accident while trying to get the equivalences to
# work. However, it is not a complete solution, because it does not handle the
# reverse case sigmoid(-x) / sigmoid(x) -> exp(-x).
# sigm_over_sigm_negx = PatternSub(
#     (true_div, (sigmoid, "x"), (sigmoid, (neg, "x"))),
#     (exp, "x"),
#     allow_multiple_clients=True,
#     name="sigm_over_sigm_negx",
# )
# register_pattern(sigm_over_sigm_negx)
# sigmoid(x) * (1 + exp(x)) -> exp(x)
# When an implicit (hand-written) sigmoid function is used, 1 / sigmoid(-x)
# can be converted to 1 + exp(x). As a result, sigmoid(x) / sigmoid(-x)
# becomes sigmoid(x) * (1 + exp(x)). Should some other strategy be used?
# sigm_mul_1pexp = PatternSub(
#     (mul, (sigmoid, "x"), (add, 1, (exp, "x"))),
#     (exp, "x"),
#     allow_multiple_clients=True,
#     name="sigm_mul_1pexp",
# )
# register_pattern(sigm_mul_1pexp)
# Rewrite: log(sigmoid(x) / (1 - sigmoid(x))) -> x
# This is only useful while 1 - sigmoid(x) is left as is. If the rewrite
# 1 - sigmoid(x) -> sigmoid(-x) were enabled, the sub-graph matched here would
# no longer appear.
useless_logit_sigmoid = PatternSub(
    (
        log,
        (true_div, (sigmoid, "x"), (sub, 1, (sigmoid, "x"))),
    ),
    "x",
    name="useless_logit_sigmoid",
    allow_multiple_clients=True,
)
register_pattern(useless_logit_sigmoid)
# Rewrite: (-softplus(-x)) - (-softplus(x)) -> x
# This difference often arises when an implicit (hand-written) sigmoid is
# inverted. The natural pattern would be softplus(x) - softplus(-x) -> x
# (kept commented out below). With that pattern, however, aesara does not
# recognise the more common form (-softplus(-x)) - (-softplus(x)), so the
# latter is matched instead.
useless_sub_softplus = PatternSub(
    # (sub, (softplus, "x"), (softplus, (neg, "x"))),
    (
        sub,
        (neg, (softplus, (neg, "x"))),
        (neg, (softplus, "x")),
    ),
    "x",
    name="useless_sub_softplus",
    allow_multiple_clients=True,
)
register_pattern(useless_sub_softplus)
# Rewrite: sigmoid(logit(x)) -> x, with logit spelled log(x / (1 - x))
# NOTE(review): the identity only holds for x in [0, 1]. Outside that range the
# unrewritten graph presumably evaluates to NaN (log of a negative ratio),
# whereas the rewritten graph returns x unchanged.
useless_sigmoid_logit = PatternSub(
    (
        sigmoid,
        (log, (true_div, "x", (sub, 1, "x"))),
    ),
    "x",
    name="useless_sigmoid_logit",
    allow_multiple_clients=True,
)
register_pattern(useless_sigmoid_logit)
# Rewrite: sigmoid(logit(x)) -> x, with logit spelled log(x) - log1p(-x),
# i.e. the log1p form of the subtracted log(1 - x) term.
# NOTE(review): as with the log(x / (1 - x)) spelling, this only holds for x
# in [0, 1]; outside that range the unrewritten graph is presumably NaN.
useless_sigmoid_logit2 = PatternSub(
    (
        sigmoid,
        (sub, (log, "x"), (log1p, (neg, "x"))),
    ),
    "x",
    name="useless_sigmoid_logit2",
    allow_multiple_clients=True,
)
register_pattern(useless_sigmoid_logit2)
# Symbolic input shared by every test graph below.
x = at.vector('x')

# log(sigmoid(x)) and log(1 - sigmoid(x)), the latter also spelled with
# log1p and with an outer negation, built from the hand-written sigmoid2.
# The message typo "softflus" and the unbalanced parenthesis are fixed.
print('Testing: log(sigmoid(+-x)) -> -softplus(-+x)')
y = at.log(sigmoid2(x))
optimize_graphs(y)
y = at.log(1 - sigmoid2(x))
optimize_graphs(y)
y = at.log1p(-sigmoid2(x))
optimize_graphs(y)
y = -at.log1p(-sigmoid2(x))
optimize_graphs(y)
# print('\nTesting: 1 - sigmoid(x) -> sigmoid(-x)')
# y = 1 - sigmoid(x)
# optimize_graphs(y)
# y = 1 - sigmoid2(x)
# optimize_graphs(y)
# print('\nTesting: sigmoid(x) / sigmoid(-x) -> exp(x)')
# y = sigmoid(x) / (1 - sigmoid(x))
# optimize_graphs(y)
# y = sigmoid(x) / (sigmoid(-x))
# optimize_graphs(y)
#
# y = sigmoid2(x) / (1 - sigmoid2(x))
# optimize_graphs(y)
# y = sigmoid2(x) / (sigmoid2(-x))
# optimize_graphs(y)
# print('\nTesting: sigmoid(-x) / sigmoid(x) -> exp(-x)')
# y = (1 - sigmoid(x)) / (sigmoid(x))
# optimize_graphs(y)
# y = sigmoid(-x) / sigmoid(x)
# optimize_graphs(y)
#
# y = (1 - sigmoid2(x)) / (sigmoid2(x))
# optimize_graphs(y)
# y = sigmoid2(-x) / sigmoid2(x)
# optimize_graphs(y)
# logit(sigmoid(x)) -> x, for every logit spelling, with both the built-in
# sigmoid and the hand-written sigmoid2.
# If the rewrite goes through sigmoid(x) / (1 - sigmoid(x)) -> exp(x), it
# additionally requires the log(exp(x)) -> x rewrite
# (see https://github.com/pymc-devs/aesara/pull/364).
# The header previously said "invlogit(sigmoid(x))", but invlogit is sigmoid
# itself; the graphs below apply logit to sigmoid.
print('\nTesting: logit(sigmoid(x)) -> x')
y = logit(sigmoid(x))
optimize_graphs(y)
y = logit2(sigmoid(x))
optimize_graphs(y)
y = logit3(sigmoid(x))
optimize_graphs(y)
y = logit(sigmoid2(x))
optimize_graphs(y)
y = logit2(sigmoid2(x))
optimize_graphs(y)
y = logit3(sigmoid2(x))
optimize_graphs(y)
# sigmoid(logit(x)) -> x, for every logit spelling, with both the built-in
# sigmoid and the hand-written sigmoid2. The header previously said
# "sigmoid(invlogit(x))", which does not match the graphs built here.
print('\nTesting: sigmoid(logit(x)) -> x')
y = sigmoid(logit(x))
optimize_graphs(y)
y = sigmoid(logit2(x))
optimize_graphs(y)
y = sigmoid(logit3(x))
optimize_graphs(y)
# This case behaves oddly if sigmoid2 is switched to its alternative
# exp(x) / (1 + exp(x)) form.
y = sigmoid2(logit(x))
optimize_graphs(y)
y = sigmoid2(logit2(x))
optimize_graphs(y)
y = sigmoid2(logit3(x))
optimize_graphs(y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment