#!/usr/bin/python
# vim: set fileencoding=utf-8 :
# -*- coding: utf-8 -*-
########################################################################
# Copyright (C) 2011 by Carlos Veiga Rodrigues. All rights reserved.
# author: Carlos Veiga Rodrigues <cvrodrigues@gmail.com>
#
# This program can be redistributed and modified
# under the terms of the GNU Lesser General Public License
# as published by the Free Software Foundation,
# either version 3 of the License or any later version.
# This program is distributed WITHOUT ANY WARRANTY,
# without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.
# For more details consult the GNU Lesser General Public License
# at http://www.gnu.org/licenses/lgpl.html.
#
# ChangeLog (date - version - author):
# * December 2011 - 1.0 - Carlos Veiga Rodrigues
#
# Summary:
# Recursive bisection method for 2d convex quadrilaterals, which allows
# capturing more than one zero of the intersection of two functions,
# u(x,y) and v(x,y). Unlike the traditional bisection method,
# which refines the search for a single zero and can only find one,
# this creates a recursive tree from the division of the original 2d element
# into 4, progressively eliminating from the search the sub-elements that do
# not meet the requirements for housing a zero.
# This is based on the work of A. Globus, C. Levit, and T. Lasinski
# on finding critical points in flow topology, although it looks like
# they use a traditional bisection method for the 3D case only.
# (see http://alglobus.net/NASAwork/topology)
#
# Advantages:
# * More than one zero can be found with this method.
# Disadvantages:
# * A case with infinitely many solutions cannot be told apart from a case
#   with a finite number of solutions.
# * If a zero lies on the boundary between 2 or more elements, each of them
#   will return the same zero.
########################################################################

import sys
sys.setrecursionlimit (10000)

import numpy as np

########################################################################
## FUNCTIONS u (x, y) AND v (x, y)
########################################################################
def ufun(x, y):
        """Sample u(x, y) field: the x coordinate itself; y is ignored."""
        u = x
        return u

def vfun(x, y):
        """Sample v(x, y) field: the y coordinate itself; x is ignored."""
        v = y
        return v

########################################################################
## BILINEAR INTERPOLATION OF SPATIAL COORDINATES AND ELEMENT DIVISION
########################################################################
def bilin_interp(xcv, c):
        """Bilinearly interpolate the vertex values of one element.

        xcv -- vertex values indexed as xcv[i,j]; only the entries
               [0,0], [1,0], [0,1] and [1,1] are read.
        c   -- pair of weights; c[0] blends along the first index of xcv
               and c[1] along the second one.

        Returns the value interpolated at c: c = [0, 0] gives xcv[0,0]
        and c = [1, 1] gives xcv[1,1].
        """
        qsi, eta = c[0], c[1]
        ## INTERPOLATE ALONG THE FIRST INDEX ON THE j=0 AND j=1 EDGES
        edge_j0 = xcv[0,0]*(1-qsi) + xcv[1,0]*qsi
        edge_j1 = xcv[0,1]*(1-qsi) + xcv[1,1]*qsi
        ## BLEND THE TWO EDGE VALUES ALONG THE SECOND INDEX
        return (1-eta)*edge_j0 + eta*edge_j1

def divide_cv_in_4(cv):
        """Split one element into its four quadrant sub-elements.

        cv -- vertex values indexed as cv[i,j] (assumed to be a (2,2)
              array, which is what the callers in this file pass).

        The new vertexes are the four edge midpoints and the element
        centre, obtained as plain averages of the original vertex values;
        this matches bilin_interp evaluated at the parametric midpoints.

        Returns the tuple (sw, se, nw, ne) of float64 (2,2) arrays, each
        using the same [i,j] vertex layout as cv.
        """
        ## ORIGINAL CORNERS
        corner_sw, corner_nw = cv[0,0], cv[0,1]
        corner_se, corner_ne = cv[1,0], cv[1,1]
        ## NEW VERTEXES: EDGE MIDPOINTS AND ELEMENT CENTRE
        mid_s, mid_n = np.mean(cv[:,0]), np.mean(cv[:,1])
        mid_w, mid_e = np.mean(cv[0,:]), np.mean(cv[1,:])
        centre = np.mean(cv)
        ## VERTEX LAYOUT OF EACH NEW CV, IN THE ORDER SW, SE, NW, NE
        layouts = (
                [[corner_sw, mid_w], [mid_s, centre]],
                [[mid_s, centre], [corner_se, mid_e]],
                [[mid_w, corner_nw], [centre, mid_n]],
                [[centre, mid_n], [mid_e, corner_ne]],
        )
        return tuple(np.array(rows, dtype=np.float64) for rows in layouts)

########################################################################
## RECURSIVE TREE BISECTION 2D ALGORITHM
########################################################################
def rectree_bisection2d (xcv, ycv, ufun, vfun,
        tol=1.E-15, maxiter=None):
        """Search one quadrilateral for common zeros of u and v and print them.

        The element is divided recursively into 4 sub-elements (see
        divide_cv_in_4). A sub-element is dropped as soon as u or v keeps
        a strict constant sign on all 4 of its vertexes; the remaining
        ones are refined until one of the stop criteria below is met.

        xcv, ycv   -- (2,2) arrays with the x and y vertex coordinates,
                      indexed [i,j] with i along qsi and j along eta.
        ufun, vfun -- callables f(x, y) evaluated on (2,2) coordinate
                      arrays; the results are indexed and compared
                      elementwise, so they must be (2,2) arrays too.
        tol        -- a sub-element is accepted as a zero when both its
                      qsi and eta extents (parametric, the initial element
                      being [0,1]x[0,1]) are <= tol. None disables it.
        maxiter    -- a sub-element is accepted as a zero when its
                      recursion level (initial element = level 1) reaches
                      maxiter. None disables it.

        If both tol and maxiter are None the recursion only ends through
        elimination or through exact zeros on vertexes.

        Every zero carries a code: 3 = u and v are exactly 0 on a vertex,
        2 = accepted by tol, 1 = accepted by maxiter.

        The deepest level reached, the number of zeros per code and the
        parametric -> physical coordinates of every zero are printed.
        Nothing is returned.

        Raises RuntimeError if xcv or ycv is not of shape (2,2), or if the
        initial element is already eliminated by the sign test.
        """
        if (2,2)!=np.shape(xcv) or (2,2)!=np.shape(ycv):
                raise RuntimeError ("element shape != (2,2). Aborting")
        ## VERTEX [i,j] INDEXES, IN THE ORDER THEY ARE TESTED FOR EXACT ZEROS
        corners = ((0,0), (1,0), (0,1), (1,1))
        ## ELIMINATION TEST: NO ZERO IS ASSUMED INSIDE AN ELEMENT WHERE
        ## u OR v IS STRICTLY POSITIVE (OR NEGATIVE) ON ALL VERTEXES
        def test_point_out (ucv, vcv):
                return (0>ucv).all() or (0<ucv).all()\
                        or (0>vcv).all() or (0<vcv).all()
        ## NOTE: CASES WITH INFINITE SOLUTIONS ARE NOT DETECTED
        ## GET U, V AND TEST IF A ZERO CAN BE INSIDE THE INITIAL CV
        ucv, vcv = ufun (xcv, ycv), vfun (xcv, ycv)
        if test_point_out (ucv, vcv):
                raise RuntimeError ("Bad initial condition. Point outside CV.")
        ## RECURSIVE SEARCH. RETURNS 3 LISTS OF EQUAL LENGTH (qsi, eta AND
        ## CODE OF EVERY ZERO) PLUS THE DEEPEST LEVEL REACHED
        def search (xcv, ycv, ucv, vcv, qsi, eta, parent_level):
                level = parent_level + 1
                ## EXACT ZEROS ON VERTEXES. ONE CODE IS RETURNED PER ZERO, SO
                ## THAT THE 3 LISTS STAY ALIGNED WHEN SEVERAL VERTEXES MATCH
                hits = [(qsi[a], eta[b]) for (a, b) in corners
                        if 0==ucv[a,b] and 0==vcv[a,b]]
                if hits:
                        return [q for (q, e) in hits],\
                                [e for (q, e) in hits],\
                                [3]*len(hits), level
                qmid, emid = np.mean(qsi), np.mean(eta)
                ## ELEMENT SMALLER THAN THE TOLERANCE IN BOTH DIRECTIONS
                if tol is not None \
                and abs(qsi[1]-qsi[0])<=tol \
                and abs(eta[1]-eta[0])<=tol:
                        return [qmid], [emid], [2], level
                ## MAXIMUM RECURSION LEVEL REACHED
                if maxiter is not None and level>=maxiter:
                        return [qmid], [emid], [1], level
                ## DIVIDE INTO 4 ELEMENTS, ALL IN THE ORDER SW, SE, NW, NE
                xsub = divide_cv_in_4 (xcv)
                ysub = divide_cv_in_4 (ycv)
                qsub = ((qsi[0], qmid), (qmid, qsi[1]),
                        (qsi[0], qmid), (qmid, qsi[1]))
                esub = ((eta[0], emid), (eta[0], emid),
                        (emid, eta[1]), (emid, eta[1]))
                ii, jj, found = [], [], []
                deepest = level
                for xq, yq, qq, ee in zip(xsub, ysub, qsub, esub):
                        uq, vq = ufun (xq, yq), vfun (xq, yq)
                        if test_point_out (uq, vq):
                                continue
                        i, j, codes, reached = search (xq, yq, uq, vq,
                                np.array(qq), np.array(ee), level)
                        ii.extend(i)
                        jj.extend(j)
                        found.extend(codes)
                        deepest = max(deepest, reached)
                return ii, jj, found, deepest
        ## RUN THE RECURSIVE SEARCH ON THE WHOLE ELEMENT
        qsi = np.array([0., 1.])
        eta = np.array([0., 1.])
        ii, jj, found, deepest = search (xcv, ycv, ucv, vcv, qsi, eta, 0)
        ## REPORT. print IS ALWAYS CALLED WITH A SINGLE PARENTHESIZED
        ## ARGUMENT, WHICH BEHAVES THE SAME IN PYTHON 2 AND PYTHON 3
        print ("Iterations = %d" % deepest)
        if found:
                print ("zeros exact = %d" % found.count(3))
                print ("zeros tol   = %d" % found.count(2))
                print ("zeros iter  = %d" % found.count(1))
                for (i, j, k) in zip(ii, jj, found):
                        xx = bilin_interp (xcv, [i, j])
                        yy = bilin_interp (ycv, [i, j])
                        print ("%d : (%.4f , %.4f) -> (%.4f, %.4f)" % \
                                (k, i, j, xx, yy))
        return

########################################################################
## CODE TO TEST THE METHOD
########################################################################
def test_method (xcv, ycv, ufun, vfun, c, msg=None, maxiter=None):
        """Run rectree_bisection2d on one test element and print a report.

        xcv, ycv   -- vertex coordinate arrays of the element.
        ufun, vfun -- functions handed over to rectree_bisection2d.
        c          -- parametric coordinates (qsi, eta) of a reference
                      point, or None. When given, the element is shifted
                      so that this point lands on the origin; with u = x
                      and v = y the zero is then expected at c, which is
                      printed after the search for comparison.
        msg        -- optional title printed before the search.
        maxiter    -- forwarded to rectree_bisection2d.

        Returns None; all results are printed.
        """
        ## print IS ALWAYS CALLED WITH A SINGLE PARENTHESIZED ARGUMENT,
        ## WHICH BEHAVES THE SAME IN PYTHON 2 AND PYTHON 3
        print ("\n" + 50*"=")
        if msg is not None:
                print (msg)
        ## "is not None" ALSO WORKS WHEN c IS A NUMPY ARRAY, WHERE "!="
        ## WOULD COMPARE ELEMENT BY ELEMENT
        has_reference = c is not None
        if has_reference:
                xp = bilin_interp (xcv, c)
                yp = bilin_interp (ycv, c)
        else:
                xp, yp = 0, 0
        rectree_bisection2d (xcv-xp, ycv-yp,
                ufun=ufun, vfun=vfun, maxiter=maxiter)
        if has_reference:
                print (33*"-")
                print ("True=(%.4f , %.4f)" % tuple(c))
        return

########################################################################
## EXECUTE
########################################################################
if __name__ == '__main__':
        ## ONE ROW PER TEST CASE, RUN IN THIS ORDER:
        ## (first array, second array, reference point c or None,
        ##  title, maxiter or None)
        point_msg = "POINT IN POLYGON TEST"
        cases = (
                ## KNOWN POINT INSIDE ELEMENTS OF DIFFERENT SHAPES
                ([[0., 1.], [0., 1.]], [[0., 0.], [1., 1.]],
                        [0.11, 0.62], point_msg, None),
                ([[0., 1.], [0.1, 1.2]], [[0., 0.], [1., 1.]],
                        [0.11, 0.62], point_msg, None),
                ([[0., 1.], [0.1, 1.2]], [[0., 0.], [1., 0.9]],
                        [0.11, 0.62], point_msg, None),
                ([[0., 1.], [0., 1.]], [[0., 0.], [1., 1.]],
                        [0.5, 0.5], point_msg, None),
                ## VERTEX VALUES OF u AND v GIVEN DIRECTLY (u = x, v = y)
                ([[10., 10.], [-2., 10.]], [[10., 10.], [-1., 10.]],
                        None, "NO ROOTS SHOULD EVER BE FOUND", None),
                ([[6., -2.], [ 6.,  6.]], [[8., -2.], [ 8., -2.]],
                        None, "ONE SOLUTION", None),
                ([[-2.,  2.], [ 2.,  2.]], [[-1., -5.], [-5., 50.]],
                        None, "TWO VALID SOLUTIONS", None),
                ([[-2.,  2.], [ 2.,  2.]], [[-1., -5.], [-5., 10.]],
                        None, "TWO VALID SOLUTIONS", None),
                ## THESE TWO ARE CAPPED AT 10 RECURSION LEVELS
                ([[10., 10.], [-2., 10.]], [[10., 10.], [-2., 10.]],
                        None, "INFINITE SOLUTIONS", 10),
                ([[2., -2.], [ 6., 2.]], [[4., -4.], [12., 4.]],
                        None, "INFINITE SOLUTIONS", 10),
        )
        for first, second, ref, title, limit in cases:
                test_method (np.array(first), np.array(second), ufun, vfun,
                        ref, msg=title, maxiter=limit)