///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
///
/// =========================================================================
//
// form(X,Y,"A") : bilinear form A: X x Y --> R
// and its associated operator  A : X --> Y'
//
// HOW TO ADD A NEW FORM:
//  a) create elementary matrix (see form_connectivity.c)
//  b) declare enum type in "form.h"
//  c) add string/enum association here in enum_and_string array
//  d) declare extern assembly algo here
//
// authors: Pierre.Saramito@imag.fr
//
// date: 7 july 1997
//
#include "rheolef/form.h"
#include "rheolef/field.h"
#include "rheolef/blas2.h"
#include "rheolef/csr-algo-amulx.h"

#include "form_assembly.h"
using namespace std;
namespace rheolef { 

// get_form_approx_name: return the approximation name of the space X,
// prefixed according to its valued-ness:
//   "vector"             --> "v" + approx
//   "tensor"             --> "t" + approx
//   "unsymmetric_tensor" --> "u" + approx
//   "scalar"             --> approx (unchanged)
// Any other valued-ness is reported through error_macro.
string
get_form_approx_name (const space&  X) 
{
    string       x_approx = X.get_approx();
    const string valued   = X.get_valued();
    if (valued == "vector") {
        x_approx = "v" + x_approx;
    } else if (valued == "tensor") {
        x_approx = "t" + x_approx;
    } else if (valued == "unsymmetric_tensor") {
        x_approx = "u" + x_approx;
    } else if (valued != "scalar") {
	error_macro ("unexpected valued space `" << X.get_valued() << "'");
    }
    return x_approx;
}
// get_function: build the elementary form descriptor from its name.
// The spaces are part of the signature but are not used to select
// the element here.
static
form_element
get_function (const string& name, const space& /* X */, const space& /* Y */)
{
    form_element element (name);
    return element;
}
// constructor: assemble the form "name" on X x Y; the element loop and the
// elementary matrix calls are inlined in the assembly routines.
// When X or Y lives on a boundary domain, a hybrid assembly is performed on
// that boundary domain (X is checked first); this case is rejected when
// locked_boundaries is requested.
form::form (const space& X, const space& Y, const string& name, bool locked_boundaries)
: X_(X), Y_(Y), _for_locked_boundaries(locked_boundaries), uu(), ub(), bu(), bb()
{
    X_.freeze();
    Y_.freeze();
    form_element elementary_form = get_function(name, X, Y);

    if (X.is_on_boundary_domain() || Y.is_on_boundary_domain()) {
       check_macro(!locked_boundaries, 
       		"Boundary forms not available with spaces with locked component");
       // at least one bdr domain: X takes precedence over Y
       const space& bdr_space = X.is_on_boundary_domain() ? X : Y;
       form_hybrid_assembly (*this, elementary_form, bdr_space.get_boundary_domain());
    } else {
       form_assembly (*this, elementary_form, locked_boundaries);
    }
}
// weighted constructor: the field wh is attached as a weight to the
// elementary form before assembly.
// The pseudo-name "convect_bdr" uses the "convect" elementary form with the
// boundary discontinuous Galerkin assembly; "convect" uses the discontinuous
// Galerkin assembly; other names use the standard or hybrid assembly.
form::form (const space& X, const space& Y, const string& name, const field& wh,
	bool use_coordinate_system_weight)
: X_(X), Y_(Y), _for_locked_boundaries(false), uu(), ub(), bu(), bb()
{
    X_.freeze();
    Y_.freeze();
    const bool   is_convect_bdr = (name == "convect_bdr");
    const string element_name   = is_convect_bdr ? string("convect") : name;
    form_element elementary_form = get_function(element_name, X, Y);
    elementary_form.set_weight (wh);
    elementary_form.set_use_coordinate_system_weight(use_coordinate_system_weight);

    if (is_convect_bdr) {
       form_assembly_discontinuous_galerkin_bdr (*this, elementary_form);
    } else if (name == "convect") {
       form_assembly_discontinuous_galerkin (*this, elementary_form);
    } else if (X.is_on_boundary_domain() || Y.is_on_boundary_domain()) {
       // at least one bdr domain: X takes precedence over Y
       const space& bdr_space = X.is_on_boundary_domain() ? X : Y;
       form_hybrid_assembly (*this, elementary_form, bdr_space.get_boundary_domain());
    } else {
       form_assembly (*this, elementary_form);
    }
}
// constructor on domain: assemble the form "name" on X x Y with
// form_bdr_assembly over the domain d.
form::form (const space& X, const space& Y, const string& name, const domain& d)
: X_(X), Y_(Y), _for_locked_boundaries(false), uu(), ub(), bu(), bb()
{
    X_.freeze();
    Y_.freeze();
    form_element elementary_form = get_function(name, X, Y);
    form_bdr_assembly (*this, elementary_form, d);
}
// weighted constructor on domain: same as the constructor on domain, with
// the field wh attached as a weight to the elementary form.
form::form (const space& X, const space& Y, const string& name, const domain& d,
	const field& wh, bool use_coordinate_system_weight)
: X_(X), Y_(Y), _for_locked_boundaries(false), uu(), ub(), bu(), bb()
{
    X_.freeze();
    Y_.freeze();
    form_element elementary_form = get_function(name, X, Y);
    elementary_form.set_weight (wh);
    elementary_form.set_use_coordinate_system_weight(use_coordinate_system_weight);
    form_bdr_assembly (*this, elementary_form, d);
}
// conversion from a diagonal form: the diagonal blocks uu and bb are
// converted to csr matrices; the off-diagonal blocks ub and bu are resized
// to the matching shapes with a last resize argument of 0 (presumably no
// nonzero entries).
form::form (const form_diag& d) 
: X_(d.get_space()), Y_(d.get_space()),
  _for_locked_boundaries(false),
  uu(), ub(), bu(), bb()
{
    const size_t n_unknown = d.uu.nrow();
    const size_t n_blocked = d.bb.nrow();
    uu = csr<Float>(d.uu);
    bb = csr<Float>(d.bb);
    ub.resize (n_unknown, n_blocked, 0);
    bu.resize (n_blocked, n_unknown, 0);
}
//
// linear algebra
//
// evaluate the bilinear form on (u,v): dot(A*u, v)
Float
form::operator () (const field& u, const field& v) const
{
    const field Au = (*this)*u;
    return dot (Au, v);
}
// transpose of a form: spaces are swapped and the 2x2 block structure
// [uu ub; bu bb] is transposed block-wise, which exchanges ub and bu.
form 
trans(const form& f)  
{
  form result (f.get_second_space(), f.get_first_space());
  result.uu = trans(f.uu);
  result.ub = trans(f.bu);
  result.bu = trans(f.ub);
  result.bb = trans(f.bb);
  return result;
}
// sum of two forms, computed block by block.
// TODO: check that f_1 and f_2 are defined on the same spaces.
form 
operator + (const form& f_1, const form& f_2)
{
  form sum (f_1.get_first_space(), f_1.get_second_space());
  sum.uu = f_1.uu + f_2.uu;
  sum.ub = f_1.ub + f_2.ub;
  sum.bu = f_1.bu + f_2.bu;
  sum.bb = f_1.bb + f_2.bb;
  return sum;
}
// opposite of a form: every block is scaled by -1
form
operator - (const form& f)
{
  form opposite (f.get_first_space(), f.get_second_space());
  opposite.uu = -1.*f.uu;
  opposite.ub = -1.*f.ub;
  opposite.bu = -1.*f.bu;
  opposite.bb = -1.*f.bb;
  return opposite;
}
// difference of two forms, computed block by block.
// TODO: check that f_1 and f_2 are defined on the same spaces.
form 
operator - (const form& f_1, const form& f_2)
{
  form diff (f_1.get_first_space(), f_1.get_second_space());
  diff.uu = f_1.uu - f_2.uu;
  diff.ub = f_1.ub - f_2.ub;
  diff.bu = f_1.bu - f_2.bu;
  diff.bb = f_1.bb - f_2.bb;
  return diff;
}
// scaling of a form by a scalar: every block is multiplied by lambda
form 
operator * (const Float& lambda, const form& f)
{
  form scaled (f.get_first_space(), f.get_second_space());
  scaled.uu = lambda*f.uu;
  scaled.ub = lambda*f.ub;
  scaled.bu = lambda*f.bu;
  scaled.bb = lambda*f.bb;
  return scaled;
}
// composition of two forms seen as 2x2 block operators:
//   [uu ub]     [uu ub]
//   [bu bb]_1 * [bu bb]_2
// The result goes from the first space of f_2 to the second space of f_1.
form 
operator * (const form& f_1, const form& f_2)
{
  form prod (f_2.get_first_space(), f_1.get_second_space());

  // rows associated with unknown components
  prod.uu = f_1.uu*f_2.uu + f_1.ub*f_2.bu;
  prod.ub = f_1.uu*f_2.ub + f_1.ub*f_2.bb;

  // rows associated with blocked components
  prod.bu = f_1.bu*f_2.uu + f_1.bb*f_2.bu;
  prod.bb = f_1.bu*f_2.ub + f_1.bb*f_2.bb;

  return prod;
}
// right product by a diagonal form: [uu ub; bu bb]*[du 0; 0 db],
// so the columns on the unknown side are scaled by d.uu and
// the columns on the blocked side by d.bb.
form 
operator * (const form& f, const form_diag& d)
{
  form prod (d.get_space(), f.get_second_space());
  prod.uu = f.uu*d.uu;
  prod.bu = f.bu*d.uu;
  prod.ub = f.ub*d.bb;
  prod.bb = f.bb*d.bb;
  return prod;
}

// left product by a diagonal form: [du 0; 0 db]*[uu ub; bu bb],
// so the unknown rows are scaled by d.uu and the blocked rows by d.bb.
form 
operator * (const form_diag& d, const form& f)
{
  form prod (f.get_first_space(), d.get_space());
  prod.uu = d.uu*f.uu;
  prod.ub = d.uu*f.ub;
  prod.bu = d.bb*f.bu;
  prod.bb = d.bb*f.bb;
  return prod;
}

// division of a form by a non-zero scalar: every block is scaled by 1/lambda.
// The previous body left the division commented out and returned a form
// with unassigned blocks; the scaled blocks are now computed with the same
// scalar*csr product used by operator*(const Float&, const form&).
form  
operator / (const form& f, const Float& lambda)
{
  check_macro (lambda != Float(0), "form: division by zero");
  const Float inv_lambda = Float(1)/lambda;
  form f_0(f.get_first_space(),f.get_second_space());
  f_0.uu = inv_lambda*f.uu ; f_0.ub = inv_lambda*f.ub ;
  f_0.bu = inv_lambda*f.bu ; f_0.bb = inv_lambda*f.bb ;
  return f_0 ;
}

// check_amux: report, through check_macro, a csr matrix a that cannot be
// applied to a vector of size n, i.e. whose column count differs from n.
template <class Matrix, class Size>
inline
void
check_amux (const Matrix& a, Size n)
{
    check_macro (a.ncol() == n,
        "incompatible csr(" << a.nrow() << "," << a.ncol() << ")*vec(" << n << ")");
}
// Apply the form operator A : X --> Y' to a field x, returning y = A*x as a
// field of the second space of a. The four csr blocks act on the unknown (u)
// and blocked (b) parts of x:
//   y.u := a.uu*x.u + a.ub*x.b
//   y.b := a.bu*x.u + a.bb*x.b
// Before each block product, check_amux verifies that the block column count
// matches the size of the corresponding part of x.
// NOTE(review): per the ":=" / "+=" step comments, csr_amux is assumed to
// overwrite its output and csr_pamux to accumulate into it — confirm in
// csr-algo-amulx.h. y is constructed without an explicit initial value, so
// this relies on steps 1) and 3) overwriting y.u and y.b.
field
operator * (const form& a, const field& x)
{
    // y.u := a.uu*x.u + a.ub*x.b
    // y.b := a.bu*x.u + a.bb*x.b

    field y(a.get_second_space()); 

    // 1)   y.u := a.uu*x.u
    check_amux (a.uu, x.u.size());
    csr_amux (
        a.uu.ia().begin(),
        a.uu.ia().end(),
        a.uu.ja().begin(),
        a.uu.a().begin(),
        x.u.begin(),
        y.u.begin());

    // 2)   y.u += a.ub*x.b
    check_amux (a.ub, x.b.size());
    csr_pamux (
        a.ub.ia().begin(),
        a.ub.ia().end(),
        a.ub.ja().begin(),
        a.ub.a().begin(),
        x.b.begin(),
        y.u.begin());

    // 3) y.b := a.bu*x.u
    check_amux (a.bu, x.u.size());
    csr_amux (
        a.bu.ia().begin(),
        a.bu.ia().end(),
        a.bu.ja().begin(),
        a.bu.a().begin(),
        x.u.begin(),
        y.b.begin());

    // 4) y.b += a.bb*x.b
    check_amux (a.bb, x.b.size());
    csr_pamux (
        a.bb.ia().begin(),
        a.bb.ia().end(),
        a.bb.ja().begin(),
        a.bb.a().begin(),
        x.b.begin(),
        y.b.begin());

    return y;
}
// Apply the form operator to a field component x (read through its
// u_begin()/b_begin() iterators and u_size()/b_size() sizes), returning
// y = A*x as a field of the second space of a:
//   y.u := a.uu*x.u + a.ub*x.b
//   y.b := a.bu*x.u + a.bb*x.b
// Before each block product, check_amux verifies that the block column count
// matches the size of the corresponding part of x.
// NOTE(review): per the ":=" / "+=" step comments, csr_amux is assumed to
// overwrite its output and csr_pamux to accumulate into it — confirm in
// csr-algo-amulx.h.
field
operator * (const form& a, const field_component& x)
{ 
    // y.u := a.uu*x.u + a.ub*x.b
    // y.b := a.bu*x.u + a.bb*x.b

    field y(a.get_second_space());

    // 1)   y.u := a.uu*x.u
    check_amux (a.uu, x.u_size());
    csr_amux (
        a.uu.ia().begin(),
        a.uu.ia().end(),
        a.uu.ja().begin(),
        a.uu.a().begin(),
        x.u_begin(),
        y.u.begin());

    // 2)   y.u += a.ub*x.b
    check_amux (a.ub, x.b_size());
    csr_pamux (
        a.ub.ia().begin(),
        a.ub.ia().end(),
        a.ub.ja().begin(),
        a.ub.a().begin(),
        x.b_begin(),
        y.u.begin());

    // 3) y.b := a.bu*x.u
    check_amux (a.bu, x.u_size());
    csr_amux (
        a.bu.ia().begin(),
        a.bu.ia().end(),
        a.bu.ja().begin(),
        a.bu.a().begin(),
        x.u_begin(),
        y.b.begin());

    // 4) y.b += a.bb*x.b
    check_amux (a.bb, x.b_size());
    csr_pamux (
        a.bb.ia().begin(),
        a.bb.ia().end(),
        a.bb.ja().begin(),
        a.bb.a().begin(),
        x.b_begin(),
        y.b.begin());

    return y;

}

// transposed product y = A^T x, using the block transpose
// [uu^T bu^T; ub^T bb^T]; y lives in the first space of the form.
field
form::trans_mult (const field& x) const
{
    field result (get_first_space(), Float(0));
    result.u = uu.trans_mult(x.u) + bu.trans_mult(x.b);
    result.b = ub.trans_mult(x.u) + bb.trans_mult(x.b);
    return result;
}
// zero form on X x Y: each block is resized with rows taken from Y and
// columns from X (unknown/blocked counts), then cleared.
form
form_nul(const space& X, const space& Y)
{
  form zero (X, Y);

  zero.uu.resize (Y.n_unknown(), X.n_unknown(), 0);
  zero.uu.clear();

  zero.ub.resize (Y.n_unknown(), X.n_blocked(), 0);
  zero.ub.clear();

  zero.bu.resize (Y.n_blocked(), X.n_unknown(), 0);
  zero.bu.clear();

  zero.bb.resize (Y.n_blocked(), X.n_blocked(), 0);
  zero.bb.clear();

  return zero;
}

// print the four blocks of the form in the order uu, ub, bu, bb
ostream& 
operator << (ostream& s, const form& a)
{
    return s << a.uu << a.ub << a.bu << a.bb;
}
}// namespace rheolef
