/*
 *  Copyright (c) 2008 Cyrille Berger <cberger@cberger.net>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation;
 * version 2 of the License.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include "Program.h"

// GTLCore
#include "GTLCore/Buffer.h"
#include "GTLCore/CodeGenerator_p.h"
#include "GTLCore/GenerationContext_p.h"
#include "GTLCore/PixelDescription.h"
#include "GTLCore/ExpressionResult_p.h"
#include "GTLCore/Function.h"
#include "GTLCore/Function_p.h"
#include "GTLCore/Parameter.h"
#include "GTLCore/String.h"
#include "GTLCore/Type.h"
#include "GTLCore/Type_p.h"
#include "GTLCore/Value.h"
#include "GTLCore/VariableNG_p.h"

// OpenCTL
#include "Module.h"
#include "GTLCore/ModuleData_p.h"
#include "Debug.h"
#include "GTLCore/VirtualMachine_p.h"

// LLVM
#include <llvm/DerivedTypes.h>
#include <llvm/Instructions.h>
#include <llvm/Module.h>
#include <llvm/ModuleProvider.h>
#include <llvm/Transforms/Utils/Cloning.h>

// Passes
#include <llvm/PassManager.h>
#include <llvm/Analysis/LoopPass.h>
#include <llvm/Analysis/LoadValueNumbering.h>
#include <llvm/Analysis/Verifier.h>
#include <llvm/Target/TargetData.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/GlobalVariable.h>

using namespace OpenCTL;

/*
 * Private data of a Program.
 *
 * The raw pointer members start out null so that the object is in a
 * well-defined state even before Program::init() assigns them.
 */
struct Program::Private {
  Private(const GTLCore::PixelDescription& _srcPixelDescription, const GTLCore::PixelDescription& _dstPixelDescription) :
      module(0), moduleData(0), func(0), moduleProvider(0),
      srcPixelDescription(_srcPixelDescription), dstPixelDescription(_dstPixelDescription)
  {
  }
  llvm::Module* module; // clone of the source module's LLVM module, null when none is available
  GTLCore::ModuleData* moduleData; // ModuleData built around 'module'
  void (*func)( const char*, char*, int); // JIT-compiled program entry point, null when init() failed
  llvm::ModuleProvider* moduleProvider; // provider registered with the VM for 'module', null until init() succeeds
  GTLCore::PixelDescription srcPixelDescription; // layout of the input pixels
  GTLCore::PixelDescription dstPixelDescription; // layout of the output pixels
  static int s_id; // counter used to give each generated program function a unique name
  std::map< GTLCore::String, llvm::GlobalVariable*> varyings; // varying name -> LLVM global holding its value
  std::map< GTLCore::String, void*> varyingsPtr; // varying name -> address returned by VirtualMachine::getGlobalVariablePointer
  std::list<GTLCore::String> varyingsName; // varying names, in the order of the function's parameters
};

int Program::Private::s_id = 0;

/*
 * For each of the first countChannels channels of a pixel, stores:
 *  - in indexes[channel]: the byte offset of the channel inside the pixel, as
 *    an integer constant generated by cg;
 *  - in needBuffer[channel]: whether the channel's storage type (Int8, UInt8,
 *    Int16, UInt16 or half) cannot be handed to CTL directly and therefore has
 *    to be converted to Int32/float around the call.
 * indexes and needBuffer must each have room for countChannels entries.
 */
void configureBuffer(GTLCore::CodeGenerator& cg, const std::vector< const GTLCore::Type*>& channelsTypes, llvm::Value**& indexes, bool*& needBuffer, int countChannels)
{
  int byteOffset = 0;
  for( int channel = 0; channel < countChannels; ++channel )
  {
    const GTLCore::Type* channelType = channelsTypes[channel];
    indexes[channel] = cg.integerToConstant( byteOffset );
    byteOffset += channelType->bitsSize() / 8;
    OCTL_ASSERT( channelType->bitsSize() % 8 == 0);
    const bool narrowInteger = channelType == GTLCore::Type::Integer8
                            or channelType == GTLCore::Type::UnsignedInteger8
                            or channelType == GTLCore::Type::Integer16
                            or channelType == GTLCore::Type::UnsignedInteger16;
    needBuffer[channel] = narrowInteger or channelType == GTLCore::Type::Half;
  }
}

// Builds a program whose input and output pixels share the same layout.
Program::Program(const GTLCore::String& functionName, const Module* module, const GTLCore::PixelDescription& pixelDescription)
    : d( new Private( pixelDescription, pixelDescription ) )
{
  init( functionName, module, pixelDescription, pixelDescription );
}


// Builds a program reading pixels laid out as srcPixelDescription and
// writing pixels laid out as dstPixelDescription.
Program::Program(const GTLCore::String& functionName, const Module* module, const GTLCore::PixelDescription& srcPixelDescription, const GTLCore::PixelDescription& dstPixelDescription)
    : d( new Private( srcPixelDescription, dstPixelDescription ) )
{
  init( functionName, module, srcPixelDescription, dstPixelDescription );
}

/*
 * Generates and JIT-compiles the pixel-processing function of this program.
 *
 * The module is cloned, then a wrapper "CTLProgram<N>" with the signature
 * void(const char* in8, char* out8, int size) is generated. For each source
 * pixel it loads every channel, passes the values (plus pointers to the
 * destination channels) to the CTL function named functionName, and stores
 * back the destination channels that needed a conversion buffer. Parameters
 * of the CTL function beyond the pixel channels become "varyings": LLVM
 * globals initialised with the parameter's default value.
 *
 * On success d->func points to the compiled wrapper; otherwise it stays null.
 */
void Program::init(const GTLCore::String& functionName, const Module* module, const GTLCore::PixelDescription& srcPixelDescription, const GTLCore::PixelDescription& dstPixelDescription)
{
  d->func = 0;
  d->moduleProvider = 0;
  d->moduleData = 0;
  // Clone the module
  if( module->data())
  {
    d->module = llvm::CloneModule(module->data()->llvmModule());
    d->moduleData = new GTLCore::ModuleData( d->module );
    const GTLCore::Function* functionDef = module->function( functionName );
    if( not functionDef )
    {
      // Same clean-up as the "function not found" branch below: without an
      // entry point the cloned module is useless, so release it here.
      OCTL_DEBUG("Function: " << functionName << " not found in module");
      delete d->module;
      d->module = 0;
      return;
    }
    llvm::Function* function = d->module->getFunction( GTLCore::Function::Data::symbolName( GTLCore::ScopedName( "", functionName), functionDef->parameters() ) );
    if(function)
    {
      GTLCore::CodeGenerator cg( d->moduleData );
      // Initialise pixel information
      int srcPixelSize = srcPixelDescription.bitsSize() / 8;
      OCTL_ASSERT(srcPixelDescription.bitsSize() % 8== 0);
      int dstPixelSize = dstPixelDescription.bitsSize() / 8;
      OCTL_ASSERT(dstPixelDescription.bitsSize() % 8== 0);
      
      // Configure input buffer
      OCTL_DEBUG("Configure input buffer");
      const std::vector< const GTLCore::Type*>& srcChannelsTypes = srcPixelDescription.channelTypes();
      llvm::Value** srcIndexes = new llvm::Value*[ srcChannelsTypes.size() ];
      bool* srcNeedBuffer = new bool[ srcChannelsTypes.size() ]; // CTL doesn't support Int8 or Int16 (and OpenCTL's VM doesn't know about half), so when some pixel channel is in Int8 or Int16 (or half) it first need to be converted to Int32 (or float)
      int srcCountChannels = srcChannelsTypes.size();
      configureBuffer( cg, srcChannelsTypes, srcIndexes, srcNeedBuffer, srcCountChannels);
      // Configure output buffer
      const std::vector< const GTLCore::Type*>& dstChannelsTypes = dstPixelDescription.channelTypes();
      llvm::Value** dstIndexes = new llvm::Value*[ dstChannelsTypes.size() ];
      bool* dstNeedBuffer = new bool[ dstChannelsTypes.size() ]; // CTL doesn't support Int8 or Int16 (and OpenCTL's VM doesn't know about half), so when some pixel channel is in Int8 or Int16 (or half) it first need to be converted to Int32 (or float)
      int dstCountChannels = dstChannelsTypes.size();
      configureBuffer( cg, dstChannelsTypes, dstIndexes, dstNeedBuffer, dstCountChannels);
      
      //---------- Program-Pseudo-code ----------//
      // This is just a rough explanation of the function that gets generated here.
      //
      // void program(const char* in8, char* out8, int size)
      // {
      //   _BUFFER_TYPE_ buffer[dstCountChannels]; // Int32/float stand-ins for Int8, Int16 and half output channels
      //   for(int iSrc = 0, iDst = 0; iSrc < size; iSrc += srcPixelSize, iDst += dstPixelSize)
      //   {
      //     const char* in = in8 + iSrc;
      //     char* out = out8 + iDst;
      //     // Input channels needing a buffer are converted to Int32/float before
      //     // the call; buffered output channels are passed as &buffer[k] and
      //     // converted back into out after the call.
      //     function(in[srcIndexes[0]], in[srcIndexes[1]], ..., out + dstIndexes[0], out + dstIndexes[1], ..., varyings...);
      //   }
      //   return;
      // }
      OCTL_DEBUG("Initialize the function type");
      llvm::Value* srcPixelSizeValue = cg.integerToConstant( srcPixelSize );
      llvm::Value* dstPixelSizeValue = cg.integerToConstant( dstPixelSize );
      std::vector<const llvm::Type*> params;
      params.push_back( llvm::PointerType::get( llvm::Type::Int8Ty, 0 ));
      params.push_back( llvm::PointerType::get( llvm::Type::Int8Ty, 0 ));
      params.push_back( llvm::Type::Int32Ty );
      // void program(const char* in8, char* out8, int size)
      llvm::FunctionType* definitionType = llvm::FunctionType::get( llvm::Type::VoidTy, params, false );
      int programNb = ++Program::Private::s_id;
      llvm::Function* func = cg.createFunction( definitionType, "CTLProgram" + GTLCore::String::number(programNb));
      // Initialise a generation context
      GTLCore::GenerationContext gc( &cg, func, 0, d->moduleData );
      // {
      OCTL_DEBUG("Initial block");
      llvm::BasicBlock* initialBlock = llvm::BasicBlock::Create();
      func->getBasicBlockList().push_back( initialBlock );
      // Initialise the buffer, as needed
      OCTL_DEBUG("Initialise buffer");
      GTLCore::VariableNG** buffer = new GTLCore::VariableNG*[dstCountChannels];
      for(int i = 0; i < dstCountChannels; ++i)
      {
        if(dstNeedBuffer[i])
        {
          GTLCore::VariableNG* vng = 0;
          if(dstChannelsTypes[i] == GTLCore::Type::Half)
          {
            vng = new GTLCore::VariableNG(GTLCore::Type::Float, false );
          } else {
            vng = new GTLCore::VariableNG(GTLCore::Type::Integer32, false );
          }
          vng->initialise( gc, initialBlock, GTLCore::ExpressionResult(), std::list<llvm::Value*>());
          buffer[i] = vng;
        } else {
          buffer[i] = 0;
        }
      }
      OCTL_DEBUG("Get the arguments");
      // Get the args.
      llvm::Function::arg_iterator arg_it = func->arg_begin();
      //   const char* in8 = first arg;
      llvm::Value* in = arg_it;
      //   char* out8 = second arg;
      ++arg_it;
      llvm::Value* out = arg_it;
      // size
      ++arg_it;
      llvm::Value* size = arg_it;
      // Construct the "conditions" of the loop
      //   int i = 0;
      OCTL_DEBUG("int i = 0");
      GTLCore::VariableNG* posSrc = new GTLCore::VariableNG( GTLCore::Type::Integer32, false);
      posSrc->initialise( gc, initialBlock, GTLCore::ExpressionResult(cg.integerToConstant(0), GTLCore::Type::Integer32), std::list<llvm::Value*>());
      GTLCore::VariableNG* posDst = new GTLCore::VariableNG( GTLCore::Type::Integer32, false);
      posDst->initialise( gc, initialBlock, GTLCore::ExpressionResult(cg.integerToConstant(0), GTLCore::Type::Integer32), std::list<llvm::Value*>());
      // i < size
      OCTL_DEBUG("i < size");
      llvm::BasicBlock* forTestBlock = llvm::BasicBlock::Create("forTestBlock");
      func->getBasicBlockList().push_back( forTestBlock);
      llvm::Value* forTest = cg.createStrictInferiorExpression(forTestBlock, posSrc->get( gc, forTestBlock ), posSrc->type(), size, GTLCore::Type::Integer32 );
      // i += pixelSize
      OCTL_DEBUG("i += pixelSize");
      llvm::BasicBlock* updateBlock = llvm::BasicBlock::Create("updateBlock");
      func->getBasicBlockList().push_back( updateBlock);
      posSrc->set( gc, updateBlock, cg.createAdditionExpression( updateBlock, posSrc->get( gc, updateBlock), posSrc->type(), srcPixelSizeValue, GTLCore::Type::Integer32 ), GTLCore::Type::Integer32 );
      posDst->set( gc, updateBlock, cg.createAdditionExpression( updateBlock, posDst->get( gc, updateBlock), posDst->type(), dstPixelSizeValue, GTLCore::Type::Integer32 ), GTLCore::Type::Integer32 );
      // Construct the body of the for loop
      OCTL_DEBUG("bodyBlock");
      llvm::BasicBlock* bodyBlock = llvm::BasicBlock::Create("bodyBlock");
      func->getBasicBlockList().push_back( bodyBlock);
      
      // function(in[i], in[i+indexes[1]] ..., out + i, out + i +indexes[1], ...);
      std::vector<llvm::Value*> arguments;
      
      // Generate in[i], in[i+indexes[1]] ...
      OCTL_DEBUG("Generate in[i], in[i+indexes[1]] ...");
      llvm::Value* srcIndex = posSrc->get( gc, bodyBlock);
      for(int i = 0; i < srcCountChannels; ++i)
      {
        // Load the value from the input buffer
        OCTL_DEBUG("Load the value from the input buffer");
        llvm::Value* convertedIn = new llvm::LoadInst(
                        cg.convertPointerTo( bodyBlock,
                                           llvm::GetElementPtrInst::Create( in, cg.createAdditionExpression( bodyBlock, srcIndex, posSrc->type(), srcIndexes[i], GTLCore::Type::Integer32), "", bodyBlock),
                                           srcChannelsTypes[i]->d->type()),
                        "", bodyBlock);
        if( srcNeedBuffer[i])
        { // if a buffer is needed that means that the value must be converted
          OCTL_DEBUG("if a buffer is needed that means that the value must be converted");
          if(srcChannelsTypes[i] == GTLCore::Type::Half)
          {
            convertedIn = GTLCore::CodeGenerator::convertFromHalf( gc, bodyBlock, convertedIn);
          } else {
            convertedIn = GTLCore::CodeGenerator::convertValueTo( bodyBlock, convertedIn, srcChannelsTypes[i], GTLCore::Type::Integer32);
          }
        }
        arguments.push_back( convertedIn );
      }
      
      // Generate out + i, out + i +indexes[1], ...
      OCTL_DEBUG("Generate ut + i, out + i +indexes[1], ...");
      llvm::Value* dstIndex = posDst->get( gc, bodyBlock);
      for(int i = 0; i < dstCountChannels; ++i)
      {
        llvm::Value* pointer = cg.convertPointerTo( bodyBlock, llvm::GetElementPtrInst::Create( out, cg.createAdditionExpression( bodyBlock, dstIndex, GTLCore::Type::Integer32, dstIndexes[i], GTLCore::Type::Integer32), "", bodyBlock), dstPixelDescription.channelTypes()[i]->d->type());
        if( dstNeedBuffer[i])
        {
          OCTL_ASSERT(buffer[i]);
          buffer[i]->set( gc, bodyBlock, new llvm::LoadInst(pointer,"", bodyBlock), dstPixelDescription.channelTypes()[i]);
          arguments.push_back( buffer[i]->pointer());
        } else {
          arguments.push_back(pointer);
        }
      }
      // Check if there are more parameters to call
      OCTL_DEBUG("Check if there are more parameters to call");
      const std::vector< GTLCore::Parameter >& parameters = functionDef->parameters();
      if( arguments.size() < parameters.size() )
      {
        OCTL_DEBUG("Filling with constant parameters");
        for( unsigned int i = arguments.size(); i < parameters.size(); ++i)
        {
          llvm::GlobalVariable* globalVar =
              new llvm::GlobalVariable(
                    parameters[i].type()->d->type(),
                    false, llvm::GlobalValue::ExternalLinkage,
                    cg.valueToConstant( parameters[i].defaultValue() ),
                    "", d->module );
          arguments.push_back( new llvm::LoadInst( globalVar, "", bodyBlock ) );
          d->varyings[ parameters[i].name() ] = globalVar;
          d->varyingsName.push_back( parameters[i].name() );
        }
      }
      // Some debug
#ifndef NDEBUG
      for( unsigned int i = 0; i < arguments.size(); i++)
      {
        OCTL_DEBUG("arguments[" << i << "] == " << *arguments[i]);
        OCTL_DEBUG("type[" << i << "] == " << function->getFunctionType()->getParamType(i) << " " << arguments[i]->getType());
      }
      OCTL_DEBUG( *function->getFunctionType());
#endif
      llvm::CallInst *CallFunc = llvm::CallInst::Create(function, arguments.begin(), arguments.end(), "", bodyBlock);
      CallFunc->setTailCall();
      // If there was buffering, save to output
      OCTL_DEBUG("If there was buffering, save to output");
      for(int i = 0; i < dstCountChannels; ++i)
      {
        if( dstNeedBuffer[i])
        {
          const llvm::Type* channelType = dstChannelsTypes[i]->d->type();
          llvm::Value* pointer = cg.convertPointerTo( bodyBlock, llvm::GetElementPtrInst::Create( out, cg.createAdditionExpression( bodyBlock, dstIndex, GTLCore::Type::Integer32, dstIndexes[i], GTLCore::Type::Integer32), "", bodyBlock), channelType);
          llvm::Value* result = 0;
          if( dstChannelsTypes[i] == GTLCore::Type::Half )
          {
            result = GTLCore::CodeGenerator::convertToHalf( gc, bodyBlock, buffer[i]->get( gc, bodyBlock ), GTLCore::Type::Float );
          } else {
            result = GTLCore::CodeGenerator::convertValueTo( bodyBlock, buffer[i]->get( gc, bodyBlock ), GTLCore::Type::Integer32, dstChannelsTypes[i] );
          }
          OCTL_DEBUG( *result->getType() << *pointer->getType() );
          new llvm::StoreInst( result, pointer, bodyBlock);
        }
      }
      // Put the for loop together
      OCTL_DEBUG("Put the for loop together");
      // for(int i = 0; i < size; i += srcPixelSize)
      OCTL_DEBUG("for(int i = 0; i < size; ++i)");
      llvm::BasicBlock* finBlock = llvm::BasicBlock::Create("finBlock");
      func->getBasicBlockList().push_back( finBlock);
      cg.createForStatement(initialBlock, forTestBlock, forTest, GTLCore::Type::Boolean, updateBlock, bodyBlock, bodyBlock, finBlock);
      // return;
      OCTL_DEBUG("return;");
      llvm::ReturnInst::Create(finBlock);
      OCTL_DEBUG(*d->module);
      // Optimize, FIXME: use GTLCore's optimizer
      OCTL_DEBUG("Optimize");
      llvm::PassManager Passes;
      // Add in the passes we want to execute
      Passes.add(new llvm::TargetData(d->module));
      // Verify we start with valid
      Passes.add(llvm::createVerifierPass());
      // Run
      Passes.run(*d->module);
      // Register module in the VM
      d->moduleProvider = new llvm::ExistingModuleProvider( d->module );
      GTLCore::VirtualMachine::instance()->registerModule( d->moduleProvider );
      GTLCore::VirtualMachine::instance()->executionEngine()->clearAllGlobalMappings();
      d->func = ( void(*)(const char*, char*,int)) GTLCore::VirtualMachine::instance()->getPointerToFunction( func );
      for( std::map< GTLCore::String, llvm::GlobalVariable*>::iterator it = d->varyings.begin(); it != d->varyings.end(); ++it)
      {
          d->varyingsPtr[ it->first ] = GTLCore::VirtualMachine::instance()->getGlobalVariablePointer( it->second );
      }
//       OCTL_DEBUG( ((void*)d->func) << " ==== " << func << "  " << func->isDeclaration() << *func);
      // Cleanup
      for(int i = 0; i < dstCountChannels; ++i)
      {
          delete buffer[i];
      }
      delete[] srcIndexes;
      delete[] srcNeedBuffer;
      delete[] dstIndexes;
      delete[] dstNeedBuffer;
      delete[] buffer;
      delete posSrc;
      delete posDst;
    } else {
      OCTL_DEBUG("Function: " << functionName << " not found in module");
      delete d->module;
      d->module = 0;
    }
  } else {
    OCTL_DEBUG("No module was supplied");
    d->module = 0;
  }
}

/*
 * Unregisters the generated module from the virtual machine and releases
 * everything init() allocated.
 */
Program::~Program()
{
  // Once init() succeeded, d->module was wrapped in d->moduleProvider, which
  // then takes care of the module; in that case it must not be deleted here.
  const bool moduleHeldByProvider = ( d->moduleProvider != 0 );
  if( moduleHeldByProvider )
  {
    GTLCore::VirtualMachine::instance()->unregisterModule( d->moduleProvider);
    delete d->moduleProvider;
  }
  delete d->moduleData;
  if( not moduleHeldByProvider )
  {
    // init() may have cloned a module and bailed out before creating a
    // provider (e.g. when the function definition is missing); free that clone.
    // Every other failure path already reset d->module to null.
    delete d->module;
  }
  delete d;
}

// True when init() produced a compiled program function that apply() can run.
bool Program::initialised() const
{
  return d->func != 0;
}

/*
 * Runs the compiled program over every pixel of input and writes the results
 * into output. Both buffers must contain the same number of pixels, laid out
 * according to the source and destination pixel descriptions.
 *
 * If the program failed to initialise, nothing is written. Previously this
 * called through a null function pointer whenever OCTL_ASSERT did not stop
 * execution (presumably release builds).
 */
void Program::apply(const GTLCore::Buffer& input, GTLCore::Buffer& output) const
{
  OCTL_ASSERT( (input.size() / (d->srcPixelDescription.bitsSize() / 8) ) == (output.size() / (d->dstPixelDescription.bitsSize() / 8) ) );
  OCTL_ASSERT( d->func );
  OCTL_ASSERT( (input.size() % (d->srcPixelDescription.bitsSize() / 8)) == 0 );
  OCTL_ASSERT( (output.size() % (d->dstPixelDescription.bitsSize() / 8)) == 0 );
  if( not d->func )
  {
    OCTL_DEBUG("Program is not initialised, it cannot be applied");
    return;
  }
  OCTL_DEBUG("Apply program on a buffer of size " << input.size() << " for a source pixel of size " << (d->srcPixelDescription.bitsSize() / 8) << " for a destination pixel of size " << (d->dstPixelDescription.bitsSize() / 8));
  // The generated function iterates while its byte index is below 'size',
  // advancing by one source pixel per step, so the input byte size is passed.
  d->func( input.rawData(), output.rawData(), input.size() );
  OCTL_DEBUG("Applied");
}

/*
 * Sets the value of the varying named _name by writing directly into the
 * memory of its JIT global. Only Int32, Int1 (boolean) and float globals are
 * supported; other element types are reported and left untouched.
 */
void Program::setVarying( const GTLCore::String& _name, const GTLCore::Value& _value )
{
  std::map< GTLCore::String, void*>::iterator it = d->varyingsPtr.find( _name );
  std::map< GTLCore::String, llvm::GlobalVariable*>::iterator it2 = d->varyings.find( _name );
  if( it != d->varyingsPtr.end() and it2 != d->varyings.end() )
  {
    void* ptr = it->second;
    // The global's element type decides how the raw storage is written.
    const llvm::Type* elementType = it2->second->getType()->getElementType();
    if( elementType == llvm::Type::Int32Ty )
    {
        *(int*)ptr =_value.asInt32();
    } else if( elementType == llvm::Type::Int1Ty )
    {
        *(bool*)ptr = _value.asBoolean();
    } else if( elementType == llvm::Type::FloatTy )
    {
        OCTL_DEBUG("Set " << _value.asFloat() << " on ptr " << ptr << " from value = " << *(float*)ptr);
        *(float*)ptr = _value.asFloat();
        OCTL_DEBUG( *(float*)ptr );
    } else {
        OCTL_DEBUG("Invalid type");
    }
  } else {
    OCTL_DEBUG(" No varying named: " << _name);
  }
}

/*
 * Returns the current value of the varying named _name, read from the memory
 * of its JIT global. Yields a default-constructed Value when the name is
 * unknown or the global's element type is not Int32, Int1 or float.
 */
GTLCore::Value Program::varying( const GTLCore::String& _name ) const
{
  std::map< GTLCore::String, void*>::iterator ptrEntry = d->varyingsPtr.find( _name );
  std::map< GTLCore::String, llvm::GlobalVariable*>::iterator globalEntry = d->varyings.find( _name );
  if( ptrEntry == d->varyingsPtr.end() )
  {
    OCTL_DEBUG(" No varying named: '" << _name << "'");
    return GTLCore::Value();
  }
  void* storage = ptrEntry->second;
  const llvm::Type* storedType = globalEntry->second->getType()->getElementType();
  if( storedType == llvm::Type::Int32Ty )
  {
    return GTLCore::Value( *static_cast<int*>( storage ) );
  } else if( storedType == llvm::Type::Int1Ty ) {
    return GTLCore::Value( *static_cast<bool*>( storage ) );
  } else if( storedType == llvm::Type::FloatTy ) {
    return GTLCore::Value( *static_cast<float*>( storage ) );
  }
  OCTL_DEBUG("Invalid type");
  return GTLCore::Value();
}

// Disabled alternative implementation of setVarying()/varying(), kept for
// reference. It goes through GTLCore::VirtualMachine::setGlobalVariable /
// getGlobalVariable on the llvm::GlobalVariable, instead of reading/writing
// the cached addresses stored in d->varyingsPtr as the active versions do.
#if 0

void Program::setVarying( const GTLCore::String& _name, const GTLCore::Value& _value )
{
  std::map< GTLCore::String, llvm::GlobalVariable*>::iterator it = d->varyings.find( _name );
  if( it != d->varyings.end() )
  {
    GTLCore::VirtualMachine::instance()->setGlobalVariable( it->second, _value);
  } else {
    OCTL_DEBUG(" No varying named: " << _name);
  }
}

GTLCore::Value Program::varying( const GTLCore::String& _name ) const
{
  std::map< GTLCore::String, llvm::GlobalVariable*>::iterator it = d->varyings.find( _name );
  if( it != d->varyings.end() )
  {
    return GTLCore::VirtualMachine::instance()->getGlobalVariable( it->second);
  }
  OCTL_DEBUG(" No varying named: '" << _name << "'");
  return GTLCore::Value();
}

#endif

// Names of the varyings, i.e. the parameters of the CTL function that come
// after the pixel channels, listed in parameter order.
const std::list<GTLCore::String>& Program::varyings() const
{
  const std::list<GTLCore::String>& names = d->varyingsName;
  return names;
}
