#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
A demonstration of using a 2D FFT on a B-scan of ultrasonic data in order
to extract a single guided wave mode from a selection of multiple,
dispersive guided wave modes.
'''
import numpy as np
import pylab as plt
import metaArray as m

# Read in the data into the correct format.
# The time axis comes from Data[0]; the traces start at Data[2].
# NOTE(review): Data[1] is skipped here -- confirm what it holds.
Data = m.pout_hist('BScan.flxhst')
Time = Data[0].data
Points = len(Data) - 2
Disp = np.zeros([Points, len(Time)])   # one row per scan position
x = np.zeros(Points)                   # scan position of each row

# range (not the Python-2-only xrange) so the script also runs on Python 3
for i in range(Points):
    Buff = Data[2+i]
    Disp[i] = Buff.data
    x[i] = float(Buff['POUT_hist.xcrd'])

# Plot the raw B-scan: time (us) horizontally, scan position (mm) vertically
Extent = (1e6*Time[0], 1e6*Time[-1], 1e3*x[0], 1e3*x[-1])
plt.figure()
plt.imshow(Disp, aspect='auto', origin='lower', extent=Extent)
# Raw string: '\m' is an invalid escape sequence in a normal literal and
# triggers a DeprecationWarning/SyntaxWarning on Python 3.
plt.xlabel(r'Time ($\mu$s)')
plt.ylabel('Distance (mm)')

# Calculating the 2D Fourier transform of the B-scan
Nx, Nt = Disp.shape

# Next power of 2 for each FFT length (the data are zero-padded up to it)
NextNx = int(np.ceil(np.log2(Nx)))
NextNt = int(np.ceil(np.log2(Nt)))

# 2D FFT of the data to get the dispersion curves:
# axis 0 (space) -> wavenumber, axis 1 (time) -> frequency
Fourier = np.fft.fft2(Disp, s=[2**NextNx, 2**NextNt])

# Sample spacings, averaged over the whole axis. This needs only two
# samples, and rounding in the stored coordinates averages out.
# The FFT assumes uniform sampling in any case.
dx = (x[-1] - x[0]) / (Nx - 1)
dt = (Time[-1] - Time[0]) / (Nt - 1)

# Fourier spectrum axes - k = wavenumber; freq = frequency
# (fftfreq returns cycles per unit of the given sample spacing)
k = np.fft.fftfreq(2**NextNx, dx)
freq = np.fft.fftfreq(2**NextNt, dt)

# Shift the zero wavenumber/frequency bins to the centre of each axis
Fourier = np.fft.fftshift(Fourier)
k = np.fft.fftshift(k)
freq = np.fft.fftshift(freq)
df = freq[1] - freq[0]
dk = k[1] - k[0]

# Image extent of the shifted spectrum: frequency in MHz horizontally,
# angular wavenumber (2*pi*k) per mm vertically.
ExtentF = (freq[0]/1.e6, freq[-1]/1.e6,
           2*np.pi*k[0]/1.e3, 2*np.pi*k[-1]/1.e3)

# Show the magnitude of the full 2D spectrum (the dispersion curves)
plt.figure()
plt.imshow(np.abs(Fourier), origin='lower', extent=ExtentF, aspect='auto')
plt.xlabel('Frequency (MHz)')
plt.ylabel('Wavenumber (mm$^{-1}$)')
# Reference lines through f = 0 and k = 0
for zero_line in (plt.axvline, plt.axhline):
    zero_line(0, color='k')
# Zoom in on the region of interest
for set_limits, limits in ((plt.xlim, (-1.5, 1.5)), (plt.ylim, (-2.5, 2.5))):
    set_limits(*limits)

# Want to extract the SH0 mode (non-dispersive), so calculate its
# dispersion relation k = f / vs over the selected frequency band.
vs = 3114.      # SH0 velocity used by the filter (m/s, assuming x in metres)
width = 3       # Half-width of the filter band, in wavenumber bins

# Index of the zero-frequency / zero-wavenumber bin on the shifted axes
f_centre = 2**(NextNt-1)
k_centre = 2**(NextNx-1)

# Number of frequency bins either side of zero to keep (up to 2 MHz).
# Clamping keeps the selection inside the frequency axis. The inclusive
# upper end means every +f bin has its -f mirror selected too.
freqstart = min(int(2e6 / df), f_centre - 1)
bin_selection = np.arange(f_centre - freqstart, f_centre + freqstart + 1)
freq_selection = freq[bin_selection]
mode_k = -freq_selection / vs

# Convert the mode wavenumbers to bin offsets from k = 0. Rounding (not
# truncation) gives the nearest bin. np.round is odd-symmetric, so -f maps
# to exactly minus the offset of +f.
mode_nk = np.round(mode_k / dk).astype('int')
FourierC = np.zeros(Fourier.shape, dtype=complex)
n_kbins = Fourier.shape[0]

for nf, nk in zip(bin_selection, mode_nk):
    # Forward (k = -f/vs) and backward (k = +f/vs) travelling mode bins
    forward = k_centre + nk
    backward = k_centre - nk
    # Copy a band of 2*width+1 bins centred on each mode bin, clipped to the
    # array. A negative slice start would otherwise wrap to the far end.
    for centre in (forward, backward):
        lo = max(centre - width, 0)
        hi = min(centre + width + 1, n_kbins)
        if lo < hi:
            FourierC[lo:hi, nf] = Fourier[lo:hi, nf]

# The unfiltered spectrum is no longer needed; release it to save memory
del Fourier

# Show the magnitude of the filtered spectrum on the same axes as before
plt.figure()
plt.imshow(np.abs(FourierC), origin='lower', extent=ExtentF, aspect='auto')
plt.xlabel('Frequency (MHz)')
plt.ylabel('Wavenumber (mm$^{-1}$)')
# Reference lines through f = 0 and k = 0
for zero_line in (plt.axvline, plt.axhline):
    zero_line(0, color='k')
# Zoom in on the region of interest
for set_limits, limits in ((plt.xlim, (-1.5, 1.5)), (plt.ylim, (-2.5, 2.5))):
    set_limits(*limits)

# Inverse 2D FFT back to the space-time domain. ifftshift is the exact
# inverse of the fftshift applied to the spectrum. fftshift only coincides
# with it for even-length axes. FourierC already has the padded FFT shape,
# so no explicit size is needed.
Filtered = np.fft.ifft2(np.fft.ifftshift(FourierC))
# Drop the zero padding and keep the real part (the input B-scan is real)
Filtered = np.real(Filtered[:len(x), :len(Time)])

# Plot the filtered B-scan on the same axes as the raw data
plt.figure()
plt.imshow(Filtered, aspect='auto', origin='lower', extent=Extent)
# Raw string: '\m' is an invalid escape sequence in a normal literal
plt.xlabel(r'Time ($\mu$s)')
plt.ylabel('Distance (mm)')
plt.show()
