Skip to content

Instantly share code, notes, and snippets.

@matthew-brett
Last active August 29, 2015 14:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matthew-brett/0596950032608d32f247 to your computer and use it in GitHub Desktop.
Save matthew-brett/0596950032608d32f247 to your computer and use it in GitHub Desktop.
Partial port of mpmath hyp1f1 function
""" mpmath hyp1f1 algorithm crudely ported """
import math
import cmath
import operator
import numpy as np
import scipy.special as sps
eps = np.finfo(float).eps
H_FACTOR = np.ldexp(1.0, -int(53 * 0.3))
class NoConvergence(Exception):
pass
def hypsum(p, q, coeffs, z, maxterms=6000):
coeffs = list(coeffs)
num = range(p)
den = range(p,p+q)
tol = eps
s = t = 1.0
k = 0
while True:
for i in num: t *= (coeffs[i]+k)
for i in den: t /= (coeffs[i]+k)
k += 1; t /= k; t *= z; s += t
if abs(t) < tol:
return s
if k > maxterms:
raise NoConvergence
def convert(x):
try:
return float(x)
except:
return complex(x)
def mag(z):
if z:
return np.frexp(abs(z))[1]
return -np.inf
def isint(z):
if z.imag:
return False
z = z.real
try:
return z == int(z)
except:
return False
def expjpi(x):
return exp(1j * np.pi * x)
def exp(x):
if type(x) is float:
return math.exp(x)
if type(x) is complex:
return cmath.exp(x)
try:
x = float(x)
return math.exp(x)
except (TypeError, ValueError):
x = complex(x)
return cmath.exp(x)
def power(*args):
try:
return operator.pow(*(float(x) for x in args))
except (TypeError, ValueError):
return operator.pow(*(complex(x) for x in args))
def fneg(x):
return -convert(x)
def isnpint(x):
if type(x) is complex:
if x.imag:
return False
x = x.real
return x <= 0.0 and round(x) == x
def nint_distance(z):
n = round(z.real)
if n == z:
return n, -np.inf
return n, mag(abs(z-n))
def _check_need_perturb(terms, discard_known_zeros):
perturb = False
discard = []
for term_index, term in enumerate(terms):
w_s, c_s, alpha_s, beta_s, a_s, b_s, z = term
have_singular_nongamma_weight = False
# Avoid division by zero in leading factors (TODO:
# also check for near division by zero?)
for k, w in enumerate(w_s):
if not w:
if np.real(c_s[k]) <= 0 and c_s[k]:
perturb = True
have_singular_nongamma_weight = True
pole_count = [0, 0, 0]
# Check for gamma and series poles and near-poles
for data_index, data in enumerate([alpha_s, beta_s, b_s]):
for i, x in enumerate(data):
n, d = nint_distance(x)
# Poles
if n > 0:
continue
if d == -np.inf:
# OK if we have a polynomial
# ------------------------------
if data_index == 2:
for u in a_s:
if isnpint(u) and u >= int(n):
break
else:
pole_count[data_index] += 1
if (discard_known_zeros and
pole_count[1] > pole_count[0] + pole_count[2] and
not have_singular_nongamma_weight):
discard.append(term_index)
elif sum(pole_count):
perturb = True
return perturb, discard
def hypercomb(function, params=[], discard_known_zeros=True):
params = params[:]
terms = function(*params)
perturb, discard = _check_need_perturb(terms, discard_known_zeros)
if perturb:
h = H_FACTOR
for k in range(len(params)):
params[k] += h
# Heuristically ensure that the perturbations
# are "independent" so that two perturbations
# don't accidentally cancel each other out
# in a subtraction.
h += h/(k+1)
terms = function(*params)
if discard_known_zeros:
terms = [term for (i, term) in enumerate(terms) if i not in discard]
if not terms:
return 0.
evaluated_terms = []
for term_index, term_data in enumerate(terms):
w_s, c_s, alpha_s, beta_s, a_s, b_s, z = term_data
# Always hyp2f0
assert len(a_s) == 2
assert len(b_s) == 0
v = np.prod([hypsum(2, 0, a_s, z)] + \
[sps.gamma(a) for a in alpha_s] + \
[sps.rgamma(b) for b in beta_s] + \
[power(w, c) for (w,c) in zip(w_s,c_s)])
evaluated_terms.append(v)
if len(terms) == 1 and (not perturb):
return evaluated_terms[0]
sumvalue = sum(evaluated_terms)
return sumvalue
def hyp1f1(a, b, z):
z = convert(z)
if not z:
return 1.0 + z
magz = mag(z)
if magz >= 7 and not (isint(a) and np.real(a) <= 0):
if np.isinf(z):
if (np.sign(a) == np.sign(b) == np.sign(z) == 1):
return np.inf
return np.nan * z
try:
sector = np.imag(z) < 0
def h(a,b):
if sector:
E = expjpi(fneg(a))
else:
E = expjpi(a)
rz = 1/z
T1 = ([E,z], [1,-a], [b], [b-a], [a, 1+a-b], [], -rz)
T2 = ([exp(z),z], [1,a-b], [b], [a], [b-a, 1-a], [], rz)
return T1, T2
v = hypercomb(h, [a,b])
if np.isrealobj(a) and np.isrealobj(b) and np.isrealobj(z):
v = np.real(v)
return v
except NoConvergence:
pass
v = hypsum(1, 1, [a, b], z)
return v
@samuelstjean
Copy link

The code still has issues (as pointed in the mpmath blog http://fredrikj.net/blog/2009/09/python-floats-and-other-unusual-things-spotted-in-mpmath/), this is inherently due to floating point precision sadly.

mpmath:
mp.hyp1f1(2.5, 1.2, -30.5)
6.62762709628679e-5
fp.hyp1f1(2.5, 1.2, -30.5)
-0.012819333651375751

This version :
In [9]: hyp1f1(2.5,1.2,-30.5)
Out[9]: -0.01281933365137575

scipy :
In [2]: hyp1f1(2.5,1.2,-30.5)
Out[2]: 6.6276270962867717e-05

So... it might no happen in our restricted usecase (first argument is always 0.5 or -1.5 I think, something close to that, but who knows), but giving that to scipy might be harder to sell, since both versions exhibit complimentary issues.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment