Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Created April 8, 2021 15:05
Show Gist options
  • Save ricardoV94/72dd91d91e25620bfe980c0d98373678 to your computer and use it in GitHub Desktop.
import aesara
import aesara.tensor as at
from aesara.tensor import basic_opt
from aesara.tensor.nnet import sigmoid, softplus
from aesara import config
from aesara.graph.opt import PatternSub
from aesara.tensor.math import add, exp, mul, log, log1p, neg, sub, true_div
from aesara.graph.basic import graph_inputs as at_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query
from aesara.compile import optdb
# We don't need to waste time compiling graphs to C
config.cxx = ""
# Manual version of the sigmoid function instead of using tt.nnet.sigmoid directly
def sigmoid2(x):
    """Logistic sigmoid 1 / (1 + exp(-x)) built from elementary symbolic ops.

    Spelled out manually (rather than calling the nnet sigmoid Op) so the
    rewrites below have to recognize the expanded graph.
    """
    # Also works with the commented out version, except for one edge case with
    # sigmoid2(logit(x)) which nevertheless still returns correct results
    # return at.exp(x) / (1 + at.exp(x))
    denominator = 1 + at.exp(-x)
    return 1 / denominator
# Different versions of the logit function that are reasonable to appear
# Either because the user wrote them as such or due to automatic rewrites
def logit(x):
    """Log-odds as a single log of the odds ratio: log(x / (1 - x))."""
    odds = x / (1 - x)
    return at.log(odds)
def logit2(x):
    """Log-odds written as a difference of logs: log(x) - log(1 - x)."""
    log_p = at.log(x)
    log_1mp = at.log(1 - x)
    return log_p - log_1mp
def logit3(x):
    """Log-odds using log1p for the complement: log(x) - log1p(-x)."""
    log_p = at.log(x)
    log_1mp = at.log1p(-x)
    return log_p - log_1mp
def optimize_graphs(*graphs, verbose=False):
    """Apply the 'fast_run' rewrite set to *graphs* in place and print the results.

    Parameters
    ----------
    *graphs
        One or more symbolic output variables to rewrite.
    verbose : bool
        When True, log each rewrite as it is applied.
    """
    inputs = list(at_inputs(graphs))
    graphs = list(graphs)
    # clone=False so the rewrites mutate the caller's graph objects directly
    fgraph = FunctionGraph(inputs, graphs, clone=False)
    # NOTE(review): despite the original "canonicalize" naming this pulls the
    # whole fast_run set, so stabilize/specialize rewrites registered above run too
    fast_run_opt = optdb.query(Query(include=['fast_run']))
    with config.change_flags(optimizer_verbose=verbose):
        fast_run_opt(fgraph)
    # Bugfix: the original printed only fgraph.outputs[0], silently dropping
    # every additional graph passed in; print them all.
    for out in fgraph.outputs:
        print(aesara.pp(out))
def register_pattern(pattern):
    """Register *pattern* in every standard rewrite phase (canonicalize,
    stabilize and specialize) so it fires no matter which pass runs."""
    registrars = (
        basic_opt.register_canonicalize,
        basic_opt.register_stabilize,
        basic_opt.register_specialize,
    )
    for register in registrars:
        register(pattern)
# log1p(-sigmoid(x)) -> -softplus(x)
# Duplicated from sigm.py, because the original was apparently not registered
# in enough rewrite phases — for instance it would fail with -log1p(-sigmoid(x)).
log1p_neg_sigmoid = PatternSub(
    (log1p, (neg, (sigmoid, "x"))),
    (neg, (softplus, "x")),
    allow_multiple_clients=True,
    name="log1p_neg_sigmoid",
)
register_pattern(log1p_neg_sigmoid)
# 1 - sigmoid(x) -> sigmoid(-x)
# This is present in sigm.py but is disabled
# It doesn't provide a sizeable speedup but it could make some of the
# downstream rewrites more straightforward
# local_one_minus_sigm = PatternSub(
# (sub, 1, (sigmoid, "x")),
# (sigmoid, (neg, "x")),
# allow_multiple_clients=True,
# name="local_one_minus_sigm",
# )
# register_pattern(local_one_minus_sigm)
# sigmoid(x) / sigmoid(-x) -> exp(x)
# This is a nice accident of trying to get the equivalences. However, it is not a
# fully fledged rewrite as it fails to do the reverse sigmoid(-x) / sigmoid(x) -> exp(-x)
# sigm_over_sigm_negx = PatternSub(
# (true_div, (sigmoid, "x"), (sigmoid, (neg, "x"))),
# (exp, "x"),
# allow_multiple_clients=True,
# name="sigm_over_sigm_negx",
# )
# register_pattern(sigm_over_sigm_negx)
# sigmoid * (1 + exp(x)) -> exp(x)
# 1 / (sigmoid(-x)) can be converted to 1 + exp(x) when using an implicit sigmoid function
# resulting in the expression sigmoid * (1 + exp(x)), Should some other strategy be used?
# sigm_mul_1pexp = PatternSub(
# (mul, (sigmoid, "x"), (add, 1, (exp, "x"))),
# (exp, "x"),
# allow_multiple_clients=True,
# name="sigm_mul_1pexp",
# )
# register_pattern(sigm_mul_1pexp)
# log(sigmoid(x) / (1 - sigmoid(x))) -> x
# only useful if not using the rewrite (1 - sigmoid(x)) -> sigmoid(-x)
useless_logit_sigmoid = PatternSub(
    (log, (true_div, (sigmoid, "x"), (sub, 1, (sigmoid, "x")))),
    "x",
    allow_multiple_clients=True,
    name="useless_logit_sigmoid",
)
register_pattern(useless_logit_sigmoid)
# softplus(x) - softplus(-x) -> x
# often arises when inverting an implicit sigmoid function
# ideally would write the commented-out pattern, but with it aesara does not
# recognize the more common reverse version (-softplus(-x)) - (-softplus(x))
useless_sub_softplus = PatternSub(
    # (sub, (softplus, "x"), (softplus, (neg, "x"))),
    (sub, (neg, (softplus, (neg, "x"))), (neg, (softplus, "x"))),
    "x",
    allow_multiple_clients=True,
    name="useless_sub_softplus",
)
register_pattern(useless_sub_softplus)
# sigmoid(logit(x)) -> x
useless_sigmoid_logit = PatternSub(
    (sigmoid, (log, (true_div, "x", (sub, 1, "x")))),
    "x",
    allow_multiple_clients=True,
    name="useless_sigmoid_logit",
)
register_pattern(useless_sigmoid_logit)
# sigmoid(logit(x)) -> x, when log1p is present in the denominator
useless_sigmoid_logit2 = PatternSub(
    (sigmoid, (sub, (log, "x"), (log1p, (neg, "x")))),
    "x",
    allow_multiple_clients=True,
    name="useless_sigmoid_logit2",
)
register_pattern(useless_sigmoid_logit2)
x = at.vector('x')

# Bugfix in the banner below: original read "log(sigmoid(+-x) -> -softflus(-+x)"
# (typo "softflus" and a missing closing parenthesis).
print('Testing: log(sigmoid(+-x)) -> -softplus(-+x)')
y = at.log(sigmoid2(x))
optimize_graphs(y)
y = at.log(1 - sigmoid2(x))
optimize_graphs(y)
y = at.log1p(-sigmoid2(x))
optimize_graphs(y)
y = -at.log1p(-sigmoid2(x))
optimize_graphs(y)
# print('\nTesting: 1 - sigmoid(x) -> sigmoid(-x)')
# y = 1 - sigmoid(x)
# optimize_graphs(y)
# y = 1 - sigmoid2(x)
# optimize_graphs(y)
# print('\nTesting: sigmoid(x) / sigmoid(-x) -> exp(x)')
# y = sigmoid(x) / (1 - sigmoid(x))
# optimize_graphs(y)
# y = sigmoid(x) / (sigmoid(-x))
# optimize_graphs(y)
#
# y = sigmoid2(x) / (1 - sigmoid2(x))
# optimize_graphs(y)
# y = sigmoid2(x) / (sigmoid2(-x))
# optimize_graphs(y)
# print('\nTesting: sigmoid(-x) / sigmoid(x) -> exp(-x)')
# y = (1 - sigmoid(x)) / (sigmoid(x))
# optimize_graphs(y)
# y = sigmoid(-x) / sigmoid(x)
# optimize_graphs(y)
#
# y = (1 - sigmoid2(x)) / (sigmoid2(x))
# optimize_graphs(y)
# y = sigmoid2(-x) / sigmoid2(x)
# optimize_graphs(y)
# Requires the optimization log(exp(x)) -> x if going through
# logit(x) / (1 - logit(x)) -> exp(x)
# (see https://github.com/pymc-devs/aesara/pull/364)
# Round-trip checks: composing each logit variant with each sigmoid variant
# should rewrite all the way down to the bare input x.
print('\nTesting: invlogit(sigmoid(x)) -> x')
for sigmoid_fn in (sigmoid, sigmoid2):
    for logit_fn in (logit, logit2, logit3):
        optimize_graphs(logit_fn(sigmoid_fn(x)))

print('\nTesting: sigmoid(invlogit(x)) -> x')
# sigmoid2(logit(x)) gets funky with the alternative sigmoid2 version
for sigmoid_fn in (sigmoid, sigmoid2):
    for logit_fn in (logit, logit2, logit3):
        optimize_graphs(sigmoid_fn(logit_fn(x)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment