'''
Use image segmentation to perform photometry on a SimpleCOSMOS image, obtaining optimum image parameters
given the limited dynamic range of such images. Follows the procedures and code fragments at

https://photutils.readthedocs.io/en/stable/segmentation.html

Version using a rough astrometric solution from astrometry.net. The conda installation of the astrometry.net
stand-alone application (not the python library) apparently cannot be invoked via a call-out to the shell,
so run it first on the SimpleCOSMOS images, e.g.

% source activate py37
% solve-field --ra 281.271667 --dec -63.963153 --radius 1.0 SDIM0137_02_intensity.fits

for further command line options / optimisations, see http://astrometry.net/doc/readme.html

This code assumes the above has been done prior to running this script, and uses the WCS so produced.

N.B. For image perusal in Starlink-GAIA, launch gaia with
% gaia -unicoderadec 0 &
to avoid X Error / opcode failure crash - see https://starlink.eao.hawaii.edu/starlink/2021ADownload

@author: nch
'''

# Switch on (True) to display the sanity-check plots interactively
visual_check = False

# Images to analyse: SDIM image number -> plate label identification.
# Deliberately excluded:
#    113 : 'R 4350',  target is blended with another source on this plate!
simplecos_images = {137: 'R10300', 138: 'R 4350'}

class Formatter(object):
    '''
    Coordinate readout for matplotlib ``Axes.format_coord``: reports the cursor
    position plus the value of the image pixel underneath it.
    '''
    def __init__(self, im):
        # im: the artist returned by Axes.imshow (anything providing get_array())
        self.im = im
    def __call__(self, x, y):
        '''
        Return 'x=..., y=..., z=...' for data coordinates (x, y), or only
        'x=..., y=...' when (x, y) does not fall on a pixel of the array.
        '''
        array = self.im.get_array()
        # With imshow's default extent, pixel (row, col) is centred on integer
        # data coordinates, so round to the nearest pixel (int() truncates and
        # reports the neighbouring pixel for the upper half of each pixel).
        col = int(round(x))
        row = int(round(y))
        nrows, ncols = array.shape[:2]
        if 0 <= row < nrows and 0 <= col < ncols:
            z = array[row, col]
            return 'x={:.01f}, y={:.01f}, z={:.01f}'.format(x, y, z)
        # Off the array: don't index, as a negative index would silently wrap
        # around to the opposite edge of the image.
        return 'x={:.01f}, y={:.01f}'.format(x, y)
    
if __name__ == '__main__':

    # imports used within the loop (hoisted: they are loop-invariant)
    import os
    import numpy as np
    from astropy import wcs
    from astropy.io import fits
    from astropy.convolution import Gaussian2DKernel
    from astropy.stats import gaussian_fwhm_to_sigma
    from photutils import detect_sources
    from photutils import deblend_sources
    from photutils.background import BkgZoomInterpolator
    from photutils.background.background_2d import Background2D
    from photutils.segmentation import SourceCatalog

    # astrometry.net is NOT run from here: a subprocess call-out such as
    #   solve-field -O --no-plots --ra <ra> --dec <dec> --radius 1.0 <image>
    # does not work because the shell process is not the conda environment in
    # which the astrometry.net app is installed. Run solve-field beforehand (see
    # the module docstring); this script only reads back the resulting .wcs file.

    for key in simplecos_images:

        # SDIM image number zero-padded to four digits, e.g. 137 -> SDIM0137
        simplecos_image = '/Users/nch/RECONS/SCR1845/SDIM%04d_02_intensity.fits'%(key)
        # image path without its .fits extension, used to name companion files
        # (splitext only touches the extension, unlike a bare str.replace)
        image_stem = os.path.splitext(simplecos_image)[0]
        print('Analysing %s'%(simplecos_image))

        # read back the rough J2000 WCS written alongside the image as <stem>.wcs;
        # the context manager closes the file on every loop iteration
        wcs_filename = image_stem + '.wcs'
        with fits.open(wcs_filename) as hdulist:
            # Parse the WCS keywords in the primary HDU
            w = wcs.WCS(hdulist[0].header)

        # Print out the "name" of the WCS, as defined in the FITS header
        print(w.wcs.name)

        # Print out all of the settings that were parsed from the header
        w.wcs.print_contents()

        with fits.open(simplecos_image) as hdu:
            # cast to float for all subsequent operations (some won't work with shorts?!);
            # astype() returns a copy, so the pixels stay valid after the file is closed
            data = hdu[0].data.astype('float64')

        bck_box = 32 # default SimpleCOS image size (1768, 2652) HCF gives reasonable results? Check.
        bkg_interp = BkgZoomInterpolator(order = 1)
        # ... override default bicubic spline interpolation since that yields crazy values (e.g. -ve rms) near VBS
        bck = Background2D(data, box_size = bck_box, interpolator = bkg_interp, filter_size = 5)
        print("Background computed with typical RMS %.2f"%(bck.background_rms_median))

        # sanity check by eye
        if visual_check:
            import matplotlib.pyplot as plt
            from astropy.visualization import LogStretch
            from astropy.visualization.mpl_normalize import ImageNormalize
            norm = ImageNormalize(stretch=LogStretch())
            plt.imshow(bck.background_rms, cmap='Greys', origin='lower', interpolation='nearest', norm=norm)
            plt.show()

        # work with a background subtracted image
        data -= bck.background

        # threshold for segmentation image is kappa x sigma above the (now subtracted) background
        # ... choose SuperCOSMOS equivalent for isophotal detection threshold
        threshold = 3.0 * bck.background_rms  # previously 2.3 * bck.background_rms; 17012 needed 3 sigma for some reason!

        # matched image detection filter: Gaussian FWHM = 2 pix, sampled on a 3 x 3 pixel kernel
        sigma = 2.0 * gaussian_fwhm_to_sigma
        kernel = Gaussian2DKernel(sigma, x_size=3, y_size=3)
        kernel.normalize()
        segm = detect_sources(data, threshold, npixels=5, kernel=kernel, connectivity = 8)
        if segm is None:
            # presumably photutils signals "no pixels above threshold" with None;
            # nothing to deblend or photometer, so move on to the next image
            print("No sources detected in %s; skipping"%(simplecos_image))
            continue
        print("Segmentation / source detection finished ... ")

        # deblend
        segm_deblend = deblend_sources(data, segm, npixels=5, kernel=kernel)#, mode = 'linear')
        print("Source deblending finished ...")

        # sanity check by eye
        if visual_check:
            from astropy.visualization import SqrtStretch
            from photutils.utils import colormaps
            rand_cmap = colormaps.make_random_cmap()#segm.max + 1, random_state=12345)
            norm = ImageNormalize(stretch=SqrtStretch())
            fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 8))
            ax1.imshow(data, origin='lower', cmap='Greys_r', norm=norm)
            im2 = ax2.imshow(segm, origin='lower', cmap=rand_cmap)
            ax2.format_coord = Formatter(im2)
            im3 = ax3.imshow(segm_deblend, origin='lower', cmap=rand_cmap)
            ax3.format_coord = Formatter(im3)
            plt.show()

        # photometer the sources, assuming we're background-limited:
        cat = SourceCatalog(data, segm_deblend, error = bck.background_rms, kernel = kernel)
        print("Source photometry finished:")

        columns = ['label', 'xcentroid', 'ycentroid', #'covar_sigx2', 'covar_sigy2',
                   'area', 'semimajor_sigma', 'semiminor_sigma', 'orientation', 'eccentricity',
                   'segment_flux', 'segment_fluxerr']
        tbl = cat.to_table(columns)
        tbl['xcentroid'].info.format = '.3f'  # optional format
        tbl['ycentroid'].info.format = '.3f'

        # a few extra table columns for convenience
        # instrumental magnitude relative to the faintest source (NOTE(review):
        # yields NaN/inf if the minimum segment flux is not positive)
        flux_min = np.min(tbl['segment_flux'])
        tbl['mag'] = -2.5*np.log10(tbl['segment_flux'] / flux_min)
        # Irwin (1984) rule-of-thumb: centroid error equals the flux relative error multiplied by the scale size of image
        # (a fixed scale size of 2.5 pixels is used here)
        tbl['xsig_estimate'] = 2.5 * tbl['segment_fluxerr'] / tbl['segment_flux']

        print(tbl)

        # sanity check by eye
        if visual_check:
            from astropy.visualization import simple_norm
            norm = simple_norm(data, 'sqrt')
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))
            ax1.imshow(data, origin='lower', cmap='Greys_r', norm=norm)
            ax1.set_title('Data')
            cmap = segm_deblend.make_cmap(seed=123)
            ax2.imshow(segm_deblend, origin='lower', cmap=cmap, interpolation='nearest')
            ax2.set_title('Segmentation Image')
            cat.plot_kron_apertures((2.5, 1.0), axes=ax1, color='white', lw=1.5)
            cat.plot_kron_apertures((2.5, 1.0), axes=ax2, color='white', lw=1.5)
            plt.show()

        # The pixel coordinates of interest, as [X, Y] pairs (NAXIS=2 image assumed).
        pixcrd = np.column_stack((tbl['xcentroid'], tbl['ycentroid']))

        # Convert pixel coordinates to world coordinates.
        # The second argument is "origin": photutils centroids are 0-based (the
        # centre of the first pixel is (0, 0), as for Numpy arrays), so declare 0.
        # Declaring 1 (FITS/SExtractor convention) would shift every position by
        # one pixel in both axes.
        # https://docs.astropy.org/en/stable/wcs/loading_from_fits.html
        world = w.wcs_pix2world(pixcrd, 0)
        tbl.add_column(world.T[0], name='ra', index=0)
        tbl.add_column(world.T[1], name='dec', index=1)

        # spew out for a close look in topcat
        tbl.write(image_stem + "_segextract.csv", format="csv", overwrite = True)

        # Just do one image for now; COMMENT OUT FOR PRODUCTION RUN:
        # break