Skip to content

Instantly share code, notes, and snippets.

Created November 3, 2015 16:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/7c6d7a5e80e1e5f6b39e to your computer and use it in GitHub Desktop.
Save anonymous/7c6d7a5e80e1e5f6b39e to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
# $ conda create -n testenv python=2.7
# $ source activate testenv
# $ conda install -c bjodah symengine-python
# $ conda install pytest
# $ py.test
from symengine import Lambdify
def lambdify(args, exprs):
    """Return a sympy-style callable built on symengine's ``Lambdify``.

    ``args`` and ``exprs`` may each be a single object or anything that
    supports ``len()``; an object without ``len()`` is wrapped in a
    one-element list before being handed to ``Lambdify``.

    The returned function takes its inputs positionally. It raises
    ``TypeError("Incorrect number of arguments")`` when the count does not
    match ``len(args)``; otherwise it forwards the inputs, as one tuple, to
    the compiled ``Lambdify`` object and returns its result.
    """
    def _as_sequence(obj):
        # EAFP probe: anything with a length is used as-is, everything else
        # (e.g. a lone symbol or expression) becomes a one-element list.
        try:
            len(obj)
        except TypeError:
            return [obj]
        return obj

    args = _as_sequence(args)
    exprs = _as_sequence(exprs)
    compiled = Lambdify(args, exprs)

    def f(*inner_args):
        # len(args) is evaluated on every call, exactly as before.
        if len(inner_args) == len(args):
            return compiled(inner_args)
        raise TypeError("Incorrect number of arguments")

    return f
from lambdify import lambdify
from sympy.utilities.pytest import XFAIL, raises
from sympy import (
symbols, sqrt, sin, cos, tan, pi, acos, acosh, Rational,
Float, Matrix, Lambda, Piecewise, exp, Integral, oo, I, Abs, Function,
true, false, And, Or, Not, ITE)
from sympy.printing.lambdarepr import LambdaPrinter
import mpmath
#from sympy.utilities.lambdify import implemented_function
from sympy.utilities.pytest import skip
from sympy.utilities.decorator import conserve_mpmath_dps
from sympy.external import import_module
import math
import sympy
# Alias for sympy's ``Matrix``; not referenced by the tests visible here.
MutableDenseMatrix = Matrix
# Loaded through sympy's import_module; the value may be None if numpy
# cannot be imported (sympy's documented behavior), so guard before use.
numpy = import_module('numpy')
# Symbols shared by every test below (whitespace-delimited names).
w, x, y, z = symbols('w x y z')
#================== Test different arguments =======================
# def test_no_args():
# f = lambdify([], 1)
# raises(TypeError, lambda: f(-1))
# assert f() == 1
def test_single_arg():
    # The argument is a bare symbol rather than a list of symbols.
    double = lambdify(x, 2*x)
    assert double(1) == 2
def test_list_args():
    # Arguments declared as a list, then supplied positionally.
    add = lambdify([x, y], x + y)
    assert add(1, 2) == 3
def test_sin():
    sine = lambdify(x, sin(x))
    assert sine(0) == 0.0
#================== Test some functions ============================
def test_exponentiation():
    square = lambdify(x, x**2)
    # (input, expected) pairs: negative, zero, positive and non-integer.
    cases = [(-1, 1), (0, 0), (1, 1), (-2, 4), (2, 4), (2.5, 6.25)]
    for arg, expected in cases:
        assert square(arg) == expected
def test_sqrt():
    root = lambdify(x, sqrt(x))
    # Perfect squares give exactly representable results.
    for arg, expected in ((0, 0.0), (1, 1.0), (4, 2.0)):
        assert root(arg) == expected
    # sqrt(2) is irrational, so compare against a rounded value.
    assert abs(root(2) - 1.414) < 0.001
    assert root(6.25) == 2.5
def test_trig():
    cos_sin = lambdify([x], [cos(x), sin(x)])
    # sympy's exact pi gets a tight tolerance; the truncated literal
    # 3.14159 is only close to pi, so it gets a looser one.
    for point, tol in ((pi, 1e-11), (3.14159, 1e-5)):
        values = cos_sin(point)
        assert -tol < values[0] + 1 < tol
        assert -tol < values[1] < tol
#================== Test vectors ===================================
def test_vector_simple():
    """Tuple args/exprs evaluate element-wise and the arity is enforced."""
    # ``numpy`` comes from import_module('numpy') and may be None when numpy
    # is unavailable; skip rather than fail with an AttributeError below.
    if not numpy:
        skip("numpy not installed.")
    f = lambdify((x, y, z), (z, y, x))
    assert numpy.allclose(f(3, 2, 1), (1, 2, 3))
    assert numpy.allclose(f(1.0, 2.0, 3.0), (3.0, 2.0, 1.0))
    # make sure correct number of args required
    raises(TypeError, lambda: f(0))
# def test_vector_discontinuous():
# f = lambdify(x, (-1/x, 1/x))
# raises(ZeroDivisionError, lambda: f(0))
# assert f(1) == (-1.0, 1.0)
# assert f(2) == (-0.5, 0.5)
# assert f(-2) == (0.5, -0.5)
def test_trig_symbolic():
    # The argument is sympy's exact constant pi, not a float.
    trig = lambdify([x], [cos(x), sin(x)])
    out = trig(pi)
    assert abs(out[0] + 1) < 0.0001  # cos(pi) == -1
    assert abs(out[1]) < 0.0001      # sin(pi) == 0
def test_trig_float():
    # A float approximation of pi; the tolerance absorbs truncation error.
    tolerance = 0.0001
    trig = lambdify([x], [cos(x), sin(x)])
    out = trig(3.14159)
    assert abs(out[0] + 1) < tolerance
    assert abs(out[1]) < tolerance
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment