#!/usr/bin/env python
#------------------------------------------------------------------------------
#$Id: ColourJpegger.py 10969 2016-01-21 15:38:06Z EckhardSutorius $
"""
   Swarp FITS files and create colour jpegs from them.
"""
#------------------------------------------------------------------------------
import os
from PIL import Image
#import Image
import math
import numpy
import astropy.io.fits as pyfits
import shutil

from wsatools.CLI                 import CLI
from wsatools.DbConnect.DbSession import DbSession
from wsatools.Logger              import Logger
import wsatools.Utilities                  as utils
#------------------------------------------------------------------------------

class EdSwarp(object):
    """
    Swarp the R, G and B frames of framesets and create colour JPEGs.

    For each frameset found by L{getFrameSets}, the three frames are processed
    in order:
      - extracted with C{imcopyAndEdHead};
      - re-binned with C{fimgbin} when their NUSTEP values differ;
      - flux-scaled relative to G using their AB zero-points and exposure
        times;
      - resampled with C{swarp};
      - combined into a JPEG by L{ColourJpeg}.

    Set before calling L{run}:
      - C{filterSet}: dict mapping 'R', 'G', 'B' to a filter name;
      - C{surveyName};
      - C{frameSetID}: 0 selects any frameset;
      - C{database};
      - C{isTestRun}.
    """
    #--------------------------------------------------------------------------
    # Define class constants (access as EdSwarp.varName)

    surveyName = "LAS"
    filterSet = {}
    frameSetID = 0
    # If True, getFrameSets only prints the query instead of running it
    isTestRun = False
    # Edited fits files
    EdFits = {'R': "r_ed.fits", 'G': "g_ed.fits", 'B': "b_ed.fits"}
    # Edited fits files after re-binning with fimgbin
    _EdBinFits = {'R': "r_ed_bin.fits", 'G': "g_ed_bin.fits",
                  'B': "b_ed_bin.fits"}

    # Offsets added to the MAGZPT header value to get AB zero-points, keyed
    # by the first letter of the filter name
    _ZpConvAB = {'Z': 0.528, 'Y': 0.634, 'J': 0.938, 'H': 1.379, 'K': 1.900}

    # Header file written for the swarp output image temp.fits
    _headFile = "temp.head"

    # Association between filter and RGB colour
    _filterAssoc = {"GCS": dict(zip("RGB", ["K_1", "H", "Z"])),
                    "LAS": dict(zip("RGB", "KHY")),
                    "GPS": dict(zip("RGB", ["K_1", "H", "J"]))}

    #--------------------------------------------------------------------------

    def run(self):
        """
        Create a colour JPEG for every frameset returned by L{getFrameSets}.
        """
        self.getFrameSets()

        for frameSetID in self.frameSetDict:
            self.cleanUp()
            # Work on a per-frameset copy: re-binning switches a colour to its
            # binned file, which must not leak into the EdFits class constant
            # (and so into the next frameset).
            edFits = dict(self.EdFits)

            # fitsFiles holds (fileName, extNum) pairs in R, G, B order.
            fitsFiles, _obsData, mfIDs = self.frameSetDict[frameSetID]
            rgbFits = {}
            for index, colour in enumerate("RGB"):
                fileAndExt = tuple(fitsFiles[2 * index:2 * index + 2])
                rgbFits[colour] = ("%s[%s]" % fileAndExt).replace("djoser:",
                                                                  '')
            print("the Files:  %s %s %s"
                  % (rgbFits['R'], rgbFits['G'], rgbFits['B']))
            rgbMfIDDict = dict(zip("RGB", mfIDs))

            # Extract each frame and read the header values used for scaling.
            rgbZP = {}
            rgbExp = {}
            rgbMu = {}
            for colour in "RGB":
                imcopyHeadCmd = ' '.join(["imcopyAndEdHead", rgbFits[colour],
                                          edFits[colour]])
                print(self.runSysCmd(imcopyHeadCmd))
                hdulist = pyfits.open(edFits[colour])
                try:
                    prihdr = hdulist[0].header
                    filterBand = self.makeTranslat(colour)[:1]
                    rgbZP[colour] = (prihdr['MAGZPT']
                                     + self._ZpConvAB[filterBand])
                    rgbExp[colour] = prihdr['EXP_TIME']
                    rgbMu[colour] = prihdr['NUSTEP']
                    if colour == 'B':
                        # The B header is written to temp.head, which swarp
                        # uses as the header of its output image temp.fits.
                        self.writeTempHeader(prihdr.cards)
                finally:
                    hdulist.close()

            # Re-bin micro-stepped frames when the micro-steps differ.
            if not (rgbMu['R'] == rgbMu['G'] == rgbMu['B']):
                print("rgbMu: %s" % (rgbMu,))
                for colour in "RGB":
                    if rgbMu[colour] > 1:
                        binnedFile = self._EdBinFits[colour]
                        cmd = ' '.join([
                            "fimgbin",
                            "clobber=yes",
                            "average=yes",
                            "xbin=%s" % math.sqrt(rgbMu[colour]),
                            edFits[colour],
                            binnedFile])
                        print(cmd)
                        os.system(cmd)
                        edFits[colour] = binnedFile
                        # Zero-point correction for the re-binning:
                        # -2.5 * log10(NUSTEP)
                        rgbZP[colour] -= math.log10(rgbMu[colour]) * 2.5

            # Flux scaling factor of each channel, normalised to G.
            rgbCount = dict(
                (colour, pow(10.0, rgbZP[colour] / 2.5) * rgbExp[colour])
                for colour in "RGB")
            print(" ".join(str(rgbCount[colour]) for colour in "RGB"))
            maxCount = max(rgbCount.values())
            print(maxCount)
            rgbFactor = dict((colour, maxCount / rgbCount[colour])
                             for colour in "RGB")
            rgbFactor['R'] /= rgbFactor['G']
            rgbFactor['B'] /= rgbFactor['G']
            rgbFactor['G'] = 1.0
            print("factor:  %s" % (rgbFactor,))

            # swarp files
            Logger.addMessage("Starting swarp...")
            for colour in "RGB":
                swarpCmd = ' '.join(["swarp",
                                     "-VERBOSE_TYPE QUIET",
                                     "-BACK_SIZE 512",
                                     "-IMAGEOUT_NAME temp.fits",
                                     "-FSCALE_DEFAULT %s" % rgbFactor[colour],
                                     edFits[colour]])
                print(swarpCmd)
                os.system(swarpCmd)
                Logger.addMessage("Finished %s swarp..." % colour)
                shutil.move("temp.fits", "%s_s.fits" % colour.lower())

            # make colour jpeg
            makeColour = ColourJpeg()
            makeColour.rgbFits = dict(
                zip("RGB", ["r_s.fits", "g_s.fits", "b_s.fits"]))
            makeColour.outFileName = createJpgName(rgbMfIDDict, rgbFits)
            makeColour.run()

    #--------------------------------------------------------------------------

    def cleanUp(self):
        """Remove old files.

        Deletes the following files, where they exist:
          - the edited FITS copies;
          - the re-binned FITS copies;
          - the swarp header file.
        """
        staleFiles = (list(self.EdFits.values())
                      + list(self._EdBinFits.values()) + [self._headFile])
        for fileName in staleFiles:
            if os.path.exists(fileName):
                os.remove(fileName)

    #--------------------------------------------------------------------------

    def runSysCmd(self, command, verbose=True):
        """
        Run a shell command and return its standard output.

        stdout and stderr are both collected by C{communicate()}, so a
        command writing a lot to stderr cannot block on a full pipe.

        @param command: The command to run (interpreted by the shell).
        @type  command: str
        @param verbose: If True, print the lines the command wrote to stderr.
        @type  verbose: bool
        @return: Lines written to stdout, trailing whitespace removed.
        @rtype: list(str)
        """
        import subprocess
        proc = subprocess.Popen(command, shell=True,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                universal_newlines=True)
        out, err = proc.communicate()
        if verbose:
            for line in err.splitlines():
                print(line.rstrip())
        return [line.rstrip() for line in out.splitlines()]

    #--------------------------------------------------------------------------

    def writeTempHeader(self, header):
        """
        Write the temporary header file C{_headFile}, one card per line.

        @param header: The cards to write, e.g. an astropy C{Header.cards}.
                       Each item is written using C{str()}; for an astropy
                       C{Card} that is its 80-character card image (its
                       C{repr} would be a Python tuple).
        """
        with open(self._headFile, 'w') as headFile:
            for card in header:
                headFile.write(str(card) + "\n")

    #--------------------------------------------------------------------------

    def makeTranslat(self, name):
        """
        Translate between colour channel and filter name.

        @param name: 'R', 'G' or 'B' to get that channel's filter; otherwise
                     a filter name to get its channel.
        @return: The filter name or the colour channel.
        """
        if name in ('R', 'G', 'B'):
            return self.filterSet[name]
        return utils.invertDict(self.filterSet)[name]

    #--------------------------------------------------------------------------

    def getFrameSets(self):
        """Get framesets from DB.

        Fills C{self.frameSetDict} with
        frameSetID -> [fitsFiles, obsData, mfIDs]:
          - fitsFiles: (fileName, extNum) pairs;
          - obsData: (exptime, nustep, photzpcat) triples;
          - mfIDs: the multiframe IDs.
        Each of these is in R, G, B order. In a test run the query is only
        printed.
        """
        self.frameSetDict = {}
        filters = [self.filterSet[colour] for colour in "RGB"]
        # restricted to 1 for testing
        selectStr = "top 1 framesetid, "
        selectStr += ', '.join(
            "mf_.filename, _enum - 1".replace('_', f) for f in filters)
        selectStr += ', ' + ', '.join(
            "mf_.exptime, mf_.nustep, mfd_.photzpcat".replace('_', f)
            for f in filters)
        selectStr += ', ' + ', '.join(
            "mf_.multiframeid".replace('_', f) for f in filters)
        fromStr = "___mergelog as ml, ".replace('___', self.surveyName)
        fromStr += ', '.join(
            "multiframe as mf_".replace('_', f) for f in filters)
        fromStr += ', ' + ', '.join(
            "multiframedetector as mfd_".replace('_', f) for f in filters)
        whereTemplates = ["_mfid > 0", "ml._mfid=mf_.multiframeid",
                          "ml._mfid=mfd_.multiframeid",
                          "ml._enum=mfd_.extnum"]
        whereStr = " and ".join(template.replace('_', f)
                                for template in whereTemplates
                                for f in filters)
        if self.frameSetID:
            whereStr += " and framesetID=%s" % self.frameSetID

        self.archive = DbSession(self.database)
        try:
            print("SELECT  " + selectStr)
            if self.isTestRun:
                print("FROM  " + fromStr)
                print("WHERE  " + whereStr)
            else:
                frameSets = self.archive.query(selectStr, fromStr, whereStr)
                if frameSets:
                    for entry in frameSets:
                        self.frameSetDict[entry[0]] = [entry[1:7],
                                                       entry[7:-3],
                                                       entry[-3:]]
                else:
                    Logger.addMessage("No framesets found!")
        finally:
            del self.archive

#------------------------------------------------------------------------------

class ColourJpeg(object):
    """
    Combine three FITS images (R, G, B) into a colour JPEG.

    The mean of the three images is stretched with an arcsinh function.
    Every channel is multiplied by the same per-pixel factor
    (stretched mean / mean), so the channel ratios are kept.
    """
    #--------------------------------------------------------------------------
    # Define class constants (access as ColourJpeg.varName)
    # FITS file name per colour 'R', 'G', 'B'
    rgbFits = {}
    outFileName = 'rgb.jpg'
    # Low cut of the mean image; maxVal is stretched to 255
    minVal = 12.0
    maxVal = 50000.0
    # If True clip channels at 255, else rescale over-bright pixels
    clip = False
    # test point (pixel whose values are printed; run sets the image centre)

    valuey = 1
    valuex = 1

    # scaling values (not used by this class)
    _rgbMin = {'R': 10., 'G': 20., 'B': 10.}
    sigma = 5.0

    #--------------------------------------------------------------------------
    def run(self):
        """
        Read the images named in C{rgbFits}, stretch them and write the
        colour JPEG C{outFileName}.

        Let M be the mean of the three images. Each channel is multiplied by
        C{slope * arcsinh((M - minVal) / sigma) / M}, where C{slope} maps
        C{maxVal} to 255.

        A pixel becomes black where M is not above C{minVal}, is zero, or is
        not finite.

        Channel values above 255 are clipped if C{clip} is set. Otherwise all
        three channels of such a pixel are divided by max(R, G, B) / 255.
        """
        self._slope = 255.0 / numpy.arcsinh(
            (self.maxVal - self.minVal) / self.sigma)
        print("%s %s %s" % (self.maxVal, self.minVal, self.sigma))

        rgbFitsData = {}
        for colour in "RGB":
            fits = pyfits.open(self.rgbFits[colour])
            try:
                rgbFitsData[colour] = fits[0].data
            finally:
                fits.close()

        # The output size is taken from the B image; the R and G images are
        # assumed to have the same shape.
        ysize, xsize = rgbFitsData['B'].shape
        self.valuex = xsize // 2
        self.valuey = ysize // 2
        y, x = self.valuey, self.valuex
        print("%s %s %s %s %s"
              % (xsize, ysize, self.minVal, self.maxVal, self._slope))

        for colour in "RGB":
            print("%s : %s" % (colour, rgbFitsData[colour][y, x]))
        print("clip  %s" % (self.clip,))

        # calculate mean and scaling factors
        mean = (rgbFitsData['R'] + rgbFitsData['G'] + rgbFitsData['B']) / 3.0
        print("mean: %s" % (mean[y, x],))

        rgbScaledData = {}
        # Zero-mean (blank) and non-finite pixels would give 0/0 or inf/inf.
        # They are masked out explicitly below, so silence numpy's warnings.
        # Without the mean != 0 test, a negative minVal let blank pixels
        # through as NaN/inf.
        with numpy.errstate(divide='ignore', invalid='ignore'):
            valid = (numpy.isfinite(mean) & (mean > self.minVal)
                     & (mean != 0))
            scale = (self._slope
                     * numpy.arcsinh((mean - self.minVal) / self.sigma)) / mean
            print("scale: %s" % (scale[y, x],))
            for colour in "RGB":
                # pop() releases each input image as soon as it is scaled
                rgbScaledData[colour] = numpy.where(
                    valid, scale * rgbFitsData.pop(colour), 0.)
                print("%s : %s" % (colour, rgbScaledData[colour][y, x]))
        del mean, scale, valid, rgbFitsData

        if self.clip:
            for colour in "RGB":
                rgbScaledData[colour] = numpy.minimum(255,
                                                      rgbScaledData[colour])
        else:
            # Divide all channels of an over-bright pixel by the same factor
            # so that the channel ratios (the hue) are kept.
            maxRGB = numpy.maximum(
                numpy.maximum(rgbScaledData['R'], rgbScaledData['G']),
                rgbScaledData['B']) / 255.0
            maxRGB = numpy.maximum(maxRGB, 1.0)
            print("maxRGB: %s" % (maxRGB[y, x],))
            for colour in "RGB":
                rgbScaledData[colour] = rgbScaledData[colour] / maxRGB
            del maxRGB

        # fill the RGB array with the scaled data; negative values become 0
        # and the float-to-uint8 assignment truncates
        rgbArray = numpy.zeros((ysize, xsize, 3), numpy.uint8)
        for plane, colour in enumerate("RGB"):
            rgbArray[:, :, plane] = numpy.maximum(0, rgbScaledData[colour])
        print("RGB array:")
        print("%s %s %s" % tuple(rgbArray[y, x]))

        # create the jpg, flipped vertically (presumably because the FITS
        # rows are stored bottom-up)
        image = Image.fromarray(rgbArray, "RGB").transpose(
            Image.FLIP_TOP_BOTTOM)
        image.save(self.outFileName, quality=95)
        Logger.addMessage("Created colour JPEG: %s" % self.outFileName)

#------------------------------------------------------------------------------

def createJpgName(mfIDDict, fitsFileDict):
    """
    Build the JPEG file name from the multiframe IDs and extension numbers.

    The name is C{<mfID>_<ext>} for R, G and B in turn, joined by '_' and
    followed by '.jpg'.

    @param mfIDDict: Multiframe ID per colour 'R', 'G', 'B'.
    @param fitsFileDict: FITS file name with an extension suffix
                         (e.g. C{file.fit[3]}) per colour.
    @return: The JPEG file name.
    @rtype: str
    """
    nameParts = []
    for colour in "RGB":
        nameParts.append("%s_%s" % (mfIDDict[colour],
                                    getExtNum(fitsFileDict[colour])))
    return '_'.join(nameParts) + ".jpg"

def getExtNum(fileName):
    """
    Return the extension number of a FITS file name of the form C{name[ext]}.

    @param fileName: File name, e.g. C{/path/file.fit[3]}. Trailing
                     whitespace is ignored.
    @type  fileName: str
    @return: The text between the last '[' and the closing ']', or an empty
             string if the name has no C{[...]} suffix.
    @rtype: str
    """
    name = fileName.rstrip()
    # Without this guard a bare file name returned a chopped-off fragment of
    # itself (e.g. 'file.fits' -> 'file.fit').
    if not name.endswith(']') or '[' not in name:
        return ''
    return name[:-1].rpartition('[')[2]

#------------------------------------------------------------------------------
#
# Entry point for ColourJpegger

if __name__ == '__main__':
    CLI.progOpts += [CLI.Option('f', "filterset",
                                "Commasep. list of filters used for R, G, B, "
                                "resp.; can be survey name for preset "
                                "values: %r" % EdSwarp._filterAssoc,
                                "LIST", EdSwarp.surveyName),
                     CLI.Option('j', "jpeg",
                                "create a JPEG from the given 3 commasep. "
                                "fits file names (in order RGB)",
                                "LIST", ''),
                     CLI.Option('L', "low",
                                "low",
                                "FLOAT", ColourJpeg.minVal),
                     CLI.Option('H', "high",
                                "high",
                                "FLOAT", ColourJpeg.maxVal),
                     CLI.Option('C', "clipping",
                                "clip"),
                     CLI.Option('z', "sigma",
                                "sigma",
                                "FLOAT", ColourJpeg.sigma),
                     CLI.Option('o', "outfile",
                                "name for the created JPEG, "
                                "(default: multiframeIDSet)",
                                "NAME", ColourJpeg.outFileName),
                     CLI.Option('s', "survey",
                                "Survey name if filterset not defined via "
                                "survey name", "NAME", EdSwarp.surveyName),
                     CLI.Option('i', "framesetid",
                                "frameset ID",
                                "INT", EdSwarp.frameSetID)]

    cli = CLI(EdSwarp.__name__, "$Revision: 10969 $", EdSwarp.__doc__,
               checkSVN=False)
    Logger.addMessage(cli.getProgDetails())

    if cli.getOpt("jpeg"):
        # Only build a JPEG from three existing FITS files.
        inFits = cli.getOpt("jpeg").split(',')
        if len(inFits) != 3:
            raise SystemExit("--jpeg needs exactly 3 comma-separated FITS "
                             "files (R,G,B), got %d" % len(inFits))
        makeColour = ColourJpeg()
        makeColour.rgbFits = dict(zip("RGB", inFits))
        makeColour.outFileName = cli.getOpt("outfile")
        makeColour.minVal = float(cli.getOpt("low"))
        makeColour.maxVal = float(cli.getOpt("high"))
        makeColour.sigma = float(cli.getOpt("sigma"))
        makeColour.clip = cli.getOpt("clipping")
        makeColour.run()
    else:
        makeSwarp = EdSwarp()
        filterOpt = cli.getOpt("filterset")
        if filterOpt in EdSwarp._filterAssoc:
            makeSwarp.filterSet = EdSwarp._filterAssoc[filterOpt]
            makeSwarp.surveyName = filterOpt
        else:
            filters = filterOpt.split(',')
            if len(filters) != 3:
                raise SystemExit("--filterset needs a survey name or exactly "
                                 "3 comma-separated filters (R,G,B), got %r"
                                 % filterOpt)
            # EdSwarp indexes filterSet by colour, so it must be a dict
            # keyed by 'R', 'G', 'B', not a plain list.
            makeSwarp.filterSet = dict(zip("RGB", filters))
            makeSwarp.surveyName = cli.getOpt("survey")
        makeSwarp.isTestRun = cli.getOpt("test")
        makeSwarp.frameSetID = cli.getOpt("framesetid")
        makeSwarp.database = cli.getOpt("database")
        makeSwarp.userName = cli.getOpt("user")
        print("filterset  %s" % (makeSwarp.filterSet['B'],))
        print("filterset  %s" % (list(makeSwarp.filterSet.values()),))
        print("filterset  %s" % (list(makeSwarp.filterSet.keys()),))
        for col in "RGB":
            print("filter  %s" % (makeSwarp.filterSet[col],))
        makeSwarp.run()
