from __future__ import with_statement
from __future__ import print_function

import sys
import os
import types

try:
    import numpy as np
    import reflex
    import pipeline_display
    import pipeline_product
    import reflex_plot_widgets as reflex_widgets
    import_success = True

except ImportError:
    import_success = False
    print("Error importing reflex modules, numpy")


def formatCoordinates(data, x, y):
    """Return the status bar text for a position in the image view.

    The plotter binds this function to the image axes' ``format_coord``
    via ``types.MethodType``, so ``data`` is the tuple ``(imgData, wcs)``.

    Parameters
    ----------
    data : tuple
        ``(imgData, wcs)``. ``imgData.image`` is indexed as ``[row, column]``.
        ``wcs`` provides ``imageToCelestial(x, y)`` and ``radecsys``.
    x, y : float
        Data coordinates of the cursor in the image axes.

    Returns
    -------
    str
        The reference system, celestial position, image position (shifted
        by +1) and pixel value. Returns "" if (x, y) lies outside the image
        or is not a finite number.
    """
    (imgData, wcs) = data
    nrows = imgData.image.shape[0]
    ncols = imgData.image.shape[1]

    # Check before converting to int: this rejects NaN, for which int()
    # would raise. The upper bounds must be strict, because int(ncols)
    # is not a valid column index.
    if not (0. < x < ncols and 0. < y < nrows):
        return ""

    xPix = int(x)
    yPix = int(y)
    alpha, delta = wcs.imageToCelestial(x, y)
    return ("%s alpha=%.4f delta=%.4f" +
            "     Image X=%.3f  Y=%.3f" +
            "     Value=%.3f") % \
        (wcs.radecsys, alpha, delta, x + 1., y + 1.,
         imgData.image[yPix, xPix])



class WorldCoordinateSystem():
    def __init__(self, fitsDataContainer, fitsDataSet=0, tiny=1.e-6):
        self.tiny = tiny
        self.ctype = None
        self.crval = None
        self.crpix = None
        self.cdelt = None
        self.pc    = None
        self.cd    = None

        try:
            _equinox = self._getInheritedKey(fitsDataContainer, fitsDataSet, "EQUINOX")
        except KeyError:
            _equinox = None

        try:
            _radecsys= self._getInheritedKey(fitsDataContainer, fitsDataSet, "RADESYS")
        except KeyError:
            try:
                _radecsys = fitsDataContainer.all_hdu[fitsDataSet].header["RADECSYS"]
            except KeyError:
                _radecsys = None

        if _equinox is not None:
            equinox = _equinox[0]
        else:
            equinox = None

        if _radecsys is not None:
            radecsys = _radecsys[0]
        else:
            radecsys = None

        if equinox is not None:
            if radecsys is None:
                if equinox < 1984.0:
                    self.radecsys = "FK4"
                else:
                    self.radecsys = "FK5"
            else:
                self.radecsys = radecsys.strip()
        else:
            if radecsys is not None:
                self.radecsys = radecsys.strip()
                if self.radecsys == "FK4" or self.radecsys == "FK4-NO-E":
                    self.equinox = 1950.
                elif self.radecsys == "FK5":
                    self.equinox = 2000.
            else:
                self.radecsys = "ICRS"

        try:
            ctype1 = fitsDataContainer.all_hdu[fitsDataSet].header["CTYPE1"]
            ctype2 = fitsDataContainer.all_hdu[fitsDataSet].header["CTYPE2"]
            crval1 = fitsDataContainer.all_hdu[fitsDataSet].header["CRVAL1"]
            crval2 = fitsDataContainer.all_hdu[fitsDataSet].header["CRVAL2"]
            crpix1 = fitsDataContainer.all_hdu[fitsDataSet].header["CRPIX1"]
            crpix2 = fitsDataContainer.all_hdu[fitsDataSet].header["CRPIX2"]

            self.ctype = np.array([ctype1, ctype2])
            self.crval = np.array([crval1, crval2])
            self.crpix = np.array([crpix1, crpix2])

            if "CD1_1" in fitsDataContainer.all_hdu[fitsDataSet].header :
                cd11 = fitsDataContainer.all_hdu[fitsDataSet].header["CD1_1"]
                cd12 = fitsDataContainer.all_hdu[fitsDataSet].header["CD1_2"]
                cd21 = fitsDataContainer.all_hdu[fitsDataSet].header["CD2_1"]
                cd22 = fitsDataContainer.all_hdu[fitsDataSet].header["CD2_2"]

                rhoA = 0.
                if cd21 > 0.:
                    rhoA = np.arctan2(cd21, cd11)
                elif cd21 < 0.:
                    rhoA = np.arctan2(-cd21, -cd11)
                rhoB = 0.
                if cd12 > 0.:
                    rhoB = np.arctan2(cd12, -cd22)
                elif cd12 < 0.:
                    rhoB = np.arctan2(-cd12, cd22)
                if np.fabs(rhoA - rhoB) < self.tiny:
                    rho = 0.5 * (rhoA + rhoB)
                    pc11 =  np.cos(rho)
                    pc12 = -np.sin(rho)
                    if cd11 != 0.:
                        cdelt1 = cd11 / pc11
                    else:
                        cdelt1 = -cd21 / pc12
                    if cd22 != 0.:
                        cdelt2 = cd22 / pc11
                    else:
                        cdelt2 = cd12 / pc12
                else:
                    raise ValueError

                self.cdelt = np.array([[cdelt1, 0.], [0., cdelt2]])
                self.pc = np.array([[pc11, pc12], [-pc12, pc11]])
                self.cd = np.array([[cd11, cd12], [cd21, cd22]])
            else:
                cdelt1 = fitsDataContainer.all_hdu[fitsDataSet].header["CDELT1"]
                cdelt2 = fitsDataContainer.all_hdu[fitsDataSet].header["CDELT2"]

                pc11 = fitsDataContainer.all_hdu[fitsDataSet].header["PC1_1"]
                pc12 = fitsDataContainer.all_hdu[fitsDataSet].header["PC1_2"]
                pc21 = fitsDataContainer.all_hdu[fitsDataSet].header["PC2_1"]
                pc22 = fitsDataContainer.all_hdu[fitsDataSet].header["PC2_2"]

                self.cdelt = np.array([[cdelt1, 0.], [0., cdelt2]])
                self.pc = np.array([[pc11, pc12], [pc21, pc22]])
                self.cd = np.dot(self.cdelt, self.pc)

        except (KeyError, ValueError):
            self.ctype = None
            self.crval = None
            self.crpix = None
            self.cdelt = None
            self.cd = None
            self.pc = None


    def _getInheritedKey(self, fitsDataContainer, fitsDataSet, key):
        value = None
        idx   = None
        try:
            idx = fitsDataSet
            value = fitsDataContainer.all_hdu[idx].header[key]
        except KeyError:
            # ESO DICB standard: Keyword inheritance is applied unless it is
            # explicitly disabled. 
            inherit = True
            try:
                inherit = fitsDataContainer.all_hdu[idx].header["INHERIT"]
                if not isinstance(inherit, bool):
                    inherit = True
            except KeyError:
                pass
            if inherit == True:
                try:
                    idx = 0
                    value = fitsDataContainer.all_hdu[idx].header[key]
                except KeyError:
                    raise

        return (value, idx)


    def imageToCelestial(self, xPix, yPix):
        # Linear transformation
        pos  = np.array([xPix - self.crpix[0], yPix - self.crpix[1]]).transpose()
        x, y = np.dot(self.cd, pos)

        # Gnomonic projection
        phi   = np.arctan2(x, -y)
        theta = np.arctan(1. / np.radians(np.sqrt(x**2. + y**2.)))

        alphaP = np.radians(self.crval[0])
        deltaP = np.radians(self.crval[1])

        stheta = np.sin(theta)
        ctheta = np.cos(theta)
        sphi   = np.sin(phi)
        cphi   = np.cos(phi)
        sdeltaP = np.sin(deltaP)
        cdeltaP = np.cos(deltaP)

        ra =  alphaP + np.arctan2(ctheta * sphi,
                                  stheta * cdeltaP + ctheta * sdeltaP * cphi)
        dec = np.arcsin(stheta * sdeltaP - ctheta * cdeltaP * cphi)
        
        return (np.degrees(ra), np.degrees(dec))


    def celestialToImage(self, raDeg, decDeg):
        ra  = np.radians(raDeg)
        dec = np.radians(decDeg)

        alphaP = np.radians(self.crval[0])
        deltaP = np.radians(self.crval[1])

        sra = np.sin(ra - alphaP)
        cra = np.cos(ra - alphaP)
        sdec = np.sin(dec)
        cdec = np.cos(dec)
        sdeltaP = np.sin(deltaP)
        cdeltaP = np.cos(deltaP)

        phi    = np.arctan2(-cdec * sra,
                            sdec * cdeltaP - cdec * sdeltaP * cra) + np.pi
        theta  = np.arcsin(sdec * sdeltaP + cdec * cdeltaP * cra)
        Rtheta = np.rad2deg(1.) / np.tan(theta)

        x =  Rtheta * np.sin(phi)
        y = -Rtheta * np.cos(phi)

        invDet = 1. / np.linalg.det(self.cd)

        xPix = (self.cd[1, 1] * x - self.cd[0, 1] * y) * invDet + self.crpix[0]
        yPix = (self.cd[0, 0] * y - self.cd[1, 0] * x) * invDet + self.crpix[1]

        return (xPix, yPix)



class DataPlotterManager(object):
    """Plot manager of the interactive Reflex window for ``muse_exp_align``.

    The window shows one of two views:

    * A single FOV image (IMAGE_FOV) with the sources of its source list
      (SOURCE_LIST) overplotted.
    * The sources of all source lists, projected onto the pixel grid of the
      first FOV image, using either the uncorrected (RA/DEC) or the
      corrected (RA_CORR/DEC_CORR) coordinates.

    Source lists are associated with FOV images by comparing their DATE-OBS
    keyword values.
    """

    recipeName       = "muse_exp_align"
    tagInputData     = "IMAGE_FOV"
    tagAuxiliaryData = "SOURCE_LIST"

    _srcListColumnIds = ["Id", "X", "Y", "RA", "DEC", "RA_CORR", "DEC_CORR", "Flux"]

    _btnLabelAll     = "All detections (uncorrected)"
    _btnLabelAllCorr = "All detections (corrected)"

    _obsId = "DATE-OBS"

    _markerSymbols = ["o", "s", "D", "p", "*", "+", "x"]
    _markerColors  = ["#0000FF", "#00FF00", "#FF0000", "#00FFFF",
                      "#FF00FF", "#FFFF00", "#FFA500", "#800080"]

    # Size (in points) of the source markers; scatter expects its square.
    _markerSize = 8.

    def __init__(self):
        self._plt     = dict()
        self._widgets = dict()
        # Index of the selected FOV image, or -1 / -2 for the "all
        # detections" (uncorrected / corrected) views.
        self._imageSelectorActive = 0

        super(DataPlotterManager, self).__init__()
        return


    def setWindowTitle(self):
        """Return the title of the interactive window."""
        return self.recipeName + " GUI"


    def setWindowHelp(self):
        """Return the help text of the interactive window."""
        return "Help for " + self.setWindowTitle()


    def setInteractiveParameters(self):
        """Return (and store) the recipe parameters editable in the GUI."""
        self.parameters = [
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="rsearch", group="Offset Calculation",
                description="Search radius (in arcsec)"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="nbins", group="Offset Calculation",
                description="Number of bins of the 2D histogram"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="weight", group="Offset Calculation",
                description="Use weighting"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="fwhm", group="Source Detection",
                description="FWHM of the convolution filter"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="bpixdistance", group="Source Detection",
                description="Minimum distance from image border [pixels]"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="threshold", group="Source Detection",
                description="Initial threshold for detecting point sources"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="bkgignore", group="Source Detection",
                description="Fraction of the image to be ignored"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="bkgfraction", group="Source Detection",
                description="Fraction of the image (without the ignored part) considered as background"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="step", group="Source Detection",
                description="Increment/decrement of the threshold value"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="iterations", group="Source Detection",
                description="Maximum number of iterations"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="srcmin", group="Source Detection",
                description="Minimum number of sources"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="srcmax", group="Source Detection",
                description="Maximum number of sources")
            ]
        return self.parameters


    def setCurrentParameterHelper(self, parameterValueGetter):
        """Store the callable used to query current parameter values."""
        self._getParameterValue = parameterValueGetter
        return


    def readFitsData(self, fitsFiles):
        """Load the input frames and select the plot delegates.

        Frames are grouped by category in ``self.frameset``. Each entry is a
        dict with a display ``label`` and the ``datasets`` product. Labels of
        FOV images also show their DATE-OBS value. The real plot delegates
        are used only if at least one IMAGE_FOV frame is present; otherwise
        a placeholder plot is drawn.
        """
        self.frameset = dict()
        for fitsFile in fitsFiles:
            datasets = pipeline_product.PipelineProduct(fitsFile)

            frameId = os.path.basename(datasets.fits_file.name)
            if fitsFile.category == self.tagInputData:
                frameId += ("\n(" + DataPlotterManager._obsId + ": " +
                            str(datasets.readKeyword(DataPlotterManager._obsId)) + ")")

            self.frameset.setdefault(fitsFile.category, []).append(
                {"label": frameId, "datasets": datasets})

        if self.tagInputData in self.frameset:
            self._pltCreator = self._dataPlotCreate
            self._pltPlotter = self._dataPlotDraw
        else:
            self._pltCreator = self._dummyPlotCreate
            self._pltPlotter = self._dummyPlotDraw
        return


    def addSubplots(self, figure):
        """Clear ``figure`` and create the axes of the selected view."""
        self._figure = figure
        self._figure.clear()
        self._pltCreator()
        return


    def plotWidgets(self):
        """Create the image selector (once) and return all widgets."""
        if "dummy" not in self._plt and "buttons" not in self._widgets:
            buttonLabels = [frame["label"] for frame in self.frameset[self.tagInputData]]
            buttonLabels.append(DataPlotterManager._btnLabelAll)
            buttonLabels.append(DataPlotterManager._btnLabelAllCorr)

            self._widgets["buttons"] = \
                reflex_widgets.InteractiveRadioButtons(self._plt["buttons"],
                self.onImageSelected, buttonLabels, self._imageSelectorActive,
                title="Image Selector")
            self._resizeCID = self._figure.canvas.mpl_connect("resize_event",
                self.onResize)

            for btnLabel in self._widgets["buttons"].rbuttons.labels:
                btnLabel.set_clip_on(True)

        return list(self._widgets.values())


    def plotProductsGraphics(self):
        """Draw the current view."""
        self._pltPlotter()
        return


    def readTable(self, frame, columns, dataset=1):
        """Read ``columns`` from the table in HDU ``dataset`` of ``frame``.

        Returns
        -------
        dict or None
            Maps column name to column data. Returns None (after printing a
            warning) if any requested column is missing.
        """
        table = dict()
        for name in columns:
            try:
                table[name] = frame.all_hdu[dataset].data.field(name)
            except KeyError:
                print("Warning: column '" + name + "' not found in " +
                      frame.fits_file.name)
                return None
        return table


    # Event handler

    def onImageSelected(self, buttonLabel):
        """Switch the view according to the clicked selector button.

        The view is redrawn only if the selection actually changed.
        """
        selectionChanged = False
        if buttonLabel == DataPlotterManager._btnLabelAll:
            if self._imageSelectorActive != -1:
                self._imageSelectorActive = -1
                selectionChanged = True
        elif buttonLabel == DataPlotterManager._btnLabelAllCorr:
            if self._imageSelectorActive != -2:
                self._imageSelectorActive = -2
                selectionChanged = True
        else:
            for idx, frame in enumerate(self.frameset[self.tagInputData]):
                if frame["label"] == buttonLabel and idx != self._imageSelectorActive:
                    self._imageSelectorActive = idx
                    selectionChanged = True

        if selectionChanged:
            self._dataPlotDraw()

        return


    def onResize(self, event):
        """Keep the radio button circles round after a figure resize."""
        widgets = getattr(self, "_widgets", {})
        if "buttons" not in widgets:
            return

        bbox = self._plt["buttons"].get_position().get_points()
        btns = widgets["buttons"].rbuttons

        # ndarray.ptp() was removed in NumPy 2.0; numpy.ptp() works with
        # all versions.
        w = np.ptp(bbox[:, 0])
        h = np.ptp(bbox[:, 1])
        fh = self._figure.get_figheight()
        fw = self._figure.get_figwidth()
        vscale = (w * fw) / (h * fh)

        width = btns.labels[0].get_fontsize()
        width /= (self._figure.get_dpi() * w * fw)

        # RadioButtons.circles is not available in recent matplotlib
        # releases. If it is missing, there is nothing to adjust.
        for circle in getattr(btns, "circles", []):
            circle.width  = width
            circle.height = width * vscale
        return


    # Utility functions

    def _cycleMarkerProperties(self, index):
        """Return ``(symbol, color)`` for ``index``.

        Colors cycle fastest; the symbol changes after every full color cycle.
        """
        ncolors  = len(DataPlotterManager._markerColors)
        nsymbols = len(DataPlotterManager._markerSymbols)

        color  = DataPlotterManager._markerColors[index % ncolors]
        symbol = DataPlotterManager._markerSymbols[(index // ncolors) % nsymbols]
        return (symbol, color)


    # Implementation of the plot creator and plotter delegates to be used
    # if the required data is actually available.

    def _dataPlotCreate(self):
        """Create the selector axes (left) and the image axes (right)."""
        self._plt["buttons"]   = self._figure.add_axes([0.05, 0.05, 0.35, 0.90])
        self._plt["imageview"] = self._figure.add_axes([0.55, 0.05, 0.40, 0.90])

        # Turn off coordinate display in the status bar for the radio
        # buttons widget
        self._plt["buttons"].format_coord = lambda x, y: ""

        return


    def _dataPlotDraw(self):
        """Draw the view chosen by ``self._imageSelectorActive``."""
        if self._imageSelectorActive >= 0:
            self._drawSingleImage(self._imageSelectorActive)
        else:
            self._drawAllSourcePositions(self._imageSelectorActive == -1)
        return


    def _loadImageAndWcs(self, imgFrame):
        """Read the image of ``imgFrame`` (HDU 1) and its WCS.

        Raises ValueError unless the WCS is a RA---TAN/DEC--TAN system.
        """
        imgData = imgFrame["datasets"]
        imgData.readImage(1)

        wcs = WorldCoordinateSystem(imgData, 1)
        # ctype is None if the WCS keywords could not be read
        if (wcs.ctype is None or wcs.ctype[0].strip() != "RA---TAN" or
                wcs.ctype[1].strip() != "DEC--TAN"):
            raise ValueError("Unsupported world coordinate system found in '%s'" %
                             (imgData.fits_file.name,))
        return imgData, wcs


    def _findSourceLists(self, imgData):
        """Return the readable source list tables matching ``imgData``.

        A source list matches if its DATE-OBS equals that of the image.
        """
        tables = []
        timestamp = imgData.readKeyword(DataPlotterManager._obsId)
        for tblFrame in self.frameset.get(self.tagAuxiliaryData, []):
            if tblFrame["datasets"].readKeyword(DataPlotterManager._obsId) == timestamp:
                print("Reading Source list '" + tblFrame["label"] +
                      "' for FOV image '" + imgData.fits_file.name + "'")
                tblData = self.readTable(tblFrame["datasets"],
                                         self._srcListColumnIds)
                # readTable returns None if a column is missing
                if tblData is not None:
                    tables.append(tblData)
        return tables


    def _showImage(self, image, title, tooltip, imgData, wcs, zlimits=None):
        """Display ``image`` in the image axes.

        The status bar shows the WCS of ``imgData``. The color scale is
        automatic unless ``zlimits`` is given.
        """
        self._plt["imageview"].clear()

        imageview = pipeline_display.ImageDisplay()
        imageview.setAspect("equal")
        imageview.setCmap("gray")
        if zlimits is None:
            imageview.setZAutoLimits(image, None)
        else:
            imageview.setZLimits(zlimits)
        imageview.setLabels("X [pixel]", "Y [pixel]")
        imageview.setXLinearWCSAxis(0., 1., 1.)
        imageview.setYLinearWCSAxis(0., 1., 1.)
        imageview.display(self._plt["imageview"], title, tooltip, image)

        # This must be called after imageview display, to override the internal
        # formatting by imageview.display. This is a hack!
        self._plt["imageview"].format_coord = types.MethodType(formatCoordinates,
                                                               (imgData, wcs))


    def _drawSingleImage(self, index):
        """Show FOV image ``index`` with the sources of its source list."""
        imgFrame = self.frameset[self.tagInputData][index]
        imgData, wcs = self._loadImageAndWcs(imgFrame)

        # If several source lists match, the last one read is used.
        srcTables = self._findSourceLists(imgData)
        tblData = srcTables[-1] if srcTables else None
        numSources = len(tblData["Id"]) if tblData is not None else 0
        tooltip = imgFrame["label"] + ": " + str(numSources) + " sources detected"

        self._showImage(imgData.image, "Exposure FOV", tooltip, imgData, wcs)

        ax = self._plt["imageview"]
        ax.autoscale(enable=False, tight=True)
        if tblData is None:
            # No readable source list for this image: show the image only.
            return

        marker, color = self._cycleMarkerProperties(index)
        xpos = tblData["X"]
        ypos = tblData["Y"]
        ax.scatter(xpos, ypos, facecolors="none", edgecolors=color,
                   marker=marker, s=self._markerSize**2)

        # Label each source with its Id, offset below and to the right of
        # its marker.
        txtOffset = 0.5 * np.sqrt(2.) * self._markerSize
        for srcId, xp, yp in zip(tblData["Id"], xpos, ypos):
            ax.annotate(str(srcId),
                        xy=(xp, yp), xycoords="data",
                        xytext=(txtOffset, -txtOffset), textcoords="offset points",
                        ha="left", va="top", fontsize="x-small",
                        color=color, clip_on=True)


    def _drawAllSourcePositions(self, uncorrected):
        """Plot the sources of all source lists on the first image's grid.

        Sky positions are converted to pixels with the WCS of the first FOV
        image. ``uncorrected`` selects the RA/DEC columns; otherwise the
        RA_CORR/DEC_CORR columns are used.
        """
        # The first image defines the display geometry and the WCS.
        imgData, wcs = self._loadImageAndWcs(self.frameset[self.tagInputData][0])

        # An all-zero image of the same shape, so that no background image
        # is displayed.
        bkgImage = np.zeros_like(imgData.image, subok=False)

        if uncorrected:
            colNameX, colNameY = "RA", "DEC"
            tooltip = "Uncorrected positions"
        else:
            colNameX, colNameY = "RA_CORR", "DEC_CORR"
            tooltip = "Corrected positions"

        # Keep the index of the owning image, so that marker styles match
        # those of the single image view.
        srcLists = []
        for imgIndex, imgFrame in enumerate(self.frameset[self.tagInputData]):
            for tblData in self._findSourceLists(imgFrame["datasets"]):
                srcLists.append((imgIndex, tblData))

        self._showImage(bkgImage, "Source Positions", tooltip, imgData, wcs,
                        zlimits=[0., 1.])

        ax = self._plt["imageview"]
        ax.autoscale(enable=False)

        for imgIndex, tblData in srcLists:
            raDeg  = np.asarray(tblData[colNameX], dtype=float)
            decDeg = np.asarray(tblData[colNameY], dtype=float)
            xPix, yPix = wcs.celestialToImage(raDeg, decDeg)
            marker, color = self._cycleMarkerProperties(imgIndex)
            ax.scatter(xPix, yPix, facecolors="none", edgecolors=color,
                       marker=marker, s=self._markerSize**2)


    # Implementation of the dummy plot creator and plotter delegates
    # follows here.

    def _dummyPlotCreate(self):
        """Create a single axes for the "no data" message."""
        self._plt["dummy"] = self._figure.add_subplot(1, 1, 1)
        return


    def _dummyPlotDraw(self):
        """Show a message that the required input data is missing."""
        label = "Data not found! Input files should contain these types:\n%s" \
            % self.tagInputData

        self._plt["dummy"].set_axis_off()
        self._plt["dummy"].text(0.1, 0.6, label, color="#11557c", fontsize=18,
            horizontalalignment="left", verticalalignment="center", alpha=0.25)
        self._plt["dummy"].tooltip = "No data found"
        return



if __name__ == "__main__":

    from reflex_interactive_app import PipelineInteractiveApp

    # Script entry point, as run by the Reflex PythonActor.
    app = PipelineInteractiveApp(enable_init_sop=True)
    app.parse_args()

    # The GUI relies on the reflex modules and numpy imported at the top
    # of the file; turn it off if those imports failed.
    if not import_success:
        app.setEnableGUI(False)

    if not app.isGUIEnabled():
        app.set_continue_mode()
    else:
        app.setPlotManager(DataPlotterManager())
        app.showGUI()

    # NOTE: Do not remove this line! This prints the output, which is parsed
    #       by the Reflex PythonActor to get the results!
    app.print_outputs()

    sys.exit()
