#!/usr/bin/env python3

import os
import sys
import astropy.io.fits as pyfits
import numpy
import pandas
import io

from astropy.wcs import WCS

# Combined mask of bit value 0x2 (counted as "saturated" below) and bit value
# 0x100 (counted as "problematic" below); numerically 258. Not used by the live
# code path, which tests each bit separately.
saturated_mask = 0x2 | 0x100

def _count_flagged_pixels(img, xy, boxsize, flag_bits=(2, 256)):
    """Count pixels carrying each flag bit in a square box around each position.

    Parameters
    ----------
    img : 2-D integer array indexed as ``img[y, x]``.
    xy : array-like of shape (N, 2) holding pixel positions ``(x, y)``.
        Positions are rounded to the nearest pixel; non-finite positions are
        skipped.
    boxsize : int
        Half-width of the box. The box spans ``[c - boxsize, c + boxsize]``
        inclusive on each axis (``2*boxsize + 1`` pixels), centred on ``c``.
    flag_bits : sequence of int
        Bit masks to test; one output column per entry, in this order.

    Returns
    -------
    numpy.ndarray of shape (N, len(flag_bits)), float
        Number of pixels in the box with the corresponding bit set. Rows for
        positions whose box does not lie entirely inside the image stay 0.
    """
    xy = numpy.asarray(xy, dtype=float).reshape(-1, 2)
    counts = numpy.zeros((xy.shape[0], len(flag_bits)))
    ny, nx = img.shape

    for i, (x, y) in enumerate(xy):
        # NaN/inf positions (e.g. failed WCS inversion) cannot be indexed.
        if not (numpy.isfinite(x) and numpy.isfinite(y)):
            continue
        cx = int(numpy.round(x))
        cy = int(numpy.round(y))

        # Require the whole box [c-b, c+b] to lie inside the image.
        if cx < boxsize or cy < boxsize or cx >= nx - boxsize or cy >= ny - boxsize:
            continue

        # +1 on the upper edge makes the box symmetric around the centre pixel,
        # matching the bounds check above.
        box = img[cy - boxsize:cy + boxsize + 1, cx - boxsize:cx + boxsize + 1]
        for j, bit in enumerate(flag_bits):
            counts[i, j] = numpy.count_nonzero(numpy.bitwise_and(box, bit))

    return counts


def _main(argv):
    """Append per-mask flagged-pixel counts to a catalogue and save it.

    Usage: ``script CATALOG OUTPUT MASK [MASK ...]``

    Columns 3 and 4 of CATALOG are converted to pixel positions with each
    mask's WCS (0-based origin). For every mask, two columns are appended:
    the count of pixels with bit 2 set ("saturated") and with bit 256 set
    ("problematic") inside the box around each source.
    """
    if len(argv) < 4:
        sys.exit("usage: %s CATALOG OUTPUT MASK [MASK ...]" % argv[0])

    cat_fn, output_fn = argv[1], argv[2]
    mask_files = argv[3:]

    # ndmin=2 keeps a single-row catalogue 2-D so the column slice works.
    cat = numpy.loadtxt(cat_fn, ndmin=2)
    # Columns 3-4 are handed to the WCS as world coordinates (presumably RA,
    # Dec in degrees -- confirm against the catalogue format).
    ra_dec = cat[:, 3:5]
    print(ra_dec.shape)
    print(ra_dec[:5])

    # Presumably 3 arcsec at 0.168 arcsec/pixel -- confirm; evaluates to 18.
    boxsize = int(numpy.round(3. / 0.168, 0))
    print("Using boxsize of %d pixels" % (boxsize))

    flag_bits = (2, 256)  # bit 2 -> "saturated", bit 256 -> "problematic"
    per_mask_counts = []
    for img_fn in mask_files:
        print("Processing %s ..." % (img_fn))

        with pyfits.open(img_fn) as hdulist:
            wcs = WCS(hdulist[0].header)
            # astype copies into a native-endian integer array, so the data
            # remains valid after the file is closed and supports bitwise_and.
            img = numpy.asarray(hdulist[0].data).astype(numpy.int64)

        xy = wcs.all_world2pix(ra_dec, 0)
        per_mask_counts.append(
            _count_flagged_pixels(img, xy, boxsize, flag_bits))

    # Column order per mask: [bit 2 count, bit 256 count].
    saturated_count = numpy.concatenate(per_mask_counts, axis=1)
    print(cat.shape, saturated_count.shape)

    cat_out = numpy.concatenate((cat, saturated_count), axis=1)
    print(cat.shape, cat_out.shape)

    numpy.savetxt(output_fn, cat_out)


if __name__ == "__main__":
    _main(sys.argv)
