#!/usr/bin/env python
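"""Build astrometry.net index files from an LSST-format reference catalog.

The pipeline is: (1) concatenate the input FITS catalogs in chunks,
converting fluxes to AB magnitudes ('convert'); (2) split the result
into HEALPix tiles with 'hpsplit'; (3) run 'build-astrometry-index' on
each tile ('generateIndexes'); and (4) write the andConfig.py and
andCache.fits files that the LSST astrometry code expects
('writeConfig').

Example invocation (paths are illustrative):

    python buildAndCatalog.py -j 8 -o /path/to/output/ps1 '*.fits'
"""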

import re
import os
import glob
import multiprocessing
from argparse import ArgumentParser

import numpy
import pyfits


JanskysPerABFlux = 3631.0

# Mapping from numpy type to FITS TFORM code
TYPEMAP = {
    numpy.float64: "D",
    numpy.float32: "E",
    numpy.int64: "K",
    numpy.bool: "L",
}
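# (TFORM codes from the FITS binary-table standard: D = 64-bit float,
# E = 32-bit float, K = 64-bit integer, L = logical.)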


# Support pickling of instance methods
import copy_reg
import types
def unpickleInstanceMethod(obj, name):
    """Unpickle an instance method

    This has to be a named function rather than a lambda because
    pickle needs to find it.
    """
    return getattr(obj, name)
def pickleInstanceMethod(method):
    """Pickle an instance method

    The instance method is divided into the object and the
    method name.
    """
    obj = method.__self__
    name = method.__name__
    return unpickleInstanceMethod, (obj, name)
copy_reg.pickle(types.MethodType, pickleInstanceMethod)
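
# With the registration above, bound methods survive a pickle round-trip,
# which is what lets a multiprocessing pool ship e.g. Function(self.convert)
# to worker processes.  A minimal sketch (Python 2; without the hook, plain
# pickle raises "can't pickle instancemethod objects"):
#
#     import pickle
#     builder = BuildAndCatalog([], "example")  # illustrative instance
#     method = pickle.loads(pickle.dumps(builder.convert))
#     # 'method' is 'convert' rebound to an unpickled copy of 'builder'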


class Function(object):
    """Wrapper for a function that operates on args,kwargs"""
    def __init__(self, func):
        self._func = func

    @staticmethod
    def args(*args, **kwargs):
        return (args, kwargs)

    def __call__(self, data):
        return self._func(*data[0], **data[1])
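
# A minimal usage sketch (someCallable is hypothetical): Function pairs with
# Function.args so a single-argument map call can carry both positional and
# keyword arguments:
#
#     data = [Function.args(x, scale=2) for x in range(3)]
#     results = map(Function(someCallable), data)  # each call: someCallable(x, scale=2)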


class MultiprocessingMapper(object):
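    """Parallelise with a multiprocessing.Pool on a single node"""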
    def __init__(self, threads, maxtasksperchild=1):
        self.pool = multiprocessing.Pool(threads, maxtasksperchild=maxtasksperchild)

    def __call__(self, func, dataList):
        return self.pool.map(func, dataList)
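
# The mappers (the builtin 'map', MultiprocessingMapper and CtrlPoolMapper
# below) share a common calling convention, e.g. (func and inputs are
# illustrative):
#
#     mapper = MultiprocessingMapper(4)
#     results = mapper(Function(func), [Function.args(x) for x in inputs])
#
# maxtasksperchild=1 makes the pool replace each worker process after a
# single task, so memory does not accumulate in long-lived workers.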


class CtrlPoolMapper(object):
    """Use the LSST ctrl_pool package for parallelisation.

    This allows parallelisation over nodes, but requires running
    python under 'mpiexec'.

    Here's an example Slurm submission script:

    #!/bin/bash
    #SBATCH --ntasks=96
    #SBATCH --time=12:00:00
    #SBATCH --job-name=andIndex
    #SBATCH --dependency=singleton
    #SBATCH --output=/tigress/HSC/users/price/ps1_pv3/andIndex.o%j
    #SBATCH --error=/tigress/HSC/users/price/ps1_pv3/andIndex.o%j
    unset EUPS_DIR EUPS_SHELL EUPS_PATH EUPS_PKGROOT SETUP_EUPS
    . /tigress/HSC/LSST/stack_20160915/eups/bin/setups.sh
    setup lsst_distrib
    setup miniconda2
    export PYTHONPATH=/tigress/HSC/users/price/ps1_pv3/python/:$PYTHONPATH
    export PATH=/tigress/HSC/users/price/ps1_pv3/bin/:/tigress/HSC/users/price/ps1_pv3/3pi.pv3.20160422/DATA/ref_cats/ps1_pv3_3pi_20170110/bin:$PATH
    cd /tigress/HSC/users/price/ps1_pv3/3pi.pv3.20160422/DATA/ref_cats/ps1_pv3_3pi_20170110
    date
    mpiexec -bind-to socket python /tigress/HSC/users/price/ps1_pv3/3pi.pv3.20160422/DATA/ref_cats/ps1_pv3_3pi_20170110/bin/buildAndCatalog.py --use-ctrl-pool -o /tigress/HSC/users/price/ps1_pv3/ps1_pv3_3pi_20170110-and/ps1_pv3_3pi_20170110 '*.fits'
    date
    """    
    def __init__(self):
        from lsst.ctrl.pool.pool import Pool, startPool
        startPool()  # Worker nodes peel off here
        self.pool = Pool(None)

    def __del__(self):
        if hasattr(self, "pool"):
            self.pool.exit()

    def __call__(self, func, dataList):
        return self.pool.map(func, dataList)


def system(command):
    print command
    return os.system(command)


class BuildAndCatalog(object):
    def __init__(self, inputList, outputRoot, nside=32, concat=1000):
        """Constructor

        The schema needs to be set appropriately for the output.  It is a
        dict with keys being the column names and values being the FITS
        column type.

        The build arguments needs to be set appropriately for the output.
        These are arguments to build-astrometry-index.
        """
        self.inputList = inputList
        self.outputRoot = outputRoot
        self.nside = nside
        self.concat = concat
        self.filters = "grizy"
        self.schema = dict([("id", "K"), ("ra", "D"), ("dec", "D")] + [(f, "E") for f in self.filters] +
                           [(f + "_err", "E") for f in self.filters])
        self.buildArgs = "-S i -L 20 -E -M -j 0.2 -n 100 -r 1"

    def filter(self, data):
        """Filter the input data, returning the appropriate columns"""
        out = {}
        out["id"] = data.field("id")
        out["ra"] = numpy.degrees(data.field("coord_ra"))
        out["dec"] = numpy.degrees(data.field("coord_dec"))
        for ff in self.filters:
            flux = data.field(ff + "_flux")
            fluxErr = data.field(ff + "_fluxSigma")
            out[ff] = -2.5*numpy.log10(flux/JanskysPerABFlux)
            out[ff + "_err"] = fluxErr/flux/0.4/numpy.log(10.0)
        return out
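
    # Worked example of the conversions above, assuming fluxes in Janskys:
    # a 1 Jy source has m_AB = -2.5*log10(1/3631) ~= 8.90, and the error
    # propagates as dm = 2.5/ln(10) * fluxErr/flux ~= 1.0857*fluxErr/flux,
    # which is exactly what fluxErr/flux/0.4/log(10) computes, since
    # 1/0.4 = 2.5.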

    def convert(self, inNames, outName):
        """Convert input data to the format to be processed by astrometry.net

        Concatenates multiple FITS files together at the same time.
        The concatenation is important because we can't dump thousands of filenames
        on the 'hpsplit' utility.
        """
        if os.path.exists(outName):
            print "Output file %s exists; not clobbering" % outName
            return
        # getval opens and closes each file, avoiding thousands of open handles
        sizes = [pyfits.getval(fn, "NAXIS2", 1) for fn in inNames]
        num = sum(sizes)

        schema = [pyfits.Column(name=name, format=TYPEMAP[format], unit=unit) for name, format, unit in (
            ("id", numpy.int64, None),
            ("ra", numpy.float64, "degrees"),
            ("dec", numpy.float64, "degrees"),
        )]
        schema += [pyfits.Column(name=ff, format=TYPEMAP[numpy.float32], unit="mag") for ff in self.filters]
        schema += [pyfits.Column(name=ff + "_err", format=TYPEMAP[numpy.float32], unit="mag") for
                   ff in self.filters]

        out = pyfits.BinTableHDU.from_columns(schema, nrows=num)
        outData = out.data
        start = 0
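        # Fill the pre-allocated table file by file, so only one input
        # catalog needs to be in memory at a time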
        for fn, sz in zip(inNames, sizes):
            fits = pyfits.open(fn)
            inData = fits[1].data
            print "Read %d rows from %s" % (sz, fn)
            assert len(inData) == sz
            select = slice(start, start + sz)

            outData.field("id")[select] = inData.field("id")
            outData.field("ra")[select] = numpy.degrees(inData.field("coord_ra"))
            outData.field("dec")[select] = numpy.degrees(inData.field("coord_dec"))
            for ff in self.filters:
                flux = inData.field(ff + "_flux")
                fluxErr = inData.field(ff + "_fluxSigma")
                outData.field(ff)[select] = -2.5*numpy.log10(flux/JanskysPerABFlux)
                outData.field(ff + "_err")[select] = fluxErr/flux/0.4/numpy.log(10.0)
            start += sz
            fits.close()

        out.writeto(outName, clobber=True)

    def hpsplit(self, inputList):
        """Split the files into healpixes

        Only a single instance of this should be run on a catalog at a time.
        """
        out = "%s_hp_%%i.fits" % self.outputRoot
        args = "-r ra -d dec -n %d" % self.nside
        system("hpsplit -o " + out + " " + args + " " + " ".join(inputList))

    def generateIndexes(self, inName, index, healpix=None):
        """Generate astrometry.net indices

        Only a single instance of this should be run per input;
        inputs are usually divided into healpixes.
        """
        outName = "%s_and_%d" % (self.outputRoot, index)
        args = self.buildArgs  # Strings are immutable, so the += below rebinds this local name only
        if healpix is not None:
            args += " -H %d" % healpix
        if self.nside is not None:
            args += " -s %d" % self.nside
        if not os.path.exists(outName + "_0.fits"):
            system("build-astrometry-index -i " + inName + " -o " + outName + "_0.fits -I " + str(index) + "0 -P 0 " + args)
        if not os.path.exists(outName + "_1.fits"):
            system("build-astrometry-index -1 " + outName + "_0.fits -o " + outName + "_1.fits -I " + str(index) + "1 -P 1 " + args)
        if not os.path.exists(outName + "_2.fits"):
            system("build-astrometry-index -1 " + outName + "_0.fits -o " + outName + "_2.fits -I " + str(index) + "2 -P 2 " + args)
        if False:
            # Don't need these: "-P 2  should work for images about 12 arcmin across" says build-astrometry-index
            system("build-astrometry-index -1 " + outName + "_0.fits -o " + outName + "_3.fits -I " + str(index) + "3 -P 3 " + args)
            system("build-astrometry-index -1 " + outName + "_0.fits -o " + outName + "_4.fits -I " + str(index) + "4 -P 4 " + args)

    def writeConfig(self, filenames):
        """Write the necessary configuration files"""
        filenames = [os.path.basename(fn) for fn in filenames]
        dirname = os.path.dirname(self.outputRoot)
        configName = os.path.join(dirname, "andConfig.py")

        config = {
            "magColumnMap": {ff: ff for ff in self.filters},
            "magErrorColumnMap": {ff: ff + "_err" for ff in self.filters},
            "indexFiles": filenames,
        }

        with open(configName, "w") as ff:
            for field in config:
                ff.write("root.%s = %s\n" % (field, config[field]))

        # This code for writing the cache is copied from
        # lsst.meas.astrom.multiindex.AstrometryNetCatalog.writeCache,
        # with optimisations added because we know the healpix and nside
        # values.
        # The full catalog is too big to load all in memory, so the usual
        # cache writing fails.
        outName = os.path.join(dirname, "andCache.fits")
        maxLength = max(len(fn) for fn in filenames) + 1
        healpix = [int(re.search(r"%s_and_(\d+)_\d+\.fits" % re.escape(os.path.basename(self.outputRoot)),
                                 fn).group(1))
                   for fn in filenames]

        # First table
        first = pyfits.BinTableHDU.from_columns(
            [pyfits.Column(name="id", format="K"),
             pyfits.Column(name="healpix", format="K"),
             pyfits.Column(name="nside", format="K"),
             ], nrows=len(filenames))
        ident = numpy.arange(len(filenames), dtype=int)
        first.data.field("id")[:] = ident
        first.data.field("healpix")[:] = numpy.array(healpix)
        first.data.field("nside")[:] = numpy.ones(len(filenames), dtype=int)*self.nside

        # Second table
        second = pyfits.BinTableHDU.from_columns(
            [pyfits.Column(name="id", format="K"),
             pyfits.Column(name="filename", format="%dA" % maxLength),
             ], nrows=len(filenames)*2)
        # The filenames are duplicated...
        for indices in (2*ident, 2*ident + 1):
            second.data.field("id")[indices] = ident
            second.data.field("filename")[indices] = filenames

        pyfits.HDUList([pyfits.PrimaryHDU(), first, second]).writeto(outName, clobber=True)

        upsDir = os.path.join(os.path.dirname(self.outputRoot), "ups")
        upsName = os.path.join(upsDir, "astrometry_net_data.table")
        if not os.path.exists(upsName):
            if not os.path.isdir(upsDir):
                os.makedirs(upsDir)
            open(upsName, 'a').close()  # Empty file

    def run(self, mapFunc=map):
        """Create astrometry.net indices"""
        # Chunk the input files, because hpsplit can't handle a really long list
        filenames = sum([glob.glob(filename) for filename in self.inputList], [])
        inputList = []
        while filenames:
            chunk = filenames[:self.concat]
            filenames = filenames[self.concat:]
            inputList.append(chunk)
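        # e.g. 2500 matched files with concat=1000 yield chunks of
        # 1000, 1000 and 500 files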

        # Convert the inputs to the desired format
        catList = ["%s_in_%d.fits" % (self.outputRoot, i) for i, _ in enumerate(inputList)]
        convertData = [Function.args(inNames, catName) for inNames, catName in zip(inputList, catList) if
                       not os.path.exists(catName)]
        if convertData:
            mapFunc(Function(self.convert), convertData)

        # Split the inputs by HEALPix
        if glob.glob("%s_hp_*.fits" % self.outputRoot):
            print "Found healpix files: NOT clobbering"
        else:
            self.hpsplit(catList)

        # Generate indices for each HEALPix
        hpFiles = glob.glob("%s_hp_*.fits" % self.outputRoot)
        healpixes = [int(re.search(r"%s_hp_(\d+)\.fits" % re.escape(self.outputRoot), inName).group(1)) for
                     inName in hpFiles]
        indexData = [Function.args(inName, healpix, healpix=healpix) for
                     inName, healpix in zip(hpFiles, healpixes)]
        mapFunc(Function(self.generateIndexes), indexData)

        # Write configuration files
        self.writeConfig(glob.glob("%s_and_*_[012].fits" % self.outputRoot))

    @classmethod
    def parseAndRun(cls):
        """Parse command-line arguments and run"""
        parser = ArgumentParser()
        parser.add_argument("input", nargs="*", help="Input files")
        parser.add_argument("-j", dest="threads", type=int, default=0, help="Number of threads")
        parser.add_argument("-o", "--output", required=True, help="Output root name")
        parser.add_argument("-s", "--nside", type=int, default=32, help="HEALPix nside (power of 2)")
        parser.add_argument("-c", "--concat", type=int, default=1000, help="Number of files to concatenate")
        parser.add_argument("--use-ctrl-pool", dest="useCtrlPool", action="store_true", default=False,
                            help="Use the LSST ctrl_pool package for multithreading? (Will ignore -j)")
        args = parser.parse_args()
        mapFunc = map
        if args.useCtrlPool:
            mapFunc = CtrlPoolMapper()
        elif args.threads > 0:
            mapFunc = MultiprocessingMapper(args.threads)
        self = cls(args.input, args.output, nside=args.nside, concat=args.concat)
        return self.run(mapFunc)


if __name__ == "__main__":
    BuildAndCatalog.parseAndRun()
