"""This module provides the class VariationalProblem in Python. This
needs special handling and we cannot use the SWIG wrapper directly
since we need to call the JIT compiler."""

__author__ = "Anders Logg (logg@simula.no)"
__date__ = "2007-08-15 -- 2011-01-22"
__copyright__ = "Copyright (C) 2007-2008 Anders Logg"
__license__  = "GNU LGPL Version 2.1"

# Modified by Johan Hake, 2009
# Modified by Marie E. Rognes, 2011

__all__ = ["VariationalProblem"]

import types

from ufl.algorithms import extract_arguments

# Import SWIG-generated extension module (DOLFIN C++)
import dolfin
import dolfin.cpp as cpp

# Local imports
from dolfin.fem.form import *
from dolfin.fem.assemble import assemble
from dolfin.function.function import *
from dolfin.function.functionspace import *

class VariationalProblem(cpp.VariationalProblem):

    # Reuse doc-string from cpp.VariationalProblem
    __doc__ = cpp.VariationalProblem.__doc__

    def __init__(self, a,
                 L=None,
                 bcs=None,
                 cell_domains=None,
                 exterior_facet_domains=None,
                 interior_facet_domains=None,
                 form_compiler_parameters=None):
        """Define a variational problem.

        *Arguments*
            a
                First UFL form. If its compiled rank is 2 (bilinear) the
                problem is treated as linear, otherwise as nonlinear.
            L
                Second UFL form. Currently required; giving only ``a``
                is reported as an error via ``dolfin.error``.
            bcs
                ``None``, a single ``cpp.BoundaryCondition`` or a list
                of boundary conditions.
            cell_domains, exterior_facet_domains, interior_facet_domains
                Forwarded unchanged to ``cpp.VariationalProblem``.
            form_compiler_parameters
                Dict of form compiler parameters used when wrapping
                ``a`` and ``L``. Defaults to an empty dict.

        *Raises*
            TypeError if ``bcs`` is not ``None``, a list or a
            ``cpp.BoundaryCondition``.
        """

        # Avoid a shared mutable default argument: build a fresh dict
        # per call.
        if form_compiler_parameters is None:
            form_compiler_parameters = {}

        # To be implemented ... (dolfin.error raises, like the other
        # call sites in this class rely on)
        if L is None:
            dolfin.error("Not supporting only giving linear form, yet.")

        # Store input (ufl forms); needed later for error control
        self.a_ufl = a
        self.L_ufl = L

        # Wrap (JIT compile) forms
        self.a = Form(a, form_compiler_parameters=form_compiler_parameters)
        self.L = Form(L, form_compiler_parameters=form_compiler_parameters)
        self.is_linear = self.a.rank() == 2

        # Normalize bcs to a list
        if bcs is None:
            self.bcs = []
        elif isinstance(bcs, cpp.BoundaryCondition):
            self.bcs = [bcs]
        elif isinstance(bcs, list):
            self.bcs = bcs
        else:
            raise TypeError(
                "expected a 'list' or a 'BoundaryCondition' as bcs argument")

        # Initialize base class
        cpp.VariationalProblem.__init__(self, self.a, self.L, self.bcs,
                                        cell_domains, exterior_facet_domains,
                                        interior_facet_domains)

    def solve(self, u=None, tolerance=None, goal=None):
        """Solve variational problem and return solution.

        *Arguments*
            u
                Optional ``Function`` in which to store the solution. If
                omitted, a new ``Function`` on the trial space of ``a``
                is created, which requires ``a`` to be bilinear.
            tolerance
                If given, solve adaptively via ``adaptive_solve``
                (``goal`` must then also be given).
            goal
                Goal functional (UFL form) used for adaptive solves.

        *Returns*
            The solution ``u``.
        """

        # Extract trial space and create function for solution (If no
        # u is given, the problem must be linear.)
        if u is None:
            if len(self.a.function_spaces) != 2:
                dolfin.error("Unable to extract trial space for solution of variational problem, is 'a' bilinear?")
            u = Function(self.a.function_spaces[1])

        if tolerance is None:
            # Solve non-adaptively
            cpp.VariationalProblem.solve(self, u)
        else:
            # Solve adaptively
            self.adaptive_solve(u, tolerance, goal)
        return u

    def adaptive_solve(self, u, tolerance, goal):
        """Solve adaptively with respect to ``goal`` and ``tolerance``.

        Generates and compiles the error control forms, builds a
        ``cpp.ErrorControl`` object and hands it, together with the
        compiled goal functional, to ``cpp.VariationalProblem.solve``.
        The solution is stored in ``u``.
        """

        # Check input
        if tolerance is None or goal is None:
            dolfin.error("Unable to perform adaptive solve without goal and tolerance. Did you forget to specify either?")

        # Generate and compile forms used for error control
        ufl_forms = self.generate_error_control_forms(u, goal)
        forms = [Form(form) for form in ufl_forms]

        # Compile goal functional separately
        M = Form(goal)

        # Initialize cpp error control object: the compiled forms
        # followed by the linearity flag
        ec = cpp.ErrorControl(*(forms + [self.is_linear]))

        # Call cpp.VariationalProblem.solve
        cpp.VariationalProblem.solve(self, u, tolerance, M, ec)

    def generate_error_control_forms(self, u, goal):
        """Generate the UFL forms required for goal-oriented error control.

        *Arguments*
            u
                The (discrete) solution ``Function``.
            goal
                The goal functional (UFL form).

        *Returns*
            Tuple ``(a_star, L_star, eta_h, a_R_T, L_R_T, a_R_dT, L_R_dT,
            eta_T)``: the forms produced by FFC's ``generate_dual_forms``,
            the error estimate (weak residual acting on an extrapolation),
            the cell residual forms, the facet residual forms and the
            error indicator form.
        """

        cpp.info("Generating forms required for error control.")

        # Extract primal variational forms (NB: Need ufl form). For
        # nonlinear problems the roles of a and L are swapped.
        if self.is_linear:
            bilinear, linear = self.a_ufl, self.L_ufl
        else:
            bilinear, linear = self.L_ufl, self.a_ufl

        # Collect primal forms and goal functional
        forms = (bilinear, linear, goal)

        # Get FFC to generate error control forms
        from ffc.errorcontrol.formmanipulations import \
             (generate_dual_forms, generate_weak_residual,
              generate_cell_residual, generate_facet_residual,
              generate_error_indicator)

        # Use DOLFIN module. Importing the submodule makes sure the form
        # manipulations are loaded; the module object handed to FFC is
        # the top-level ``dolfin`` package (which is also what
        # ``__import__`` of a dotted name returns).
        __import__("dolfin.fem.formmanipulations")
        module = dolfin

        (a_star, L_star) = generate_dual_forms(forms, u, module)
        if self.is_linear:
            weak_residual = generate_weak_residual((bilinear, linear), u)
        else:
            weak_residual = generate_weak_residual(linear)

        V = u.function_space()
        mesh = V.mesh()
        tdim = mesh.topology().dim()

        # Generate extrapolation space by increasing order of trial space
        E = module.increase_order(V)
        Ez_h = Function(E)

        # Generate error estimate (residual) (# FIXME: Add option here)
        eta_h = dolfin.action(weak_residual, Ez_h)

        # Define approximation space for cell and facet residuals
        V_h = module.tear(V)

        # Generate cell residual stuff: bubble function of degree tdim + 1
        # set to one at every degree of freedom
        B = FunctionSpace(mesh, "B", tdim + 1)
        b_T = Function(B)
        b_T.vector()[:] = 1.0
        a_R_T, L_R_T = generate_cell_residual(weak_residual, V_h, b_T, module)

        # Generate facet residual stuff
        R_T = Function(V_h)
        C = FunctionSpace(mesh, "DG", tdim)
        b_e = Function(C)
        a_R_dT, L_R_dT = generate_facet_residual(weak_residual, V_h, b_e, R_T,
                                                 module)

        # Generate error indicators (# FIXME: Add option here)
        R_dT = Function(V_h)
        z_h = Function(extract_arguments(a_star)[1].function_space())
        q = TestFunction(FunctionSpace(mesh, "DG", 0))
        eta_T = generate_error_indicator(weak_residual, R_T, R_dT, Ez_h, z_h, q)

        return (a_star, L_star, eta_h, a_R_T, L_R_T, a_R_dT, L_R_dT, eta_T)
