#!/usr/bin/env python

import os, sys, numpy, pyfits, math
import astLib.astWCS as astWCS
import photutils

if __name__ == "__main__":

    # Aperture photometry driver (Python 2, legacy pyfits/photutils API).
    #
    # Command-line arguments:
    #   1: catfn     - comma-separated source catalog (objects to photometer)
    #   2: fitsfn    - FITS image to measure
    #   3: size      - aperture radius; positive => arcsec (converted via the
    #                  image pixel scale), negative => |size| taken directly
    #                  as a radius in pixels
    #   4: r_sky1    - sky-annulus inner radius [arcsec]
    #   5: r_sky2    - sky-annulus outer radius [arcsec]
    #   6: output_fn - filename for the final photometry table
    catfn = sys.argv[1]
    fitsfn = sys.argv[2]
    size = float(sys.argv[3])

    r_sky1 = float(sys.argv[4])
    r_sky2 = float(sys.argv[5])

    output_fn = sys.argv[6]

    # Slurp the whole catalog into memory; rows are parsed in the loop below.
    with open(catfn, "r") as cat:
        lines = cat.readlines()

    # Open the image and build a WCS so catalog RA/Dec positions can be
    # converted to pixel coordinates.
    hdulist = pyfits.open(fitsfn)
    imghdu = hdulist[0]
    wcs = astWCS.WCS(imghdu.header, mode='pyfits')
    pixelscale = wcs.getPixelSizeDeg() * 3600.  # arcsec per pixel
    if (size < 0):
        # Negative size: |size| is already a radius in pixels.
        npixels = int(numpy.fabs(size))
    else:
        # Positive size: radius given in arcsec; convert to pixels.
        npixels = int(size / pixelscale)

    # check if a weight-map exists
    # Convention: same basename with a ".weight.fits" suffix (assumes fitsfn
    # ends in ".fits").  Pixels with non-positive weight are blanked to NaN
    # so they are excluded from the photometry sums.
    weight_fn =  fitsfn[:-5]+".weight.fits"
    if (os.path.isfile(weight_fn)):
        print "Loading and accounting for weight map"
        weight_hdu = pyfits.open(weight_fn)
        weight = weight_hdu[0].data
        imghdu.data[weight <= 0] = numpy.NaN


    # Sky-annulus radii converted from arcsec to pixels.
    np_sky1 = r_sky1 / pixelscale
    np_sky2 = r_sky2 / pixelscale

    results = []       # formatted text rows, one per measured object
    result_data = []   # numeric rows: [id, ra, dec, vel, flux, bg_mean, bgsub_flux]
    result_names = []  # object names, parallel to result_data
    # Skip the first two lines of the catalog (assumed header) and any
    # additional '#'-comment lines.
    for line in lines[2:]:
        if (line.startswith("#")):
            continue
        items = line.split(",")
        
        try:
            # Catalog column layout (0-based, comma-separated):
            #   0=id, 1=name, 2=RA [deg], 3=Dec [deg], 5=velocity,
            #   17/18/19/25 = one-character flags where a leading "x" means
            #   the flag is set (dust / interacting / debris / shells)
            #   -- presumed from usage; confirm against the catalog format.
            id = int(items[0])  # NOTE(review): shadows the builtin id()
            objname = items[1].replace(" ", "_")
            ra = float(items[2])
            dec = float(items[3])
            vel = float(items[5])

            dust = items[17].startswith("x")
            interacting = items[18].startswith("x")
            debris = items[19].startswith("x")
            shells = items[25].startswith("x")

        except:
            # Skip rows that fail to parse (headers, short/malformed rows).
            # NOTE(review): bare except also hides genuine errors.
            continue

        # Sky -> pixel coordinates for this object.
        xy = wcs.wcs2pix(ra, dec)
        positions = (xy[0], xy[1])

        # Circular object aperture plus a surrounding sky annulus.
        obj_apertures = photutils.CircularAperture(positions, r=npixels)
        #sky_apertures = photutils.CircularAperture(positions, r=numpy.sqrt(2)*npixels)
        sky_annulus = photutils.CircularAnnulus(positions, r_in=np_sky1, r_out=np_sky2)

        apertures = [obj_apertures, sky_annulus]
        
        # Legacy photutils API: a list of apertures yields columns
        # 'aperture_sum_0' (object) and 'aperture_sum_1' (sky annulus).
        photometry = photutils.aperture_photometry(imghdu.data, apertures)

        obj_flux = photometry['aperture_sum_0'][0]
        obj_ap_size = obj_apertures.area()  # aperture area in pixels
        sky_ap_size = sky_annulus.area()
        # Mean sky level per pixel, scaled to the object-aperture area.
        background_mean = photometry['aperture_sum_1'][0] / sky_ap_size
        background_sum = background_mean * obj_ap_size

        # Background-subtracted object flux.
        full_flux = photometry['aperture_sum_0'][0] - background_sum

        # print photometry
        results.append("% 5d %-30s %.5f %.5f %d %10.4f %10.4f %10.4f" % (
            id, objname, ra, dec, vel, obj_flux, background_mean, full_flux)
        )
        result_names.append(objname)
        result_data.append([id, ra, dec, vel, obj_flux, background_mean, full_flux])


        # Echo flagged (interacting/debris/shells) objects to stdout for
        # quick visual inspection; "###" marks them as grep-able.
        if (interacting or debris or shells):
            print "% 5d %-30s %.5f %.5f %d" % (id, objname, ra, dec, vel), \
                " %5s"*4 % (interacting, debris, shells, (interacting| debris| shells)), \
                "###"
        else:
            continue

    result_data = numpy.array(result_data)

    #
    # And finally, create a list of random aperture positions to get a noise 
    # estimate
    #
    print "Calculating aperture noise"
    loops = 1
    random_fluxes = None
    for loop in range(loops):
        # 10000 uniform random (x, y) positions across the full image frame.
        random_xy = numpy.random.rand(10000,2) * \
                    [imghdu.data.shape[1], imghdu.data.shape[0]]

        random_apertures = [
            photutils.CircularAperture(random_xy, r=npixels),
            photutils.CircularAnnulus(random_xy, r_in=np_sky1, r_out=np_sky2),
            ]
        random_phot = photutils.aperture_photometry(imghdu.data, random_apertures)
        # print random_phot
        # print random_phot['aperture_sum_0'][:]
        # Background-subtracted flux in each random aperture.
        # NOTE(review): sky_ap_size/obj_ap_size are leftovers from the LAST
        # catalog object above -- this raises NameError if no catalog row
        # parsed successfully (the areas themselves are position-independent,
        # so the values are otherwise fine).
        random_flux = random_phot['aperture_sum_0'] - random_phot['aperture_sum_1']/sky_ap_size*obj_ap_size
        # print random_flux

        # do some outlier rejection and calculate median and 1-sigma uncertainties
        # Iterative 3-sigma clipping: sigma estimated from the 16/84
        # percentile half-width of the surviving sample.
        good = numpy.isfinite(random_flux)
        for iter in range(5):
            _stats = numpy.percentile(random_flux[good], [16,50,84])
            _med = _stats[1]
            _sigma = 0.5*(_stats[2]-_stats[0])
            bad = (random_flux < _med-3*_sigma) | (random_flux > _med+3*_sigma)
            good[bad] = False
            print _med, _sigma

        print "uncertainties:", _med, _sigma, numpy.var(random_flux[good])

        # with open("output.phot", "w") as o:
        #     o.write("\n".join(results))
        # Diagnostic dumps: all random apertures, and the clipped subset.
        numpy.savetxt("output.random.phot.%d" % (loop+1),
                      numpy.array([random_xy[:,0], random_xy[:,1], random_flux]).T)
        print(random_xy.shape, random_flux.shape, good.shape)
        numpy.savetxt("output.random.phot.%d.clean" % (loop+1),
                      numpy.array([
                          random_xy[:,0][good], 
                          random_xy[:,1][good], 
                          random_flux[good],
                               ]).T)

        combined = numpy.array([random_xy[:,0], random_xy[:,1], random_flux]).T
        numpy.savetxt("output.random.phot.all.%d" % (loop+1), combined)

        # Accumulate random-aperture samples across loops (loops=1 here).
        random_fluxes = combined if random_fluxes is None else numpy.append(random_fluxes, combined, axis=0)

    print "Writing results"
    results = numpy.array(results)
    numpy.savetxt("output.phot", results, fmt="%s")

    # do some outlier rejection and calculate median and 1-sigma uncertainties
    # Same 3-sigma clipping as above, now on the combined random-flux sample
    # (column 2 of random_fluxes).
    print "Doing master-combine for random fluxes"
    good = numpy.isfinite(random_fluxes[:,2])
    for iter in range(3):
        _stats = numpy.percentile(random_fluxes[good,2], [16,50,84])
        _med = _stats[1]
        _sigma = 0.5*(_stats[2]-_stats[0])
        bad = (random_fluxes[:,2] < _med-3*_sigma) | (random_fluxes[:,2] > _med+3*_sigma)
        good[bad] = False
        print _med, _sigma
    

    # Now that we have uncertainties, add them to the output data
    # also correct fluxes for the mean flux deviation calculated above 
    # (this should fix problems with background under/over-subtraction)
    print result_data.shape
    # Column 6 of result_data is the background-subtracted flux; _med is the
    # clipped median of the random (empty) apertures, _sigma their spread.
    corrected_flux = result_data[:,6] - _med
    errors = numpy.ones_like(corrected_flux) * _sigma
    fix = numpy.array([corrected_flux, errors]).T
    print errors.shape, fix.shape, result_data.shape

    # Append [corrected_flux, uncertainty] columns to the per-object table.
    final = numpy.append(result_data, fix, axis=1)
    print results.shape, final.shape
    numpy.savetxt("output.phot.final", final, fmt="%s")

    # Self-describing column header for the final output file.
    columns = [
        'ID', 'RA', "DEC", "velocity", "raw_flux", "mean_background", "bgsub_flux",
        "corrected_flux", "uncertainty",
    ]
    header = ["# Column %2d: %s" % (i+1,c) for i,c in enumerate(columns)]
    header_str = "\n".join(header)

    numpy.savetxt(output_fn, final, fmt="%s",
                  header=header_str, comments="")



