import numpy as np
import pyfits

NUM_PHOTCODES = 9  # Number of photcodes (consecutive rows per object) in this DVO catalog
FILTERS = "grizy"  # Iterable of filters; order must match the photcode order used by readPhotom

# Mapping from numpy scalar type to FITS TFORM code.
# ``np.bool`` has only ever been an alias: of the builtin ``bool`` in older numpy
# (deprecated in 1.20, removed in 1.24) and of ``np.bool_`` since numpy 2.0.
# Keying on both ``bool`` and ``np.bool_`` keeps ``TYPEMAP[np.bool]`` working
# wherever ``np.bool`` exists, without touching ``np.bool`` at import time.
TYPEMAP = {
    np.float64: "D",
    np.float32: "E",
    np.int64: "K",
    bool: "L",
    np.bool_: "L",
}


def readAstrom(filename):
    """Read a DVO astrometry table and convert it to the reference-catalog schema.

    Parameters
    ----------
    filename : str
        Path to the ``.cpt`` table. A ``.cps`` path is also accepted and is
        rewritten to its ``.cpt`` sibling before opening.

    Returns
    -------
    pyfits.BinTableHDU
        One output row per row of the input's first extension, with columns
        ``id``, ``ra``, ``dec``, ``ra_err``, ``dec_err``, ``epoch``, ``pm_ra``,
        ``pm_dec``, ``pm_ra_err``, ``pm_dec_err`` and the boolean flags
        ``good_visit``, ``good_stack`` and ``extended``.

    NOTE(review): ``ra_err`` and ``dec_err`` are declared in the schema but
    never filled here, so they keep the default fill value of
    ``BinTableHDU.from_columns``. Confirm the matching DVO columns before
    relying on them.
    """
    cptName = filename.replace(".cps", ".cpt")
    fits = pyfits.open(cptName)
    try:
        data = fits[1].data
        num = len(data)
        # Report the file actually opened, which may differ from ``filename``.
        print("Read %d rows from %s" % (num, cptName))

        schema = [pyfits.Column(name=name, format=TYPEMAP[dtype], unit=unit) for name, dtype, unit in (
            ("id", np.int64, None),
            ("ra", np.float64, "degrees"),
            ("dec", np.float64, "degrees"),
            ("ra_err", np.float32, "arcsec"),
            ("dec_err", np.float32, "arcsec"),
            ("epoch", np.float32, "sec"),
            ("pm_ra", np.float32, "arcsec/year"),
            ("pm_dec", np.float32, "arcsec/year"),
            ("pm_ra_err", np.float32, "arcsec/year"),
            ("pm_dec_err", np.float32, "arcsec/year"),
            ("good_visit", np.bool, None),
            ("good_stack", np.bool, None),
            ("extended", np.bool, None),
        )]

        out = pyfits.BinTableHDU.from_columns(schema, nrows=num)

        # Straight column copies: output name -> DVO input column.
        for colOut, colIn in (("id", "EXT_ID"),
                              ("ra", "RA"),
                              ("dec", "DEC"),
                              ("epoch", "MEAN_EPOCH"),
                              ("pm_ra", "U_RA"),
                              ("pm_dec", "U_DEC"),
                              ("pm_ra_err", "V_RA_ERR"),
                              ("pm_dec_err", "V_DEC_ERR"),
                              ):
            out.data.field(colOut)[:] = data.field(colIn)

        # The three flags come from the top byte (bits 24-31) of FLAGS. Within
        # that byte, 0x04 -> good_visit, 0x10 -> good_stack, 0x01 -> extended.
        # NOTE(review): bit meanings as assigned here; confirm against the DVO
        # flag definitions.
        flags = (data.field("FLAGS") >> 24) & 0xFF
        out.data.field("good_visit")[:] = (flags & 0x04) > 0
        out.data.field("good_stack")[:] = (flags & 0x10) > 0
        out.data.field("extended")[:] = (flags & 0x01) > 0
    finally:
        # Every value has been copied into ``out``, so the input can be released.
        fits.close()

    return out


def readPhotom(filename, astrom):
    """Read a DVO photometry table and reduce it to one magnitude per filter.

    Parameters
    ----------
    filename : str
        Path to the ``.cps`` table. A ``.cpt`` path is also accepted and is
        rewritten to its ``.cps`` sibling before opening.
    astrom : pyfits.BinTableHDU
        Output of ``readAstrom`` for the same objects. Its ``good_visit`` and
        ``good_stack`` flags gate which measurements may be used.

    Returns
    -------
    pyfits.BinTableHDU
        One row per object, with a ``<f>`` and a ``<f>_err`` column (mag) for
        each filter ``f`` in ``FILTERS``.

    The input is read as ``NUM_PHOTCODES`` consecutive rows per object, and
    the ``ii``-th of those rows is used for ``FILTERS[ii]``. For each filter,
    the magnitude is taken from the first available source, in this order:

    1. the visit magnitude, only if ``good_visit`` is set;
    2. the warp PSF magnitude;
    3. the stack PSF magnitude, only if ``good_stack`` is set.

    The value is NaN if none of them is finite.

    Raises
    ------
    ValueError
        If the number of objects in the table does not match ``astrom``.
    """
    cpsName = filename.replace(".cpt", ".cps")
    fits = pyfits.open(cpsName)
    try:
        data = fits[1].data
        # NOTE(review): rows beyond a whole multiple of NUM_PHOTCODES are ignored.
        num = len(data)//NUM_PHOTCODES
        # Report the file actually opened, which may differ from ``filename``.
        print("Read %d entries from %s" % (num, cpsName))
        if num != len(astrom.data):
            raise ValueError("%s holds %d objects but the astrometry table has %d rows"
                             % (cpsName, num, len(astrom.data)))

        schema = [pyfits.Column(name=ff, format=TYPEMAP[np.float32], unit="mag") for ff in FILTERS]
        schema += [pyfits.Column(name=ff + "_err", format=TYPEMAP[np.float32], unit="mag") for ff in FILTERS]

        out = pyfits.BinTableHDU.from_columns(schema, nrows=num)

        visit = data.field("MAG")
        visitErr = data.field("MAG_ERR")
        warp = data.field("MAG_PSF_WRP")
        stack = data.field("MAG_PSF_STK")

        # Warp/stack errors are given as fluxes. Convert the fractional flux
        # error to a magnitude error, sigma_m = 2.5/ln(10) * sigma_F/F, so it
        # is commensurate with the visit MAG_ERR in the same "mag" column.
        # Zero or missing fluxes yield inf/NaN; suppress the warnings for those.
        magPerFracFlux = 2.5/np.log(10.0)
        with np.errstate(divide="ignore", invalid="ignore"):
            warpErr = magPerFracFlux*data.field("FLUX_PSF_WRP_ERR")/data.field("FLUX_PSF_WRP")
            stackErr = magPerFracFlux*data.field("FLUX_PSF_STK_ERR")/data.field("FLUX_PSF_STK")

        goodVisit = astrom.data.field("good_visit")
        goodStack = astrom.data.field("good_stack")

        for ii, ff in enumerate(FILTERS):
            # Row ii of each object's NUM_PHOTCODES-row group belongs to filter ff.
            rows = slice(ii, NUM_PHOTCODES*num + ii, NUM_PHOTCODES)
            visitMag, warpMag, stackMag = visit[rows], warp[rows], stack[rows]

            useVisit = goodVisit & np.isfinite(visitMag)
            useWarp = np.isfinite(warpMag)
            useStack = goodStack & np.isfinite(stackMag)

            # Priority: visit, then warp, then stack, else NaN.
            mag = np.where(useVisit, visitMag,
                           np.where(useWarp, warpMag,
                                    np.where(useStack, stackMag, np.nan)))
            err = np.where(useVisit, visitErr[rows],
                           np.where(useWarp, warpErr[rows],
                                    np.where(useStack, stackErr[rows], np.nan)))
            out.data.field(ff)[:] = mag
            out.data.field(ff + "_err")[:] = err
    finally:
        # Every value has been copied into ``out``, so the input can be released.
        fits.close()

    return out


def select(astrom, photom, filters=None, minFilters=4):
    """Return a boolean mask of rows usable as reference objects.

    A row is selected when it is not flagged ``extended`` in *astrom* and at
    least *minFilters* of *filters* have both a finite magnitude and a finite
    magnitude error in *photom*.

    Parameters
    ----------
    astrom : table HDU (as returned by ``readAstrom``)
        Must provide a boolean ``extended`` column.
    photom : table HDU (as returned by ``readPhotom``)
        Must provide ``<f>`` and ``<f>_err`` columns for every filter counted,
        and have the same number of rows as *astrom*.
    filters : iterable of str, optional
        Filter column names to count. Defaults to the module-level ``FILTERS``.
    minFilters : int, optional
        Minimum number of usable filters. The default of 4 matches the
        original ``> 3`` threshold.

    Returns
    -------
    numpy.ndarray of bool
        One entry per row.

    Raises
    ------
    ValueError
        If *astrom* and *photom* have different numbers of rows.
    """
    if filters is None:
        filters = FILTERS
    num = len(astrom.data)
    if len(photom.data) != num:
        raise ValueError("astrometry has %d rows but photometry has %d"
                         % (num, len(photom.data)))

    numGood = np.zeros(num, dtype=int)
    for ff in filters:
        numGood += (np.isfinite(photom.data.field(ff)) &
                    np.isfinite(photom.data.field(ff + "_err")))
    return np.logical_not(astrom.data.field("extended")) & (numGood >= minFilters)
