#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
A quick programme demonstrating filtering and cross-correlation. 

First, create a synthetic signal (a Gaussian-windowed sine wave), then add
random noise and weak sinusoidal interference to it so that it looks more
like an experimental signal.

Then, filter the signal in the frequency domain to remove most of the
random noise.

Finally, cross-correlate the signal with a reference signal to measure
the time shifts.
'''
import numpy as np
import pylab as plt
import utlib as ut
import scipy.signal as signal

# Synthetic test signal: a sine wave windowed by a Gaussian envelope
def Gaussian(A, freq, sigma, t0, t):
    '''
    Return a sine wave windowed by a Gaussian envelope (a simulated pulse).

    Both factors are referenced to ``t0``: the envelope peaks at ``t0`` and
    the sine term is zero there, so moving ``t0`` shifts the whole pulse.

    Arguments:
    ----------------
    ** A **:                        Peak amplitude (scales the whole pulse)
    ** freq **:                     Sine frequency, in cycles per unit of t
    ** sigma **:                    Standard deviation of the Gaussian
                                    envelope (same units as t)
    ** t0 **:                       Centre of the envelope and phase
                                    reference of the sine
    ** t **:                        Time value(s); scalar or numpy array

    Returns:
    ----------------
    ** pulse **:                    A * envelope * sine, evaluated
                                    element-wise (same shape as t)
    '''
    shifted = t - t0
    envelope = np.exp(-shifted**2 / (2 * sigma**2))
    carrier = np.sin(2 * np.pi * freq * shifted)
    return A * envelope * carrier

def _plot_real_imag(freq, spectrum, mask=None):
    '''
    Plot the real and imaginary parts of a spectrum on two stacked axes.

    Both parts are divided by the peak *magnitude* of the complex spectrum.
    They therefore share one scale (consistent with the magnitude plot) and
    are guaranteed to lie in [-1, 1]. Dividing each part by its own signed
    maximum could flip the sign, divide by zero, or overshoot the y-limits.

    Arguments:
    ----------------
    ** freq **:                     Frequency axis (Hz)
    ** spectrum **:                 Complex spectrum, same length as freq
    ** mask **:                     Optional window overlaid on both axes

    Returns:
    ----------------
    ** (fig, ax_re, ax_im) **:      The figure and its real/imaginary axes
    '''
    peak = np.abs(spectrum).max()
    if peak == 0:
        peak = 1.0  # all-zero spectrum: avoid 0/0, plot zeros
    fig, (ax_re, ax_im) = plt.subplots(2, 1, sharex=True, sharey=True)
    ax_re.plot(freq, np.real(spectrum)/peak)
    ax_im.plot(freq, np.imag(spectrum)/peak)
    if mask is not None:
        ax_re.plot(freq, mask)
        ax_im.plot(freq, mask)
    ax_re.set_xlim(0, 5)
    ax_re.set_ylim(-1.2, 1.2)
    ax_re.set_ylabel('Real Component')
    ax_im.set_ylabel('Imaginary Component')
    ax_im.set_xlabel('Frequency (Hz)')
    fig.subplots_adjust(hspace=0.02)
    # Hide the x tick labels on every axis except the bottom one.
    plt.setp([a.get_xticklabels() for a in fig.axes[:-1]], visible=False)
    return fig, ax_re, ax_im


# Creating time vector and noisy test signal
t = np.linspace(0, 50, 2**13)
# FFT length exponent: the next power of two >= len(t), plus 2 (i.e. 4x longer).
# The zero-padding samples the spectrum on a finer frequency grid.
pow2 = int(np.ceil(np.log2(len(t)))) + 2
Signal = Gaussian(1., 2, 1.75, 20, t)

# Add noise: uniform random noise in [-0.1, 0.1) plus two weak sinusoidal
# interferers at 3.5 Hz and 0.5 Hz.
Signal += (0.2*(np.random.rand(len(t)) - 0.5)
           + 0.05*np.cos(2*np.pi*3.5*t)
           + 0.07*np.cos(2*np.pi*.5*t))
plt.plot(t, Signal)
plt.ylim(-1.2, 1.2)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude (Arb.)')

# Fourier transform into the frequency domain.
nfft = 2**pow2
nhalf = 2**(pow2-1)
# Keep only the non-negative-frequency half. For a real signal, the negative
# half is the complex-conjugate mirror and carries no extra information.
Fourier = np.fft.fft(Signal, nfft)[:nhalf]
Freq = np.fft.fftfreq(nfft, t[1] - t[0])[:nhalf]

# Frequency-domain window from utlib. NOTE(review): hannfilter's contract is
# not visible here. It presumably builds a Hann-tapered band-pass with
# corners at 1, 1.5, 2.5 and 3 Hz, sampled up to Freq[-1] with spacing
# Freq[1]. The [:-1] presumably trims one extra sample so that
# len(FreqMask) == len(Freq). TODO confirm against utlib.
FreqMask = ut.hannfilter(1, 1.5, 2.5, 3., Freq[-1], 1/Freq[1])[:-1]

# Plotting Fourier magnitude and windowing function
plt.figure()
plt.plot(Freq, np.abs(Fourier)/np.abs(Fourier).max())
plt.plot(Freq, FreqMask)
plt.xlim(0, 5)
plt.ylim(0, 1.05)
plt.xlabel('Frequency (Hz)')
plt.ylabel('Fourier Amplitude (Arb.)')

# Real and imaginary components of the FT, with the window overlaid
fig1, ax1, ax2 = _plot_real_imag(Freq, Fourier, FreqMask)

# Apply the window, then plot the components again
Fourier *= FreqMask
fig2, ax3, ax4 = _plot_real_imag(Freq, Fourier)

# Inverse transform back into the time domain. The negative-frequency half
# is zero-filled by ifft's n argument, and 2*Re(.) restores the real signal.
# The DC bin has no mirrored partner, so it is halved first; otherwise the
# signal mean would be doubled. (The Nyquist bin was already dropped by the
# [:nhalf] slice above.)
onesided = Fourier.copy()
onesided[0] *= 0.5
Filtered = 2 * np.real(np.fft.ifft(onesided, nfft)[:len(t)])

# Plot results: noisy signal with the filtered signal overlaid
plt.figure()
plt.plot(t, Signal)
plt.plot(t, Filtered, 'r')
plt.ylim(-1.1, 1.1)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude (Arb.)')

# Filtered signal on its own
plt.figure()
plt.plot(t, Filtered)
plt.ylim(-1.1, 1.1)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude (Arb.)')
plt.show()
