/*
 * Copyright (c) 2004, 2005 The University of Wroclaw.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *    1. Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *    2. Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *    3. The name of the University may not be used to endorse or promote
 *       products derived from this software without specific prior
 *       written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
 * NO EVENT SHALL THE UNIVERSITY BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

using Nemerle;
using Nemerle.Collections;
using Nemerle.Compiler.SolverMacros;
using Nemerle.Compiler.Typedtree;

/*

There are two kinds of type variables:

  * free type variables, with associated upper and lower bounds on the
    types that can be substituted for them
    
  * fixed type variables, which have already been substituted with some
    type constructor
    
The constraint solver maintains a graph of type variables (vertices)
and subtyping relations (edges).  The graph follows several invariants:

  1. There are only free type variables in it.

  2. There are no cycles in it. If a cycle emerges, all type variables
     involved in it are merged into one type variable. (The graph is
     therefore a DAG).

  3. The graph is transitively closed, that is if A :> B and B :> C, then
     A :> C, where X :> Y stands for an edge in the graph from X to Y.
     
  4. The upper and lower bounds are also transitively closed, that is
     if t :> A, A :> B, B :> t' then t :> t', where :> stands for a
     subtyping relation.

  5. If t :> A and A :> t', then t :> t' (that is, the upper bound has
     to be bigger than the lower bound). If t = t', then the type t is
     substituted for the variable A (that is, A gets fixed), since it is
     the only type to fulfill both the upper and lower limits. To maintain
     invariant 1, A is then removed from the graph.
     
It is sometimes crucial to save the graph in a certain state and then go
back to it. This is done with the PushState and PopState methods -- they
maintain a stack of maps from type variable identifiers to the type
variables themselves. Type variables in a given state are looked up with
the Find method, while if there is a need to update a type variable, Copy
should be called.  Copy is a noop if there is already a copy in the
current state, while both Copy and Find are noops if the stack of states
is empty.

The Find() method also takes into account the equality constraints
on type variables (they result from subtyping tests or cycle merging).
It returns the selected representative of the given merged variable class.
New equality constraints are added using the AddLink method.

*/

namespace Nemerle.Compiler 
{
  /** A constraint solver. */
  /** A constraint solver.

      Keeps a stack of speculative typing states (PushState / PopState),
      the messenger used to report type errors, and helpers (Intersect,
      Sum) that combine types while recording the implied subtyping
      constraints. */
  public class Solver
  {
    /** Store current constraint state.
        
        Called before some speculative type checking fragment, like
        overload resolution.  States may be nested to any depth: the
        serial stack is grown on demand. */
    public PushState () : void
    {
      ++serial_stack_top;
      ++last_serial;
      top_serial = last_serial;

      // serial_stack starts with a small fixed capacity (see the
      // constructor); grow it when speculative states nest deeper than
      // that, instead of failing with an index out of range.
      when (serial_stack_top >= serial_stack.Length) {
        def grown : array [int] = array (serial_stack.Length * 2);
        System.Array.Copy (serial_stack, grown, serial_stack.Length);
        serial_stack = grown;
      }
      serial_stack [serial_stack_top] = top_serial;

      dt_stack.Push (dt_store);

      messenger.PushState ();
    }


    /** Go one constraint state back.
        
        Rolls back all constraints made in the current state: restores
        the previous serial, the delayed-typing store saved by the
        matching PushState, and the messenger state (which also clears
        the error state, if it was entered in the popped state). */
    public PopState () : void
    {
      // Popping the top-level state would index serial_stack with -1 and
      // leave serial_stack_top corrupted; report the unbalanced call
      // before touching any state.
      assert (serial_stack_top > 0,
              "Solver.PopState called without a matching PushState");

      --serial_stack_top;
      top_serial = serial_stack [serial_stack_top];
      
      dt_store = dt_stack.Pop ();
      messenger.PopState ();
    }


    /** Pop every speculative state, returning to the top-level state. */
    public Unwind () : void
    {
      while (!IsTopLevel)
        PopState ()
    }


    /** Enter a computation that may recurse forever on cyclic types.

        The nesting counter is incremented on every call.  Returns [true]
        while it stays below 100; otherwise increments CyclicTypeCount,
        reports "cyclic type found" to the current messenger and returns
        [false].
        NOTE(review): presumably every call is paired with
        LeavePossiblyLooping, even when [false] is returned — confirm at
        call sites. */
    internal CanEnterPossiblyLooping () : bool
    {
      possibly_looping++;

      if (possibly_looping < 100) true
      else {
        CyclicTypeCount++;
        SaveError (CurrentMessenger, "cyclic type found");
        false
      }
    }


    /** Leave a computation entered with CanEnterPossiblyLooping. */
    internal LeavePossiblyLooping () : void
    {
      possibly_looping--;
    }


    /** Make sure all TyVars created so far won't go anywhere.
        
        Called when the type checking process for a method is finished (or
        for a class, if we decide to go with inferred private types).  */
    [Nemerle.NotImplemented] // XXX it is not clear if it will be needed.
    public static FixateAll () : void
    {
    }


    /** Generate a new type variable. */
    public static FreshTyVar () : TyVar
    {
      TyVar ()
    }


    /** View a list of monotypes as a list of type variables.

        This is an element-wise identity map; it relies on an MType being
        usable where a TyVar is expected.  An empty input returns the
        empty list without allocating. */
    public static MonoTypes (m : list [MType]) : list [TyVar]
    {
      if (m.IsEmpty) []
      else List.Map (m, fun (x) { x })
    }


    /** Map each type variable to its FixedValue. */
    public static FixedValues (m : list [TyVar]) : list [MType]
    {
      if (m.IsEmpty) []
      else List.Map (m, fun (x : TyVar) { x.FixedValue })
    }
    

    /** Call Fix () on each type variable and collect the results. */
    public static Fix (m : list [TyVar]) : list [MType]
    {
      if (m.IsEmpty) []
      else List.Map (m, fun (x : TyVar) { x.Fix () })
    }


    /** Increment current type variable rank.

        Called before typing of a local function. */
    public PushRank () : void
    {
      ++current_rank;
    }

    
    /** Restore previous type variable rank.

        Called after typing of a local function. */
    public PopRank () : void
    {
      --current_rank;
    }


    #region Anti dead lock queue
    /** Run [action], deferring it when another action is already running.

        While an action is running, further calls only queue their action.
        The outermost call runs its own action and then keeps running
        queued actions until the queue is empty.  If any action throws, the
        lock is released and the remaining queued actions are discarded. */
    public Enqueue (action : void -> void) : void
    {
      if (locked) {
        comp_queue.Push (action);
      } else {
        try {
          locked = true;
          action ();
          while (!comp_queue.IsEmpty) {
            def queued = comp_queue.Pop ();
            queued ()
          }
        } finally {
          locked = false;
          comp_queue.Clear ();
        }
      }
    }

    comp_queue : Queue [void -> void] = Queue ();
    mutable locked : bool;
    #endregion


    /** Return the biggest type [t] such that [t <: t1] and [t <: t2]. 
        It doesn't work for separated types.

        A null argument means "no constraint": the other argument is
        returned.  When the result is not a single type, an
        MType.Intersection of the parts is returned.  Incompatible types
        are reported with SaveError on this solver's messenger.  Along the
        way, subtyping constraints are recorded with Require. */
    public Intersect (t1 : MType, t2 : MType) : MType
    {
      // Intersect two class types.  If one derives from the other, the
      // more derived one is required to be a subtype of the other and is
      // the result.  Otherwise both are kept if an interface is involved,
      // and anything else is an error.
      def intersect_classes (t1 : MType, t2 : MType) {
        def tc1 = (t1 :> MType.Class).tycon;
        def tc2 = (t2 :> MType.Class).tycon;
        match (tc1.SuperType (tc2)) {
          | None =>
            match (tc2.SuperType (tc1)) {
              | None =>
                match ((tc1.GetTydecl (), tc2.GetTydecl ())) {
                  | (TypeDeclaration.Interface, TypeDeclaration.Interface)
                  | (TypeDeclaration.Class, TypeDeclaration.Interface)
                  | (TypeDeclaration.Interface, TypeDeclaration.Class) =>
                    [t1, t2]
                  | _ =>
                    SaveError (messenger, 
                                 $ "types $t1 and $t2 are not compatible "
                                   "[during intersection]");
                    [t1]
                }
              | Some =>
                intersect_classes (t2, t1)
            }
          | Some =>
            // tc1 : tc2(args)
            _ = t1.Require (t2);
            [t1]
        }
      }

      if (t1 == null) t2
      else if (t2 == null) t1
      else {
        t1.Validate ();
        t2.Validate ();
        def result =
          match ((t1, t2)) {
            | _ when t1.Equals (t2) => [t1]

            // object is the neutral element of intersection
            | (MType.Class (tc, []), t) when tc.Equals (InternalType.Object_tc)
            | (t, MType.Class (tc, [])) when tc.Equals (InternalType.Object_tc) =>
              [t]

            // a type variable: the first class constraint deriving from tc
            // must satisfy the class type; otherwise report an error
            | ((MType.TyVarRef (tv)) as tvr, (MType.Class (tc, _)) as t)
            | ((MType.Class (tc, _)) as t, (MType.TyVarRef (tv)) as tvr) =>
              mutable res = true;
              mutable seen = false;
              foreach (MType.Class (tc', _) as t' in tv.Constraints)
                when (!seen && tc'.SuperType (tc).IsSome) {
                  res = t'.Require (t);
                  // Message.Debug ($ "$(t') vs $t : $res");
                  seen = true;
                }
              unless (seen && res)
                SaveError (messenger, 
                           $ "types $t1 and $t2 are not compatible "
                             "[during intersection, tyvar]");
              [tvr : MType]

            // merge the class into the related member of the intersection
            // (if any); unrelated members are kept as they are
            | (MType.Class (tc, args), MType.Intersection (lst))
            | (MType.Intersection (lst), MType.Class (tc, args)) =>
              def loop (res, tc, args, lst) {
                match (lst) {
                  | (MType.Class (tc', args') as t) :: tl =>
                    if (tc'.SuperType (tc).IsSome ||
                        tc.SuperType (tc').IsSome)
                      match (intersect_classes (MType.Class (tc, args), 
                                                MType.Class (tc', args'))) {
                        | [MType.Class (tc, args)] =>
                          loop (res, tc, args, tl)
                        | _ => assert (false)
                      }
                    else loop ((t : MType) :: res, tc, args, tl)
                    
                  | _ :: _ => assert (false)

                  | [] => MType.Class (tc, args) :: res
                }
              }
              loop ([], tc, args, lst)
              
            | (MType.Class (tc1, args1), MType.Class (tc2, args2)) =>
              intersect_classes (MType.Class (tc1, args1), 
                                 MType.Class (tc2, args2))

            | _ => 
              SaveError (messenger, 
                         $ "types $t1 and $t2 are not compatible "
                           "[during intersection]");
              [t1]
          }
        match (result) {
          | [x] => x
          | lst =>
            def res = MType.Intersection (lst);
            res.Validate ();
            res
        }
      }
    }
    

    /** Return the most specific common supertype of [t1] and [t2].

        A null argument means "no constraint": the other argument is
        returned.  Both arguments (and the members of any Intersection
        argument) must be class types.  Each input is required (Require)
        to be a subtype of the result.  Classes are preferred over
        interfaces.  A result of just object / System.ValueType that none
        of the inputs was is reported as an error, and so is a result with
        several candidate interfaces, which is then returned as an
        MType.Intersection. */
    public Sum (t1 : MType, t2 : MType) : MType
    {
      def sum_list (lst : list [MType])
      {
        mutable supertypes = null;
        mutable seen_object = false;
        mutable seen_value_type = false;

        // supertypes := intersection over all inputs of
        // { tc, object } + tc.GetSuperTypes ()
        foreach (t in lst) {
          match (t) {
            | MType.Class (tc, _) =>
              when (tc.Equals (InternalType.Object_tc))
                seen_object = true;
              when (tc.Equals (InternalType.ValueType_tc))
                seen_value_type = true;
              def s =
                List.FoldLeft (InternalType.Object :: tc.GetSuperTypes (), 
                  Set.Singleton (tc),
                  fun (e, s : Set [TypeInfo]) {
                    match (e) {
                      | MType.Class (tc, _) =>
                        s.Replace (tc)
                      | _ => assert (false)
                    }
                  });
              if (supertypes == null)
                supertypes = s
              else
                supertypes = supertypes.Intersect (s);
            | _ => assert (false, $ "wrong type in Sum: $t")
          }
        }

        assert (!supertypes.IsEmpty);

        // Keep only the most derived common supertypes: t is dropped when
        // an already kept l derives from it, and kept l's that t derives
        // from are dropped.
        def maximal =
          supertypes.Fold ([], fun (t, lst) {
            mutable seen_better = false;
            def lst = List.RevFilter (lst, fun (l : TypeInfo) {
              if (l.SuperType (t).IsSome) {
                seen_better = true;
                true
              } else if (t.SuperType (l).IsSome) {
                assert (!seen_better);
                false
              } else true
            });

            if (seen_better) lst else t :: lst
          });

        def maximal =
          List.RevMap (maximal, fun (tc : TypeInfo) {
            // reuse existing type if possible
            mutable res = null;
            foreach (x in lst)
              match (x) {
                | MType.Class (tc', _) when tc'.Equals (tc) =>
                  res = x
                | _ => {}
              }
            // and create fresh substitution if not
            when (res == null)
              res = tc.GetFreshType ();
              
            foreach (x : MType in lst)
              _ = x.Require (res);

            res
          });


        // if there are interfaces and class in the bag choose the class
        def aint_interface (t) { ! t.IsInterface }
        def maximal =
          if (maximal.Exists (aint_interface))
            maximal.Filter (aint_interface)
          else
            maximal;

        match (maximal) {
          | [t] when t.Equals (InternalType.ValueType) && !seen_value_type
          | [t] when t.Equals (InternalType.Object) && !seen_object =>
            SaveError (messenger, 
                         $ "common super type of types $lst is just "
                           "`$t', please upcast one of the types to "
                           "`$t' if this is desired")
          | _ => ()
        }

        match (maximal) {
          | [x] => x
          | lst' =>
            SaveError (messenger, 
                         $ "common super type of types $lst is a set of "
                           "interfaces $(lst'). This is not supported");
            def res = MType.Intersection (lst');
            res.Validate ();
            res
        }
      }

      if (t1 == null) t2
      else if (t2 == null) t1
      else {
        t1.Validate ();
        t2.Validate ();
        match ((t1, t2)) {
          | (MType.Intersection (l1), MType.Intersection (l2)) =>
            sum_list (l1 + l2)
          | (t, MType.Intersection (l))
          | (MType.Intersection (l), t) =>
            sum_list (t :: l)
          // a little special case, for better speed
          | (MType.Class (tc1, []), MType.Class (tc2, [])) =>
            if (tc1.Equals (tc2))
              t1
            else if (tc1.SuperType (tc2).IsSome)
              t2
            else if (tc2.SuperType (tc1).IsSome)
              t1
            else
              sum_list ([t1, t2])
          | (t1, t2) =>
            sum_list ([t1, t2])
        }
      }
    }
    

    /** Messenger used to report errors in the current state. */
    public CurrentMessenger : Messenger
    {
      get { messenger }
    }

    
    /** True when no speculative state (PushState) is active. */
    public IsTopLevel : bool
    {
      get { serial_stack_top == 0 }
    }
    
    
    /** Create a solver in its top-level state.

        The messenger starts with NeedMessage and InErrorMode set; the
        serial stack holds only the top-level serial 1. */
    public this ()
    {
      dt_stack = Stack ();
      
      messenger = Messenger ();
      messenger.NeedMessage = true;
      messenger.InErrorMode = true;

      // Initial capacity only; PushState grows the array when needed.
      serial_stack = array (5);
      last_serial = 1;
      top_serial = 1;
      serial_stack [0] = 1;
      serial_stack_top = 0;
      possibly_looping = 0;
    }


    #region Interface for DelayedTyping
    // Current delayed-typing store; saved by PushState, restored by PopState.
    internal mutable dt_store : NemerleMap [Typer.DelayedTyping, Typer.DelayedTyping.Kind];
    // Stores saved by PushState, one entry per active nested state.
    dt_stack : Stack [NemerleMap [Typer.DelayedTyping, Typer.DelayedTyping.Kind]];
    #endregion


    // serial_stack [i] is the serial of the i-th nested state, and
    // serial_stack [serial_stack_top] == top_serial.  Mutable because
    // PushState replaces it with a larger array as nesting deepens.
    internal mutable serial_stack : array [int];
    internal mutable serial_stack_top : int;
    // Serial of the current (innermost) state.
    internal mutable top_serial : int;
    // Last serial handed out; every PushState takes a fresh one.
    private mutable last_serial : int;
    // Nesting depth counted by Can/LeavePossiblyLooping.
    private mutable possibly_looping : int;
    // Number of times the possibly-looping limit was exceeded.
    internal mutable CyclicTypeCount : int;

    // Adjusted by PushRank / PopRank around typing of local functions.
    mutable current_rank : int;
    messenger : Messenger;
  }
}
