module Optimise where

import Language
import Inliner
import Propagate

import Debug.Trace

-- | An optimisation: a named transformation from an expression to an
-- optimised expression.
--
-- The 'String' is the pass name; 'useOpts' selects passes by comparing a
-- requested name against it. The function is a rewrite on a single
-- expression node; 'oapply' applies it across a whole expression and
-- 'optAll' across every defined function body of a program.
data Optimisation n = Opt String (Expr n -> Expr n)


-- | Compose two optimisations. As with function composition, the second
-- argument runs first: @f `optComp` g@ rewrites with @g@ and then with @f@.
--
-- The combined name lists the passes in the order they run, separated by
-- @" -> "@. Empty names (such as 'optID''s) are skipped, so 'optID' is an
-- identity for 'optComp' in both the rewrite and the name. Previously,
-- composing with 'optID' left a dangling arrow in the name.
optComp :: Optimisation n -> Optimisation n -> Optimisation n
optComp (Opt fn f) (Opt gn g) = Opt (joinNames gn fn) (f . g)
  where
    joinNames "" later = later
    joinNames earlier "" = earlier
    joinNames earlier later = earlier ++ " -> " ++ later

-- | The do-nothing optimisation: an empty name and the identity rewrite.
-- Used as the unit when building pipelines with 'optComp'.
optID :: Optimisation n
optID = Opt "" (\expr -> expr)

-- Top level optimisation loop
-- Optimisation levels are currently:
--    0: Just parse time constant folding
--    1: Inlining, propagation, for loop collapsing, append collapsing,
--       remaining constant folds
--       Do some of them twice to catch anything which previous steps may 
--       expose.

-- | @dontimes n f x@ applies @f@ to @x@ exactly @n@ times (so
-- @dontimes 2 f x == f (f x)@). @n@ must be a non-negative whole number.
dontimes :: (Eq i, Num i) => i -> (a -> a) -> a -> a
dontimes count step start
    | count == 0 = start
    | otherwise  = step (dontimes (count - 1) step start)

-- | Run the optimiser at the given level over a program.
--
-- * Level 0 (or any negative level): return the program unchanged; only the
--   parser's own constant folding has been done.
-- * Level 1 and above: run the pipeline below twice, so later passes can
--   pick up opportunities exposed by earlier ones. Levels above 1 currently
--   behave exactly like level 1.
--
-- Negative levels used to recurse forever (@runOpts (n-1)@ never reached
-- 0 or 1); they are now treated as level 0.
runOpts :: Int -> Program -> Program
runOpts level p
    | level <= 0 = p
    | otherwise  = dontimes 2 pipeline p
  where
    -- Read right to left: basic rewrites, copy propagation, inlining,
    -- then basic rewrites and copy propagation again.
    pipeline = copyPropagate . optAll basicOpts . inliner
             . copyPropagate . optAll basicOpts

-- | The standard rewrite pipeline used by 'runOpts'. Since
-- @a `optComp` b@ runs @b@ first, each node is rewritten by 'constantFold',
-- then by 'appendOpt', then by 'forOpt'.
basicOpts :: Optimisation Name
basicOpts =
    ((forOpt `optComp` appendOpt) `optComp` constantFold) `optComp` optID

-- | Every individually selectable optimisation pass; 'useOpts' looks
-- passes up here by name.
allOpts :: [Optimisation Name]
allOpts =
    [ forOpt
    , appendOpt
    , constantFold
    ]

-- | Names of the passes in 'allOpts', in the same order. Derived from
-- 'allOpts' rather than written out by hand so the two lists cannot drift
-- apart when a pass is added or renamed.
allOptNames :: [String]
allOptNames = [ name | Opt name _ <- allOpts ]

-- | Build one optimisation from a list of pass names, looking each name up
-- in 'allOpts' and chaining them with 'optComp', ending in 'optID'.
-- Because @this `optComp` later@ runs @later@ first, the last name in the
-- list runs first.
--
-- Unknown names are reported with 'fail'. The tail of the list is resolved
-- before its head, so in a short-circuiting monad the error names the
-- last unknown entry.
--
-- NOTE(review): since GHC 8.8 'fail' is a 'MonadFail' method, so this
-- @Monad m@ constraint needs to become @MonadFail m@ on modern compilers.
useOpts :: Monad m => [String] -> m (Optimisation Name)
useOpts [] = return optID
useOpts (name:names) = do
    later <- useOpts names
    this  <- lookupOpt name
    return (this `optComp` later)
  where
    lookupOpt wanted =
        case [ o | o@(Opt n _) <- allOpts, n == wanted ] of
          (o:_) -> return o
          []    -> fail ("Unknown optimisation " ++ wanted)

{-testOpt = let src = Infix Plus (GConst (Num 4)) (Metavar "foo" 0 0)
	      dest = Infix Plus (Metavar "foo" 0 0) (Metavar "foo" 0 0) in
	      genericOpt "test" src dest
-}

-- | Constant folding. The parser already does this, but other passes
-- (especially inlining) can expose new opportunities.
--
-- * Integer/real binary and unary operations on constant operands are
--   evaluated with 'foldint', 'foldreal', 'foldunint' and 'foldunreal';
--   when those return 'Nothing' the node is left as it is.
-- * An 'If' with a constant boolean condition is replaced by the selected
--   branch.
-- * A Number-to-RealNum coercion of an integer constant becomes a real
--   constant. Any other coercion has its operand folded, and is retried
--   while that changes the operand.
constantFold :: Optimisation Name
constantFold = Opt "constant" cfold'
  where
    cfold' e@(Infix op (GConst (Num x)) (GConst (Num y)))
        = maybe e GConst (foldint op x y)
    cfold' e@(RealInfix op (GConst (Re x)) (GConst (Re y)))
        = maybe e GConst (foldreal op x y)
    cfold' e@(Unary op (GConst (Num x)))
        = maybe e GConst (foldunint op x)
    cfold' e@(RealUnary op (GConst (Re x)))
        = maybe e GConst (foldunreal op x)
    cfold' (If (GConst (Bo cond)) t f)
        = if cond then t else f
    cfold' (Coerce (Prim Number) (Prim RealNum) (GConst (Num x)))
        = GConst (Re (fromIntegral x))
    cfold' c@(Coerce t1 t2 e)
        | e /= e'   = cfold' (Coerce t1 t2 e')
        | otherwise = c
      where
        e' = cfold' e
    cfold' other = other

-- | Append-chain optimisation.
--
-- * A nested 'Append' is flattened into a single 'AppendChain' of its
--   operands; nesting on the left is flattened, the right operand is kept
--   as one element.
-- * An assignment @lval = AppendChain [lval', v]@ where @lval'@ reads the
--   same location as @lval@ becomes @AssignApp lval v@.
appendOpt :: Optimisation Name
appendOpt = Opt "append" appOpt'
  where
    appOpt' (Assign lval (AppendChain [apped, v]))
        | sameLoc lval apped = AssignApp lval v
    appOpt' app@(Append _ _) = AppendChain (flatten app)
    appOpt' other = other

    -- Does the lvalue denote the same location that the expression reads?
    sameLoc (AName i) (Loc j) = i == j
    sameLoc (AGlob i) (GVar j) = i == j
    sameLoc (AIndex l ix) (Index l' ix') = sameLoc l l' && ix == ix'
    sameLoc (AField l nm a t) (Field l' nm' a' t')
        = sameLoc l l' && nm == nm' && a == a' && t == t'
    sameLoc _ _ = False

    -- Operands of a left-nested append, in order.
    flatten (Append x y) = flatten x ++ [y]
    flatten (AppendChain xs) = xs
    flatten x = [x]

-- | Rewrite a @for@ loop over @Array::range(start, end, step)@ into an
-- explicit counting @while@ loop, so the range array is never built.
--
-- The loop's two temporary slots @i@ and @j@ are declared, @x@ is set to
-- @start@, and @end@ is saved in @i@ so it is evaluated once. With a
-- constant step the direction of the end test is chosen statically.
-- Otherwise the step is saved in @j@ and its sign is tested on each
-- iteration. Loops of any other shape are left unchanged.
forOpt :: Optimisation Name
forOpt = Opt "forloop" forOpt'
  where
    forOpt' (For i _ j (AName x)
                 (Annotation _ (Apply (Global (NS (UN "Array") (UN "range")) _ _)
                                      [start, end, step]))
                 body)
        -- x = start; i = end; while (x <= i) { body; x = x + step; }
        = Declare "optimiser" 0 (MN ("foropt", i), True) (Prim Number) $
          Declare "optimiser" 0 (MN ("foropt", j), True) (Prim Number) $
          Seq (Assign (AName x) start) $
          Seq (Assign (AName i) end) $
          countingLoop step
      where
        -- Constant step: pick <= or >= now.
        countingLoop (GConst (Num stepval))
            = While (Infix (endtest stepval) (Loc x) (Loc i)) $
              Seq body (Assign (AName x) (Infix Plus (Loc x) step))
        -- Unknown step: evaluate it once into j and test its sign each time.
        countingLoop _
            = Seq (Assign (AName j) step) $
              While (If (Infix OpGT (Loc j) (GConst (Num 0)))
                        (Infix OpLE (Loc x) (Loc i))
                        (Infix OpGE (Loc x) (Loc i))) $
              Seq body (Assign (AName x) (Infix Plus (Loc x) (Loc j)))
    forOpt' other = other

    -- Continue while x <= end for a positive step, x >= end otherwise.
    endtest stepval
        | stepval > 0 = OpLE
        | otherwise   = OpGE

-- | Apply an optimisation to every defined function body in a program,
-- rewriting each body with 'oapply'. All other declarations pass through
-- unchanged.
optAll :: Optimisation Name -> Program -> Program
optAll (Opt _ rewrite) = map optimiseDecl
  where
    optimiseDecl (FunBind (f, l, nm, ty, flags, Defined body) com ority)
        = FunBind (f, l, nm, ty, flags, Defined (oapply rewrite body)) com ority
    optimiseDecl decl = decl

-- | Build a generic optimisation named @n@ that rewrites any term matching
-- the pattern @e1@ into the template @e2@. Metavariables in @e1@ are bound
-- by 'match' and substituted into @e2@ by 'replace'. Terms that do not
-- match are returned unchanged.
--
-- A leftover debugging 'trace' of the bindings (written to stderr on every
-- match) has been removed. The 'Show' constraint is no longer needed, but
-- it is kept so the signature seen by callers is unchanged.
genericOpt :: (Show n, Eq n) => String -> Expr n -> Expr n -> Optimisation n
genericOpt n e1 e2 = Opt n rewrite
  where
    rewrite t = case match e1 t of
                  Nothing -> t
                  Just ms -> replace ms e2

-- | Match a pattern (first argument) against a term (second argument).
-- Each metavariable in the pattern is mapped to the corresponding subterm.
-- Returns the consistent list of bindings on success, and fails (via
-- 'fail') if the shapes differ or a metavariable would be bound to two
-- different subterms.
match e1 e2 = getmatches e1 e2 >>= checkMatches []

-- | Check that metavariable bindings are consistent: all bindings for the
-- same index must bind equal expressions. On success, the last occurrence
-- of each binding is kept and prepended to @acc@ in reverse order. On a
-- conflict, fails with @"No match"@.
checkMatches acc [] = return acc
checkMatches acc (binding@(key, val) : rest) =
    case lookup key rest of
      Nothing -> checkMatches (binding : acc) rest
      Just other
        | val == other -> checkMatches acc rest
        | otherwise    -> fail "No match"

-- | Structurally match a pattern (first argument) against a term (second
-- argument), returning the subterm bound to each metavariable index.
--
-- * A 'Metavar' in the pattern matches any subterm.
-- * Nodes with a case below are compared constructor by constructor;
--   binder details (argument lists, declared names and types, loop
--   indices, annotations) are ignored.
-- * Any other node matches only an equal term ('==').
--
-- Failure is reported with 'fail'. The same index may appear more than
-- once in the result; 'checkMatches' checks that repeated bindings agree.
--
-- NOTE(review): since GHC 8.8 'fail' is a 'MonadFail' method, so the
-- @Monad m@ constraint needs to become @MonadFail m@ on modern compilers.
getmatches :: (Monad m, Eq n) => Expr n -> Expr n -> m [(Int,Expr n)]
getmatches (Metavar _ _ x) e = return [(x,e)]
getmatches (Global x xm _) (Global y ym _) | x == y && xm == ym = return []
getmatches (Loc x) (Loc y) | x == y = return []
getmatches (Lambda _ _ sx) (Lambda _ _ sy) = getmatches sx sy
getmatches (Closure _ _ sx) (Closure _ _ sy) = getmatches sx sy
getmatches (Bind _ _ vx sx) (Bind _ _ vy sy) = do
    mv <- getmatches vx vy
    ms <- getmatches sx sy
    return $ mv ++ ms
getmatches (Declare _ _ _ _ sx) (Declare _ _ _ _ sy) = getmatches sx sy
getmatches (Return x) (Return y) = getmatches x y
getmatches (Assign ax ex) (Assign ay ey) = do
    ma <- agetmatches ax ay
    me <- getmatches ex ey
    return $ ma ++ me
getmatches (AssignOp op1 ax ex) (AssignOp op2 ay ey) | op1 == op2 = do
    ma <- agetmatches ax ay
    me <- getmatches ex ey
    return $ ma ++ me
getmatches (Seq x1 x2) (Seq y1 y2) = do
    m1 <- getmatches x1 y1
    m2 <- getmatches x2 y2
    return $ m1 ++ m2
getmatches (Apply x xs) (Apply y ys) = do
    m <- getmatches x y
    ms <- getmatcheses xs ys
    return $ m ++ ms
getmatches (Partial x xs i) (Partial y ys j) | i == j = do
    m <- getmatches x y
    ms <- getmatcheses xs ys
    return $ m ++ ms
getmatches (Foreign _ _ xs) (Foreign _ _ ys) =
    getmatcheses (map fst xs) (map fst ys)
getmatches (While a x) (While b y) = do
    ma <- getmatches a b
    mb <- getmatches x y
    return $ ma ++ mb
getmatches (DoWhile a x) (DoWhile b y) = do
    ma <- getmatches a b
    mb <- getmatches x y
    return $ ma ++ mb
getmatches (For _ _ _ a1 s1 b1) (For _ _ _ a2 s2 b2) = do
    ma <- agetmatches a1 a2
    ms <- getmatches s1 s2
    mb <- getmatches b1 b2
    return $ ma ++ ms ++ mb
getmatches (TryCatch a1 s1 b1 f1) (TryCatch a2 s2 b2 f2) = do
    ma <- getmatches a1 a2
    ms <- getmatches s1 s2
    mb <- getmatches b1 b2
    mf <- getmatches f1 f2
    return $ ma ++ ms ++ mb ++ mf
getmatches (Throw x) (Throw y) = getmatches x y
getmatches (Except a x) (Except b y) = do
    ma <- getmatches a b
    mb <- getmatches x y
    return $ ma ++ mb
getmatches (Infix op1 a x) (Infix op2 b y) | op1 == op2 = do
    ma <- getmatches a b
    mb <- getmatches x y
    return $ ma ++ mb
getmatches (RealInfix op1 a x) (RealInfix op2 b y) | op1 == op2 = do
    ma <- getmatches a b
    mb <- getmatches x y
    return $ ma ++ mb
getmatches (Append a x) (Append b y) = do
    ma <- getmatches a b
    mb <- getmatches x y
    return $ ma ++ mb
getmatches (Unary opx x) (Unary opy y) | opx == opy = getmatches x y
-- Mirrors the Unary case, just as RealInfix mirrors Infix. Without it, a
-- metavariable under a RealUnary could never be bound.
getmatches (RealUnary opx x) (RealUnary opy y) | opx == opy = getmatches x y
getmatches (Coerce t1 t2 x) (Coerce v1 v2 y)
    | t1 == v1 && t2 == v2 = getmatches x y
getmatches (Case e1 a1) (Case e2 a2) = do
    me <- getmatches e1 e2
    ma <- getmatchesalts a1 a2
    return $ me ++ ma
getmatches (If a1 s1 b1) (If a2 s2 b2) = do
    ma <- getmatches a1 a2
    ms <- getmatches s1 s2
    mb <- getmatches b1 b2
    return $ ma ++ ms ++ mb
getmatches (Index a x) (Index b y) = do
    ma <- getmatches a b
    mb <- getmatches x y
    return $ ma ++ mb
getmatches (Field e1 n1 _ _) (Field e2 n2 _ _) | n1 == n2 = getmatches e1 e2
getmatches (ArrayInit es1) (ArrayInit es2) = getmatcheses es1 es2
getmatches (Annotation _ e) (Annotation _ e2) = getmatches e e2
getmatches x y
    | x == y    = return []
    | otherwise = fail "No getmatches"

-- | Match an assignment-target pattern against an lvalue, the lvalue
-- counterpart of 'getmatches'. Names and globals must be equal. Index
-- expressions are matched with 'getmatches'. Field accesses must name the
-- same field. Anything else fails with @"No getmatches"@.
agetmatches lx ly = case (lx, ly) of
    (AName x, AName y) | x == y -> return []
    (AGlob x, AGlob y) | x == y -> return []
    (AIndex ax ex, AIndex ay ey) -> do
        fromTarget <- agetmatches ax ay
        fromIndex  <- getmatches ex ey
        return (fromTarget ++ fromIndex)
    (AField ax n _ _, AField ay m _ _) | n == m -> agetmatches ax ay
    _ -> fail "No getmatches"

-- | Match two lists of expressions pairwise with 'getmatches', concatenating
-- the bindings. Lists of different lengths do not match. Before this fix
-- they crashed with a non-exhaustive pattern error; now they fail with
-- @"No getmatches"@ like any other mismatch.
getmatcheses [] [] = return []
getmatcheses (x:xs) (y:ys) = do
    m <- getmatches x y
    ms <- getmatcheses xs ys
    return $ m ++ ms
getmatcheses _ _ = fail "No getmatches"

-- | Match two lists of case alternatives pairwise, concatenating the
-- bindings. 'Default' alternatives match on their bodies. 'Alt'
-- alternatives match on their argument expressions and bodies. Lists of
-- different lengths fail with @"No getmatches"@; they previously crashed
-- with a non-exhaustive pattern error.
--
-- NOTE(review): the first two fields of 'Alt' are ignored, and any other
-- pair of alternatives (including a 'Default' paired with an 'Alt') is
-- skipped as if it matched, so metavariables inside them stay unbound.
-- Confirm whether that is intended.
getmatchesalts [] [] = return []
getmatchesalts ((Default e1):as1) ((Default e2):as2) = do
    m <- getmatches e1 e2
    ma <- getmatchesalts as1 as2
    return $ m ++ ma
getmatchesalts ((Alt _ _ es1 e1):as1) ((Alt _ _ es2 e2):as2) = do
    ms <- getmatcheses es1 es2
    m <- getmatches e1 e2
    ma <- getmatchesalts as1 as2
    return $ ms ++ m ++ ma
getmatchesalts (_:as1) (_:as2) = getmatchesalts as1 as2
getmatchesalts _ _ = fail "No getmatches"
       
-- | Replace each metavariable in a term with the expression bound to its
-- index in @ms@. Metavariables with no binding are left in place.
--
-- The substitution recurses through every subexpression in the same way
-- 'oapply' does, by passing itself to 'mapsubexpr'. Passing 'id' there (as
-- before) would only visit the top of the term.
-- NOTE(review): this assumes 'mapsubexpr' maps its first argument over
-- immediate subexpressions only, which is what 'oapply' relies on.
-- Confirm against Language.
replace :: [(Int,Expr n)] -> Expr n -> Expr n
replace ms = rep
   where rep e = mapsubexpr rep repmv e
         repmv fl l x = case lookup x ms of
                          Just bound -> bound
                          Nothing -> Metavar fl l x

-- | Apply a single-node rewrite across an expression. The node's
-- subexpressions are rewritten via 'mapsubexpr', presumably first, and then
-- the rewrite is applied to the rebuilt node. Metavariables are handed to
-- 'Metavar' itself, i.e. presumably rebuilt unchanged.
oapply opt = go
  where
    go e = opt (mapsubexpr go Metavar e)

{-
replace ms = rep' where
    rep' m@(Metavar _ _ x) = case lookup x ms of
			        (Just e) -> e
				Nothing -> m
    rep' (Lambda args e) = Lambda args (rep' e)
    rep' (Closure args t e) = Closure args t (rep' e)
    rep' (Bind n ty e1 e2) = Bind n ty (rep' e1) (rep' e2)
    rep' (Declare f l n t e) = Declare f l n t (rep' e)
    rep' (Return e) = Return (rep' e)
    rep' (Assign a e) = Assign (arep a) (rep' e)
    rep' (AssignOp op a e) = AssignOp op (arep a) (rep' e)
    rep' (Seq x y) = Seq (rep' x) (rep' y)
    rep' (Apply f as) = Apply (rep' f) (reps as)
    rep' (Partial f as i) = Partial (rep' f) (reps as) i
    rep' (Foreign ty n es) = Foreign ty n 
			        (zip (reps (map fst es)) (map snd es))
    rep' (While e b) = While (rep' e) (rep' b)
    rep' (DoWhile e b) = DoWhile (rep' e) (rep' b)
    rep' (For i j a e1 e2) = For i j (arep a) (rep' e1) (rep' e2)
    rep' (TryCatch t e f) = TryCatch (rep' t) (rep' e) (rep' f)
    rep' (Throw e) = Throw (rep' e)
    rep' (Except e i) = Except (rep' e) (rep' i)
    rep' (Infix op x y) = Infix op (rep' x) (rep' y)
    rep' (RealInfix op x y) = RealInfix op (rep' x) (rep' y)
    rep' (Append x y) = Append (rep' x) (rep' y)
    rep' (Unary op x) = Unary op (rep' x)
    rep' (RealUnary op x) = RealUnary op (rep' x)
    rep' (Coerce t1 t2 x) = Coerce t1 t2 (rep' x)
    rep' (Case e as) = Case (rep' e) (altrep as)
    rep' (If a t e) = If (rep' a) (rep' t) (rep' e)
    rep' (Index a b) = Index (rep' a) (rep' b)
    rep' (Field e n i j) = Field (rep' e) n i j
    rep' (ArrayInit as) = ArrayInit (reps as)
    rep' x = x

    arep (AIndex a e) = AIndex (arep a) (rep' e)
    arep (AField a n i j) = AField (arep a) n i j
    arep x = x

    reps [] = []
    reps (x:xs) = (rep' x) : (reps xs)

    altrep [] = []
    altrep ((Alt i j es e):as) = (Alt i j (reps es) (rep' e)):(altrep as)
-}

{-
oapply opt expr = opt (oa' expr)
  where oapply' = oapply opt
	oa' (Lambda args e) = Lambda args (oapply' e)
	oa' (Closure args t e) = Closure args t (oapply' e)
	oa' (Bind n ty e1 e2) = Bind n ty (oapply' e1) (oapply' e2)
	oa' (Declare f l n t e) = Declare f l n t (oapply' e)
	oa' (Assign a e) = Assign (aapply a) (oapply' e)
	oa' (AssignOp op a e) = AssignOp op (aapply a) (oapply' e)
	oa' (Seq a b) = Seq (oapply' a) (oapply' b)
	oa' (Apply f as) = Apply (oapply' f) (applys as)
	oa' (Partial f as i) = Partial (oapply' f) (applys as) i
	oa' (Foreign ty n es) = Foreign ty n 
			        (zip (applys (map fst es)) (map snd es))
	oa' (While e b) = While (oapply' e) (oapply' b)
	oa' (DoWhile e b) = DoWhile (oapply' e) (oapply' b)
	oa' (For i j a e1 e2) = For i j (aapply a) (oapply' e1) (oapply' e2)
	oa' (TryCatch t e f) = TryCatch (oapply' t) (oapply' e) (oapply' f)
	oa' (Throw e) = Throw (oapply' e)
	oa' (Except e i) = Except (oapply' e) (oapply' i)
	oa' (Infix op x y) = Infix op (oapply' x) (oapply' y)
	oa' (RealInfix op x y) = RealInfix op (oapply' x) (oapply' y)
	oa' (Append x y) = Append (oapply' x) (oapply' y)
	oa' (Unary op x) = Unary op (oapply' x)
	oa' (RealUnary op x) = RealUnary op (oapply' x)
	oa' (Coerce t1 t2 x) = Coerce t1 t2 (oapply' x)
	oa' (Case e as) = Case (oapply' e) (altapp as)
	oa' (If a t e) = If (oapply' a) (oapply' t) (oapply' e)
	oa' (Index a b) = Index (oapply' a) (oapply' b)
	oa' (Field e n i j) = Field (oapply' e) n i j
	oa' (ArrayInit as) = ArrayInit (applys as)
	oa' x = x

        aapply (AIndex a e) = AIndex (aapply a) (oapply' e)
	aapply (AField a n i j) = AField (aapply a) n i j
	aapply x = x

        applys [] = []
	applys (x:xs) = (oapply' x) : (applys xs)

        altapp [] = []
	altapp ((Alt i j es e):as) 
	    = (Alt i j (applys es) (oapply' e)):(altapp as)
-}