import os
from glob import glob
from collections import defaultdict
from lsst.pipe.base import ButlerInitializedTaskRunner, InputOnlyArgumentParser
from lsst.meas.algorithms import IngestIndexedReferenceTask
from lsst.ctrl.pool.parallel import BatchPoolTask, Pool, abortOnError
from lsst.pex.config import Config, ConfigurableField, Field
from lsst.afw.table import BaseCatalog, SourceCatalog

import numpy as np
import cPickle as pickle
import operator

JanskysPerABFlux = 3631.0  # Flux (Jy) corresponding to zero AB magnitude


class IngestDriverConfig(Config):
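    """Configuration for IngestDriverTask

    The 'ingest' subtask performs the conversion to reference catalog format;
    'chunkSize' limits how many HTM indices are sent to the process pool at a time.
    """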
    ingest = ConfigurableField(target=IngestIndexedReferenceTask, doc="Ingestion")
    chunkSize = Field(dtype=int, default=1000, doc="Size of chunks for building outputs")

    def setDefaults(self):
        Config.setDefaults(self)
        self.ingest.dataset_config.ref_dataset_name = "ps1_pv3_3pi_20170110"
        self.ingest.dataset_config.indexer.name = "HTM"
        self.ingest.dataset_config.indexer["HTM"].depth = 7  # Colin says that 8 gives far too many files
        self.ingest.ra_name = "ra"
        self.ingest.dec_name = "dec"
        self.ingest.mag_column_list = list("grizy")
        self.ingest.mag_err_column_map = {ff: ff + "_err" for ff in self.ingest.mag_column_list}
        self.ingest.id_name = "id"
        self.ingest.extra_col_names = ["coord_ra_err", "coord_dec_err", "epoch",
                                       "pm_ra", "pm_dec", "pm_ra_err", "pm_dec_err"]


class IngestDriverRunner(ButlerInitializedTaskRunner):
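    """TaskRunner for IngestDriverTask

    Constructs the task with a butler and passes it the list of input
    filenames, with any globs expanded.
    """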
    @staticmethod
    def getTargetList(parsedCmd, **kwargs):
        """Return a single element, with the list of filenames"""
        kwargs["clobber"] = bool(parsedCmd.clobberConfig)
        kwargs["doBackup"] = not bool(parsedCmd.noBackupConfig)
        return [(sum((glob(filename) for filename in parsedCmd.filenames), []), kwargs)]

    def run(self, parsedCmd):
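        """Construct the task with a butler and run it on the (globbed) input filenames"""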
        butler = parsedCmd.butler
        task = self.TaskClass(config=self.config, log=self.log, butler=butler)
        filenames = sum((glob(filename) for filename in parsedCmd.filenames), [])
        return task.run(filenames)


def unpickle(factory, args, kwargs):
    """Unpickle something by calling a factory"""
    return factory(*args, **kwargs)


class IndicesToFilenames(object):
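    """Mapping from HTM index to the set of filenames containing sources in that index

    Instances support in-place addition (merging), which makes them suitable
    for accumulation with Pool.reduce(operator.iadd, ...).
    """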
    def __init__(self, *args):
        self._data = defaultdict(set, *args)

    @classmethod
    def forFilename(cls, filename, indices):
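        """Construct from a single filename and the indices of its sources"""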
        return cls({ii: set([filename]) for ii in indices})

    def update(self, other):
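        """Merge the contents of another IndicesToFilenames into this one"""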
        for ii, filenames in other._data.iteritems():
            self._data[ii].update(filenames)

    def __iadd__(self, other):
        self.update(other)
        return self

    def items(self):
        return self._data.items()

    def keys(self):
        return self._data.keys()

    def __getitem__(self, index):
        return self._data[index]

    def __len__(self):
        return len(self._data)


class IngestDriverTask(BatchPoolTask):
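    """Driver for ingesting an external catalog into HTM-indexed reference catalog files

    The work proceeds in two parallel phases: first, each input file is read
    and its sources are assigned HTM indices (runIndex); then, for each HTM
    index, the contributing rows from all input files are gathered into a
    single catalog and written through the butler (runExtract).
    """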
    RunnerClass = IngestDriverRunner
    ConfigClass = IngestDriverConfig
    _DefaultName = "ingester"

    @classmethod
    def _makeArgumentParser(cls, **kwargs):
        """Create an argument parser

        This overrides the original because we need the file arguments
        """
        parser = InputOnlyArgumentParser(name=cls._DefaultName)
        parser.add_argument("filenames", nargs="+", help="Names of (or globs for) files to index")
        return parser

    def __init__(self, *args, **kwargs):
        """!Constructor for the ingestion driver task

        @param[in] butler  dafPersistence.Butler object for reading and writing catalogs
        """
        self.butler = kwargs.pop('butler')
        BatchPoolTask.__init__(self, *args, **kwargs)
        self.makeSubtask("ingest", butler=self.butler)
        self.refCatName = self.config.ingest.dataset_config.ref_dataset_name  # Fragile!

    def __reduce__(self):
        """Pickler

        Required because we changed the signature of __init__ to include 'butler'.
        """
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler))

    @abortOnError
    def run(self, filenames):
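        """Ingest the input files

        HTM indices for the sources in each file are computed (and cached on
        disk) in parallel, and then a reference catalog is extracted and
        written for each HTM index.

        @param[in] filenames  List of input filenames
        """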
        if len(filenames) == 0:
            self.log.fatal("No files found")
            return

        self.writeDatasetConfig()
        schema = self.makeMasterSchema()
        pool = Pool(None)  # No context cache needed, and simpler function signature

        if os.path.exists("indexMapping.pickle"):
            with open("indexMapping.pickle", "rb") as fd:
                indexMapping = pickle.load(fd)
        else:
            indexMapping = pool.reduce(operator.iadd, self.runIndex, filenames)
            with open("indexMapping.pickle", "wb") as fd:
                pickle.dump(indexMapping, fd, -1)

        self.log.info("Indexed %d files" % (len(filenames),))

        pool = Pool(None)

        schemaCarrier = SourceCatalog(schema)  # Schema won't pickle, so need to wrap it

        # Break up into chunks because 'indexMapping.items()' is HUGE.
        # This wouldn't be necessary if our Pool.map supported iterators instead of lists.
        pending = indexMapping.keys()
        num = 0
        while pending:
            chunk = pending[:self.config.chunkSize]
            pending = pending[self.config.chunkSize:]
            num += pool.reduce(operator.add, self.runExtract, [(ii, indexMapping[ii]) for ii in chunk],
                               schemaCarrier)
        self.log.info("Extracted a total of %d sources", num)

    def writeDatasetConfig(self):
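        """Persist the dataset configuration for the reference catalog"""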
        dataId = self.ingest.indexer.make_data_id(None, self.refCatName)
        self.butler.put(self.config.ingest.dataset_config, 'ref_cat_config', dataId=dataId)

    def makeMasterSchema(self):
        """Persist an empty catalog to hold the master schema"""
        dtype = np.dtype([
            ("id", np.int64),
            ("ra", np.float64),
            ("dec", np.float64),
            ("coord_ra_err", np.float32),
            ("coord_dec_err", np.float32),
            ("epoch", np.int64),  # Sec since the UNIX epoch
            ("pm_ra", np.float32),
            ("pm_dec", np.float32),
            ("pm_ra_err", np.float32),
            ("pm_dec_err", np.float32),
            ] +
            [(ff + "_flux", np.float32) for ff in self.config.ingest.mag_column_list] +
            [(ff + "_err", np.float32) for ff in self.config.ingest.mag_column_list])
        schema, keyMap = self.ingest.make_schema(dtype)
        dataId = self.ingest.indexer.make_data_id('master_schema', self.refCatName)
        self.butler.put(self.ingest.get_catalog(dataId, schema), "ref_cat", dataId=dataId)
        return schema

    @abortOnError
    def runIndex(self, filename):
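        """Compute the HTM indices for the sources in a single file

        The indices are cached in a pickle alongside the input file so that
        subsequent runs can skip re-reading the file.

        @param[in] filename  Name of the file to index
        @return IndicesToFilenames mapping each index to this filename
        """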
        outName = filename + ".indices"
        if os.path.exists(outName):
            with open(outName, "rb") as fd:
                indices = pickle.load(fd)
        else:
            catalog = self.read(filename)
            indices = self.ingest.indexer.index_points(catalog[self.config.ingest.ra_name],
                                                       catalog[self.config.ingest.dec_name])
            with open(outName, "wb") as fd:
                pickle.dump(indices, fd, -1)
        out = IndicesToFilenames.forFilename(filename, set(indices))
        self.log.info("Found %d indices in %s" % (len(out), filename))
        return out

    @abortOnError
    def runExtract(self, data, schemaCarrier):
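        """Extract and write the reference catalog for a single HTM index

        If the catalog for this index already exists, it is left untouched.

        @param[in] data  Tuple of the HTM index and the set of contributing filenames
        @param[in] schemaCarrier  Empty SourceCatalog carrying the master schema
        @return Number of sources written (or already present) for this index
        """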
        index, filenames = data
        dataId = self.ingest.indexer.make_data_id(index, self.refCatName)
        if self.butler.datasetExists("ref_cat", dataId=dataId):
            num = self.butler.get("ref_cat_len", dataId=dataId, immediate=True)
            self.log.info("Found %d existing sources for index %d; skipping", num, index)
            return num
        # The input files are already spatially indexed, which means there should be a small number
        # contributing to this spatial index, so it doesn't hurt to load the indices for all of them.
        indicesList = [pickle.load(open(fn + ".indices", "rb")) for fn in filenames]
        maxSize = sum((indices == index).sum() for indices in indicesList)
        catalog = SourceCatalog(schemaCarrier.schema)
        self.log.info("Catalog size: %d rows" % maxSize)
        try:
            catalog.reserve(maxSize)
        except Exception:
            self.log.warn("Failed to reserve %d rows" % maxSize)
            raise

        # Pre-allocate all the rows up front so the catalog is contiguous
        # and can be filled below via column slices
        for ii in range(maxSize):
            catalog.addNew()
        assert catalog.isContiguous()
        start = 0
        for fn, indices in zip(filenames, indicesList):
            data = self.read(fn)
            select = (indices == index)
            num = select.sum()
            if num == 0:
                continue
            assert catalog.isContiguous()
            subCat = catalog[start:start + num]

            for col in ("id", "epoch", "pm_ra", "pm_dec", "pm_ra_err", "pm_dec_err"):
                subCat[col][:] = data[col][select]
            for fromCol, toCol in (("ra", "coord_ra"),
                                   ("dec", "coord_dec"),
                                   ):
                subCat[toCol][:] = np.radians(data[fromCol][select])
            for fromCol, toCol in (("ra_err", "coord_ra_err"),
                                   ("dec_err", "coord_dec_err"),
                                   ):
                subCat[toCol][:] = data[fromCol][select]
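            # Convert AB magnitudes to fluxes in Janskys (m_AB = -2.5*log10(F/3631 Jy))
            # and propagate the magnitude errors to flux errors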
            for ff in self.config.ingest.mag_column_list:
                flux = 10.0**(-0.4*data[ff][select])*JanskysPerABFlux
                fluxErr = np.abs(-0.4*data[ff + "_err"][select]*flux*np.log(10.0))
                subCat[ff + "_flux"][:] = flux
                subCat[ff + "_fluxSigma"][:] = fluxErr
            start += num

        self.butler.put(catalog, "ref_cat", dataId=dataId)
        self.log.info("Wrote %d sources for index %d", len(catalog), index)
        return len(catalog)

    def read(self, filename):
        """Read an input catalog

        Input files are FITS binary tables; they are read as an afw
        BaseCatalog so that columns can be accessed by name.
        """
        return BaseCatalog.readFits(filename)

    def _getConfigName(self):
        """Disable persistence of the task configuration"""
        return None

    def _getMetadataName(self):
        """Disable persistence of the task metadata"""
        return None

