Skip to content

Instantly share code, notes, and snippets.

@javipus
Created August 22, 2018 20:30
Show Gist options
  • Save javipus/ba149aa7f296a16900bf7edc7a4f1e09 to your computer and use it in GitHub Desktop.
Save javipus/ba149aa7f296a16900bf7edc7a4f1e09 to your computer and use it in GitHub Desktop.
from sympy import simplify, Pow
from sympy.polys import Poly, domains, polyroots
from sympy.assumptions import assuming, ask, Q
import signal
from functools import wraps
# Exceptions #
class GaloisOverflow(Exception):
    """Raised when a polynomial degree is beyond what the closed-form solvers here handle (at most 4)."""
class TimeOutError(Exception):
    """Raised by the ``timeout`` decorator when a decorated call exceeds its time limit."""
# Helpers #
def timeout(seconds=30, message='Function call took too long!'):
    """
    Decorator factory that aborts the wrapped call with TimeOutError after ``seconds``.

    Implemented with SIGALRM, so it is Unix-only and the decorated function must be
    called from the main thread (``signal.signal`` refuses other threads).

    @param seconds: Whole number of seconds before the call is interrupted (passed to signal.alarm).
    @param message: Message carried by the raised TimeOutError.
    @return Decorator that applies the time limit to a function.
    """
    def decorator(f):
        def _timeout_handler(signum, frame):
            raise TimeOutError(message)

        @wraps(f)  # keep f's __name__ / __doc__ so introspection and help() still work
        def wrapper(*args, **kwds):
            previous_handler = signal.signal(signal.SIGALRM, _timeout_handler)
            signal.alarm(seconds)
            try:
                return f(*args, **kwds)
            finally:
                # Cancel the pending alarm and reinstall whatever handler was active
                # before, so other SIGALRM users are not silently clobbered.
                signal.alarm(0)
                # signal.signal returns None when the old handler was not set from
                # Python; None cannot be reinstalled, so skip the restore then.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
        return wrapper
    return decorator
def _solver(deg):
    """
    Select the closed-form root finder from sympy's polyroots for a polynomial of degree ``deg``.

    Degrees 1 through 4 map to roots_linear, roots_quadratic, roots_cubic and roots_quartic;
    any other degree raises GaloisOverflow.
    """
    routines = (
        polyroots.roots_linear,
        polyroots.roots_quadratic,
        polyroots.roots_cubic,
        polyroots.roots_quartic,
    )
    for degree, routine in enumerate(routines, start=1):
        if deg == degree:
            return routine
    raise GaloisOverflow('Degree must be at most 4!')
# Main function #
# TODO prevent expr.subs(x, root) from doing floating point evaluation - I want to keep a sympy expression
@timeout()
def intersect(*args, x=None, interval=None, removeComplex=True, doAssume=None, solver=None, **kwds):
    """
    Calculate the pairwise intersections of a family of curves given as polynomials of degree <= 4.

    @param args: Sympy expressions or Poly objects, at least two. Each must be a polynomial of
        degree <= 4 unless a custom ``solver`` is supplied.
    @param x: Sympy symbol, the independent variable shared by all curves. Required unless every
        curve already has a ``degree()`` method (e.g. is a Poly), in which case it is used as-is.
    @param interval: Return only solutions inside the given interval.
        - Tuple (a, b): open interval a < x < b.
        - List [a, b]: closed interval a <= x <= b.
        - None (or empty): no restriction.
        Half-open intervals such as [a, b) are not supported. A root is discarded only when sympy
        can prove it violates a bound; roots whose position cannot be decided are kept.
    @param removeComplex: If True (default), discard roots that sympy can prove are not real.
        Roots whose realness cannot be decided are kept.
    @param doAssume: A single assumption or a list of assumptions about the coefficients, e.g.
        [Q.positive(a), Q.real(b)], active while roots are filtered. None means no assumptions.
    @param solver: Root finder called as ``solver(p - q, **kwds)`` for every pair. If None, the
        closed-form routine from sympy's polyroots matching the degree of p - q is used.
    @param kwds: Keyword arguments forwarded to the root finder.
    @return List of tuples (p, q, points), one per pair of curves in argument order, where p and q
        are the curves as expressions and points is a list of (x, y) intersection points, e.g.
        (x**2, x + 1, [(1/2 + sqrt(5)/2, 3/2 + sqrt(5)/2), (1/2 - sqrt(5)/2, 3/2 - sqrt(5)/2)]).
        points is the empty list when a pair does not intersect.
    @raises TypeError: Fewer than two curves, a non-Poly curve without ``x``, or an interval that
        is neither a list nor a tuple.
    @raises GaloisOverflow: A curve has degree > 4 and no custom solver was given.
    @raises ValueError: Two curves are identical (they meet everywhere); default solver only.
    @raises TimeOutError: The whole call exceeded the ``timeout`` decorator's limit (30 s).
    """
    N = len(args)
    if N < 2:
        raise TypeError('Need at least two curves to intersect!')
    # None replaces the former shared mutable default list.
    if doAssume is None:
        doAssume = []
    elif not hasattr(doAssume, '__len__'):
        doAssume = [doAssume]  # a single assumption was passed
    # Validate up front instead of failing on an unbound name deep inside the pair loop.
    if interval and not isinstance(interval, (list, tuple)):
        raise TypeError('interval must be a list [a, b] (closed) or a tuple (a, b) (open)!')

    print('Pre-processing...')
    ps = []
    for curve in args:
        poly = _as_poly(curve, x)
        if poly.degree() > 4 and not solver:
            raise GaloisOverflow('Degrees must be at most 4!')
        ps.append(poly)
    print('Done!')

    sols = []
    print('\nIntersection begins:\n')
    for i in range(N):
        for j in range(i + 1, N):
            print('Intersecting p{} with p{}...'.format(i, j))
            p, q = ps[i], ps[j]
            # The root finder is chosen per pair from deg(p - q), never stored back into
            # `solver`, so every pair gets the routine that matches its own difference.
            x_star = _pair_roots(p - q, solver, kwds)
            x_star = _filter_roots(x_star, interval, removeComplex, doAssume)
            print('Evaluating solutions (if any)...')
            # Substitute into the Poly's own generator, which also works when x is None.
            gen, p_expr = p.gen, p.as_expr()
            xy_star = [(simplify(r), simplify(p_expr.subs(gen, r))) for r in x_star]
            sols.append((p_expr, q.as_expr(), xy_star))
            print('Done!\n')
    return sols


def _as_poly(p, x):
    """Convert one curve to a Poly in ``x``; objects that already have ``degree()`` pass through."""
    try:
        p.degree()
        return p
    except AttributeError:
        pass
    if x is None:
        raise TypeError('Need to pass a value for x if expressions are not of type Poly!')
    try:
        coeffs = list(p.free_symbols - {x})
        coeffs += [Pow(coeff, -1) for coeff in coeffs]  # just in case a symbol is dividing
    except AttributeError:  # not a sympy expression - could be float, int, etc.
        coeffs = None
    return Poly(p, x, domain=domains.RR[coeffs] if coeffs else domains.RR)


def _pair_roots(diff, solver, kwds):
    """Return the roots of ``diff`` (= p - q) using ``solver`` or the closed form for its degree."""
    if solver:
        return solver(diff, **kwds)
    if diff.is_zero:
        # p == q: every point is an intersection, which cannot be returned as a finite list.
        raise ValueError('Cannot intersect identical curves: they meet at every point!')
    deg = diff.degree()
    if deg == 0:
        return []  # p - q is a nonzero constant: the curves never meet
    # deg(p - q) can be lower than deg(p) and deg(q) when leading terms cancel.
    return _solver(deg)(diff, **kwds)


def _filter_roots(roots, interval, removeComplex, doAssume):
    """Drop roots sympy can prove are non-real or outside ``interval``, under ``doAssume``."""
    with assuming(*doAssume):  # these ask() queries can be very slow
        if removeComplex:
            print('Filtering out complex roots...')
            # ask() returns None when undecidable; such roots are kept.
            roots = [r for r in roots if ask(Q.real(r)) is not False]
        if interval:
            print('Filtering out solutions outside of {}...'.format(interval))
            # List -> closed interval (distances to bounds >= 0); tuple -> open (> 0).
            cond = Q.nonnegative if isinstance(interval, list) else Q.positive
            lower, upper = interval[0], interval[1]
            # Drop a root only if either bound is provably violated. Testing
            # `(lo and hi) in (True, None)` would keep a root with lo None and hi False.
            roots = [
                r for r in roots
                if ask(cond(r - lower)) is not False and ask(cond(upper - r)) is not False
            ]
    return roots
if __name__ == '__main__':
    from sympy import symbols, init_printing

    init_printing()

    # Example: intersect a parabola and a line in the variable s, restricted to 0 <= s <= L.
    # All symbols are declared real; s, L, k could additionally be declared positive=True.
    x, p, theta, u = symbols('x p theta u', real=True)
    s, L, k = symbols('s L k', real=True)

    line = u + theta * s
    parabola = x - u + (p - theta) * s + 0.5 * k * s**2 / L

    results = intersect(parabola, line, x=s, interval=[0, L])
    for first, second, points in results:
        print('{} intersects {} at\n{}'.format(first, second, points))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment