// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "StochasticGradient.h"
#include "Trainer.h"

namespace Torch {

// Registers the user-settable options of the optimizer, binding each option
// name to a member variable with a default value:
//   "end accuracy"        -> end_accuracy        (default 0.0001)
//   "learning rate"       -> learning_rate       (default 0.01)
//   "learning rate decay" -> learning_rate_decay (default 0)
//   "max iter"            -> max_iter            (default -1; train() only
//                                                 enforces the limit when > 0)
// NOTE(review): the meaning of the trailing `true` argument is defined by
// addROption/addIOption, which are not visible here.
StochasticGradient::StochasticGradient()
{
  // Real-valued options.
  this->addROption("end accuracy", &end_accuracy, 0.0001, "end accuracy", true);
  this->addROption("learning rate", &learning_rate, 0.01, "learning rate", true);
  this->addROption("learning rate decay", &learning_rate_decay, 0, "learning rate decay", true);

  // Integer-valued option.
  this->addIOption("max iter", &max_iter, -1, "maximum number of iterations", true);
}

// Trains `machine` on `data` with per-example stochastic gradient descent.
//
// Each pass ("iteration") over the training set:
//   1. visits every example of `data` once, in a freshly shuffled order;
//   2. for each example, runs forward/backward through `machine` and
//      `criterion`, calls measureEx() on the measurers of dataset 0, then
//      updates every parameter in place: param -= current_learning_rate * der;
//   3. calls measureIter() on the dataset-0 measurers, then, for every other
//      dataset grouped by extractMeasurers(), forwards each of its examples
//      through `machine` (no update) and calls measureEx()/measureIter() on
//      that dataset's measurers;
//   4. averages the criterion output over the pass and stops when it changed
//      by less than `end_accuracy` since the previous pass, when it becomes
//      NaN, or after `max_iter` passes (only enforced when max_iter > 0).
// After the loop, measureEnd() is called on every grouped measurer.
//
// measurers: linked list of Measurer*; each one is reset() before training.
void StochasticGradient::train(List *measurers)
{
  int iter = 0;
  real err = 0;
  real prev_err = INF;
  real current_learning_rate = learning_rate;
  int n_train = data->n_examples;
  // Request at least one slot so an empty training set never becomes a
  // zero-byte allocation (malloc(0) is allowed to return NULL).
  // NOTE(review): assumes xalloc is a checked malloc wrapper - confirm.
  int *shuffle = (int *)xalloc((n_train > 0 ? n_train : 1)*sizeof(int));

  DataSet **datas;
  Measurer ***mes;
  int *n_mes;
  int n_datas;

  message("StochasticGradient: training");

  // Give every measurer and the criterion a clean slate before the first pass.
  List *measurers_ = measurers;
  while(measurers_)
  {
    ((Measurer *)measurers_->ptr)->reset();
    measurers_ = measurers_->next;
  }
  criterion->reset();

  // Group the measurers by the dataset they observe; dataset 0 is the
  // training set in every case.
  extractMeasurers(measurers, data, &datas, &mes, &n_mes, &n_datas);

  while(1)
  {
    getShuffledIndices(shuffle, n_train);

    machine->iterInitialize();
    criterion->iterInitialize();
    err = 0;
    for(int t = 0; t < n_train; t++)
    {
      data->setExample(shuffle[t]);
      machine->forward(data->inputs);
      criterion->forward(machine->outputs);
      criterion->backward(machine->outputs, NULL);
      machine->backward(data->inputs, criterion->beta);

      for(int i = 0; i < n_mes[0]; i++)
        mes[0][i]->measureEx();

      // Plain SGD step. params and der_params are walked in lockstep, and
      // each node holds params->n reals.
      List *params = machine->params;
      List *der_params = machine->der_params;
      while(params)
      {
        real *ptr_params = (real *)params->ptr;
        real *ptr_der_params = (real *)der_params->ptr;

        for(int i = 0; i < params->n; i++)
          *ptr_params++ -= current_learning_rate * *ptr_der_params++;

        params = params->next;
        der_params = der_params->next;
      }

      // The pass error is the sum of the criterion's first output over all
      // examples. Perhaps Criterion should own an "error accumulator" in case
      // the error is not a sum, but a priori it comes from an integral, so a
      // sum is fine.
      // TODO: tell users here that a Criterion's outputs[0] must hold the error.
      err += ((real *)(criterion->outputs->ptr))[0];
    }

    for(int i = 0; i < n_mes[0]; i++)
      mes[0][i]->measureIter();

    // Evaluate the remaining datasets (indices 1..n_datas-1), forward only.
    for(int d = 1; d < n_datas; d++)
    {
      DataSet *dataset = datas[d];

      for(int t = 0; t < dataset->n_examples; t++)
      {
        dataset->setExample(t);
        machine->forward(dataset->inputs);

        for(int i = 0; i < n_mes[d]; i++)
          mes[d][i]->measureEx();
      }

      for(int i = 0; i < n_mes[d]; i++)
        mes[d][i]->measureIter();
    }

    // NOTE(review): computed before iter++, so passes 0 and 1 both use the
    // base rate and pass k (k >= 1) uses learning_rate/(1 + (k-1)*decay).
    // Kept as-is to preserve existing training results.
    current_learning_rate = learning_rate/(1.+((real)(iter))*learning_rate_decay);

    // Mean training error of this pass. With no training examples this
    // would be 0/0 = NaN, which never satisfies the convergence test below
    // and would loop forever when max_iter <= 0 (the default).
    if(n_train > 0)
      err /= (real)(n_train);

    // A NaN error (e.g. diverging parameters) cannot satisfy the convergence
    // test either: stop instead of iterating forever on garbage.
    if(err != err)
    {
      print("\n");
      warning("StochasticGradient: training error is NaN, stopping");
      break;
    }

    if(fabs(prev_err - err) < end_accuracy)
    {
      print("\n");
      break;
    }
    prev_err = err;

    print(".");
    iter++;
    if( (iter >= max_iter) && (max_iter > 0) )
    {
      print("\n");
      warning("StochasticGradient: you have reached the maximum number of iterations");
      break;
    }
  }
  free(shuffle);

  for(int d = 0; d < n_datas; d++)
  {
    for(int i = 0; i < n_mes[d]; i++)
      mes[d][i]->measureEnd();
  }

  deleteExtractedMeasurers(datas, mes, n_mes, n_datas);
}

// Destructor: intentionally empty; nothing is released here.
StochasticGradient::~StochasticGradient() {}

}

