'''
Created Dec 2020; vectorised and enhanced over original version 7 Sep 2020.

@author: nch
'''
import math
import numpy as np
TINY = 1.0e-6

def ranrm(theta):
    '''
    Normalise an angle expressed in radians into the range 0 to 2pi.

    theta may be a Python number, any NumPy scalar, or an array; scalar input gives a NumPy
    float scalar back, array input gives an array of the same shape.
    '''
    twopi = 2.0 * math.pi
    w = np.fmod(theta, twopi)

    # np.fmod keeps the sign of theta, so shift negative remainders up by one full turn. Doing
    # this arithmetically (rather than type-sniffing for `float` and then assigning in place
    # through a mask) works for every NumPy scalar type, e.g. np.float32, which is not a
    # Python float and does not support item assignment.
    return w + twopi * (w < 0.0)
    
def tangent_plane_to_spherical(xi, eta, raz, dcz):
    '''
    De-project local tangent-plane coordinates back onto the celestial sphere.

    This is the inverse gnomonic projection about the tangent point (raz, dcz), following
    Pat Wallace's dtp2s.c in SLALIB.

    xi, eta  : standard (tangent-plane) coordinates; NumPy arrays or scalars.
    raz, dcz : RA and Dec of the tangent point; dcz must be a scalar.

    Returns spherical coordinates (ra, dec), with ra normalised into 0..2pi by ranrm.

    All arguments and results in RADIANS.
    '''
    sin_dc0, cos_dc0 = math.sin(dcz), math.cos(dcz)

    # shared denominator of the inverse projection
    d = cos_dc0 - eta * sin_dc0

    dec = np.arctan2(sin_dc0 + (eta * cos_dc0), np.sqrt((xi * xi) + (d * d)))
    ra = ranrm(np.arctan2(xi, d) + raz)

    return ra, dec

def spherical_to_tangent_plane(ra, dec, raz, decz):
    '''
    Gnomonic projection of spherical coordinates onto the tangent plane about (raz, decz).

    Takes arbitrary point(s) in spherical polars and projects about the given (local) projection
    origin in the same frame, yielding local plane angular coordinates and a flag indicating any
    issues with the method applied. ra and dec may be scalars or NumPy arrays of any
    broadcast-compatible shapes; decz must be a scalar. All angles in RADIANS.

    Returns xi, eta: local plane (rectangular) coordinates [radians]
           status j:   0 = OK, star on tangent plane
                       1 = error, star too far from axis
                       2 = error, antistar on tangent plane
                       3 = error, antistar too far from axis (also given for NaN input)

    For scalar input j is a Python int; otherwise it is an integer array with the broadcast
    shape of the inputs. xi and eta are computed even in the error cases, using a denominator
    clamped to +/-TINY when it is too close to zero.

    Cribbed from Patrick Wallace's SLALIB C code dst2p.
    '''

    # Trig functions for convenience
    sdecz = math.sin(decz)
    cdecz = math.cos(decz)
    sdec = np.sin(dec)
    cdec = np.cos(dec)
    radif = ra - raz
    sradif = np.sin(radif)
    cradif = np.cos(radif)

    # Reciprocal of star vector length to tangent plane
    denom = sdec * sdecz + cdec * cdecz * cradif

    # Status flags. np.select takes the first condition that holds, mirroring SLALIB's
    # if / else-if chain, and works the same for scalars and arrays of any shape. Anything that
    # matches none of the conditions (denom <= -TINY, or NaN) falls through to status 3.
    j = np.select([denom > TINY, denom >= 0.0, denom > -TINY], [0, 1, 2], default=3)

    # Handle vectors too far from axis: keep near-zero denominators away from zero
    denom = np.where(j == 1, TINY, np.where(j == 2, -TINY, denom))

    # Compute tangent plane coordinates (even in dubious cases)
    xi = cdec * sradif / denom
    eta = (sdec * cdecz - cdec * sdecz * cradif) / denom

    if np.ndim(j) == 0:
        # scalar input: return a plain integer status, as the scalar SLALIB logic does
        j = int(j)

    return xi, eta, j

def linear_coefficients_array(x, y, order, mags = None, mag_order = 2):
    '''
    Create the linear coefficients (design) array for the given fit order, one row per point.

    Spatial columns accumulate with order:

        order 1: 1, x, y
        order 2: xy, xx, yy
        order 3: xyy, xxy, xxx, yyy
        order 4: xyyy, xxyy, xxxy, xxxx, yyyy

    e.g. third order gives 1, x, y, xy, xx, yy, xyy, xxy, xxx, yyy. Orders above 4 are not
    implemented and give the fourth-order columns.

    If mags is given and 0 < mag_order < 4, magnitude columns m, m^2, ... up to m^mag_order are
    appended after the spatial ones; otherwise no magnitude columns are added.
    '''

    ax = np.array(x)
    ay = np.array(y)

    # first order coefficients
    columns = [np.ones(len(ax)), ax, ay]

    # append higher order coefficients if required
    if order > 1:
        columns += [ax*ay, ax*ax, ay*ay]
    if order > 2:
        columns += [ax*ay*ay, ax*ax*ay, ax*ax*ax, ay*ay*ay]
    if order > 3:
        columns += [ax*ay*ay*ay, ax*ax*ay*ay, ax*ax*ax*ay, ax*ax*ax*ax, ay*ay*ay*ay]
    # ... higher terms not implemented at present.

    # Magnitude terms: the polynomial degree is set by mag_order (1, 2 or 3), independently of
    # the spatial order. A cubic is destabilising for some plates, hence the quadratic default.
    if mags is not None and 0 < mag_order < 4:
        am = np.array(mags)
        columns.append(am)
        if mag_order > 1:
            columns.append(np.power(am, 2.0))
        if mag_order > 2:
            columns.append(np.power(am, 3.0))

    return np.column_stack(columns)

import scipy.stats as spstats
#import matplotlib.pyplot as plt

def mads_versus_mags(y, x, inliers, yminimum = 0.0):
    '''
    Computes robust predictions of y(x) given the inlier flags, assuming ln(y) is linear in x.

    Fits ln(y) = intercept + slope * x with Siegel's repeated-medians estimator
    (scipy.stats.siegelslopes) using only the inlier points, then evaluates
    exp(intercept + slope * x) at every x. Intended for e.g. the scatter of absolute deviations
    of astrometric reference stars about a local plane model as a function of magnitude.

    y, x     : array-likes of the same length.
    inliers  : boolean mask (or integer index array) selecting the points used in the fit.
               Points with y <= 0 (or NaN) are also left out, since they have no finite log.
    yminimum : safety-valve floor that prevents unphysically small predictions.

    Returns a NumPy array the same length as x containing the robustly estimated value of y(x).
    '''
    ay = np.asarray(y, dtype=float)
    ax = np.asarray(x, dtype=float)

    # Build a boolean selection (this works for masks and index arrays alike) and drop
    # non-positive y: a single exact-zero deviation would give log(0) = -inf and turn the
    # median slopes into NaN/inf.
    usable = np.zeros(len(ay), dtype=bool)
    usable[inliers] = True
    usable &= ay > 0.0

    # robust linear regression, working in log space on the y axis
    medslope, medintercept = spstats.siegelslopes(np.log(ay[usable]), ax[usable])

    # predictions based on the above regression, transformed back to linear y-space
    ypred = np.exp(medintercept + medslope * ax)

    # give back the result with the appropriate safety valve applied
    return ypred.clip(min = yminimum)

import numpy.linalg as la

    
def pair_up(x1, y1, x2, y2, tol):
    '''
    Nearest-neighbour association of two lists of 2-D cartesian locations.

    For every source in the first set, finds the closest source of the second set lying
    strictly within the radial tolerance tol. Coordinates and tolerance must all share the same
    (arbitrary) units. This is plain proximity pairing: several first-set sources may point at
    the same second-set source, and no unique hand-shake association is attempted.

    The second set is sorted on its first coordinate once, so each search only visits
    candidates whose x lies within tol of the target's x.

    Returns a NumPy integer array p with one entry per first-set source such that
    (x1[i], y1[i]) corresponds to (x2[p[i]], y2[p[i]]); p[i] is -1 where nothing in the second
    set is close enough. Equal distances go to the candidate met first in x-sorted order.
    '''
    tolsq = tol * tol
    order = np.argsort(x2)
    ncand = len(order)
    pointers = np.empty(len(x1), dtype=int)

    for idx in range(len(x1)):
        xa = x1[idx]
        upper = xa + tol
        best = -1
        bestsq = tolsq

        # first candidate in sorted order whose x could still be within range
        k = np.searchsorted(x2, xa - tol, sorter = order)

        # walk forward until the candidates' x exceeds the upper bound
        while k < ncand and x2[order[k]] <= upper:
            cand = order[k]
            dx = x2[cand] - xa
            dy = y2[cand] - y1[idx]
            distsq = dx * dx + dy * dy
            if distsq < bestsq:
                best, bestsq = cand, distsq
            k += 1

        pointers[idx] = best

    return pointers
    
def apply_transformation(x, y, coeffs, order = 1, mags = None, mag_order = 2):
    '''
    Evaluate a fitted plate model at the given (x, y) positions.

    coeffs holds the model for the first output coordinate followed by the model for the
    second, in two equal halves (the layout local_plate_model returns). Each half must match the
    columns built by linear_coefficients_array for the same order, mags and mag_order.

    Returns two NumPy vectors of the transformed coordinates.
    '''
    design = linear_coefficients_array(x, y, order, mags, mag_order = mag_order)
    half = len(coeffs) // 2

    first = design @ np.array(coeffs[:half])
    second = design @ np.array(coeffs[half:])

    return first, second

def local_plate_model(xi, eta, x, y, order = 1, iterate = False, mags = None, residmin = 0.0, ids = None, mag_order = 2):
    '''
    Least-squares fit of a polynomial plate model with iterative outlier rejection.

    Given the local plane angular coordinates (xi, eta) of standard points and a corresponding
    set of coordinates, e.g. ccd (x, y) or approximate (xi, eta), by default solves

        xi  = a + bx + cy
        eta = d + ex + fy

    for the 6-vector of model coefficients (a, b, ..., f). For higher order fits specify
    order = 2 or higher. Magnitude terms are included when mags is given; mags and mag_order are
    passed straight to linear_coefficients_array, which defines the column layout. xi and eta
    are solved together as one block-diagonal least-squares system.

    After each fit, points whose absolute xi or eta deviation exceeds normtol = 3 x 1.48 times
    the typical absolute deviation (a Gaussian-equivalent ~3 sigma cut) are rejected, and
    rejected points never come back. The typical deviation is a single median over the current
    inliers when mags is None. Otherwise it is a robust log-linear function of magnitude
    (mads_versus_mags), floored at residmin (same units as xi and eta). Iteration stops when a
    pass rejects nothing, or when the inliers no longer outnumber the coefficients per axis.

    NOTE(review): iterate and ids are accepted but not used; rejection iteration always runs.

    Returns
        coeffs   : xi coefficients followed by eta coefficients
        rms      : sqrt(sum(dxi^2 + deta^2) / N) over the N points of the final fit
                   (0.0 when the fit is exactly determined)
        rank, s  : rank and singular values from numpy.linalg.lstsq
        inliers  : boolean mask of the points kept
        xmads, ymads : per-point typical absolute deviations used in the final rejection test
    '''

    # Gaussian-equivalent tolerance: 1.48 x median absolute deviation estimates sigma, so a cut at
    # normtol typical absolute deviations is a ~3 sigma cut
    normtol = 3.0 * 1.48
    axi = np.array(xi)
    aeta = np.array(eta)
    # mads_versus_mags boolean-indexes the magnitudes, so make sure they are an array
    amags = None if mags is None else np.asarray(mags)

    # design matrix for all points; rows are selected per pass via the inlier mask
    c = linear_coefficients_array(x, y, order, amags, mag_order = mag_order)

    inliers = np.ones(len(x), dtype=bool)
    iteration = 0

    # always do at least one iteration
    again = True
    while again:

        # solve xi and eta together: [[c1, 0], [0, c1]] . coeffs = [xi, eta]
        c1 = c[inliers]
        zeros = np.zeros_like(c1)
        A = np.block([[c1, zeros], [zeros, c1]])
        b = np.hstack([axi[inliers], aeta[inliers]])
        coeffs, _, rank, s = la.lstsq(A, b, rcond = None)
        n = len(coeffs) // 2

        # check for outlying points: apply the transformation to everything
        xallads = np.abs(c @ coeffs[:n] - axi)
        yallads = np.abs(c @ coeffs[n:] - aeta)

        # absolute deviations of the points actually used in this fit
        xad = xallads[inliers]
        yad = yallads[inliers]

        # Gaussian-equivalent 3 sigma tolerance based on the typical over-all values
        xtoltyp = normtol * np.median(xad)
        ytoltyp = normtol * np.median(yad)

        # typical absolute deviation for every point (magnitude-dependent if required)
        if amags is None:
            xmads = np.full(len(xallads), xtoltyp / normtol)
            ymads = np.full(len(yallads), ytoltyp / normtol)
        else:
            xmads = mads_versus_mags(xallads, amags, inliers, yminimum = residmin)
            ymads = mads_versus_mags(yallads, amags, inliers, yminimum = residmin)

        oldpt = np.count_nonzero(inliers)

        # finish here if exact solution, i.e. points <= numcoeffs / 2 and we can't go any further
        if oldpt <= n:
            rms = 0.0
            break

        # Keep points within normtol typical deviations on both axes. Since xmads is a typical
        # absolute deviation (~0.674 sigma for Gaussian residuals), this is a ~3 sigma cut.
        # Multiplying rather than dividing avoids 0/0 = NaN (which would reject every point)
        # when the typical deviation is exactly zero, e.g. for noiseless data.
        # Note that outliers cannot come back in once rejected.
        keep = (xallads <= normtol * xmads) & (yallads <= normtol * ymads)
        inliers = inliers & keep
        newpt = np.count_nonzero(inliers)

        # sigma = 1.48 x median = xtoltyp / 3, and x 3600 converts degrees to arcsec, so
        # sigma [arcsec] = xtoltyp * 1200 -- valid IFF fitting in angular units of degrees!
        print("Iteration %d: previous Nfit %d; current Nfit %d; xsig, ysig: %.3f, %.3f [arcsec]"%
                (iteration, oldpt, newpt, xtoltyp * 1200.0, ytoltyp * 1200.0))

        # Combined (xi and eta) RMS residual of this fit, computed directly from the residuals:
        # lstsq returns an empty residual array for rank-deficient systems.
        rms = math.sqrt(np.sum(xad * xad + yad * yad) / oldpt)

        # iteration condition:
        iteration += 1
        again = newpt < oldpt

    return coeffs, rms, rank, s, inliers, xmads, ymads
