@matthew-brett
Partial port of mpmath hyp1f1 function
""" mpmath hyp1f1 algorithm crudely ported """
import math
import cmath
import operator
import numpy as np
import scipy.special as sps
eps = np.finfo(float).eps
H_FACTOR = np.ldexp(1.0, -int(53 * 0.3))
class NoConvergence(Exception):
pass
def hypsum(p, q, coeffs, z, maxterms=6000):
coeffs = list(coeffs)
num = range(p)
den = range(p,p+q)
tol = eps
s = t = 1.0
k = 0
while True:
for i in num: t *= (coeffs[i]+k)
for i in den: t /= (coeffs[i]+k)
k += 1; t /= k; t *= z; s += t
if abs(t) < tol:
return s
if k > maxterms:
raise NoConvergence
def convert(x):
try:
return float(x)
except:
return complex(x)
def mag(z):
if z:
return np.frexp(abs(z))[1]
return -np.inf
def isint(z):
if z.imag:
return False
z = z.real
try:
return z == int(z)
except:
return False
def expjpi(x):
return exp(1j * np.pi * x)
def exp(x):
if type(x) is float:
return math.exp(x)
if type(x) is complex:
return cmath.exp(x)
try:
x = float(x)
return math.exp(x)
except (TypeError, ValueError):
x = complex(x)
return cmath.exp(x)
def power(*args):
try:
return operator.pow(*(float(x) for x in args))
except (TypeError, ValueError):
return operator.pow(*(complex(x) for x in args))
def fneg(x):
return -convert(x)
def isnpint(x):
if type(x) is complex:
if x.imag:
return False
x = x.real
return x <= 0.0 and round(x) == x
def nint_distance(z):
n = round(z.real)
if n == z:
return n, -np.inf
return n, mag(abs(z-n))
def _check_need_perturb(terms, discard_known_zeros):
perturb = False
discard = []
for term_index, term in enumerate(terms):
w_s, c_s, alpha_s, beta_s, a_s, b_s, z = term
have_singular_nongamma_weight = False
# Avoid division by zero in leading factors (TODO:
# also check for near division by zero?)
for k, w in enumerate(w_s):
if not w:
if np.real(c_s[k]) <= 0 and c_s[k]:
perturb = True
have_singular_nongamma_weight = True
pole_count = [0, 0, 0]
# Check for gamma and series poles and near-poles
for data_index, data in enumerate([alpha_s, beta_s, b_s]):
for i, x in enumerate(data):
n, d = nint_distance(x)
# Poles
if n > 0:
continue
if d == -np.inf:
# OK if we have a polynomial
# ------------------------------
if data_index == 2:
for u in a_s:
if isnpint(u) and u >= int(n):
break
else:
pole_count[data_index] += 1
if (discard_known_zeros and
pole_count[1] > pole_count[0] + pole_count[2] and
not have_singular_nongamma_weight):
discard.append(term_index)
elif sum(pole_count):
perturb = True
return perturb, discard
def hypercomb(function, params=[], discard_known_zeros=True):
params = params[:]
terms = function(*params)
perturb, discard = _check_need_perturb(terms, discard_known_zeros)
if perturb:
h = H_FACTOR
for k in range(len(params)):
params[k] += h
# Heuristically ensure that the perturbations
# are "independent" so that two perturbations
# don't accidentally cancel each other out
# in a subtraction.
h += h/(k+1)
terms = function(*params)
if discard_known_zeros:
terms = [term for (i, term) in enumerate(terms) if i not in discard]
if not terms:
return 0.
evaluated_terms = []
for term_index, term_data in enumerate(terms):
w_s, c_s, alpha_s, beta_s, a_s, b_s, z = term_data
# Always hyp2f0
assert len(a_s) == 2
assert len(b_s) == 0
v = np.prod([hypsum(2, 0, a_s, z)] + \
[sps.gamma(a) for a in alpha_s] + \
[sps.rgamma(b) for b in beta_s] + \
[power(w, c) for (w,c) in zip(w_s,c_s)])
evaluated_terms.append(v)
if len(terms) == 1 and (not perturb):
return evaluated_terms[0]
sumvalue = sum(evaluated_terms)
return sumvalue
def hyp1f1(a, b, z):
z = convert(z)
if not z:
return 1.0 + z
magz = mag(z)
if magz >= 7 and not (isint(a) and np.real(a) <= 0):
if np.isinf(z):
if (np.sign(a) == np.sign(b) == np.sign(z) == 1):
return np.inf
return np.nan * z
try:
sector = np.imag(z) < 0
def h(a,b):
if sector:
E = expjpi(fneg(a))
else:
E = expjpi(a)
rz = 1/z
T1 = ([E,z], [1,-a], [b], [b-a], [a, 1+a-b], [], -rz)
T2 = ([exp(z),z], [1,a-b], [b], [a], [b-a, 1-a], [], rz)
return T1, T2
v = hypercomb(h, [a,b])
if np.isrealobj(a) and np.isrealobj(b) and np.isrealobj(z):
v = np.real(v)
return v
except NoConvergence:
pass
v = hypsum(1, 1, [a, b], z)
return v
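
For a quick sanity check of the port against scipy, something along these lines could be run in the same session. The sample arguments are arbitrary illustrations, not values from the gist; the last one exercises the large-|z| asymptotic branch.

# Hypothetical sanity check: compare the port above with scipy.special.hyp1f1.
# The sample arguments below are arbitrary, not taken from the gist.
import scipy.special as sps

for a, b, z in [(0.5, 1.5, 10.0), (2.0, 3.0, -4.0), (0.5, 1.5, 100.0)]:
    print("a=%s b=%s z=%s port=%s scipy=%s"
          % (a, b, z, hyp1f1(a, b, z), sps.hyp1f1(a, b, z)))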
@samuelstjean

An equally crude Cython port; there is still so much Python interaction that we can't remove, though.

""" mpmath hyp1f1 algorithm crudely ported """

from __future__ import division

import math
import cmath
import operator

import numpy as np
cimport numpy as cnp
cimport cython

import scipy.special as sps

DEF eps = 1e-8
DEF H_FACTOR = 3.0517578125e-05
DEF pi = 3.141592653589793

class NoConvergence(Exception):
    """Hypergeometric series did not converge :("""
    pass


cdef double hypsum(int p, int q, coeffs, double z, int maxterms=6000):

    cdef:
        int k = 0, i
        double s = 1., t = 1.

    coeffs = list(coeffs)
    num = range(p)
    den = range(p,p+q)

    while True:

        for i in num: t *= (coeffs[i]+k)
        for i in den: t /= (coeffs[i]+k)

        k += 1; t /= k; t *= z; s += t

        if abs(t) < eps:
            return s

        if k > maxterms:
            raise NoConvergence


cdef double mag(z):
        return np.frexp(abs(z))[1]


cdef expjpi(x):
    return exp(1j * pi * x)


cdef exp(x):
    if type(x) is float:
        return math.exp(x)
    if type(x) is complex:
        return cmath.exp(x)


def power(*args):
    try:
        return operator.pow(*(float(x) for x in args))
    except (TypeError, ValueError):
        return operator.pow(*(complex(x) for x in args))


cdef isnpint(x):
    if type(x) is complex:
        if x.imag:
            return False
        x = x.real
    return x <= 0.0 and round(x) == x


cdef nint_distance(z):
    cdef int n
    n = round(z.real)
    if n == z:
        return n, -np.inf
    return n, mag(abs(z-n))


cdef _check_need_perturb(terms, int discard_known_zeros):

    cdef:
        int perturb, have_singular_nongamma_weight, n
        double d

    perturb = False
    discard = []
    for term_index, term in enumerate(terms):
        w_s, c_s, alpha_s, beta_s, a_s, b_s, z = term
        have_singular_nongamma_weight = False
        # Avoid division by zero in leading factors (TODO:
        # also check for near division by zero?)
        for k, w in enumerate(w_s):
            if not w:
                if np.real(c_s[k]) <= 0 and c_s[k]:
                    perturb = True
                    have_singular_nongamma_weight = True
        pole_count = [0, 0, 0]
        # Check for gamma and series poles and near-poles
        for data_index, data in enumerate([alpha_s, beta_s, b_s]):
            for i, x in enumerate(data):
                n, d = nint_distance(x)
                # Poles
                if n > 0:
                    continue
                if d == -np.inf:
                    # OK if we have a polynomial
                    # ------------------------------
                    if data_index == 2:
                        for u in a_s:
                            if isnpint(u) and u >= int(n):
                                break
                    else:
                        pole_count[data_index] += 1
        if (discard_known_zeros and
            pole_count[1] > pole_count[0] + pole_count[2] and
            not have_singular_nongamma_weight):
            discard.append(term_index)
        elif sum(pole_count):
            perturb = True
    return perturb, discard


cdef double hypercomb(function, params=[], int discard_known_zeros=True):
    cdef:
        int perturb
        double h, sumvalue
    # `discard` is a Python list, so it is left untyped.

    params = params[:]
    terms = function(*params)
    perturb, discard =  _check_need_perturb(terms, discard_known_zeros)
    if perturb:
        h = H_FACTOR
        for k in range(len(params)):
            params[k] += h
            # Heuristically ensure that the perturbations
            # are "independent" so that two perturbations
            # don't accidentally cancel each other out
            # in a subtraction.
            h += h/(k+1)
        terms = function(*params)
    if discard_known_zeros:
        terms = [term for (i, term) in enumerate(terms) if i not in discard]
    if not terms:
        return 0.
    evaluated_terms = []
    for term_index, term_data in enumerate(terms):
        w_s, c_s, alpha_s, beta_s, a_s, b_s, z = term_data
        # Always hyp2f0
        assert len(a_s) == 2
        assert len(b_s) == 0
        v = np.prod([hypsum(2, 0, a_s, z)] + \
            [sps.gamma(a) for a in alpha_s] + \
            [sps.rgamma(b) for b in beta_s] + \
            [power(w, c) for (w,c) in zip(w_s,c_s)])
        evaluated_terms.append(v)

    if len(terms) == 1 and (not perturb):
        return evaluated_terms[0]

    sumvalue = sum(evaluated_terms)
    return sumvalue


cdef double _hyp1f1(double a, double b, double z):

    cdef:
        double magz, rz, v

    magz = mag(z)

    if magz >= 7:
        try:

            def h(a,b):
                E = expjpi(a)
                rz = 1./z
                T1 = ([E,z], [1,-a], [b], [b-a], [a, 1+a-b], [], -rz)
                T2 = ([exp(z),z], [1,a-b], [b], [a], [b-a, 1-a], [], rz)
                return T1, T2

            v = hypercomb(h, [a,b])
            return np.real(v)

        except NoConvergence:
            pass

    return hypsum(1, 1, [a, b], z)


def hyp1f1(a, b, z):
    return _hyp1f1(a, b, z)
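
For what it's worth, a minimal setup.py sketch for compiling a .pyx module like the one above might look like this; the file/module name hyp1f1_cy is made up for illustration.

# Hypothetical setup.py sketch for building the Cython port above,
# assuming it is saved as hyp1f1_cy.pyx (the name is made up).
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy

setup(
    name="hyp1f1_cy",
    ext_modules=cythonize(
        Extension(
            "hyp1f1_cy",
            sources=["hyp1f1_cy.pyx"],
            include_dirs=[numpy.get_include()],  # needed for "cimport numpy"
        )
    ),
)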

@samuelstjean

The code still has issues (as pointed out on the mpmath blog: http://fredrikj.net/blog/2009/09/python-floats-and-other-unusual-things-spotted-in-mpmath/); this is inherently due to floating-point precision, sadly.

mpmath:
mp.hyp1f1(2.5, 1.2, -30.5)
6.62762709628679e-5
fp.hyp1f1(2.5, 1.2, -30.5)
-0.012819333651375751

This version:
In [9]: hyp1f1(2.5,1.2,-30.5)
Out[9]: -0.01281933365137575

scipy:
In [2]: hyp1f1(2.5,1.2,-30.5)
Out[2]: 6.6276270962867717e-05

So... it might not happen in our restricted use case (the first argument is always 0.5 or -1.5, I think, or something close to that, but who knows), but giving that to scipy might be a harder sell, since the two versions exhibit complementary issues.
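
One standard workaround for large negative z is Kummer's transformation, 1F1(a; b; z) = exp(z) * 1F1(b - a; b; -z), which avoids the cancellation of the alternating series. A rough sketch, not part of either port, reusing hypsum() from the Python gist at the top:

import math

def hyp1f1_kummer(a, b, z):
    # Kummer's transformation: 1F1(a; b; z) = exp(z) * 1F1(b - a; b; -z).
    # For z < 0 the transformed series has (mostly) positive terms, so the
    # catastrophic cancellation seen for hyp1f1(2.5, 1.2, -30.5) is avoided.
    if z < 0:
        return math.exp(z) * hypsum(1, 1, [b - a, b], -z)
    return hypsum(1, 1, [a, b], z)

With this, hyp1f1_kummer(2.5, 1.2, -30.5) should land near mpmath's 6.62762709628679e-5 rather than the -0.0128 value above.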
