/*****************************************************************
* Unipro UGENE - Integrated Bioinformatics Suite
* Copyright (C) 2008,2009 Unipro, Russia (http://ugene.unipro.ru)
* All Rights Reserved
* 
*     This source code is distributed under the terms of the
*     GNU General Public License. See the files COPYING and LICENSE
*     for details.
*****************************************************************/

#ifdef SW2_BUILD_WITH_CUDA
#include <cuda_runtime.h>
#endif

#include "SWAlgorithmTask.h"

#include "SmithWatermanAlgorithmATISTREAM.h"
#include "SmithWatermanAlgorithmCUDA.h"
#include "SmithWatermanAlgorithmSSE2.h"
#include "sw_cuda_cpp.h"

#include <core_api/AppContext.h>
#include <core_api/AppSettings.h>
#include <core_api/AppResources.h>
#include <core_api/Log.h>
#include <core_api/Counter.h>
#include <core_api/Timer.h>

#include <core_api/CudaGpuRegistry.h>
#include <core_api/AtiStreamGpuRegistry.h>
#include <util_smith_waterman/SmithWatermanResult.h>

#include <QtCore/QMutexLocker>

using namespace std;

namespace GB2 {

// File-local log category, constructed from ULOG_CAT_SW; every diagnostic
// message emitted by the tasks in this file goes through it.
static LogCategory log(ULOG_CAT_SW);

/**
 * Creates the Smith-Waterman search task.
 *
 * @param s         search settings; copied into sWatermanConfig
 * @param taskName  name passed to the Task base class
 * @param _algType  requested back-end (classic, sse2, cuda or atistream)
 *
 * The constructor picks the effective back-end, registers the GPU resource
 * the back-end needs, derives the minimal accepted score from the best
 * possible score of the pattern and finally creates the sequence-walker
 * subtask through setupTask().
 */
SWAlgorithmTask::SWAlgorithmTask(const SmithWatermanSettings& s,
                                 const QString& taskName, AlgType _algType)
    : Task(taskName, TaskFlag_NoRun),
      sWatermanConfig(s)
{
    GCOUNTER( cvar, tvar, "SWAlgorithmTask" );
    log.info("RUN constructor SWAlgorithmTask");

    // SSE2 is replaced by the classic implementation for patterns shorter than 8 symbols.
    algType = _algType;
    if (algType == sse2 && sWatermanConfig.ptrn.length() < 8) {
        algType = classic;
    }

    // GPU back-ends request one unit of the matching resource (prepare stage).
    if (algType == cuda) {
        TaskResourceUsage gpuUsage( RESOURCE_CUDA_GPU, 1, true /*prepareStage*/);
        taskResources.append( gpuUsage );
    } else if (algType == atistream) {
        TaskResourceUsage gpuUsage( RESOURCE_ATISTREAM_GPU, 1, true /*prepareStage*/);
        taskResources.append( gpuUsage );
    }

    const int maxScore = calculateMaxScore(s.ptrn, s.pSm);

    // minScore is percentOfScore percent of maxScore; it is bumped by one
    // whenever the integer product is not an exact multiple of 100.
    minScore = (maxScore * s.percentOfScore) / 100;
    const bool hasRemainder = (maxScore * static_cast<int>(s.percentOfScore)) % 100 != 0;
    if (hasRemainder) {
        minScore += 1;
    }

    setupTask(maxScore);

    log.info("FINISH constructor SWAlgorithmTask");
}

// Destroys the task together with the result listener and the report
// callback held in the settings copy: both objects are deleted here.
SWAlgorithmTask::~SWAlgorithmTask() {
    delete sWatermanConfig.resultListener;
    delete sWatermanConfig.resultCallback;
    // resultFilter is intentionally NOT deleted here: filters are stored in a special registry
}

/**
 * Configures and registers the sequence-walker subtask that feeds sequence
 * chunks to onRegion().
 *
 * @param maxScore  best score the pattern can reach; used together with
 *                  minScore to compute the overlap between adjacent chunks
 *
 * The number of chunks depends on the back-end: 2.5 x ideal thread count for
 * SSE2, the ideal thread count for the classic implementation and a single
 * chunk for the GPU back-ends. The sequence is only split when the summed
 * overlaps are shorter than the searched region; otherwise one chunk covering
 * the whole sequence is used.
 */
void SWAlgorithmTask::setupTask(int maxScore) {

    SequenceWalkerConfig c;
    c.seq = sWatermanConfig.sqnc.constData();
    c.seqSize = sWatermanConfig.sqnc.size();
    c.range = sWatermanConfig.globalRegion;
    c.complTrans = sWatermanConfig.complTT;
    c.aminoTrans = sWatermanConfig.aminoTT;
    c.strandToWalk = sWatermanConfig.strand;
    log.info(QString("Strand: %1 ").arg(c.strandToWalk));

    const int overlapSize = calculateMatrixLength(sWatermanConfig.sqnc,
        sWatermanConfig.ptrn,
        sWatermanConfig.gapModel.scoreGapOpen,
        sWatermanConfig.gapModel.scoreGapExtd,
        maxScore,
        minScore);

    // divide sequence by partsNumber parts
    const int idealThreadCount = AppContext::getAppSettings()->getAppResourcePool()->getIdealThreadCount();

    int partsNumber = 0;
    if (algType == sse2) {
        partsNumber = static_cast<int>(idealThreadCount * 2.5);
    } else if (algType == classic) {
        partsNumber = idealThreadCount;
    } else if (algType == cuda || algType == atistream) {
        partsNumber = 1;
    }

    // An unrecognised back-end or a non-positive ideal thread count would leave
    // partsNumber at 0 (or below) and the chunk-size computation would divide
    // by it. Fall back to a single chunk in that case.
    if (partsNumber < 1) {
        partsNumber = 1;
    }

    const bool splitSequence = (partsNumber != 1)
        && (partsNumber - 1) * overlapSize < sWatermanConfig.globalRegion.len;
    if (splitSequence) {
        c.chunkSize = (c.seqSize + overlapSize * (partsNumber - 1)) / partsNumber;
        // chunkSize is bumped by one when it equals the overlap
        // (presumably so that a chunk is never made of overlap only - TODO confirm)
        if (c.chunkSize == overlapSize) {
            c.chunkSize++;
        }
        c.overlapSize = overlapSize;
    } else {
        c.overlapSize = 0;
        c.chunkSize = c.seqSize;
        partsNumber = 1;
    }

    log.info(tr("PARTS_NUMBER: %1").arg(partsNumber));

    c.lastChunkExtraLen = partsNumber - 1;
    c.nThreads = partsNumber;

    t = new SequenceWalkerTask(c, this, tr("Smith Waterman2 SequenceWalker"));
    addSubTask(t);
}

void SWAlgorithmTask::prepare() {
    if( cuda == algType ) {
        CudaGpuModel *& gpu = gpuModel.cudaGpu;
        gpu = AppContext::getCudaGpuRegistry()->acquireAnyReadyGpu();
        assert( gpu );
#ifdef SW2_BUILD_WITH_CUDA
        quint64 needMemBytes = SmithWatermanAlgorithmCUDA::estimateNeededGpuMemory( 
                sWatermanConfig.pSm, sWatermanConfig.ptrn, sWatermanConfig.sqnc );
        quint64 gpuMemBytes = gpu->getGlobalMemorySizeBytes();
        if( gpuMemBytes < needMemBytes ) {
            stateInfo.setError( tr("Not enough memory on Cuda-enabled device. "
                                "Needed %1 bytes, device has %2 bytes. Device id: %3, device name: %4").
                                arg(QString::number(needMemBytes), QString::number(gpuMemBytes), QString::number(gpu->getId()), QString(gpu->getName()))
                              );
            return;
        } else {
            log.details( tr("Smith-Waterman search allocates ~%1 bytes (%2 Mb) on CUDA device").
                arg(QString::number(needMemBytes), QString::number(needMemBytes/1024/1024+1)) );
        }

        log.info(QString("GPU model: %1").arg(gpuModel.cudaGpu->getId()));

        cudaSetDevice( gpuModel.cudaGpu->getId() );
#else
        assert(false);  
#endif 
    } else if ( atistream == algType ) {
        static GCounter memCounter( "ATI SW2 Mem counter", "Mb", 1024*1024 );

        AtiStreamGpuModel *& gpu = gpuModel.atiGpu;
        gpu = AppContext::getAtiStreamGpuRegistry()->acquireAnyReadyGpu();
        assert( gpu);
#ifdef SW2_BUILD_WITH_ATISTREAM
        quint64 needMemBytes = SmithWatermanAlgorithmATISTREAM::estimateNeededGpuMemory( 
            sWatermanConfig.pSm, sWatermanConfig.ptrn, sWatermanConfig.sqnc );
        quint64 gpuMemBytes = gpu->getGlobalMemorySizeBytes();
        if( gpuMemBytes < needMemBytes ) {
            stateInfo.setError( 
                tr("Not enough memory on ATI device. "
                "Needed %1 bytes, device has %2 bytes. Device id: %3, device name: %4").
                arg(QString::number(needMemBytes), QString::number(gpuMemBytes), QString::number(gpu->getId()), QString(gpu->getName()))
                );
            return;
        } else {
            log.details( tr("Smith-Waterman search allocates ~%1 bytes (%2 Mb) on ATI device").
                         arg(QString::number(needMemBytes), QString::number(needMemBytes/1024/1024+1)) );
            memCounter.totalCount += needMemBytes;
        }
#else
        assert(false);  
#endif 

    }

}

/**
 * Returns the accumulated alignments after cleaning them up in place:
 * entries sharing the same sequence interval are collapsed first, then the
 * list is ordered with SmithWatermanAlgorithm::sortByScore().
 *
 * @return reference to the task-owned result list
 */
QList<PairAlignSequences> &  SWAlgorithmTask::getResult() {
    QList<PairAlignSequences> & collected = pairAlignSequences;

    removeResultFromOverlap(collected);
    SmithWatermanAlgorithm::sortByScore(collected);

    return collected;
}

/**
 * Sequence-walker callback: runs the selected Smith-Waterman implementation
 * on one sequence chunk and stores the hits via addResult().
 *
 * @param t   walker subtask describing the chunk (sequence, strand, translation)
 * @param ti  unused
 *
 * Hit coordinates returned by the algorithm are local to the chunk. They are
 * scaled by 3 for amino-translated chunks, mirrored against the end of the
 * walker's global range for complemented chunks, and otherwise shifted by the
 * chunk offset relative to the searched region.
 *
 * If the requested back-end was not compiled in, an error is logged and the
 * chunk is skipped.
 */
void SWAlgorithmTask::onRegion(SequenceWalkerSubtask* t, TaskStateInfo& ti) {
    Q_UNUSED(ti);

    log.info("RUN SWAlgorithmTask::onRegion(SequenceWalkerSubtask* t, TaskStateInfo& ti)");

    int regionLen = t->getRegionSequenceLen();
    QByteArray localSeq(t->getRegionSequence(), regionLen);

    // The label is chosen together with the implementation so that the timing
    // message below names the back-end that actually ran.
    SmithWatermanAlgorithm * sw = NULL;
    QString algName;

    if (algType == sse2) {
#ifdef SW2_BUILD_WITH_SSE2
        sw = new SmithWatermanAlgorithmSSE2;
        algName = "SSE2";
#else
        log.error( "SSE2 was not enabled in this build" );
        return;
#endif //SW2_BUILD_WITH_SSE2
    } else if (algType == cuda) {
#ifdef SW2_BUILD_WITH_CUDA
        sw = new SmithWatermanAlgorithmCUDA;
        algName = "CUDA";
#else
        log.error( "CUDA was not enabled in this build" );
        return;
#endif //SW2_BUILD_WITH_CUDA
    } else if (algType == atistream) {
#ifdef SW2_BUILD_WITH_ATISTREAM
        sw = new SmithWatermanAlgorithmATISTREAM;
        algName = "ATI Stream";
#else
        log.error( "ATI Stream was not enabled in this build" );
        return;
#endif //SW2_BUILD_WITH_ATISTREAM
    } else {
        assert(algType == classic);
        sw = new SmithWatermanAlgorithm;
        algName = "Classic";
    }

    // The first gap argument passed to launch() is the sum of the gap-open
    // and gap-extension scores; the second is the extension score alone.
    quint64 t1 = GTimer::currentTimeMicros();
    sw->launch(sWatermanConfig.pSm, sWatermanConfig.ptrn, localSeq,
        sWatermanConfig.gapModel.scoreGapOpen + sWatermanConfig.gapModel.scoreGapExtd,
        sWatermanConfig.gapModel.scoreGapExtd,
        minScore);
    log.details("**************");

    QString testName;
    if (getParentTask() != NULL) {
        testName = getParentTask()->getTaskName();
    } else {
        testName = "SW alg";
    }
    log.details(QString("\n%1 %2 run time is %3\n").arg(testName).arg(algName).arg(GTimer::secsBetween(t1, GTimer::currentTimeMicros())));
    log.details("**************");

    QList<PairAlignSequences> res = sw->getResults();

    // These flags describe the whole chunk, so read them once.
    const bool complemented = t->isDNAComplemented();
    const bool translated = t->isAminoTranslated();

    for (int i = 0; i < res.size(); i++) {
        PairAlignSequences & hit = res[i];
        hit.isDNAComplemented = complemented;
        hit.isAminoTranslated = translated;

        if (translated) {
            // amino positions -> nucleotide positions
            hit.intervalSeq1.startPos *= 3;
            hit.intervalSeq1.len *= 3;
        }

        if (complemented) {
            hit.intervalSeq1.startPos = t->getGlobalConfig().range.endPos() - hit.intervalSeq1.startPos - hit.intervalSeq1.len;
        } else {
            hit.intervalSeq1.startPos +=
                (t->getGlobalRegion().startPos - sWatermanConfig.globalRegion.startPos);
        }
    }

    addResult(res);

    delete sw;
    log.info("FINISH SWAlgorithmTask::onRegion(SequenceWalkerSubtask* t, TaskStateInfo& ti)");
}

/**
 * Collapses, in place, all entries of `res` that cover the same interval of
 * the searched sequence (intervalSeq1), keeping a single entry per interval.
 *
 * When two entries share an interval, the earlier one survives only if its
 * score is strictly greater; otherwise the earlier one is dropped and the
 * scan restarts from the element that took its place.
 *
 * @param res  list to de-duplicate; modified in place
 */
void SWAlgorithmTask::removeResultFromOverlap(QList<PairAlignSequences> & res) {
    log.info("Removing results From Overlap");

    int current = 0;
    while (current < res.size() - 1) {
        bool currentDropped = false;

        int other = current + 1;
        while (other < res.size()) {
            const bool sameInterval = (res.at(current).intervalSeq1 == res.at(other).intervalSeq1);
            if (!sameInterval) {
                ++other;
                continue;
            }

            if (res.at(current).score > res.at(other).score) {
                // the later duplicate loses; `other` now indexes its successor
                res.removeAt(other);
            } else {
                // the current entry loses; its successor slides into `current`
                res.removeAt(current);
                currentDropped = true;
                break;
            }
        }

        if (!currentDropped) {
            ++current;
        }
    }
}


/**
 * Appends the hits found for one sequence chunk to the shared result list.
 *
 * The method takes `lock` for the duration of the append because onRegion()
 * calls it for every chunk.
 *
 * @param res  hits to append; each hit is appended exactly once
 */
void SWAlgorithmTask::addResult(QList<PairAlignSequences> & res) {
    QMutexLocker ml(&lock);
    // Append once: a second append only produced duplicates that
    // removeResultFromOverlap() had to discard again later.
    pairAlignSequences += res;
}

/**
 * Computes the length used as the overlap between adjacent sequence chunks.
 *
 * The value is patternLength + (maxScore - minScore) / gap * (-1) + 1, where
 * `gap` is the larger of the two gap scores, capped at searchLength + 1, and
 * finally incremented by one.
 *
 * @param searchSeq     sequence being searched
 * @param patternSeq    pattern being searched for
 * @param gapOpen       gap-open score
 * @param gapExtension  gap-extension score
 * @param maxScore      best score the pattern can reach
 * @param minScore      minimal accepted score
 * @return the computed length; when both gap scores are zero the cap
 *         (searchLength + 1, plus the final increment) is returned because
 *         the formula cannot be evaluated
 */
int SWAlgorithmTask::calculateMatrixLength(const QByteArray & searchSeq, const QByteArray & patternSeq, int gapOpen, int gapExtension, int maxScore, int minScore) {

    const int searchBound = searchSeq.length() + 1;

    // the larger of the two gap scores
    const int gap = (gapOpen < gapExtension) ? gapExtension : gapOpen;

    int matrixLength = searchBound;
    if (gap != 0) {
        // guarded: an integer division by a zero gap score is undefined behaviour
        matrixLength = patternSeq.length() + (maxScore - minScore)/gap * (-1) + 1;
        if (searchBound < matrixLength) {
            matrixLength = searchBound;
        }
    }

    matrixLength += 1;

    return matrixLength;
}

int SWAlgorithmTask::calculateMaxScore(const QByteArray & seq, const SMatrix& substitutionMatrix) {
    int maxScore = 0;
    int max;    
    int substValue = 0;    

    QByteArray alphaChars = substitutionMatrix.getAlphabet()->getAlphabetChars();
    for (int i = 0; i < seq.length(); i++) {        
        max = 0;        
        for (int j = 0; j < alphaChars.size(); j++) {            
            //TODO: use raw pointers!
            char c1 = seq.at(i);
            char c2 = alphaChars.at(j);
            substValue = substitutionMatrix.getScore(c1, c2);
            if (max < substValue) max = substValue;                                
        }
        maxScore += max;
    }    
    return maxScore;
}

/**
 * Final stage of the task: marks the GPU used by a GPU back-end as no longer
 * acquired, logs how many results the listener holds and, when a report
 * callback is configured, passes the results to it.
 *
 * A non-empty string returned by the callback is stored as the task error.
 *
 * @return always ReportResult_Finished
 */
Task::ReportResult SWAlgorithmTask::report() {

    log.info("RUN SWAlgorithmTask::report()");

    // hand the GPU back
    if( algType == cuda ) {
        gpuModel.cudaGpu->setAcquired(false);
    } else if ( algType == atistream ) {
        gpuModel.atiGpu->setAcquired(false);
    }

    QList<SmithWatermanResult> listenerResults = sWatermanConfig.resultListener->getResults();

    const int resultsNum = listenerResults.size();
    log.details(tr("%1 results found").arg(resultsNum));

    SmithWatermanReportCallback* callback = sWatermanConfig.resultCallback;
    if (callback != 0) {
        const QString errorText = callback->report(listenerResults);
        if (!errorText.isEmpty()) {
            stateInfo.setError(errorText);
        }
    }

    log.info("FINISH SWAlgorithmTask::report()");
    return ReportResult_Finished;
}

/**
 * Schedules the post-processing step once the sequence-walker subtask is done.
 *
 * @param subTask  the subtask that has just finished
 * @return a list holding a new SWResultsPostprocessingTask when `subTask` is
 *         the walker task and this task has neither errors nor been
 *         cancelled; an empty list otherwise
 */
QList<Task*> SWAlgorithmTask::onSubTaskFinished( Task* subTask ){
    QList<Task*> followUps;

    const bool aborted = hasErrors() || isCanceled();
    if (!aborted && subTask == t) {
        followUps.append(new SWResultsPostprocessingTask(sWatermanConfig, resultList, getResult()));
    }

    return followUps;
}


/**
 * Creates the task that converts raw alignments into search results.
 *
 * @param _sWatermanConfig  search settings (filter, listener, searched region)
 * @param _resultList       list that receives the converted results
 * @param _resPAS           raw alignments to convert
 */
SWResultsPostprocessingTask::SWResultsPostprocessingTask( SmithWatermanSettings &_sWatermanConfig,
                                                          QList<SmithWatermanResult> &_resultList,
                                                          QList<PairAlignSequences> &_resPAS )
    : Task("SWResultsPostprocessing", TaskFlag_None),
      sWatermanConfig(_sWatermanConfig),
      resultList(_resultList),
      resPAS(_resPAS)
{
}

// Intentionally empty: this task performs no work in prepare();
// everything it does happens in run().
void SWResultsPostprocessingTask::prepare(){
	
}

void SWResultsPostprocessingTask::run(){
	SmithWatermanResult r;
	for (int i = 0; i < resPAS.size(); i++) {

		r.complement = resPAS.at(i).isDNAComplemented;
		r.trans = resPAS.at(i).isAminoTranslated;
		r.region = resPAS.at(i).intervalSeq1;
		r.region.startPos += sWatermanConfig.globalRegion.startPos;
		r.score = resPAS.at(i).score;

		resultList.append(r);
	}

	if (0 != sWatermanConfig.resultFilter) {
		SmithWatermanResultFilter* rf = sWatermanConfig.resultFilter;
		rf->applyFilter(&resultList);
	}    
    foreach( const SmithWatermanResult & r, resultList ) { /* push results after filters */
		sWatermanConfig.resultListener->pushResult( r );
	}
}

} //namespace
