'''
Takes as inputs a catalogue of approximate (non-ICRF3) star positions from a SuperCOSMOS Sky Survey
IAM detection list with Hipparcos-Tycho2 astrometry, and a Gaia E/DR3 catalogue generated for the
field. Pairs the two up, then fits a local plane astrometric model to put the plate catalogue solidly
on the Gaia CRF3. The Gaia positions are proper-motion corrected to the plate epoch before use so that
the same catalogue can be used at different epochs.

Input file format should be plain csv with a single header line of column identifications at the
start. All celestial coordinates should be decimal degrees at Equinox J2000.0

Created on 26 Feb 2024 for Kepler 444 (HIP 94931) checks on magnitude terms in global plate astrometry

@author: nch
'''

# global imports
from astropy.time import Time
from astropy.table import Table
from gdr3cat import gdr3_epoch
from localplanecoords import *
from CatExtract import rad2as
from math import radians

from astropy.coordinates import SkyCoord

# Run configuration for plate OE 0340 (B1950 field centre 19h03m +42 09).
# Order of the magnitude term handed to the plate model fit:
mag_order = 0
# Gaia DR3 reference catalogue extracted for this field:
gdr3_catalogue_file_name = '/Users/nch/BDs/HIP94931/gdr3cat-E0340-Glt18p0.csv'
# Tangent point of the local plane projection, as whitespace-separated sexagesimal RA and Dec:
celestial_tangent_point = '19h04m35.561 +42d13m35.41'

# Split into the RA / Dec strings and build the tangent point coordinate from them.
(raz, dcz) = celestial_tangent_point.split()
tangent_point_coord = SkyCoord(raz, dcz)

datafiles = dict()
# Detection catalogue path -> plate label (the label keys the plate record lookup).
datafiles['/Users/nch/BDs/HIP94931/PAE0340/iam_rdsrt.csv'] = ' E 0340'
#datafiles['/Users/nch/BDs/HIP94931/POSSE0281.fits'] = ' E 0281'

# ... timestamps to be parsed from platerecords.dat. Map each plate label (the first
# 7 characters of a record, e.g. ' E 0340') to its full record line, newline included.
# A context manager guarantees the file is closed (a bare `f.close` never called it).
plate_dict = {}
with open('/Users/nch/BDs/HIP94931/platerecords.dat') as f:
    for line in f:
        plate_dict[line[0:7]] = line

# Whole-plate solution settings.
# Reference stars are kept when both |xi| and |eta| about the tangent point are below this:
max_ref_star_distance_deg = 5.0
# Order of the plate model fit: 3 seems good for SuperCOS; 5 for SimpleCOS....?
fit_order = 5

# Output columns for the target star (HIP 94931), each collecting one value per plate.
# The insertion order below fixes the column order of any table built from this dict.
# 'epoch' is deliberately excluded at the moment.
od = {column: [] for column in (
    'ra', 'dec',
    'sigmaRaDeg', 'sigmaDecDeg',
    'xiMas', 'etaMas',
    'sigmaXiMas', 'sigmaEtaMas',
    'mjd', 'numrefstars',
    'plateLabel', 'plateRecord',
)}

# loop imports (explicit, rather than relying on names leaking from a star import)
import math
import os

import numpy as np
import astropy.units as u
from astropy.table import QTable

# Plate magnitude brighter than which the astrometric errors are assumed to plateau;
# inlying reference stars brighter than this set the target's precision estimate.
precision_mag_limit = 14.0

# HIP 94931 proper motion [mas/yr], used to propagate its J2000.0 position to each plate epoch.
target_pmra_cosdec = 94.639
target_pmdec = -632.269
j2000_epoch = Time('2000-01-01T12:00:00.0', format='isot', scale='utc')

# The Gaia catalogue and its coordinates at the Gaia epoch do not depend on the plate,
# so read and build them once. Only the proper-motion propagation inside the loop is per-plate.
gdr3_catalogue = Table.read(gdr3_catalogue_file_name, format = 'csv')
gaia_coords = SkyCoord(ra = gdr3_catalogue['ra'] * u.deg, dec = gdr3_catalogue['dec'] * u.deg,
                       pm_ra_cosdec = gdr3_catalogue['pmra'] * u.mas/u.year,
                       pm_dec = gdr3_catalogue['pmdec'] * u.mas/u.year,
                       obstime = gdr3_epoch, frame = 'icrs')
ref_gs = gdr3_catalogue["phot_g_mean_mag"]
ref_bmr = gdr3_catalogue["bp_rp"]

for detection_cat_file in datafiles:

    # derive the detection epoch (good to ~12hr or so) from the plate record:
    label = datafiles[detection_cat_file]
    plate_record = plate_dict[label]
    # UT date as 2-digit year, month and day in characters 30-35 of the record; blanks read as zeros.
    # The '19' century prefix assumes 20th-century plates.
    yy = int(plate_record[30:32].replace(' ','0'))
    mm = int(plate_record[32:34].replace(' ','0'))
    dd = int(plate_record[34:36].replace(' ','0'))
    utcstr = '19%02d-%02d-%02dT12:00:00.0'%(yy,mm,dd)
    detection_epoch = Time(utcstr, format='isot', scale='utc')

    # for SSA SuperCOS measures: plate magnitudes are offset by scrmag so the magnitude term stays
    # of order unity; set scrmag = None to switch the magnitude term off.
    mag_field_name = 'sMag'
    scrmag = 9.0

    ra_colname = 'ra'
    dec_colname = 'dec'
    plate_catalogue = Table.read(detection_cat_file, format = 'csv')

    # Gaia positions at the plate epoch. Only proper motions are supplied (no distances or RVs),
    # so this is a proper-motion-only propagation.
    gaia_coords_at_plate_epoch = gaia_coords.apply_space_motion(new_obstime = detection_epoch)

    # extract vectors of local plane angular coordinates [radians]
    rough_ras = np.radians(plate_catalogue[ra_colname])
    rough_dcs = np.radians(plate_catalogue[dec_colname])
    ref_ras = gaia_coords_at_plate_epoch.ra.radian
    ref_dcs = gaia_coords_at_plate_epoch.dec.radian
    raz = tangent_point_coord.ra.radian
    dcz = tangent_point_coord.dec.radian
    rough_xis, rough_etas, _ = spherical_to_tangent_plane(rough_ras, rough_dcs, raz, dcz)
    all_xis, all_etas, _ = spherical_to_tangent_plane(ref_ras, ref_dcs, raz, dcz)

    # limit the reference stars to a square box about the tangent point
    max_dist_rad = max_ref_star_distance_deg * math.pi / 180.0
    refs_wanted = (np.abs(all_xis) < max_dist_rad) & (np.abs(all_etas) < max_dist_rad)
    all_xis = all_xis[refs_wanted]
    all_etas = all_etas[refs_wanted]
    wanted_gs = ref_gs[refs_wanted]
    wanted_bmr = ref_bmr[refs_wanted]

    # pair them up using units of arcsec for convenience (match tolerance 2.0 arcsec).
    # pidx has one entry per plate detection; presumably a negative value marks "no pair".
    pidx = np.asarray(pair_up(rough_xis*rad2as, rough_etas*rad2as, all_xis*rad2as, all_etas*rad2as, 2.0))

    # grab the relevant pairs-only data for the full-blown, iterative plate model (working in degrees)
    paired = pidx >= 0
    pairs = pidx[paired]                     # indices into the box-limited reference-star arrays
    plate_pair_idx = np.flatnonzero(paired)  # matching indices into the plate catalogue
    ref_xis = all_xis[pairs]*rad2as/3600.0
    ref_etas = all_etas[pairs]*rad2as/3600.0
    x = rough_xis*rad2as/3600.0
    y = rough_etas*rad2as/3600.0
    ids = plate_catalogue['recno']
    radial_distances = plate_catalogue['rdeg']

    # switch on magnitude term (regularise to keep numbers of order unity):
    mags = None
    refplatemags = None
    if scrmag is not None:
        mags = plate_catalogue[mag_field_name] - scrmag
        refplatemags = mags[plate_pair_idx]

    # fit a plate model. ids are selected with the same paired mask as the positions
    # (pidx >= 0), so the detection paired with reference star 0 is not dropped.
    coeffs, rms, rank, s, inliers, ximads, etamads = local_plate_model(
            ref_xis.tolist(), ref_etas.tolist(), x[plate_pair_idx].tolist(), y[plate_pair_idx].tolist(),
            iterate = True, order = fit_order, mags = refplatemags, residmin = 2.0e-5, ids = ids[plate_pair_idx],
            mag_order = mag_order)
    # ... minimum residual: 2.0e-5 deg = 72 mas ~ 1 micron on a POSS plate and/or VBS:
    # don't believe precision can ever be better than that!

    print(coeffs)
    print(s)
    print("Transformation has RMS residual %.3f [arcsec]"%(rms*3600.0))
    if mags is not None and 0 < mag_order < 4:
        magidx = len(coeffs)//2 - 1
        print("Linear magnitude term has Xi gradient %.1f [mas/mag]"%(coeffs[magidx]*3.6e6))
        print("                     and Eta gradient %.1f [mas/mag]"%(coeffs[-1]*3.6e6))

    # fit residuals of the paired reference stars [deg], written to csv for offline analysis in topcat
    ref_xp, ref_yp = apply_transformation(x[plate_pair_idx], y[plate_pair_idx], coeffs, fit_order,
                                          mags = refplatemags, mag_order = mag_order)
    xresids = ref_xp - ref_xis
    yresids = ref_yp - ref_etas
    gmags = wanted_gs[pairs]
    colours = wanted_bmr[pairs]
    rresids = np.sqrt(xresids * xresids + yresids * yresids) * 3.6e6
    # normalise residuals via MAD (x1.48 ~ Gaussian sigma) as a function of mag for comparison with N(0,1)
    xn01 = xresids / (1.48 * ximads)
    yn01 = yresids / (1.48 * etamads)
    if mags is not None:
        output_mags = refplatemags + scrmag
    else:
        output_mags = np.full(len(xresids), 99.99)
    resids_table = QTable(
        [ids[plate_pair_idx], xresids*3.6e6, yresids*3.6e6, gmags, colours, inliers, output_mags, xn01, yn01,
         ref_xp*3600.0, ref_yp*3600.0, radial_distances[plate_pair_idx], rresids],
        names = ("id", "xresidmas", "yresidmas", "Gmag", "bmr", "inFit", "plateMag", "deltXiN01", "deltaEtaN01", "xiArcsec", "etaArcsec", "rdeg", "rDeltaMas"))
    # Output names are derived from the input path minus its extension. This works for any
    # extension and never maps an output back onto the input file itself.
    base_name = os.path.splitext(detection_cat_file)[0]
    if mags is not None:
        output_info_filename = base_name + "_gaiacrf3fit_resids.csv"
    else:
        output_info_filename = base_name + "_gaiacrf3fit_resids_nomags.csv"
    resids_table.write(output_info_filename, format="csv", overwrite = True)

    # apply transformation to all detections and create new coordinate columns in plate catalogue
    xp, yp = apply_transformation(x, y, coeffs, fit_order, mags, mag_order = mag_order)
    ra_icrf3, dc_icrf3 = tangent_plane_to_spherical(np.radians(xp), np.radians(yp), raz, dcz)
    plate_catalogue.add_column(np.degrees(ra_icrf3), name = "ra_icrf3", index=1)
    plate_catalogue.add_column(np.degrees(dc_icrf3), name = "dec_icrf3", index=2)
    output_ast_details = base_name + "_icrf3.csv"
    plate_catalogue.write(output_ast_details, format='csv', overwrite=True)

    # output the coordinates of HIP 94931 according to this new astrometry.
    # NOTE(review): the J2000.0 starting position used here is the tangent point (raz, dcz),
    # not HIP 94931's own catalogue position -- confirm before trusting the match.
    target2000coo = SkyCoord(raz * u.rad, dcz * u.rad,
                             pm_ra_cosdec = target_pmra_cosdec * u.mas / u.year,
                             pm_dec = target_pmdec * u.mas / u.year,
                             obstime = j2000_epoch, frame = 'icrs')
    target_obs_epoch = target2000coo.apply_space_motion(new_obstime = detection_epoch)
    xi_target, eta_target, _ = spherical_to_tangent_plane(target_obs_epoch.ra.radian, target_obs_epoch.dec.radian, raz, dcz)
    # xp, yp are in degrees, hence the degree conversion and the 3 arcsec tolerance in degrees
    targetidx = pair_up([xi_target*180.0/math.pi], [eta_target*180.0/math.pi], xp, yp, 3.0/3600.0)

    if targetidx[0] >= 0:

        ra_target_icrf3, dc_target_icrf3 = tangent_plane_to_spherical(xp[targetidx[0]]*math.pi/180.0, yp[targetidx[0]]*math.pi/180.0, raz, dcz)
        target_icrf3_coo = SkyCoord(ra_target_icrf3 * u.rad, dc_target_icrf3 * u.rad)
        print("HIP 94931 coordinates on ICRF3: %s"%(target_icrf3_coo.to_string("hmsdms")))

        # Estimate astrometric precision from the inlying reference stars. With plate magnitudes
        # available, restrict to stars brighter than precision_mag_limit (errors assumed to plateau
        # there). refplatemags are offset by scrmag, so add it back before comparing with the
        # absolute limit. Fall back to all inliers if that selection is empty.
        selection = inliers
        if scrmag is not None:
            bright = (refplatemags + scrmag <= precision_mag_limit) & inliers
            if np.any(bright):
                selection = bright
        xistdev = np.std(xresids[selection] * 3.6e6)
        etastdev = np.std(yresids[selection] * 3.6e6)

        print("Estimated astrometric precision in RA / Dec is %.3f / %.3f [mas]."%(xistdev, etastdev))

        # append plate data to outputs
        od['ra'].append(ra_target_icrf3 * 180.0 / math.pi)
        od['dec'].append(dc_target_icrf3 * 180.0 / math.pi)
        od['sigmaRaDeg'].append(xistdev / (3.6e6 * math.cos(dc_target_icrf3)))
        od['sigmaDecDeg'].append(etastdev / 3.6e6)
        od['mjd'].append(detection_epoch.mjd)
        od['numrefstars'].append(np.count_nonzero(inliers))
        od['plateLabel'].append(label.replace(" ","_"))
        od['plateRecord'].append(plate_record.replace(" ","_"))
        od['xiMas'].append(xp[targetidx[0]] * 3.6e6)
        od['etaMas'].append(yp[targetidx[0]] * 3.6e6)
        od['sigmaXiMas'].append(xistdev)
        od['sigmaEtaMas'].append(etastdev)

    else:
        print("HIP 94931 not found! Alter code to search for object not tangent point! No out put for %s"%(label))

    # TODO remove for production run, and uncomment file output below !!!
    #break

# finally write out the HIP 94931 plate astrometry (currently disabled: uncomment for production runs)
#scrtable = QTable([od[key] for key in od.keys()], names = (key for key in od.keys()))
#scrtable.write('/Users/nch/BDs/HIP94931/hip94931-plate-astrometry-icrs.csv', format="csv", overwrite = True)


