Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Created April 8, 2021 12:51
Show Gist options
  • Save ricardoV94/87ed8447a639949132ffd688a606170a to your computer and use it in GitHub Desktop.
import aesara
import aesara.tensor as at
from aesara.tensor import basic_opt
from aesara.tensor.nnet import sigmoid, softplus
from aesara.tensor.nnet.sigm import _is_1, _skip_mul_1
from aesara import config
from aesara.graph.opt import PatternSub
from aesara.tensor.math import exp, log1p, neg, sub, true_div
from aesara.graph.basic import graph_inputs as at_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query
from aesara.compile import optdb
from aesara.tensor.type import values_eq_approx_remove_inf
# We don't need to waste time compiling graphs to C
# config.cxx = ""
# Manual version of the sigmoid function, written out with elementary ops
# so that Aesara does not immediately recognize it as the builtin `sigmoid`.
def sigmoid2(x):
    """Logistic function spelled out explicitly: 1 / (1 + exp(-x))."""
    denom = 1 + at.exp(-x)
    return 1 / denom
# Different versions of the logit function that are reasonable to appear,
# either because the user wrote them as such or due to automatic rewrites.
def logit(x):
    """Log-odds as a single log of a ratio: log(x / (1 - x))."""
    odds = x / (1 - x)
    return at.log(odds)
def logit2(x):
    """Log-odds as a difference of logs: log(x) - log(1 - x)."""
    numer_term = at.log(x)
    denom_term = at.log(1 - x)
    return numer_term - denom_term
def logit3(x):
    """Log-odds using log1p for the complement: log(x) - log1p(-x)."""
    complement = at.log1p(-x)
    return at.log(x) - complement
def optimize_graphs(*graphs, verbose=False):
    """Apply the 'fast_run' rewrites to `graphs` and print the results.

    Parameters
    ----------
    graphs : TensorVariable
        One or more output variables defining the graph(s) to rewrite.
    verbose : bool
        If True, print each rewrite as it is applied
        (via `optimizer_verbose`).

    Returns
    -------
    FunctionGraph
        The rewritten function graph, for further inspection.
    """
    inputs = list(at_inputs(graphs))
    graphs = list(graphs)
    fgraph = FunctionGraph(inputs, graphs, clone=False)
    canonicalize_opt = optdb.query(Query(include=['fast_run']))
    with config.change_flags(optimizer_verbose=verbose):
        canonicalize_opt(fgraph)
    # BUG FIX: the original printed only fgraph.outputs[0], silently
    # ignoring any additional graphs passed in; print every output.
    for out in fgraph.outputs:
        print(aesara.pp(out))
    return fgraph
# This rewrite is present in sigm.py but is disabled there:
#   1 - sigmoid(x) -> sigmoid(-x)
_1msigm = PatternSub(
    # match: <one> - sigmoid(x), where the constant is validated by _is_1
    (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")),
    # replacement: sigmoid(-x)
    (sigmoid, (neg, "x")),
    values_eq_approx=values_eq_approx_remove_inf,
    allow_multiple_clients=True,
    skip_identities_fn=_skip_mul_1,
)
# This shouldn't be needed, but -log1p(-sigmoid(x)) is not rewritten
# whereas log1p(-sigmoid(x)) is.
# NOTE(review): still fails in practice.
neg_log1p_neg_sigmoid = PatternSub(
    # match: -log1p(-sigmoid(x))
    (neg, (log1p, (neg, (sigmoid, "x")))),
    # replacement: softplus(x)
    (softplus, "x"),
    values_eq_approx=values_eq_approx_remove_inf,
    allow_multiple_clients=True,
)
# sigmoid(x) / (1 - sigmoid(x)) -> exp(x)
sigm_over_1msigm = PatternSub(
    # match: sigmoid(x) / (<one> - sigmoid(x))
    (true_div, (sigmoid, "x"), (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))),
    # replacement: exp(x)
    (exp, "x"),
    values_eq_approx=values_eq_approx_remove_inf,
    allow_multiple_clients=True,
    # skip_identities_fn=_skip_mul_1,
)
# sigmoid(x) / sigmoid(-x) -> exp(x)
# NOTE(review): fails to do anything.
sigm_over_sigm_negx = PatternSub(
    # match: sigmoid(x) / sigmoid(-x)
    (true_div, (sigmoid, "x"), (sigmoid, (neg, "x"))),
    # replacement: exp(x)
    (exp, "x"),
    values_eq_approx=values_eq_approx_remove_inf,
    allow_multiple_clients=True,
    # skip_identities_fn=_skip_mul_1,
)
# softplus(x) - softplus(-x) -> x
# NOTE(review): also not working.  The name keeps its original typo
# ("uselusse") because it is referenced by name when registered below.
sub_uselusse_softplus = PatternSub(
    # match: softplus(x) - softplus(-x)
    (sub, (softplus, "x"), (softplus, (neg, "x"))),
    # replacement: x itself
    "x",
    values_eq_approx=values_eq_approx_remove_inf,
    allow_multiple_clients=True,
    # skip_identities_fn=_skip_mul_1,
)
# basic_opt.register_stabilize(logexp, name="logexp")
# Register every custom rewrite in the "stabilize" phase, preserving the
# original registration order.
for _rewrite_name, _rewrite in (
    ("neg_log1p_neg_sigmoid", neg_log1p_neg_sigmoid),
    ("1msigm", _1msigm),
    ("sigm_over_1msigm", sigm_over_1msigm),
    ("sigm_over_sigm_negx", sigm_over_sigm_negx),
    ("sub_uselusse_softplus", sub_uselusse_softplus),
):
    basic_opt.register_stabilize(_rewrite, name=_rewrite_name)
x = at.vector('x')
# Testing conversion log(sigmoid(x)) -> softplus(x).
# All work with the default sigmoid, but some fail with the manual sigmoid2.
# FIX: the message said "softflus" (typo) — corrected to "softplus".
print('Testing: sigmoid -> softplus')
y = at.log(sigmoid2(x))  # works
optimize_graphs(y)
y = at.log(1 - sigmoid2(x))  # fails without default sigmoid
optimize_graphs(y)
y = at.log1p(-sigmoid2(x))  # fails without default sigmoid
optimize_graphs(y)
y = -at.log1p(-sigmoid(x))  # negative makes it fail, even with custom opt added above
optimize_graphs(y)
# Testing conversion 1 - sigmoid(x) -> sigmoid(-x).
# This was added in sigm.py but disabled due to conflicts with other
# operations.
print('\nTesting: 1 - sigmoid(x) -> sigmoid(-x)')
for y in (
    1 - sigmoid(x),   # works
    1 - sigmoid2(x),  # works if skip_identities_fn=_skip_mul_1,
):
    optimize_graphs(y)
print('\nTesting: sigmoid(x) / sigmoid(-x) -> exp(x)')
# Testing conversion sigmoid(x) / (1 - sigmoid(x)) -> exp(x),
# or equivalently sigmoid(x) / sigmoid(-x) -> exp(x).
ratio_cases = [
    sigmoid(x) / (1 - sigmoid(x)),    # works, if _1msigm is enabled
    sigmoid(x) / (sigmoid(-x)),       # fails
    sigmoid2(x) / (1 - sigmoid2(x)),  # fails
    sigmoid2(x) / (sigmoid2(-x)),     # fails even more somehow
]
for y in ratio_cases:
    optimize_graphs(y)
# FIX: the message said "invlogit(sigmoid(x)) -> x", but invlogit *is* the
# sigmoid; the identity exercised here is logit(sigmoid(x)) -> x.
print('\nTesting: logit(sigmoid(x)) -> x')
y = logit(sigmoid(x))  # works
optimize_graphs(y)
y = logit2(sigmoid(x))  # fails,
optimize_graphs(y)
y = logit3(sigmoid(x))  # works
optimize_graphs(y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment