/*****************************************************************
* Unipro UGENE - Integrated Bioinformatics Suite
* Copyright (C) 2008 Unipro, Russia (http://ugene.unipro.ru)
* All Rights Reserved
* 
*     This source code is distributed under the terms of the
*     GNU General Public License. See the files COPYING and LICENSE
*     for details.
*****************************************************************/

#include "HMMSearchTask.h"
#include "TaskLocalStorage.h"
#include "HMMIO.h"

#include <core_api/AppContext.h>
#include <core_api/DNAAlphabet.h>
#include <core_api/DNATranslation.h>

#include <hmmer2/funcs.h>
#include <util_tasks/SequenceWalkerTask.h>

namespace GB2 {

// Builds a search task for the profile HMM `_hmm` over the sequence `_seq`
// using the settings `s`. Both translation pointers start out as NULL and are
// resolved later by checkAlphabets().
// `_hmm` must be non-null: its name is read here to build the task name.
HMMSearchTask::HMMSearchTask(plan7_s* _hmm, const DNASequence& _seq, const UHMMSearchSettings& s)
    : Task("", TaskFlags_NR_DWF | TaskFlags_FAIL_OSCOE),
      hmm(_hmm),
      seq(_seq),
      settings(s),
      complTrans(NULL),
      aminoTrans(NULL)
{
    // The base Task is created with an empty name; the real, translated name
    // embeds the model name and can only be composed once `hmm` is stored.
    const QString taskName = tr("HMM search with '%1'").arg(hmm->name);
    setTaskName(taskName);
}

void HMMSearchTask::prepare() {
    if (!checkAlphabets(hmm->atype, seq.alphabet, complTrans, aminoTrans)) {
        return;
    }

    SequenceWalkerConfig config;
    config.seq = seq.seq.data();
    config.seqSize = seq.seq.size();
    config.complTrans = complTrans;
    config.aminoTrans = aminoTrans;
    config.chunkSize = qMin(settings.searchChunkSize, seq.seq.size());
    config.overlapSize = 2 * hmm->M;
    settings.overlapSize = 2 * hmm->M;
    config.exOverlapSize = (settings.searchChunkSize + (int)((config.overlapSize/2.0) + 0.5))/2;
    config.parallel = settings.nThreads > 1;
    config.maxThreads = settings.nThreads;

    addSubTask(new SequenceWalkerTask(config, this, tr("parallel_hmm_search_task")));
}


// Runs the HMM search on a single sequence region supplied by the walker
// subtask `t` and stores the hits in this task.
//
// t  - subtask describing the region: its (possibly complemented and/or
//      amino-translated) sequence and its position in the whole sequence.
// si - state info handed to UHMMSearch::search(); any error it reports is
//      copied into this task's stateInfo.
//
// Hits that start within `settings.overlapSize` of the region start, or end
// within that distance of the region end, are put into the cache lists
// (cacheDirect / cacheCompl) so that report() can merge them with hits from
// neighbouring regions. All other hits go straight to results / resultsCompl.
//
// NOTE(review): prepare() requests parallel walking when settings.nThreads > 1,
// and no lock is taken here around the shared result lists - confirm that the
// walker serializes onRegion() calls or add synchronization.
// NOTE(review): for complemented regions the local hit start is added to the
// region start as-is - confirm this matches the coordinate convention of the
// sequence handed out by the walker.
void HMMSearchTask::onRegion(SequenceWalkerSubtask* t, TaskStateInfo& si) 
{
    const char* localSeq = t->getRegionSequence();
    const int localSeqSize = t->getRegionSequenceLen();
    const bool wasCompl = t->isDNAComplemented();
    const bool wasAmino = t->isAminoTranslated();
    const LRegion globalReg = t->getGlobalRegion();

    //set TLS data
    TaskLocalData::initializeHMMContext(t->getTaskId());

    QList<UHMMSearchResult> sresults;
    try {
        sresults = UHMMSearch::search(hmm, localSeq, localSeqSize, settings, si);
    } catch (const HMMException& e) {
        // caught by const reference: no copy of the exception object, no slicing
        stateInfo.error = e.error;
    }
    if (si.hasErrors()) {
        stateInfo.error = si.error;
    }
    if (sresults.isEmpty() || stateInfo.cancelFlag || stateInfo.hasErrors()) {
        TaskLocalData::freeHMMContext();
        return;
    }

    // identical for every hit of this region
    const int overlap = settings.overlapSize;

    //convert all UHMMSearchResults into HMMSearchTaskResult
    foreach(const UHMMSearchResult& sr, sresults) {
        HMMSearchTaskCachedResult r;
        r.evalue = sr.evalue;
        r.score = sr.score;
        r.onCompl = wasCompl;
        r.onAmino = wasAmino;

        // hits found on an amino translation are scaled by 3 to get back to
        // nucleotide coordinates
        const int resStart = wasAmino ? sr.r.startPos * 3 : sr.r.startPos;
        const int resLen   = wasAmino ? sr.r.len * 3 : sr.r.len;
        r.r.startPos = globalReg.startPos + resStart;
        r.r.len = resLen;

        const bool nearRegionBorder = resStart < overlap || resStart + resLen > globalReg.len - overlap;
        if (nearRegionBorder) {
            // remember the source region: report() needs it when merging
            r.globalReg = globalReg;
            if (wasCompl) {
                cacheCompl.append(r);
            } else {
                cacheDirect.append(r);
            }
        } else {
            if (wasCompl) {
                resultsCompl.append(r);
            } else {
                results.append(r);
            }
        }
    }
    TaskLocalData::freeHMMContext();
}

// Comparison predicate on search results: true when the E-value of `r1` is
// less than the E-value of `r2`. Passed to qSort() to order hits by
// ascending E-value.
bool HMMSearchResultEValLessThan(const HMMSearchTaskResult& r1, const HMMSearchTaskResult& r2)
{
    const bool firstHasSmallerEValue = r1.evalue < r2.evalue;
    return firstHasSmallerEValue;
}

// Post-processes one cache of region-border hits and appends the outcome to `merged`.
//
// Every cached hit r1 is compared with the hits that follow it in the cache.
// The first later hit r2 whose region intersects r1 is removed from the cache,
// and r2 is reported instead of r1 when
//   - both describe the same hit (same region, E-value, score and amino flag), or
//   - r2 fully contains r1, or
//   - r1 starts exactly at the start of the sequence region it was found in.
// Otherwise r1 itself is reported. Only the first intersecting hit is merged;
// any further ones stay in the cache and are processed on their own turn.
//
// CachedList / ResultList are the QList types of the cache and result members;
// the element types are taken from the lists' value_type.
template <class CachedList, class ResultList>
static void mergeCachedBorderResults(CachedList& cache, ResultList& merged)
{
    for (int i = 0; i < cache.count(); i++) {
        const typename CachedList::value_type& r1 = cache[i];
        // copy of r1 as a plain result; replaced below if a neighbour wins
        typename ResultList::value_type res = r1;
        for (int j = i + 1; j < cache.count(); j++) {
            const typename CachedList::value_type& r2 = cache[j];
            if (!r1.r.intersects(r2.r)) {
                continue;
            }
            const bool sameHit = r1.r == r2.r && r1.evalue == r2.evalue
                && r1.score == r2.score && r1.onAmino == r2.onAmino;
            if (sameHit || r2.r.contains(r1.r) || r1.globalReg.startPos == r1.r.startPos) {
                res = r2;
            }
            // j is always greater than i, so removing it does not shift index i
            cache.removeAt(j);
            break;
        }
        merged.append(res);
    }
}

// Finalizes the search: merges the cached region-border hits of both strands
// into the result lists, sorts each strand's hits by ascending E-value, appends
// the complement-strand hits after the direct-strand ones in `results`, and
// clears the temporary lists. If the task has errors nothing is touched.
// Always returns ReportResult_Finished.
Task::ReportResult HMMSearchTask::report() {
    if (hasErrors()) {
        return ReportResult_Finished;
    }

    //postprocess cached results
    mergeCachedBorderResults(cacheDirect, results);      //direct task cache
    mergeCachedBorderResults(cacheCompl, resultsCompl);  //complement task cache

    // sort results by E-value
    if (results.count() > 1) {
        qSort(results.begin(), results.end(), HMMSearchResultEValLessThan);
    }
    if (resultsCompl.count() > 1) {
        qSort(resultsCompl.begin(), resultsCompl.end(), HMMSearchResultEValLessThan);
    }

    // direct-strand hits first, then complement-strand hits
    results << resultsCompl;

    cacheCompl.clear();
    cacheDirect.clear();
    resultsCompl.clear();

    return ReportResult_Finished;
}

QList<SharedAnnotationData> HMMSearchTask::getResultsAsAnnotations(const QString& name) const {
    QList<SharedAnnotationData>  annotations;
    foreach(const HMMSearchTaskResult& hmmRes, results) {
        AnnotationData* a = new AnnotationData();
        a->name = name;
        a->complement = hmmRes.onCompl;
        a->aminoStrand = hmmRes.onAmino ? TriState_Yes :TriState_No;
        a->location.append(hmmRes.r);

        QString str; /*add zeros at begin of evalue exponent part, so exponent part must contains 3 numbers*/
        str.sprintf("%.2g", ((double) hmmRes.evalue));
        QRegExp rx("\\+|\\-.+");
        int pos = rx.indexIn(str,0);
        if(pos!=-1){
            str.insert(pos+1,"0");
        }
        QString info = hmm->name;
        if (hmm->flags & PLAN7_ACC) {
            info += QString().sprintf("\nAccession number in PFAM : %s", hmm->acc);
        }
        if (hmm->flags & PLAN7_DESC) {
            info += QString().sprintf("\n%s", hmm->desc);
        }
        if (!info.isEmpty()) {
            a->qualifiers.append(Qualifier("HMM-model", info));
        }
        //a->qualifiers.append(Qualifier("E-value", QString().sprintf("%.2lg", ((double) hmmRes.evalue))));
        a->qualifiers.append(Qualifier("E-value", str));
        a->qualifiers.append(Qualifier("Score", QString().sprintf("%.1f", hmmRes.score)));
        annotations.append(SharedAnnotationData(a));
    }
    return annotations;
}

bool HMMSearchTask::checkAlphabets(int hmmAlType, DNAAlphabet* seqAl, DNATranslation*& complTrans, DNATranslation*& aminoTrans) 
{
    assert(stateInfo.error.isEmpty());
    DNAAlphabetType hmmAl = HMMIO::convertHMMAlphabet(hmmAlType);
    if (hmmAl == DNAAlphabet_RAW) {
        stateInfo.error = tr("invalid_hmm_alphabet_type");
        return false;
    }
    if (seqAl->isRaw()) {
        stateInfo.error = tr("invalid_sequence_alphabet_type");
        return false;
    }

    complTrans = NULL;
    aminoTrans = NULL;
    if (seqAl->isNucleic()) {
        DNATranslationRegistry* tr = AppContext::getDNATranslationRegistry();
        QList<DNATranslation*> complTs = tr->lookupTranslation(seqAl, DNATranslationType_NUCL_2_COMPLNUCL);
        if (!complTs.empty()) {
            complTrans = complTs.first();
        }
        if (hmmAl == DNAAlphabet_AMINO) {
            QList<DNATranslation*> aminoTs = tr->lookupTranslation(seqAl, DNATranslationType_NUCL_2_AMINO);
            if (!aminoTs.empty()) {
                aminoTrans = aminoTs.first();
            }
        }
    } else {
        assert(seqAl->isAmino());
    }

    // check the result;
    if (hmmAl == DNAAlphabet_AMINO) {
        if (seqAl->isAmino()) {
            assert(complTrans == NULL && aminoTrans == NULL);
        } else {
            if (aminoTrans == NULL) {
                stateInfo.error = tr("can_t_find_amino");
                return false;
            }
        }
    }

    return true;
}

}//namespace GB2
