"""
Collection of functions for doing stuff with forms
"""

__author__ = "Marie E. Rognes (meg@simula.no)"
__copyright__ = "Copyright (C) 2009 - Marie E. Rognes"
__license__  = "GNU GPL version 3 or any later version"

# Modified by Anders Logg, 2009-2010
# Last changed: 2010-06-03

from ufl.algorithms import preprocess

from dolfin import VariationalProblem, DirichletBC
from dolfin import TestFunction, TrialFunction, MixedFunctionSpace, FunctionSpace
from dolfin import adjoint, derivative, lhs, rhs, replace
import dolfin.cpp as cpp


# Names exported by "from <module> import *"
__all__ = [
    "tear",
    "glue",
    "construct_dual_form",
    "higher_order_problem",
    "increase_order",
    "increase_bc_order",
    "replace_arguments",
]

def change_regularity(V, family):
    """
    Return a function space matching V but built from finite elements
    of the given 'family' (any family supported by the form compiler,
    e.g. "CG" or "DG"), keeping the polynomial degree.

    A space with sub spaces is rebuilt recursively as a
    MixedFunctionSpace. A leaf space whose element has a non-empty
    value shape is rebuilt as a MixedFunctionSpace of scalar spaces,
    one per entry of the first shape dimension.
    """

    num_subspaces = V.num_sub_spaces()
    if num_subspaces > 0:
        converted = [change_regularity(V.sub(k), family)
                     for k in range(num_subspaces)]
        return MixedFunctionSpace(converted)

    element = V.ufl_element()
    value_shape = element.value_shape()
    mesh = V.mesh()
    degree = element.degree()

    if value_shape:
        # Vector-valued leaf: one scalar space per component
        return MixedFunctionSpace([FunctionSpace(mesh, family, degree)
                                   for _ in range(value_shape[0])])

    return FunctionSpace(mesh, family, degree)

def construct_dual_form(F, goal, u=None):
    """
    Construct the dual (adjoint) form for the form F and the goal
    functional 'goal' (denoted M below).

    If u is None, F and goal are assumed to be linear:

        Linear case:     a*(v, w) = M(v)

    where a* is the adjoint of lhs(F).

    If u is given, the problem is linearized around u:

        Nonlinear case:  a'*(v, w; u) = M'(v; u)

    where a' is the derivative of F with respect to u and M' the
    derivative of the goal with respect to u.

    Although the linear case is a special case of the nonlinear one,
    the code paths differ: in the linear case u need not be specified.

    Returns the form a_star - l_star.
    """

    if u is not None:
        # Nonlinear case: differentiate F with respect to u, then adjoint
        trial = TrialFunction(u.function_space())
        a_star = adjoint(derivative(F, u, trial))

        # The test function of a_star is the direction in which the goal
        # is differentiated
        processed_dual = preprocess(a_star)
        dual_test = processed_dual.form_data().original_arguments[0]
        l_star = derivative(goal, u, dual_test)
        return a_star - l_star

    # Linear case
    a_star = adjoint(lhs(F))

    processed_goal = preprocess(goal)
    processed_dual = preprocess(a_star)

    # Swap the argument of the goal for the (original) test function of
    # a_star so both forms share the same argument
    goal_argument = processed_goal.form_data().original_arguments[0]
    dual_test = processed_dual.form_data().original_arguments[0]
    l_star = replace(goal, {goal_argument: dual_test})

    return a_star - l_star

def glue(V_h):
    """
    Return the continuous ("CG") counterpart of the function space V_h,
    with the same structure and polynomial degrees.
    """
    continuous_family = "CG"
    return change_regularity(V_h, continuous_family)

def higher_order_problem(problem, V_h, exterior_facets):
    """
    Build a VariationalProblem corresponding to 'problem', but with its
    arguments and boundary conditions moved to spaces of increased
    polynomial degree.

    problem          -- object providing the form F and the boundary
                        conditions bcs
    V_h              -- full function space, passed to increase_bc_order
    exterior_facets  -- passed on as exterior_facet_domains

    Returns the new VariationalProblem.
    """

    enriched_form = increase_argument_order(problem.F)
    enriched_bcs = increase_bc_order(problem.bcs, V_h)

    return VariationalProblem(lhs(enriched_form), rhs(enriched_form),
                              enriched_bcs,
                              exterior_facet_domains=exterior_facets)

def replace_arguments(form, W_hs):
    """
    Replace the arguments (basis functions) of 'form' by the
    corresponding basis functions on the given function space(s).

    W_hs may be a single function space, which is then used for every
    argument, or a list/tuple of function spaces with one entry per
    argument, in the order given by the form data. The first argument
    is replaced by a TestFunction and the second by a TrialFunction, so
    only forms of arity at most 2 are supported.

    Returns the form with its arguments replaced.
    """

    # Get the arguments of the original (unprocessed) form
    arguments = preprocess(form).form_data().original_arguments

    # A single space is shared by all arguments. Tuples are accepted as
    # well as lists; previously a tuple was wrongly treated as one space.
    if not isinstance(W_hs, (list, tuple)):
        W_hs = [W_hs] * len(arguments)

    # Argument number i becomes basis[i] on W_hs[i]; a form of arity > 2
    # (or too few spaces) fails with IndexError, as before.
    basis = (TestFunction, TrialFunction)
    replace_dict = {}
    for i, argument in enumerate(arguments):
        replace_dict[argument] = basis[i](W_hs[i])

    return replace(form, replace_dict)

def increase_argument_order(form):
    """
    Return 'form' with each of its arguments (basis functions) replaced
    by one defined on a space of increased polynomial degree, as
    produced by increase_order. Intended for forms of arity at most 2.
    """

    arguments = preprocess(form).form_data().original_arguments

    # One enriched space per argument, in argument order
    enriched_spaces = [increase_order(argument.function_space())
                       for argument in arguments]

    return replace_arguments(form, enriched_spaces)

def increase_bc_order(bcs, V_h):
    """
    Return Dirichlet boundary conditions equivalent to 'bcs', but defined
    on function spaces of increased polynomial degree (see increase_order).

    bcs  -- iterable of DirichletBC objects, or None
    V_h  -- the full function space; when a boundary condition lives on a
            sub space, the sub space of increase_order(V_h) with the same
            first component index is used instead

    Returns None if bcs is None, otherwise a list of new DirichletBCs.
    Boundary conditions with more than one domain argument (defined via
    a MeshFunction) are reported through cpp.error.
    """

    if bcs is None:
        return None

    enriched_bcs = []
    for bc in bcs:
        domain_args = bc.domain_args
        if len(domain_args) > 1:
            cpp.error("Unable to increase bc order for DirichletBC defined with MeshFunction")

        bc_space = bc.function_space()
        component = bc_space.component()
        if component.size() == 0:
            enriched_space = increase_order(bc_space)
        else:
            # FIXME: Component/Array issues to be resolved; only the
            # first component index is honoured here
            enriched_space = increase_order(V_h).sub(component[0])

        enriched_bcs.append(DirichletBC(enriched_space, bc.value(),
                                        domain_args[0]))

    return enriched_bcs

def increase_order(V):
    """
    Return a function space like V, but with the polynomial degree
    raised by one.

    Spaces with sub spaces are rebuilt recursively as a
    MixedFunctionSpace. A space of the "Real" family is returned as a
    degree 0 "Real" space instead of being raised.
    """

    num_subspaces = V.num_sub_spaces()
    if num_subspaces > 0:
        enriched = [increase_order(V.sub(k)) for k in range(num_subspaces)]
        return MixedFunctionSpace(enriched)

    element = V.ufl_element()
    family = element.family()
    if family == "Real":
        return FunctionSpace(V.mesh(), "Real", 0)

    return FunctionSpace(V.mesh(), family, element.degree() + 1)

def tear(V_h):
    """
    Return the discontinuous ("DG") counterpart of the function space
    V_h, with the same structure and polynomial degrees.
    """
    return change_regularity(V_h, "DG")


def extract_mesh(form):
    """
    Return the mesh of the function space of the first argument of
    'form'. The form must have at least one argument; otherwise the
    lookup of the first argument raises IndexError.
    """
    first_argument = preprocess(form).form_data().original_arguments[0]
    return first_argument.function_space().mesh()
