#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
General utilities for code generation.
"""



__author__ = "Martin Sandve Alnes"
__date__ = "2008-02-27 -- 2008-10-13"
__copyright__ = "(C) 2008-2009 Martin Sandve Alnes and Simula Resarch Laboratory"
__license__  = "GNU GPL Version 2, or (at your option) any later version"


import re
import swiginac

from pprint import *


# Number of spaces per indentation level, shared by all generators in this module.
indent_size = 4

def indent(text, n=1):
    """Return *text* shifted right by *n* indentation levels.

    *text* may be a string or a (possibly nested) list of strings; lists are
    indented element-wise and returned as lists. Empty text or n == 0 is
    returned unchanged. A trailing newline is preserved without adding
    spaces after it.
    """
    # Nothing to do for empty input or a zero indentation level.
    if not text or not n:
        return text

    # Lists are indented element by element, keeping their structure.
    if isinstance(text, list):
        return [indent(item, n) for item in text]

    has_trailing_newline = text.endswith("\n")
    # Strip the final newline so it does not produce an indented empty line.
    body = text[:-1] if has_trailing_newline else text

    prefix = " " * (indent_size * n)
    shifted = "\n".join(prefix + line for line in body.split("\n"))

    return shifted + "\n" if has_trailing_newline else shifted


class CodeFormatter:
    """Utility class for assembling code strings into a multiline string.

    Supports checking for matching parentheses and applying indentation
    to generate code that is more robust with respect to correctness
    of program flow and readability of code.

    Support for the following constructs:
    {} (basic block), if, else if, else, switch, case, while, do, class

    Each construct pushes a named context on a stack when opened and pops it
    when closed; closing the wrong construct raises RuntimeError. str()
    refuses to render code whose contexts/indentation are not closed.

    Typical usage:

    >>> c = CodeFormatter()
    >>> c.begin_switch("i")
    >>> c.begin_case(0)
    >>> c += "foo();"
    >>> c.end_case()
    >>> c.begin_case(1)
    >>> c.begin_if("a > b")
    >>> c += "bar();"
    >>> c.begin_else_if("c > b")
    >>> c += "bar2();"
    >>> c.end_if()
    >>> c.end_case()
    >>> c.end_switch()
    >>> print(str(c))
    switch(i)
    {
    case 0:
        foo();
        break;
    case 1:
        if( a > b )
        {
            bar();
        }
        else if( c > b )
        {
            bar2();
        }
        break;
    }
    """
    def __init__(self, name=""):
        self.name        = name     # free-form label; not used by the formatter itself
        self.text        = []       # code fragments, concatenated by __str__
        self.indentation = 0        # current indentation level (units of indent_size)
        self.context     = ["ROOT"] # stack of open constructs, innermost last

    def get_context(self):
        """Return the name of the innermost open context."""
        return self.context[-1]

    def add_context(self, context):
        """Push *context* onto the context stack."""
        self.context.append(context)

    def remove_context(self, context):
        """Pop the innermost context, which must equal *context*.

        Raises RuntimeError (leaving the stack untouched) on a mismatch.
        """
        self.assert_context(context)
        self.context.pop()

    def assert_context(self, context):
        """Raise RuntimeError unless the innermost context is *context*."""
        if self.context[-1] != context:
            raise RuntimeError("Expected to be in context '%s', state is %s." % (context, str(self.context)))

    def assert_contexts(self, contexts):
        """Raise RuntimeError unless the innermost contexts match *contexts*.

        *contexts* lists the expected contexts innermost first. Asking for
        more contexts than are currently open is also a mismatch.
        """
        innermost_first = self.context[::-1][:len(contexts)]
        if list(contexts) != innermost_first:
            raise RuntimeError("Expected to be in contexts %s, state is %s." % (str(contexts), str(self.context)))

    def assert_closed_code(self):
        """Raise RuntimeError unless all contexts are closed and indentation is 0."""
        self.assert_context("ROOT")
        if self.indentation != 0:
            raise RuntimeError("Code is not closed, indentation state is %d != 0." % self.indentation)

    def __str__(self):
        """Return the assembled code; raises RuntimeError if it is not closed."""
        self.assert_closed_code()
        return "".join(self.text)

    def new_text(self, text):
        "Add a block of text directly with no modifications."
        self.text.append(text)

    def new_line(self, text=""):
        "Add a line with auto-indentation."
        # Separate from previous content with a newline, but never start with one.
        if self.text:
            self.text.append("\n")
        self.text.append(indent(text, self.indentation))

    def __iadd__(self, text):
        """Support ``c += text`` as shorthand for ``c.new_line(text)``."""
        self.new_line(text)
        return self

    def indent(self):
        """Increase the indentation level by one."""
        self.indentation += 1

    def dedent(self):
        """Decrease the indentation level by one, clamping at zero with a warning."""
        if self.indentation <= 0:
            print("WARNING: Dedented non-indented code, something"
                  " is wrong in code generation program flow.")
            self.indentation = 0
        else:
            self.indentation -= 1

    def comment(self, text):
        """Emit *text* as a C++ comment: ``//`` if single-line, ``/* */`` otherwise."""
        if "\n" in text:
            self.new_line( "/*" )
            # new_line already applies the current indentation; indenting the
            # text here as well would double-indent it relative to "/*".
            self.new_line( text )
            self.new_line( "*/" )
        else:
            self.new_line( "// " + text )

    def begin_debug(self):
        """Open an ``#ifdef SFCDEBUG`` section."""
        self.add_context("debug")
        self.new_line("#ifdef SFCDEBUG")

    def end_debug(self):
        """Close the ``#ifdef SFCDEBUG`` section opened by begin_debug."""
        self.new_line("#endif // SFCDEBUG")
        self.remove_context("debug")

    def begin_block(self):
        """Open a ``{`` block and indent its contents."""
        self.add_context("block")
        self.new_line("{")
        self.indent()

    def end_block(self):
        """Close the innermost ``{`` block."""
        self.dedent()
        self.new_line("}")
        self.remove_context("block")

    def begin_switch(self, arg):
        """Open ``switch(arg) {``; case labels stay at the switch's indentation."""
        self.new_line("switch(%s)" % str(arg))
        self.new_line("{")
        self.add_context("switch")

    def begin_case(self, arg, braces=False):
        """Open ``case arg:`` inside a switch, optionally wrapping the body in braces."""
        self.assert_context("switch")
        self.add_context("case")
        self.new_line("case %s:" % str(arg))
        self.indent()
        if braces:
            self.begin_block()

    def end_case(self):
        """Close a case with ``break;`` (closing its braces first if it has them)."""
        if self.get_context() == "block": # braces = True in begin_case
            self.end_block()
        self.new_line("break;")
        self.dedent()
        self.remove_context("case")

    def end_switch(self):
        """Close the innermost switch."""
        self.new_line("}")
        self.remove_context("switch")

    def begin_while(self, arg):
        """Open ``while( arg )`` followed by a block."""
        self.add_context("while")
        self.new_line("while( %s )" % str(arg))
        self.begin_block()

    def end_while(self):
        """Close the innermost while loop."""
        self.end_block()
        self.remove_context("while")

    def begin_do(self):
        """Open a ``do { ... }`` loop; the condition is given to end_do."""
        self.add_context("do")
        self.new_line("do")
        self.new_line("{")
        self.indent()

    def end_do(self, arg):
        """Close a do loop with ``} while(arg);``."""
        self.dedent()
        self.new_line( "} while(%s);" % str(arg) )
        self.remove_context("do")

    def begin_if(self, arg):
        """Open ``if( arg )`` followed by a block."""
        self.add_context("if")
        self.new_line("if( %s )" % str(arg))
        self.begin_block()

    def begin_else_if(self, arg):
        """Close the current if/else-if branch and open ``else if( arg )``."""
        self.end_block()
        self.assert_context("if")
        self.new_line("else if( %s )" % str(arg))
        self.begin_block()

    def begin_else(self):
        """Close the current if/else-if branch and open ``else``."""
        self.end_block()
        self.assert_context("if")
        self.new_line("else")
        self.begin_block()

    def end_if(self):
        """Close the innermost if construct."""
        self.end_block()
        self.remove_context("if")

    def begin_class(self, classname, bases=()):
        """Open ``class classname: bases...`` followed by a block."""
        self.add_context("class")
        code = "class %s" % str(classname)
        if bases:
            code += ": " + ", ".join(str(b) for b in bases)
        self.new_line( code )
        self.begin_block()

    def end_class(self):
        """Close the innermost class definition with ``};``."""
        self.end_block()
        self.new_text(";")
        self.remove_context("class")

    def declare_function(self, name, return_type="void", args=(), const=False, virtual=False, inline=False, classname=None):
        """Emit a function declaration (signature followed by ``;``)."""
        code = "%s;" % function_signature(name, return_type, args, virtual, inline, const, classname)
        self.new_line(code)

    def define_function(self, name, return_type="void", args=(), const=False, virtual=False, inline=False, classname=None, body="// Empty body"):
        """Emit a function definition with *body* inside an indented block."""
        signature = function_signature(name, return_type, args, virtual, inline, const, classname)
        self.new_line(signature)
        self.begin_block()
        self.new_line(body)
        self.end_block()

    def call_function(self, name, args=()):
        """Emit a call statement ``name(args...);``."""
        arglist = ", ".join(str(a) for a in args)
        self.new_line("".join((name, "(", arglist, ");")))


# ... Utility functions for code generation

def row_major_index_string(i, shape):
    """Constructs an index string for a row major (C type)
    indexing of a flattened tensor of rank 0, 1, or 2.

    i: the multi-index (only the first len(shape) entries are used).
    shape: the tensor shape; its length is the rank.

    Raises NotImplementedError (a RuntimeError subclass) for rank >= 3.
    """
    rank = len(shape)
    if rank == 0:
        return "0"
    if rank == 1:
        return "%d" % i[0]
    if rank == 2:
        return "%d*%d + %d" % (shape[1], i[0], i[1])
    # A bare string is not a valid exception; raise a real exception type.
    raise NotImplementedError("Rank 3 or higher not supported in row_major_index_string()")


# Trailing zeros of a mantissa's fractional part, immediately before the exponent.
_TRAILING_FRACTION_ZEROS = re.compile(r"(\.\d*?)0+([eE])")

def optimize_floats(code):
    """Optimize storage size of floating point numbers by removing unneeded trailing zeros.

    Only zeros after the decimal point and directly before the exponent are
    removed (e.g. ``1.500e+01`` -> ``1.5e+01``). Integer mantissas such as
    ``10e5`` and identifiers such as ``w10e`` are left untouched, since
    stripping zeros there would change their meaning.
    """
    return _TRAILING_FRACTION_ZEROS.sub(r"\1\2", code)

def function_signature(name, return_type="void", args=(), virtual=False, inline=False, const=False, classname=None):
    """Render a C++ function signature (no trailing ';' and no body).

    Each entry of *args* is either a ready-made string (``"int a"``), a
    (type, name) pair, or a (type, name, suffix) triple such as
    ``("double", "x", "[3]")``; pairs/triples may be tuples or lists.

    Raises RuntimeError for an argument spec of any other length.
    """
    code = []
    if virtual:
        code.append("virtual ")
    if inline:
        code.append("inline ")

    code.extend((return_type, " "))

    if classname:
        code.extend((classname, "::"))

    def joinarg(arg):
        if isinstance(arg, str):
            return arg
        # %-formatting requires a tuple; accept list specs as well.
        if isinstance(arg, list):
            arg = tuple(arg)
        if len(arg) == 2:
            return "%s %s" % arg
        elif len(arg) == 3:
            return "%s %s%s" % arg
        raise RuntimeError("Invalid arg: %s" % repr(arg))
    arglist = ", ".join(joinarg(arg) for arg in args)

    code.extend((name, "(", arglist, ")"))

    if const:
        code.append(" const")
    return "".join(code)

# TODO: These token code generation functions can be optimized by iterating differently
# TODO: A function that can apply two rules to the same stream and return two joined codes

default_code_rule = r"double %(symbol)s = %(value)s;"
def gen_token_code(token, rule=default_code_rule):
    """Generate code for a (symbol, value) token by applying a %-format rule.

    token[0] is a symbol or string or a swiginac matrix/lst of symbols;
    token[1] is a scalar, expression or matrix/lst of the same length with
    the corresponding values. *rule* is a %-format string using the keys
    ``symbol`` and ``value``; it defaults to
    ``double %(symbol)s = %(value)s;``. For container tokens one line per
    element is produced, joined by newlines.

    Raises RuntimeError if symbol and value containers differ in length.
    """
    # NOTE: this used to be a "docstring" built with `% default_code_rule`,
    # which Python does not treat as a docstring (__doc__ was None).

    def code_rule(s, v):
        return rule % { "symbol": s, "value": v }

    sym, val = token

    # Plain Python numbers have no printc(); wrap them as swiginac numerics.
    if isinstance(val, (int, float)):
        val = swiginac.numeric(val)

    # check if we have one or more tokens here
    if isinstance(sym, (swiginac.matrix, swiginac.lst)):
        if len(sym) != len(val):
            raise RuntimeError("sym and val must have same size.")
        return "\n".join( code_rule(sym[i].printc(), val[i].printc()) for i in range(len(sym)) )

    return code_rule(str(sym), val.printc())

#===============================================================================
# def gen_symbol_declaration(symbol, prefix="double ", postfix=";\n"):
#    """symbol is a swiginac.symbol or matrix or lst with symbols.
#       Generates code for declaration of all symbols in symbol."""
# 
#    if not isinstance(symbol, (swiginac.matrix, swiginac.lst) ):
#        symbol = swiginac.matrix(1,1, [symbol])
# 
#    # cat all symbols in comma separated string
#    symbol_list = ", ".join( str(symbol[i]) for i in xrange(len(symbol)) )
# 
#    code = prefix + symbol_list + postfix
#    return code
#===============================================================================


def gen_token_prints(tokens, indent="    "):
    """Generate C++ statements printing ``<indent><symbol> = <value>`` per token.

    *indent* is a prefix placed inside the emitted C++ string literal (it
    shadows the module-level indent() function within this body only).
    """
    # Escape '%' so a prefix containing it is not read as a format directive
    # when gen_token_code applies the rule; build the rule once, not per token.
    rule = 'std::cout << "' + indent.replace("%", "%%") + '%(symbol)s = " << %(symbol)s << std::endl;'
    return "\n".join(gen_token_code(token, rule=rule) for token in tokens)


def gen_symbol_declarations(symbols):
    """Return one ``double <symbol>;`` declaration per symbol, newline-separated."""
    declarations = []
    for symbol in symbols:
        declarations.append("double %s;" % str(symbol))
    return "\n".join(declarations)


def gen_token_declarations(tokens):
    """Declare (without initializing) a ``double`` for every symbol in *tokens*."""
    declaration_rule = "double %(symbol)s;"
    snippets = [gen_token_code(tok, rule=declaration_rule) for tok in tokens]
    return "\n".join(snippets)


def gen_token_definitions(tokens):
    """Define and initialize a ``double`` for every (symbol, value) token."""
    pieces = []
    for tok in tokens:
        pieces.append(gen_token_code(tok, rule="double %(symbol)s = %(value)s;"))
    return "\n".join(pieces)


def gen_const_token_definitions(tokens):
    """Like gen_token_definitions, but declares each variable ``const double``."""
    const_rule = "const double %(symbol)s = %(value)s;"
    return "\n".join(map(lambda tok: gen_token_code(tok, rule=const_rule), tokens))


def gen_token_assignments(tokens):
    """Assign each token's value to its (already declared) symbol."""
    assignment_rule = "%(symbol)s = %(value)s;"
    lines = [gen_token_code(tok, rule=assignment_rule) for tok in tokens]
    return "\n".join(lines)


def gen_token_additions(tokens):
    """Accumulate each token's value into its symbol with ``+=``."""
    emitted = []
    for tok in tokens:
        emitted.append(gen_token_code(tok, rule="%(symbol)s += %(value)s;"))
    return "\n".join(emitted)


def gen_switch(argument, cases, default_case=None, braces=False):
    """Generate a C/C++ ``switch`` statement as a string.

    argument: the switch expression.
    cases: sequence of (label, code) pairs; each case body ends in ``break;``.
    default_case: optional code emitted under ``default:``.
    braces: if true, wrap each case body in its own ``{ }`` block.
    """
    if braces:
        start_brace = "{\n"
        end_brace   = "\n}"
    else:
        start_brace = ""
        end_brace   = ""

    # Indent ``break;`` with indent() so it lines up with the case body
    # (previously a hard-coded 2 spaces, inconsistent with indent_size).
    indented_break = indent("break;")

    parts = [ "case %s:\n%s%s\n%s%s" % (str(c[0]), start_brace, indent(c[1]), indented_break, end_brace)
              for c in cases ]
    if default_case:
        # No trailing spaces after the default body.
        parts.append("default:\n%s%s%s" % (start_brace, indent(default_case), end_brace))

    # Joining the parts avoids a blank first line when there are no cases.
    case_code = "\n".join(parts)
    return "switch(%s)\n{\n%s\n}" % (argument, indent(case_code))


#class Switch:
#    def __init__(self, argument, cases=[], default_case=None, braces=False):
#        self.argument = argument
#        self.cases = cases
#        self.default_case = default_case
#        self.braces = braces
#
#    def __str__(self):
#        return gen_switch(self.argument, self.cases, self.default_case, self.braces)
#
#
#class IfElse:
#    def __init__(self, cases=[]):
#        self.cases = cases
#
#    def __str__(self):
#        c = self.cases[0]
#        code = "if(%s)\n{\n%s\n}" % (c[0], indent(c[1]))
#        for c in self.cases[1:]:
#            code += "\nelse if(%s)\n{\n%s\n}" % (c[0], indent(c[1]))
#        return code
#
#
#class Struct:
#    def __init__(self, name, variables=[]):
#        self.name = name
#        self.variables = variables
#
#    def __str__(self):
#        inner_code = ""
#        for v in self.variables:
#            if isinstance(v, tuple):
#                inner_code += "%s %s;\n" % (v[0], v[1])
#            else:
#                inner_code += "double %s;\n" % v
#        return "struct %s\n{\n%s}" % (self.name, indent( inner_code ))



def outline(name, tokens, targets, deps):
    """Move a sequence of assignments into a separate C++ function.

    name: name of the generated function.
    tokens: sequence of (symbol, expression) pairs, emitted as assignments
        in order.
    targets: symbols whose values are passed back through ``double &``
        output arguments; all other token symbols become local variables.
    deps: symbols the expressions read; passed as ``double`` inputs.

    Returns a (fundef, funcall) pair of strings: the function definition
    and the statement that calls it.
    """
    # Accept any sequence (e.g. a tuple) for deps.
    inputargs = list(deps)

    # Split token symbols into output arguments and locals.
    localtokens = []
    outputargs = []
    for t in tokens:
        s = t[0]
        if s in targets:
            outputargs.append(s)
        else:
            localtokens.append(s)

    # Argument list in the call: inputs first, then outputs.
    callargumentlist = ", ".join(str(a) for a in inputargs + outputargs)

    # Argument list in the definition: inputs by value, outputs by reference.
    parameters = ["double %s" % str(a) for a in inputargs]
    parameters += ["double & %s" % str(a) for a in outputargs]
    argumentlist = ", ".join(parameters)

    # Function body: local declarations, then every assignment in order.
    body_lines = ["  double %s;" % s for s in localtokens]
    body_lines += ["  %s = %s;" % (t[0], t[1]) for t in tokens]
    body = "\n".join(body_lines)

    fundef = "void %s(%s)\n{\n%s\n}" % (name, argumentlist, body)
    funcall = "%s(%s);" % (name, callargumentlist)

    return fundef, funcall


def test_outliner():
    """Smoke test for outline(): print the generated code and check it."""
    tokens = [
        ("a", "u * v"),
        ("b", "u + w"),
    ]
    deps = ["u", "v", "w"]
    targets = ["a"]

    fundef, funcall = outline("foo", tokens, targets, deps)
    # print(x) with a single argument behaves the same on Python 2 and 3.
    print(fundef)
    print(funcall)

    assert funcall == "foo(u, v, w, a);"
    assert fundef == ("void foo(double u, double v, double w, double & a)\n"
                      "{\n  double b;\n  a = u * v;\n  b = u + w;\n}")



#year, month, day, hour, minute = time.localtime()[:5]
#date_string = "%d:%d, %d/%d %d" % (hour, minute, day, month, year)


if __name__ == "__main__":
    # Demonstration of CodeFormatter constructs; prints the assembled code.
    c = CodeFormatter()

    c.begin_class( "Foo" )
    c += "dings"
    c.end_class()

    c += ""

    c.begin_class( "Bar", ("fee", "foe") )
    c.comment( "something funny" )
    c.declare_function("blatti", "int", const=False, virtual=False)
    c.declare_function("foo", "double", const=True)
    c.declare_function("bar", virtual=True)
    c.end_class()

    c.define_function("blatti", "int", const=False, virtual=False, classname="Bar", body='cout << "Hello world!" << endl;')
    c.call_function("blatti")

    c += ""

    # a basic if
    c.begin_if( "a < b" )
    c += "foo.bar();"
    c.end_if()

    c += ""

    # a compound if
    c.begin_if( "a < b" )
    c += "foo.bar();"
    c.begin_else_if( "c < b" )
    c += "foo.bar();"
    c.begin_else()
    c += "foo.bar();"
    c.end_if()

    c += ""

    # a simple do loop
    c.begin_do()
    c += "foo();"
    c.end_do("a > 0")

    c += ""

    # a simple while loop
    c.begin_while("a > 0")
    c += "foo();"
    c.end_while()

    c += ""

    # an empty switch
    c.begin_switch("i")
    c.end_switch()

    c += ""

    # a compound switch
    c.begin_switch("k")
    c.begin_case(0)
    c += "foo.bar();"
    c.end_case()
    c.begin_case("1")
    c += "bar.foo();"
    c += "bar.foo();"
    c.end_case()
    c.end_switch()

    c += ""

    # verify that the code is closed
    c.assert_closed_code()

    # print the result (single-argument print() works on Python 2 and 3)
    print(c)



if __name__ == '__main__':
    # Demonstration of the gen_token_* helpers on a few swiginac tokens.
    from swiginac import symbol

    x = symbol("x")
    y = symbol("y")
    z = symbol("z")
    pi = swiginac.Pi
    cos = swiginac.cos
    tokens = [ (x, 1), (y, x**2-1), (z, cos(2*pi*x*y)) ]
    # Single-argument print() works on Python 2 and 3.
    print(gen_token_declarations(tokens))
    print(gen_token_definitions(tokens))
    print(gen_const_token_definitions(tokens))
    print(gen_token_assignments(tokens))
    print(gen_token_additions(tokens))

#    print Switch("i", [(1, "foo();"), (2, "bar();")], "foe();", False)
#    print Switch("i", [(1, "foo();"), (2, "bar();")], None, True)
#    print Switch("i", [(1, "foo();"), (2, "bar();")], "foe();", True)
#    print Switch("i", [(1, "foo();"), (2, "bar();")], None, False)
#
#    print IfElse([("i==0", "foo();"), ("i!=1", "bar();")])
#
#    print Struct("foostruct", ["a", "b", "c"])
#    print Struct("barstruct", [("int", "a"), ("bool", "b"), "c"])


def _test():
    import doctest
    return doctest.testmod()

if __name__ == "__main__":
    _test()
