# -*- coding: utf-8 -*-
"""
Created on Nov 2 2014
@author: Erlend Ronnekleiv

This code is written for Python(x,y) v2.7 (Python 2.7).
    The Spyder GUI can be used to run the program: 
       https://code.google.com/p/spyderlib/  (download version for python 2.7)
    
    If you google "python <any keyword>" you will find a lot of tutorials and
    Q&As about how to use this language.
    
    You probably have to download some missing libraries to get the code working. 
    For this you can use pip. In your system command line interface 
    (cmd.exe under Windows), run pip to install missing packages.
    Example:  pip install pyfits
    
Suggested workflow:
 1) Set file directory and name: dir0 and fname_in 
 2) Set parameters for starFilter to match the recorded star shape. 
 3) Adjust parameters for nebFilter to have a radius > detected star radius 
    for proper interpolation.
 4) Adjust starThr to a few times the image st.dev. after applying starFilter 
    Adjust starThrBgrFactor to represent the increase in st.dev with increased
    nebula background.
 5) Set initial unsharp mask parameters usmr and usmw. Unsharp masking sharpens
    the stars so that they take up a smaller diameter before they are removed.
 6) Run the program, review results, adjust parameters, and run again until 
    results are acceptable. 
 7) Open <filename>_StarsRemoved.fit and <filename>_usm.fit in an image editor 
    (depending on the editor you may first have to convert to 16 bit tiff).
    Review the difference. Sharp nebula details may have been removed in the 
    process. This can be repaired by "cloning" back the original usm image and
    manually retouching stars in those regions.

Note: I wrote this program to use it myself. It is not very user friendly, 
although I have added a number of comments throughout the code. I would be
pleased to help if someone wants to make the program more user friendly. 
Please contact erlend.web@eronn.net

"""


from __future__ import division
import numpy as np  
import pyfits 
from numpy.fft import rfft2, irfft2, ifftshift
import time
import scipy.ndimage as ndimg
#import matplotlib.pyplot as pl
#import PIL
#import scipy.special
#import matplotlib.cm as cm # colormaps
#import scipy.misc.imsave as imsave
nx = np.newaxis  # short alias; x[:, nx] appends a length-1 axis for broadcasting

def cosap2d(r1,r2):
    """ 
    Cosine apodised circular filter response. Flat inside ratius r1. 
    Apodized down to zero at radius r2.
    """
    a = int(np.ceil(r2))-1
    x = np.arange(-a,a+1);
    r = np.sqrt(x**2 + x[:,nx]**2)
    ap = np.zeros(np.shape(r))
    ap[r<=r1]=1;
    ii = (r>r1)&(r<r2); ap[ii] = (1 + np.cos(np.pi*(r[ii]-r1)/(r2-r1)))/2;
    return ap/np.sum(ap)
    
def starFilt(r2=1, r1=1.2, a1=.037, r3=4):
    """
    Model star profile used as a matched ("replica") filter for star detection.

    Sum of a Gaussian core exp(-(r/r2)**2) and an exponential halo
    a1*exp(-r/r1), meant to mimic the telescope PSF. The profile is sampled
    on a square grid of half-width ceil(r3)-1 pixels, set to zero for r > r3
    and normalised so that its elements sum to 1.

    r2 : radius at which the Gaussian core has dropped to 1/e
    r1 : e-folding length of the exponential halo
    a1 : halo amplitude relative to the core's peak value (1)
    r3 : cut-off radius
    """
    half = int(np.ceil(r3)) - 1
    offsets = np.arange(-half, half + 1)
    dist = np.sqrt(offsets**2 + offsets[:, np.newaxis]**2)
    core = np.exp(-(dist/r2)**2)
    halo = a1*np.exp(-dist/r1)
    profile = core + halo
    profile[dist > r3] = 0
    return profile/profile.sum()
    
def star(r1,r2):
    """
    1D cosine-apodised window (the 1D counterpart of cosap2d).

    Flat (=1) for |x| <= r1, tapered with a raised cosine
    (1 + cos(pi*(|x| - r1)/(r2 - r1)))/2 for r1 < |x| < r2, zero beyond, and
    normalised so the samples sum to 1. The samples are taken at the integers
    x = -(ceil(r2)-1) ... ceil(r2)-1.
    """
    a = int(np.ceil(r2)) - 1
    x = np.arange(-a, a + 1)
    # Work with the distance |x| so the window is symmetric; comparing the
    # signed x with r1 would put every negative sample in the flat region.
    r = np.abs(x)
    ap = np.zeros(np.shape(r))
    ap[r <= r1] = 1
    ii = (r > r1) & (r < r2)
    ap[ii] = (1 + np.cos(np.pi*(r[ii] - r1)/(r2 - r1)))/2
    return ap/np.sum(ap)
    
def fftConv2d(Fimg, imgshape, kernel):
    """
    2D circular convolution computed with FFTs.

    Fimg     : rfft2 of the image, i.e. rfft2(img) with img.shape == imgshape
    imgshape : (rows, cols) of the image
    kernel   : 2D convolution kernel; element (kr//2, kc//2) is the kernel
               origin (the centre pixel for odd sizes, as with ifftshift)
    Returns the real convolved image with shape imgshape. Kernel taps that
    fall outside the image wrap around (circular boundary conditions).
    """
    kernel = np.asarray(kernel)
    rows, cols = int(imgshape[0]), int(imgshape[1])
    krows, kcols = kernel.shape
    # Offset of every kernel tap from the kernel origin, wrapped onto the
    # image grid so that the origin lands on filt[0, 0]. Computed per axis,
    # so odd, even and non-square kernels (including 1x1) are all placed
    # correctly.
    ri = (np.arange(krows) - krows//2) % rows
    ci = (np.arange(kcols) - kcols//2) % cols
    filt = np.zeros((rows, cols))
    # add.at accumulates taps that wrap onto the same pixel (kernel larger
    # than the image) instead of letting one overwrite another.
    np.add.at(filt, (ri[:, np.newaxis], ci[np.newaxis, :]), kernel)
    # s= restores the exact image size; without it irfft2 returns an even
    # width, i.e. one column short for odd-width images.
    return irfft2(Fimg*rfft2(filt), s=(rows, cols))

    
def usm(img, r, weight):
    """
    Unsharp mask: (1 + weight)*img - weight*gaussian_blur(img, sigma=r).

    img    : image array; non-floating (e.g. integer) input is converted to
             float64 first, floating-point input keeps its dtype
    r      : standard deviation (sigma, in pixels) of the Gaussian blur
    weight : sharpening strength; weight=0 leaves the values unchanged
    Returns the sharpened image as a floating-point array.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.inexact):
        # gaussian_filter returns an array of its input's dtype, so blurring
        # an integer image would drop the fractional part of the blur before
        # it is subtracted.
        img = img.astype(np.float64)
    return (1+weight)*img - weight*ndimg.gaussian_filter(img,r)


tstart = time.time()  # reference time for the "Time used" progress messages

# ---- Input file -------------------------------------------------------------
dir0 = 'E:/astrofoto/2014/14-10-17/'   # directory holding the FITS file
fname_in = 'ic1805ic1848-Tak365-H.fit' # FITS file to process

# ---- Filters ----------------------------------------------------------------
# Star "replica" (matched) filter; tune it so it resembles the recorded stars.
starFilter = starFilt(r2=1.0, r1=1.1, a1=.010, r3=5)     # used for ...Tak365-H.fit
# Settings used for the other channels:
# starFilter = starFilt(r2=0.95, r1=1.1, a1=.010, r3=5)  # used for ...Tak365-S.fit
# starFilter = starFilt(r2=0.85, r1=1.1, a1=.010, r3=5)  # used for ...Tak365-O.fit

# Nebula (background) filter, (star radius) < radius < (nebula details).
# It is also the filter that interpolates the background under masked stars,
# so its radius must exceed the detected star radius.
nebFilter = cosap2d(4, 6)

# Star detection threshold: constant + background-dependent term.
starThr = 6.5
starThrBgrFactor = 0.04

zeroPercentile = 25         # image percentile used as black point
usmr = 3                    # initial unsharp mask: Gaussian sigma (pixels)
usmw = 0.025                # initial unsharp mask: weight
maskFilter = cosap2d(0, 2)  # smooths the binary star mask (soft mask edges)
iterations = 4              # number of star detection/removal passes
border = 10                 # width (pixels) of the mirrored border around the image

# Output file names: input name without extension + suffix.
fname_out = fname_in.rsplit('.', 1)[0] + '_StarsRemoved.fit'
fname_usm = fname_in.rsplit('.', 1)[0] + '_usm.fit'

# ---- Load the image -----------------------------------------------------------
# (print(...) with a single formatted string behaves the same in Python 2 and 3.)
print('loading %s' % fname_in)
h = pyfits.open(dir0+fname_in)
N1 = h[0].header['NAXIS1']   # number of columns
N2 = h[0].header['NAXIS2']   # number of rows
# Not every FITS file carries a PEDESTAL keyword; treat a missing one as 0
# instead of failing with a KeyError.
pds = h[0].header.get('PEDESTAL', 0)
imgdtype = h[0].data.dtype

# Add a mirrored border to the image to handle filtering edge artefacts.
# mode='reflect' mirrors about the outermost row/column without repeating it,
# identically on all four sides.
img = np.pad(h[0].data, border, mode='reflect')
crop = (slice(border, border+N2), slice(border, border+N1))  # original frame inside the padded image

zeroPt = np.percentile(img, zeroPercentile)  # black point estimate
img = usm(img-zeroPt, usmr, usmw)            # subtract black point and apply unsharp mask
# Apply the star "replica filter" to optimize the SNR for star detection:
starFiltImg = ndimg.convolve(img, starFilter)

print('Time used1: %s' % (time.time()-tstart))

interpImg = img
for n in range(iterations):  # iterate star detection and removal
    print('Iteration %d' % n)

    # Apply the nebula filter to the current star-reduced image to get a
    # background estimate:
    nebFiltImg = ndimg.convolve(interpImg, nebFilter)
    print('Time used2: %s' % (time.time()-tstart))

    # Star mask: 1 (keep pixel) where starFiltImg does NOT exceed nebFiltImg
    # by the threshold, 0 at detected star pixels that are to be replaced.
    excess = starFiltImg - nebFiltImg
    if n == 0:
        mask = excess < 3*starThr  # increased threshold in the first iteration
    else:
        mask = excess < starThr + starThrBgrFactor*nebFiltImg
    # Smooth the mask with maskFilter (soft transition around each star):
    mask = ndimg.convolve(mask.astype(np.double), maskFilter)

    # Normalised convolution: the nebFilter-weighted average of the original
    # image over the unmasked pixels interpolates the background under stars.
    filtMaskImg = ndimg.convolve(mask*img, nebFilter)
    filtMask    = ndimg.convolve(mask, nebFilter)
    # Where no unmasked pixel lies within reach of nebFilter, filtMask is 0
    # and the ratio would be 0/0 = NaN; keep the original pixel there.
    valid = filtMask > 0
    fillImg = np.where(valid, filtMaskImg/np.where(valid, filtMask, 1), img)

    # Keep unmasked pixels and blend in the interpolation where masked:
    interpImg = img*mask + (1-mask)*fillImg
    print('Time used3: %s' % (time.time()-tstart))

    # Save the mask image of the current iteration.
    # NOTE(review): mask values lie in [0, 1]; with an integer imgdtype the
    # cast truncates the soft edges. The -pds offset is kept as written.
    h[0].data = np.array(mask[crop]-pds, dtype=imgdtype)
    h.writeto('%smask%d.fit' % (dir0, n), clobber=True)

    # Save the star-removed image of the current iteration
    h[0].data = np.array(interpImg[crop]+zeroPt, dtype=imgdtype)
    h.writeto('%sinterpImg%d.fit' % (dir0, n), clobber=True)

    print('Time used4: %s' % (time.time()-tstart))


# Save the input image with the unsharp mask applied
h[0].data = np.array(img[crop]+zeroPt, dtype=imgdtype)
h.writeto(dir0+fname_usm, clobber=True)

# Save the final output image without stars
h[0].data = np.array(interpImg[crop]+zeroPt, dtype=imgdtype)
h.writeto(dir0+fname_out, clobber=True)
h.close()