from __future__ import print_function
try:
    import cv2
except ModuleNotFoundError:
    print("Please install opencv-python module using following command:\npip3 install opencv-python")
import stmpy
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy.optimize as opt
import scipy.ndimage as snd
from scipy.interpolate import interp1d, interp2d
from skimage import transform as tf
from skimage.feature import peak_local_max
#import stmpy.driftcorr as dfc

'''
Local drift correction of square/triangular lattice. (Please carefully rewrite if a rectangular lattice version is needed.)

Usage:
0. Import stmpy.driftcorr (the 'scikit-image' package is required: run 'pip install -U scikit-image' in a terminal; gshearcorr additionally needs 'opencv-python');

1. findBraggs: Fourier-transform the topo image (interpolate it onto a square 2D array first if it is not square), then find all Bragg peaks with peak_local_max, and plot the result over abs(FT) to check validity (points near the center sometimes need to be removed);

2. gshearcorr: global shear correction; outputs the transformation matrix M and the corrected image;

3. phasemap: use the new Bragg peaks to generate local phase shift (theta) maps. Use chkphasemap to get phase (phi) maps and check whether the pi lines in them match the raw image. If not, check your Bragg peaks or adjust sigma in the FT DC filter;

4. fixphaseslip: fix phase slips of 2D phase maps (the current implementation unwraps line by line along rows and then columns; the 1D spiral approach is not implemented yet). Phase slips can still be present
near the edges after processing, so please crop images and DOS maps accordingly and manually AFTER applying driftmap. Output is a 2D array with the same shape as the input array;

5. driftmap: calculate drift fields u(x,y) in both x and y directions;

6. driftcorr: local drift correction using u(x,y) on 2D (topo) or 3D(DOS map) input.

7. (OPTIONAL) chkperflat: check perfect lattice generated by Q1, Q2 (to be improved)

REFERENCES:
[1] MH Hamidian, et al. "Picometer registration of zinc impurity states in Bi2Sr2CaCu2O8+d for phase determination in intra-unit-cell Fourier transform STM", New J. Phys. 14, 053017 (2012).
[2] JA Slezak, PhD thesis (Ch. 3), http://davisgroup.lassp.cornell.edu/theses/Thesis_JamesSlezak.pdf

History:
    2017-04-28      CREATED BY JIANFENG GE
    04/29/2019      RL : Add documents for all functions. Add another method to calculate phasemap.
                            Add inverse FFT method to apply the drift field.
'''

#1. - getAttrs
def getAttrs(obj, a0, size=None, pixels=None):
    '''
    Create attributes of lattice constant, map size, number of pixels, and qscale for Spy object.

    Input:
        obj         - Required : Spy object of topo (2D) or map (3D). It only needs a `header`
                                    mapping when `size` or `pixels` is not given.
        a0          - Required : Lattice constant in the unit of nm.
        size        - Optional : Size of the map in meters (the unit read from the header); it is
                                    multiplied by 1e9 and stored in nm. If not offered, it is read
                                    from header['scan_range'][-1] or, failing that, from
                                    header['Grid settings'].
        pixels      - Optional : Number of pixels of the topo/map. If not offered, it is read from
                                    header['scan_pixels'][-1] or, failing that, from header['Grid dim'].

    Returns:
        N/A. Sets obj.a0, obj.size (nm), obj.pixels, obj.qmag (= size / a0) and
        obj.qscale (= pixels / (2 * qmag)).

    Raises:
        ValueError  - if size or pixels is not given and cannot be parsed from the header.

    Usage:
        import stmpy.driftcorr as dfc
        dfc.getAttrs(topo, a0=a0)

    '''
    if size is None:
        try:
            size = obj.header['scan_range'][-1]
        except KeyError:
            try:
                size = float(obj.header['Grid settings'].split(";")[-2])
            except (KeyError, IndexError, ValueError, AttributeError) as err:
                # Previously this only printed and then crashed on `None * 1e9`.
                raise ValueError("Error: Cannot find map size from header. "
                                 "Please input it manually.") from err
    if pixels is None:
        try:
            pixels = int(obj.header['scan_pixels'][-1])
        except KeyError:
            try:
                # The last token of 'Grid dim' carries one trailing character that is dropped.
                pixels = int(obj.header['Grid dim'].split()[-1][:-1])
            except (KeyError, IndexError, ValueError, AttributeError) as err:
                raise ValueError("Error: Cannot find number of pixels from header. "
                                 "Please input it manually.") from err
    obj.a0 = a0
    obj.size = size * 1e9
    obj.pixels = pixels
    obj.qmag = obj.size / obj.a0
    obj.qscale = obj.pixels / (2*obj.qmag)

#2. - findBraggs
def findBraggs(A, obj=None, rspace=True, min_dist=5, thres=0.25, r=0.25,
    w=None, maskon=True, show=False, angle=0, even_out=True, update_obj=True):
    '''
    Find Bragg peaks (in pixels) of a topo or its FT pattern A using peak_local_max. If obj is
    offered and update_obj is True, the result is also stored as obj.bp.

    Input:
        A           - Required : 2D array of topo in real space, or FFT in q space.
        obj         - Optional : Object associated with A. When given together with maskon=True,
                                    a Gaussian ring of radius obj.qmag (see getAttrs) is also applied.
        rspace      - Optional : Boolean indicating if A is a real (True) or Fourier (False) space
                                    image. Default: True
        min_dist    - Optional : Minimum distance (in pixels) between peaks. Default: 5
        thres       - Optional : Minimum intensity of Bragg peaks relative to max value. Default: 0.25
        r           - Optional : Width of the Gaussian mask that removes low-q noise, as a fraction
                                    of the image size (=r*width). Default: 0.25
        w           - Optional : Half-width, as a fraction of the image size, of the masks that filter
                                    out noise along the qx=0 and qy=0 lines. w=None disables this mask.
        maskon      - Optional : Boolean, if False then no mask will be applied.
        show        - Optional : Boolean, if True then the masked FT and the Bragg peaks are plotted.
        angle       - Optional : Currently unused; kept for backward compatibility.
        even_out    - Optional : Boolean, if True the offset of every peak from the array centre is
                                    rounded down (towards -inf) to an even number of pixels.
        update_obj  - Optional : Boolean, if True and obj is given then obj.bp is updated.

    Returns:
        coords      -  (N x 2) integer array of peaks in the format [[x1,y1],[x2,y2],...]

    Usage:
        import stmpy.driftcorr as dfc
        bp = dfc.findBraggs(A, obj=topo, min_dist=10, thres=0.2, rspace=True, show=True)

    History:
        04/28/2017      JG : Initial commit.
        04/29/2019      RL : Add maskon option, and add documents.

    '''
    if rspace is True:
        F = stmpy.tools.fft(A, zeroDC=True)
    else:
        F = np.copy(A)
    # Remove low-q high intensity data with Gaussian mask
    if maskon is True:
        *_, Y, X = np.shape(A)
        x = np.arange(X)
        y = np.arange(Y)
        p0 = [int(X/2), int(Y/2), X * r, Y * r, 1, np.pi/2]
        G = 1 - stmpy.tools.gauss2d(x, y, p=p0)
        if w is not None:
            mask3 = np.ones([Y, X])
            mask3[Y//2-int(Y*w):Y//2+int(Y*w), :] = 0
            mask3[:, X//2-int(X*w):X//2+int(X*w)] = 0
        else:
            mask3 = 1
        # Not in place: with rspace=False, F may have an integer dtype that cannot hold the product.
        F = F * (G * mask3)
        if obj is not None:
            # NOTE(review): the ring uses the last-axis length for both axes (assumes square A).
            L = np.shape(A)[-1]
            x = np.arange(L)
            y = np.arange(L)
            mask2 = stmpy.tools.gauss_ring(x, y, major=obj.qmag, minor=obj.qmag,
                                           sigma=10, x0=L/2, y0=L/2)
            F = F * (mask2 * mask3)
    # peak_local_max returns (row, col); flip to [x, y].
    coords = peak_local_max(F, min_distance=min_dist, threshold_rel=thres)
    coords = np.fliplr(coords)

    if show is True:
        plt.figure(figsize=[4, 4])
        c = np.mean(F)
        s = np.std(F)
        # matplotlib only accepts origin='upper' or 'lower'.
        plt.imshow(F, cmap=plt.cm.gray_r, interpolation='None', origin='lower',
                   clim=[0, c+5*s], aspect=1)
        plt.plot(coords[:, 0], coords[:, 1], 'r.')
        plt.gca().set_aspect(1)
        plt.axis('tight')
        print('#:\t[x y]')
        for ix, iy in enumerate(coords):
            print(ix, end='\t')
            print(iy)
    if even_out is True:
        center = np.array(np.shape(A)[::-1]) // 2
        for i, ix in enumerate(coords):
            coords[i] = center + (ix - center) // 2 * 2
    if obj is not None:
        if update_obj is True:
            obj.bp = coords
    return coords

#3. - gshearcorr
def gshearcorr(A, bp=None, obj=None, rspace=True, pts1=None, pts2=None, angle=np.pi/4, matrix=None, update_obj=True):
    '''
    Global shear correction based on position of Bragg peaks in FT of 2D or 3D array A

    Inputs:
        A           - Required : 2D or 3D array to be shear corrected.
        bp          - Optional : (Nx2) array of Bragg peaks [x, y] in pixels. Required when neither
                                    matrix nor pts1 is given.
        obj         - Optional : Spy object of topo (2D) or map (3D). If given, the target peak
                                    distance is taken from obj.qmag / obj.pixels (see getAttrs)
                                    instead of the measured peaks.
        rspace      - Optional : Boolean indicating if A is real or Fourier space image. Default: True
        pts1        - Optional : 3x2 array containing coordinates of three points in the raw FT (center,
                                    bg_x, bg_y). If given, pts2 must be given too.
        pts2        - Optional : 3x2 array containing coordinates of three corresponding points in the corrected
                                    FT (i.e., model center and bg_x and bg_y coordinates).
        angle       - Optional : Specify angle between scan direction and lattice unit vector direction (x and ux direction)
                                    in the unit of radian. Default is pi/4 -- 45 degrees rotated.
        matrix      - Optional : If provided, matrix will be used to transform the dataset directly.
                                    The caller's array is not modified.
        update_obj  - Optional : Currently unused; kept for backward compatibility.

    Returns:
        M           - 2x3 transformation matrix. For rspace=True its translation column is zeroed.
        A_corr      - 2D or 3D array after global shear correction.

    Raises:
        ValueError  - if pts1 is given without pts2.

    Usage:
        import stmpy.driftcorr as dfc
        M, A_gcorr = dfc.gshearcorr(A, bp, obj=topo, rspace=True)
    '''
    *_, s2, s1 = np.shape(A)
    if matrix is None:
        if pts1 is None:
            bp = sortBraggs(bp, s=np.shape(A))
            # Peaks [x, y] are scaled by (rows, cols) so that distances from `center_scaled`
            # equal s1*s2 times the fractional wave-vector magnitude (works for non-square A).
            s = np.array([s2, s1])
            bp_scaled = bp * s
            center_scaled = [int(s2*s1/2), int(s2*s1/2)]
            if obj is None:
                Qx_mag = compute_dist(bp_scaled[0], center_scaled)
                Qy_mag = compute_dist(bp_scaled[1], center_scaled)
                Q_corr = np.mean([Qx_mag, Qy_mag])
            else:
                # obj.qmag is a distance in FT pixels of an obj.pixels-wide grid; convert it to the
                # scaled units above (using it raw collapsed the target points onto the centre).
                Q_corr = obj.qmag / obj.pixels * s2 * s1
            Qc1 = (Q_corr*np.array([-np.cos(angle), -np.sin(angle)]) + center_scaled) / s
            Qc2 = (Q_corr*np.array([np.sin(angle), -np.cos(angle)]) + center_scaled) / s
            Q1, Q2, *_ = bp
            # Pixel-space centre in the same [x, y] order as the peaks.
            center = [int(s1/2), int(s2/2)]
            print(Q1, Q2, Qc1, Qc2, center)
            pts1 = np.float32([center, Q1, Q2])
            if pts2 is None:
                pts2 = np.float32([center, Qc1, Qc2])
        elif pts2 is None:
            raise ValueError('pts2 must be provided when pts1 is given.')
        else:
            pts1 = np.asarray(pts1).astype(np.float32)
        pts2 = np.asarray(pts2).astype(np.float32)
        M = cv2.getAffineTransform(pts1, pts2)
    else:
        # Copy so the translation reset below never mutates the caller's matrix.
        M = np.array(matrix, dtype=float)

    if rspace is not True:
        # cv2 expects dsize as (width, height) = (columns, rows).
        A_corr = cv2.warpAffine(A, M, (s1, s2),
                                flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
    else:
        M[:, -1] = 0
        # Shift so the constant (zero) border fill equals the image minimum once the offset is re-added.
        offset = np.min(A)
        # The image is transposed and flipped before warping and restored afterwards.
        # NOTE(review): this appears to turn the FT-space matrix into the matching real-space
        # transform (for 2x2 L, a 90-degree rotation J gives J L J^-1 = det(L) L^-T) — confirm.
        # The transposed image is (s1 rows, s2 cols), hence dsize (s2, s1).
        A_corr = cv2.warpAffine(np.flipud((A - offset).T), M, (s2, s1),
                                flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
        A_corr = np.flipud(A_corr).T + offset
    return M, A_corr


#4. phasemap
def phasemap(A, bp, obj=None, sigma=10, method="lockin", update_obj=True):
    '''
    Calculate local phase shift maps of the two lattice modulations. Two methods are available now:
    spatial lockin or Gaussian mask convolution.

    Input:
        A           - Required : 2D arrays after global shear correction with bad pixels cropped on the edge
        bp          - Required : Coords of Bragg peaks of FT(A), can be computed by findBraggs(A). They
                                    are sorted with sortBraggs and the first two are used.
        obj         - Optional : Spy object of topo (2D) or map (3D).
        sigma       - Optional : Gaussian width in pixels. "lockin": width of the FT DC filter
                                    (FTDCfilter). "convolution": x-width of the real-space kernel
                                    (the y-width is sigma*s1/s2).
        method      - Optional : Specify which method to use to calculate phase map.
                                "lockin": Spatial lock-in method to find phase map
                                "convolution": Gaussian mask convolution method to find phase map
        update_obj  - Optional : Boolean, if True and obj is given, obj.phix, obj.phiy, obj.Q1 and
                                    obj.Q2 are set.

    Returns:
        thetax      -       2D array, Phase shift map for Q1, relative to perfectly generated cos lattice
        thetay      -       2D array, Phase shift map for Q2, relative to perfectly generated cos lattice
        Q1          -       Wave vector (radians per pixel, [qx, qy]) of the 1st Bragg peak
        Q2          -       Wave vector (radians per pixel, [qx, qy]) of the 2nd Bragg peak

    Raises:
        ValueError  -       if method is neither "lockin" nor "convolution".

    Usage:
        import stmpy.driftcorr as dfc
        thetax, thetay, Q1, Q2 = dfc.phasemap(A, bp, sigma=10, method='lockin')

    History:
        04/28/2017      JG : Initial commit.
        04/29/2019      RL : Add "convolution" method, and add documents.
        11/30/2019      RL : Add support for non-square dataset
    '''
    *_, s2, s1 = A.shape
    bp = sortBraggs(bp, s=np.shape(A))
    # Wave vectors of the first two sorted peaks, measured from the FT centre.
    Q1 = 2*np.pi*np.array([(bp[0][0]-int(s1/2))/s1, (bp[0][1]-int(s2/2))/s2])
    Q2 = 2*np.pi*np.array([(bp[1][0]-int(s1/2))/s1, (bp[1][1]-int(s2/2))/s2])
    t1 = np.arange(s1, dtype='float')
    t2 = np.arange(s2, dtype='float')
    x, y = np.meshgrid(t1, t2)
    arg1 = Q1[0]*x + Q1[1]*y
    arg2 = Q2[0]*x + Q2[1]*y
    # Compare strings by value; `is` only works by accident of interning.
    if method == "lockin":
        thetax = np.arctan2(FTDCfilter(A * np.sin(arg1), sigma),
                            FTDCfilter(A * np.cos(arg1), sigma))
        thetay = np.arctan2(FTDCfilter(A * np.sin(arg2), sigma),
                            FTDCfilter(A * np.cos(arg2), sigma))
    elif method == "convolution":
        from scipy.signal import fftconvolve
        # -1j replaces np.complex(0, -1); np.complex was removed in NumPy 1.24.
        A_x = A * np.exp(-1j * arg1)
        A_y = A * np.exp(-1j * arg2)
        sx = sigma
        sy = sigma * s1 / s2
        Amp = 1/(4*np.pi*sx*sy)
        # Centre the kernel on each axis (was (min/2, min/2), off-centre for non-square data).
        p0 = [int(s1/2), int(s2/2), sx, sy, Amp, np.pi/2]
        G = stmpy.tools.gauss2d(np.arange(s1), np.arange(s2), p=p0, symmetric=True)
        thetax = np.angle(fftconvolve(A_x, G, mode='same'))
        thetay = np.angle(fftconvolve(A_y, G, mode='same'))
    else:
        raise ValueError('Only two methods are available now:\n1. lockin\n2. convolution')
    if obj is not None and update_obj is True:
        obj.phix = thetax
        obj.phiy = thetay
        obj.Q1 = Q1
        obj.Q2 = Q2
    return thetax, thetay, Q1, Q2

#5. fixphaseslip
def fixphaseslip(A, thres=None, maxval=None, method='unwrap', orient=0):
    '''
    Fix phase slips by adding multiples of the phase period at phase jump lines.

    The map is unwrapped line by line with unwrap_phase: first every row, then every column.
    Both passes run on the map flipped in x and y, and the result is flipped back.

    Inputs:
        A       - Required : 2D array of phase shift map, potentially containing phase slips
        thres   - Optional : Float, jump threshold as a fraction of the phase period; forwarded to
                                unwrap_phase as `tolerance`. Default: None (unwrap_phase default)
        maxval  - Optional : Phase period forwarded to unwrap_phase. Default: None (unwrap_phase default)
        method  - Optional : Accepted for backward compatibility but currently ignored; only the
                                line-by-line "unwrap" scheme is implemented (no "spiral" method).
        orient  - Optional : Accepted for backward compatibility but currently ignored.

    Returns:
        phase_corr      -       2D array of phase shift map with phase slips corrected

    Raises:
        ValueError      -       if A is not 2D (previously a 3D input silently returned None).

    Usage:
        import stmpy.driftcorr as dfc
        thetaxf = dfc.fixphaseslip(thetax, method='unwrap')

    History:
        04/28/2017      JG : Initial commit.
        04/29/2019      RL : Add "unwrap" method, and add documents.
    '''
    if np.ndim(A) != 2:
        raise ValueError('fixphaseslip expects a 2D phase map.')
    output = np.copy(A[::-1, ::-1])
    rows, cols = output.shape
    for i in range(rows):
        output[i, :] = unwrap_phase(output[i, :], tolerance=thres, maxval=maxval)
    for j in range(cols):
        output[:, j] = unwrap_phase(output[:, j], tolerance=thres, maxval=maxval)
    return output[::-1, ::-1]

def unwrap_phase(ph, tolerance=None, maxval=None):
    '''
    Remove phase jumps from a 1D phase profile by adding multiples of the phase period.

    Every step ph[i+1] - ph[i] smaller than -tol is treated as a wrap upwards (+maxval from there on).
    Every step larger than +tol is treated as a wrap downwards (-maxval). All other steps are left alone.

    Inputs:
        ph          - 1D float array. It is modified IN PLACE and also returned.
        tolerance   - Optional : jump threshold as a fraction of maxval. Default: 0.25
        maxval      - Optional : phase period. Default: 2*pi

    Returns:
        ph          - the same array, unwrapped.
    '''
    # Fixed: the original referenced an undefined `maxval0` whenever maxval was given.
    maxval = 2 * np.pi if maxval is None else maxval
    tol = 0.25*maxval if tolerance is None else tolerance*maxval
    if len(ph) < 2:
        return ph

    dph = np.diff(ph)
    # Classify all steps from the ORIGINAL differences. Relabelling dph in place let a +1 label be
    # re-labelled -1 whenever tol < 1. It also kept raw differences that were exactly +/-tol.
    steps = np.zeros_like(dph)
    steps[dph < -tol] = 1
    steps[dph > tol] = -1
    ph[1:] += maxval * np.cumsum(steps)
    return ph

#6. driftmap
def driftmap(phix=None, phiy=None, Q1=None, Q2=None, obj=None, method="lockin", update_obj=True):
    '''
    Calculate drift fields based on phase shift maps, with Q1 and Q2 generated by phasemap.

    Inputs:
        phix        - Optional : 2D arrays of phase shift map in x direction with phase slips corrected
        phiy        - Optional : 2D arrays of phase shift map in y direction with phase slips corrected
        Q1          - Optional : Wave vector [qx, qy] of 1st Bragg peak, generated by phasemap
        Q2          - Optional : Wave vector [qx, qy] of 2nd Bragg peak, generated by phasemap
        obj         - Optional : Spy object of topo (2D) or map (3D).
        method      - Optional : Specifying which method to use.
                                    "lockin": Used for phase shift map generated by lockin method
                                    "convolution": Used for phase shift map generated by convolution method
        update_obj  - Optional : Boolean, if True and obj is given then obj.ux and obj.uy are set.

    Returns:
        ux          - 2D array of drift field in x direction
        uy          - 2D array of drift field in y direction

    Raises:
        ValueError  - if method is neither "lockin" nor "convolution".

    Usage:
        import stmpy.driftcorr as dfc
        ux, uy = dfc.driftmap(thetaxf, thetayf, Q1, Q2, method='lockin')

    History:
        04/28/2017      JG : Initial commit.
        04/29/2019      RL : Add "lockin" method, and add documents.
        11/30/2019      RL : Add support for non-square dataset
    '''
    # Compare strings by value; `is` only works by accident of interning.
    if method == "lockin":
        tx = np.copy(phix)
        ty = np.copy(phiy)
        # Cramer's rule for Q1.u = -tx, Q2.u = -ty with u = (ux, uy).
        ux = -(Q2[1]*tx - Q1[1]*ty) / (Q1[0]*Q2[1] - Q1[1]*Q2[0])
        uy = -(Q2[0]*tx - Q1[0]*ty) / (Q1[1]*Q2[0] - Q1[0]*Q2[1])
    elif method == "convolution":
        Qx_mag = np.sqrt((Q1[0])**2 + (Q1[1])**2)
        Qy_mag = np.sqrt((Q2[0])**2 + (Q2[1])**2)
        Qx_ang = np.arctan2(Q1[1], Q1[0])  # in radians
        Qy_ang = np.arctan2(Q2[1], Q2[0])  # in radians
        # Displacement along each Q direction, then projected onto x and y.
        Qxdrift = 1/(Qx_mag) * phix
        Qydrift = 1/(Qy_mag) * phiy
        ux = -(Qxdrift * np.cos(Qx_ang) - Qydrift * np.sin(Qy_ang-np.pi/2))
        uy = -(Qxdrift * np.sin(Qx_ang) + Qydrift * np.cos(Qy_ang-np.pi/2))
    else:
        raise ValueError("Only two methods are available now:\n1. lockin\n2. convolution")
    if obj is not None and update_obj is True:
        obj.ux = ux
        obj.uy = uy
    return ux, uy

#7. - driftcorr
def driftcorr(A, ux=None, uy=None, obj=None, method="lockin", interpolation='cubic'):
    '''
    Correct the drift in the topo according to drift fields

    Inputs:
        A           - Required : 2D or 3D arrays of topo to be drift corrected
        ux          - Optional : 2D arrays of drift field in x direction, generated by driftmap()
        uy          - Optional : 2D arrays of drift field in y direction, generated by driftmap()
        obj         - Optional : Spy object of topo (2D) or map (3D). Currently unused.
        method      - Optional : Specifying which method to use.
                                    "lockin" (alias "interpolate"): spline-interpolate A and evaluate it
                                                    at the shifted coordinates (x-ux, y-uy). Points
                                                    outside the grid take the value at the nearest edge.
                                    "convolution": Use an inverse FFT to apply the drift fields
        interpolation - Optional : Spline order for the "lockin" method: 'linear', 'cubic' or 'quintic'

    Returns:
        A_corr      - 2D or 3D array of topo with drift corrected

    Raises:
        ValueError  - for an unknown method or interpolation kind, or if A is not 2D/3D.

    Usage:
        import stmpy.driftcorr as dfc
        A_corr = dfc.driftcorr(A, ux, uy, method='lockin', interpolation='cubic')

    History:
        04/28/2017      JG : Initial commit.
        04/29/2019      RL : Add "invfft" method, and add documents.
        11/30/2019      RL : Add support for non-square dataset
    '''
    ndim = np.ndim(A)
    if method in ("lockin", "interpolate"):
        # scipy.interpolate.interp2d is removed in recent SciPy. For gridded data it fitted the
        # same FITPACK interpolating spline that RectBivariateSpline builds (s=0).
        from scipy.interpolate import RectBivariateSpline
        orders = {'linear': 1, 'cubic': 3, 'quintic': 5}
        if interpolation not in orders:
            raise ValueError("interpolation must be one of 'linear', 'cubic', 'quintic'")
        k = orders[interpolation]
        if ndim not in (2, 3):
            raise ValueError('ERR: Input must be 2D or 3D numpy array!')
        *_, s2, s1 = np.shape(A)
        t1 = np.arange(s1, dtype='float')
        t2 = np.arange(s2, dtype='float')
        x, y = np.meshgrid(t1, t2)
        xnew = (x - ux).ravel()
        ynew = (y - uy).ravel()

        def _resample(layer):
            # Axis order is (rows, cols); evaluate all shifted points in one vectorised call.
            spline = RectBivariateSpline(t2, t1, layer, kx=k, ky=k)
            return spline.ev(ynew, xnew).reshape(s2, s1)

        if ndim == 2:
            return _resample(A)
        A_corr = np.zeros_like(A)
        for iz, layer in enumerate(A):
            A_corr[iz] = _resample(layer)
            print('Processing slice %d/%d...'%(iz+1, A.shape[0]), end='\r')
        return A_corr
    elif method == "convolution":
        if ndim == 2:
            return _apply_drift_field(A, ux=ux, uy=uy, zeroOut=True)
        elif ndim == 3:
            A_corr = np.zeros_like(A)
            for iz, layer in enumerate(A):
                A_corr[iz] = _apply_drift_field(layer, ux=ux, uy=uy, zeroOut=True)
                print('Processing slice %d/%d...'%(iz+1, A.shape[0]), end='\r')
            return A_corr
        raise ValueError('ERR: Input must be 2D or 3D numpy array!')
    raise ValueError('Only two methods are available now:\n1. lockin\n2. convolution')

def _apply_drift_field(A, ux, uy, zeroOut=True):
    '''
    Resample a 2D array at drift-shifted coordinates using a Fourier (non-uniform DFT) approach.

    Each pixel value is assigned to position (x - ux, y - uy). The DFT of A at those positions is
    evaluated on the regular q grid and inverted with ifft2. The mean is removed before the
    transform and added back afterwards.

    Inputs:
        A       - 2D array (any real dtype; it is converted to float).
        ux, uy  - drift fields broadcastable to A's shape.
        zeroOut - if True, pixels whose shifted position lies outside [0, s1] x [0, s2] are set to 0
                  before the transform.

    Returns:
        Real-valued 2D float array with the shape of A.

    Memory: builds (s1 x N) and (s2 x N) complex phase matrices, N = s1*s2.
    '''
    *_, s2, s1 = np.shape(A)
    # Float copy: the in-place mean subtraction failed for integer input.
    A_corr = np.array(A, dtype=float)
    t1 = np.arange(s1, dtype='float')
    t2 = np.arange(s2, dtype='float')
    x, y = np.meshgrid(t1, t2)
    xshifted = x - ux
    yshifted = y - uy
    if zeroOut is True:
        outside = (xshifted < 0) | (yshifted < 0) | (xshifted > s1) | (yshifted > s2)
        A_corr[outside] = 0
    # Centred q grids: index k corresponds to q = 2*pi*(k - s//2)/s.
    qcoordx = (2*np.pi/s1)*(np.arange(s1)-int(s1/2))
    qcoordy = (2*np.pi/s2)*(np.arange(s2)-int(s2/2))
    xphase = np.exp(-1j*np.outer(qcoordx, xshifted.ravel()))   # (s1, N)
    yphase = np.exp(-1j*np.outer(qcoordy, yshifted.ravel()))   # (s2, N)
    avgData = np.mean(A_corr)
    flat = (A_corr - avgData).ravel()
    # Broadcasting the flat data against xphase works for non-square arrays as well
    # (the old (s2, N) helper array only matched xphase when s1 == s2).
    FT = np.matmul(flat * xphase, yphase.T).T                  # (s2, s1)
    # ifftshift moves the centred q = 0 entry to index 0; fftshift is wrong for odd sizes.
    invFT = np.fft.ifft2(np.fft.ifftshift(FT)) + avgData
    return np.real(invFT)

##################################################################################
####################### Useful functions in the processing #######################
##################################################################################

#8. - sortBraggs
def sortBraggs(br, s):
    '''
    Sort the Bragg peaks in the order of "lower left, lower right, upper right, and upper left".

    br is an (N x 2) array of [x, y] peak positions, and s is the array shape (..., rows, cols).
    Each output slot takes the last peak lying strictly inside the corresponding quadrant around
    the centre [int(cols/2), int(rows/2)]. Slots with no such peak (including peaks exactly on an
    axis) stay zero. The output has the same shape and dtype as br.
    '''
    *_, rows, cols = s
    origin = np.array([int(cols/2), int(rows/2)])
    quadrant_order = ([-1, -1], [1, -1], [1, 1], [-1, 1])
    signs = np.sign(br - origin)
    ordered = np.zeros_like(br)
    for slot, quadrant in enumerate(quadrant_order):
        hits = np.flatnonzero(np.all(signs == quadrant, axis=1))
        if hits.size:
            ordered[slot] = br[hits[-1]]
    return ordered

#9. - cropedge
def cropedge(A, n, obj=None, bp=None, c1=2,c2=2, a1=None, a2=None, force_commen=False, update_obj=True):
    """
    Crop out bad pixels or highly drifted regions from topo/dos map.

    Inputs:
        A           - Required : 2D or 3D array of image to be cropped.
        n           - Required : Integer or list of integers specifying how many bad pixels to crop
                                    on each side. Order: [left, right, down, up]; a single value is
                                    used for all four sides.
        obj         - Optional : Spy object of topo (2D) or map (3D).
        bp          - Optional : Bragg peaks of A; found with findBraggs(A) if not given. Only used
                                    when force_commen is True.
        c1          - Optional : Used to derive a1 = c1 * L1 / N1 when a1 is not given, where L1 is
                                    the width of A and N1 the x-separation (FT pixels) of the first
                                    two sorted Bragg peaks. Default: 2
        c2          - Optional : Currently unused (a2 defaults to a1).
        a1, a2      - Optional : Cell lengths in pixels along x and y used for the commensurate
                                    resampling.
        force_commen- Optional : Boolean. If True, after cropping, a centred region whose width and
                                    height are whole multiples of (a1, a2) is resampled with cubic
                                    splines onto the full cropped pixel grid.
        update_obj  - Optional : Boolean, if True then the attributes of obj are updated
                                    (force_commen=True only).

    Returns:
        A_crop  - 2D or 3D array of image after cropping.

    Raises:
        ValueError - if force_commen is True and A is not 2D or 3D.

    Usage:
        import stmpy.driftcorr as dfc
        A_crop = dfc.cropedge(A, n=5, obj=obj)

    History:
        06/04/2019      RL : Initial commit.
        11/30/2019      RL : Add support for non-square dataset
    """
    # interp2d is removed in recent SciPy. For gridded data it fitted the same FITPACK
    # interpolating spline that RectBivariateSpline builds.
    from scipy.interpolate import RectBivariateSpline

    if not isinstance(n, list):
        n = [n]
    if force_commen is not True:
        B = _rough_cut(A, n=n)
        print('Shape before crop:', end=' ')
        print(A.shape)
        print('Shape after crop:', end=' ')
        print(B.shape)
        return B

    # n is always a list at this point, so the old `n != 0` guard was always true.
    B = _rough_cut(A, n)
    *_, L2, L1 = np.shape(A)
    if bp is None:
        bp = findBraggs(A, show=False)
    bp = sortBraggs(bp, s=np.array([L2, L1]))
    bp_new = bp - np.array([int(L1/2), int(L2/2)])
    N1 = np.absolute(bp_new[0, 0] - bp_new[1, 0])
    offset = 0

    if a1 is None:
        a1 = c1 * L1 / N1
    if a2 is None:
        a2 = a1
    *_, L2, L1 = np.shape(B)
    L_new1 = a1 * ((L1-offset)//(a1))
    L_new2 = a2 * ((L2-offset)//(a2))
    # Centre the commensurate window inside the cropped image.
    delta1 = (L1 - offset - L_new1) / 2
    delta2 = (L2 - offset - L_new2) / 2
    t1 = np.arange(L1)
    t2 = np.arange(L2)
    t_new1 = np.linspace(delta1, L_new1+delta1, num=L1-offset+1)[:-1]
    t_new2 = np.linspace(delta2, L_new2+delta2, num=L2-offset+1)[:-1]

    def _resample(layer):
        # Axis order is (rows, cols); the result has shape (len(t_new2), len(t_new1)).
        return RectBivariateSpline(t2, t1, layer, kx=3, ky=3)(t_new2, t_new1)

    if np.ndim(A) == 2:
        z_new = _resample(B)
    elif np.ndim(A) == 3:
        z_new = np.zeros([np.shape(A)[0], L2-offset, L1-offset])
        # Same centred window as the 2D branch (it used to start at 0, ignoring delta).
        for i in range(len(A)):
            z_new[i] = _resample(B[i])
    else:
        raise ValueError('ERR: Input must be 2D or 3D numpy array!')
    if obj is not None:
        if update_obj is True:
            obj.size = obj.size * (obj.pixels - offset) / obj.pixels
            obj.pixels = obj.pixels - offset
            obj.qmag = obj.size / obj.a0
            obj.qscale = obj.pixels / (2*obj.qmag)
            obj.a1 = a1
            obj.a2 = a2
    return z_new

def _rough_cut(A, n):
    '''
    Crop a fixed number of pixels from each edge of a 2D or 3D array.

    Inputs:
        A   - 2D (rows, cols) or 3D (layers, rows, cols) array.
        n   - sequence: [k] crops k pixels from every side; otherwise [left, right, down, up, ...]
              (extra entries ignored). A value of 0 keeps that edge.

    Returns:
        Cropped copy of A (does not share memory with A), or None if A is neither 2D nor 3D.
    '''
    B = np.copy(A)
    if len(n) == 1:
        n1 = n2 = n3 = n4 = n[0]
    else:
        n1, n2, n3, n4, *_ = n
    # `==`/`in` instead of `is 2`/`is 3`: identity checks on ints rely on CPython's small-int cache.
    if B.ndim not in (2, 3):
        return None
    # A stop index of -0 would select nothing; -(-size) means "up to the end" instead.
    if n2 == 0:
        n2 = -B.shape[-1]
    if n4 == 0:
        n4 = -B.shape[-2]
    return B[..., n3:-n4, n1:-n2]

def Gaussian2d(x, y, sigma_x, sigma_y, theta, x0, y0, Amp):
    '''
    Evaluate a rotated elliptical 2D Gaussian on the grid spanned by x and y.

    Inputs:
        x, y        - ascending 1D coordinate arrays
        sigma_x/y   - widths along the (rotated) principal axes
        theta       - rotation angle in radians
        x0, y0      - center
        Amp         - peak amplitude

    Returns:
        z           - 2D array of shape (len(y), len(x)): rows follow y, columns follow x.
    '''
    a = np.cos(theta)**2/2/sigma_x**2 + np.sin(theta)**2/2/sigma_y**2
    # Cross term: sin(2*theta) must not be squared, otherwise +theta and -theta give the same ellipse.
    b = -np.sin(2*theta)/4/sigma_x**2 + np.sin(2*theta)/4/sigma_y**2
    c = np.sin(theta)**2/2/sigma_x**2 + np.cos(theta)**2/2/sigma_y**2
    X, Y = np.meshgrid(x, y)
    dx = X - x0
    dy = Y - y0
    return Amp * np.exp(-(a*dx**2 + 2*b*dx*dy + c*dy**2))

def FTDCfilter(A, sigma):
    '''
    Low-pass filter A by multiplying its centred Fourier transform with a Gaussian around the DC
    (q = 0) component and transforming back.

    Inputs:
        A       - array whose last two axes are (rows, cols); the FFT acts on those axes.
        sigma   - Gaussian width in pixels along the column (x) axis of the shifted FT; the
                  row (y) width is sigma * s1 / s2.

    Returns:
        Real part of the filtered inverse transform, same shape as A.
    '''
    *_, s2, s1 = A.shape
    m1, m2 = np.arange(s1, dtype='float'), np.arange(s2, dtype='float')
    # fftshift places q = 0 at index s//2; centring at (s-1)/2 was half a pixel off for even sizes.
    # float() replaces the np.float alias removed in NumPy 1.24.
    c1, c2 = float(s1 // 2), float(s2 // 2)
    sigma1 = sigma
    sigma2 = sigma * s1 / s2
    g = Gaussian2d(m1, m2, sigma1, sigma2, 0, c1, c2, 1)
    ft_A = np.fft.fftshift(np.fft.fft2(A))
    ft_Af = ft_A * g
    Af = np.fft.ifft2(np.fft.ifftshift(ft_Af))
    return np.real(Af)

def unwrap_phase_2d(A, thres=None):
    '''
    Unwrap phase jumps in a 2D phase map with unwrap_phase, first along every row, then along every
    column. Both passes run on the map flipped in x and y, and the result is flipped back.

    Inputs:
        A       - 2D phase map (may be non-square).
        thres   - jump threshold as a fraction of the phase period, forwarded as `tolerance`.

    Returns:
        2D array with the same shape as A.

    Raises:
        ValueError - if A is not 2D.
    '''
    if np.ndim(A) != 2:
        raise ValueError('unwrap_phase_2d expects a 2D phase map.')
    output = np.copy(A[::-1, ::-1])
    # Use each axis' own length; the old code used the column count for both loops.
    rows, cols = output.shape
    for i in range(rows):
        output[i, :] = unwrap_phase(output[i, :], tolerance=thres)
    for j in range(cols):
        output[:, j] = unwrap_phase(output[:, j], tolerance=thres)
    return output[::-1, ::-1]

#10. - compute_dist
def compute_dist(x1, x2, p=None):
    '''
    Return the (optionally axis-weighted) Euclidean distance between two 2D points.

    Inputs:
        x1, x2  - points indexable as [0] (x) and [1] (y).
        p       - optional (p1, p2) factors multiplied onto the x and y differences;
                  None means no weighting.
    '''
    weight_x, weight_y = (1, 1) if p is None else p
    dx = (x1[0] - x2[0]) * weight_x
    dy = (x1[1] - x2[1]) * weight_y
    return np.sqrt(dx**2 + dy**2)

#11. - global_corr
def global_corr(A, obj=None, bp=None, show=False, angle=np.pi/4, update_obj=True, **kwargs):
    """
    Global shear correct the 2D topo automatically.

    Inputs:
        A           - Required : 2D array of topo to be shear corrected.
        obj         - Optional : Spy object of topo (2D) or map (3D).
        bp          - Optional : Bragg points. If not offered, they are calculated with findBraggs(A).
        show        - Optional : Boolean specifying if the results are plotted or not
        update_obj  - Optional : Boolean, forwarded to findBraggs and gshearcorr.
        angle       - Optional : orientation of the Bragg peaks, default as pi/4. Will be passed to gshearcorr
        **kwargs    - Optional : key word arguments for findBraggs function (thres defaults to 0.2
                                    and may be overridden here).
    Returns:
        m       - 2x3 transformation matrix returned by gshearcorr
        data_1  - 2D array of topo after global shear correction

    Usage:
        import stmpy.driftcorr as dfc
        m, data1 = dfc.global_corr(A, show=True)

    History:
        04/29/2019      RL : Initial commit.
    """
    *_, s2, s1 = np.shape(A)
    if bp is None:
        # Passing thres in **kwargs used to raise "got multiple values for keyword 'thres'".
        kwargs.setdefault('thres', 0.2)
        bp_1 = findBraggs(A, obj=obj, show=show, update_obj=update_obj, **kwargs)
    else:
        bp_1 = bp
    m, data_1 = gshearcorr(A, bp_1, obj=obj, rspace=True, angle=angle, update_obj=update_obj)
    if show is True:
        fig, ax = plt.subplots(1, 2, figsize=[8, 4])
        ax[0].imshow(data_1, cmap=stmpy.cm.blue2, origin='lower')
        ax[0].set_xlim(0, s1)
        ax[0].set_ylim(0, s2)
        ax[1].imshow(stmpy.tools.fft(data_1, zeroDC=True), cmap=stmpy.cm.gray_r, origin='lower')
        fig.suptitle('After global shear correction', fontsize=14)
        # Zoom into the four corners (10% of each axis) to inspect bad pixels.
        fig, ax = plt.subplots(2, 2, figsize=[8, 8])
        ax[0, 0].imshow(data_1, cmap=stmpy.cm.blue2, origin='lower')
        ax[0, 1].imshow(data_1, cmap=stmpy.cm.blue2, origin='lower')
        ax[1, 0].imshow(data_1, cmap=stmpy.cm.blue2, origin='lower')
        ax[1, 1].imshow(data_1, cmap=stmpy.cm.blue2, origin='lower')
        ax[0, 0].set_xlim(0, s1/10)
        ax[0, 0].set_ylim(s2-s2/10, s2)
        ax[0, 1].set_xlim(s1-s1/10, s1)
        ax[0, 1].set_ylim(s2-s2/10, s2)
        ax[1, 0].set_xlim(0, s1/10)
        ax[1, 0].set_ylim(0, s2/10)
        ax[1, 1].set_xlim(s1-s1/10, s1)
        ax[1, 1].set_ylim(0, s2/10)
        fig.suptitle('Bad pixels in 4 corners', fontsize=14)
    return m, data_1

#12. - local_corr
def local_corr(data, obj=None, bp=None, sigma=10, method="lockin", fixMethod='unwrap', show=False, update_obj=True, **kwargs):
    """
    Locally drift correct 2D topo automatically.

    Inputs:
        data        - Required : 2D array of topo after global shear correction, with bad pixels removed on the edge
        obj         - Optional : Spy object of topo (2D) or map (3D).
        bp          - Optional : Bragg points. If not offered, they will be calculated by findBraggs(data)
        sigma       - Optional : Floating number specifying the size of mask to be used in phasemap()
        method      - Optional : Specifying which method to use in phasemap(), driftmap() and driftcorr()
                                "lockin": Spatial lock-in method to find phase map
                                "convolution": Gaussian mask convolution method to find phase map
        fixMethod   - Optional : Specifying which method to use in fixphaseslip()
                                "unwrap": fix phase jumps line by line in x direction and y direction, respectively
                                "spiral": fix phase slip in phase shift maps by flattening A into a 1D array in a spiral way
        show        - Optional : Boolean specifying if the results are plotted or not
        update_obj  - Optional : Boolean, if True then all the attributes of the object will be updated.
        **kwargs    - Optional : key word arguments for findBraggs function

    Returns:
        ux          - 2D array of drift field in x direction
        uy          - 2D array of drift field in y direction
        data_corr   - 2D array of topo after local drift correction

    Raises:
        ValueError  - if method is neither "lockin" nor "convolution". The check happens before any
                      computation, so no Bragg peak search or phase map is wasted on a bad argument.

    Usage:
        import stmpy.driftcorr as dfc
        ux, uy, data_corr = dfc.local_corr(data, sigma=5, method='lockin', fixMethod='unwrap', show=True)

    History:
        04/29/2019      RL : Initial commit.
    """
    # Validate first: an unknown method used to print an error only after all the work below,
    # and then fail with UnboundLocalError because data_corr was never assigned.
    if method not in ('lockin', 'convolution'):
        raise ValueError("Only two methods are available, lockin or convolution (got {!r}).".format(method))

    if bp is None:
        bp_2 = findBraggs(data, obj=obj, thres=0.2, show=show, update_obj=update_obj, **kwargs)
    else:
        bp_2 = bp
    thetax, thetay, Q1, Q2 = phasemap(data, bp=bp_2, obj=obj, method=method, sigma=sigma, update_obj=update_obj)
    if show is True:
        fig, ax = plt.subplots(1, 2, figsize=[8, 4])
        ax[0].imshow(thetax, origin='lower')
        ax[1].imshow(thetay, origin='lower')
        fig.suptitle('Raw phase maps')
    thetaxf = fixphaseslip(thetax, method=fixMethod)
    thetayf = fixphaseslip(thetay, method=fixMethod)
    if show is True:
        fig, ax = plt.subplots(1, 2, figsize=[8, 4])
        ax[0].imshow(thetaxf, origin='lower')
        ax[1].imshow(thetayf, origin='lower')
        fig.suptitle('After fixing phase slips')
    ux, uy = driftmap(thetaxf, thetayf, Q1, Q2, obj=obj, method=method, update_obj=update_obj)
    if method == 'lockin':
        data_corr = driftcorr(data, ux, uy, method='lockin', interpolation='cubic')
    else:
        # method == 'convolution' (guaranteed by the check at the top)
        data_corr = driftcorr(data, ux, uy, method='convolution')
    if show is True:
        fig, ax = plt.subplots(2, 2, figsize=[8, 8])
        ax[1, 0].imshow(data_corr, cmap=stmpy.cm.blue1, origin='lower')
        ax[1, 1].imshow(stmpy.tools.fft(data_corr, zeroDC=True), cmap=stmpy.cm.gray_r, origin='lower')
        ax[0, 0].imshow(data, cmap=stmpy.cm.blue1, origin='lower')
        ax[0, 1].imshow(stmpy.tools.fft(data, zeroDC=True), cmap=stmpy.cm.gray_r, origin='lower')
        fig.suptitle('Before and after local drift correction')
    return ux, uy, data_corr

#14. - apply_dfc_3d
def apply_dfc_3d(data, ux, uy, matrix, bp=None, obj=None, n1=None, n2=None, method='convolution',update_obj=False):
    """
    Apply drift field (both global and local) found in 2D to corresponding 3D map.

    Steps: global shear correction of every layer with gshearcorr(), optional edge crop with
    cropedge(), local drift correction of the whole stack with driftcorr(), optional final crop.

    Inputs:
        data        - Required : 3D array of map to be drift corrected, layers along axis 0
        ux          - Required : 2D array of drift field in x direction. Usually generated by local_corr()
        uy          - Required : 2D array of drift field in y direction. Usually generated by local_corr()
        matrix      - Required : Shear correction matrix passed to gshearcorr() for every layer
                                 (presumably the first value returned by global_corr())
        bp          - Optional : Coordinates of Bragg peaks, passed to cropedge() for the final crop
        obj         - Optional : Spy object, passed to gshearcorr() and to the final cropedge()
        n1          - Optional : Passed as n to cropedge() after global shear correction; no crop if None
        n2          - Optional : Passed as n to cropedge() after local drift correction; no crop if None
        method      - Optional : Method passed to driftcorr(), 'convolution' (default) or 'lockin'
        update_obj  - Optional : Boolean passed to gshearcorr() and to the final cropedge()

    Returns:
        data_out    - 3D array of map after global and local drift correction (and optional cropping)

    Usage:
        import stmpy.driftcorr as dfc
        m, topo_1 = dfc.global_corr(topo)
        ux, uy, topo_corr = dfc.local_corr(topo_1, method='convolution')
        data_corr = dfc.apply_dfc_3d(data, ux, uy, m, method='convolution')

    History:
        04/29/2019      RL : Initial commit.
    """
    data = np.asarray(data)
    # gshearcorr presumably returns floating-point layers; np.zeros_like on an integer map would
    # silently truncate them, so allocate a float buffer for non-floating input.
    if np.issubdtype(data.dtype, np.inexact):
        data_c = np.zeros_like(data)
    else:
        data_c = np.zeros(data.shape, dtype=float)
    for i, layer in enumerate(data):
        _, data_c[i] = gshearcorr(layer, matrix=matrix, obj=obj, rspace=True, update_obj=update_obj)
    if n1 is not None:
        data_c = cropedge(data_c, n=n1)
    data_corr = driftcorr(data_c, ux, uy, method=method, interpolation='cubic')
    if n2 is None:
        return data_corr
    # NOTE(review): 'force_commen' looks like a misspelling, but it must match cropedge()'s
    # keyword exactly, so it is kept verbatim — confirm against cropedge's signature.
    return cropedge(data_corr, obj=obj, bp=bp, n=n2, force_commen=True, update_obj=update_obj)

#15. - display
def display(A, B=None, sigma=3, clim_same=True):
    '''
    Plot a real-space image beside its Fourier transform, optionally above a second image for comparison.

    Inputs:
        A           - Required : Real space image to display (top row).
        B           - Optional : Another real space image to be compared with A (bottom row).
        sigma       - Optional : The FT color scale runs from 0 to mean + sigma * std of the FT.
        clim_same   - Optional : If True, the FT of B uses the same color limit as the FT of A
                                    (computed from A's FT); otherwise B's FT gets its own limit.

    Returns:
        N/A

    Usage:
        import stmpy.driftcorr as dfc
        dfc.display(topo.z)

    '''
    def fft_upper(ft):
        # Upper color limit of an FT panel: mean plus `sigma` standard deviations.
        return np.mean(ft) + sigma * np.std(ft)

    fft_a = stmpy.tools.fft(A, zeroDC=True)

    if B is None:
        # Single image: one row, real space on the left and FT on the right.
        _, axes = plt.subplots(1, 2, figsize=[8, 4])
        axes[0].imshow(A, cmap=stmpy.cm.blue2, origin='lower')
        axes[1].imshow(fft_a, cmap=stmpy.cm.gray_r, origin='lower', clim=[0, fft_upper(fft_a)])
        return

    fft_b = stmpy.tools.fft(B, zeroDC=True)
    upper_a = fft_upper(fft_a)
    upper_b = upper_a if clim_same is True else fft_upper(fft_b)

    # Two images: A on the top row, B on the bottom row.
    _, axes = plt.subplots(2, 2, figsize=[8, 8])
    rows = ((A, fft_a, upper_a), (B, fft_b, upper_b))
    for r, (real, ft, upper) in enumerate(rows):
        axes[r, 0].imshow(real, cmap=stmpy.cm.blue2, origin='lower')
        axes[r, 1].imshow(ft, cmap=stmpy.cm.gray_r, origin='lower', clim=[0, upper])
    for panel in axes.ravel():
        panel.set_aspect(1)
    plt.tight_layout()


def quick_linecut(A, width=2, n=4, bp=None, ax=None, thres=3):
    """
    Take up to four linecuts automatically: horizontal, vertical, and two diagonal.

    Inputs:
        A           - Required : FT space image to take linecuts. Either a 2D image, or a 3D stack whose
                                 mean along axis 0 is used for Bragg peak finding and display.
        width       - Optional : Number of pixels for averaging.
        n           - Optional : Number of linecuts (at most 4), taken in the order horizontal, vertical,
                                 diagonal from [0, 0] to [2X, 2Y], diagonal from [0, 2Y] to [2X, 0].
                                 Values above 4 are treated as 4.
        bp          - Optional : Bragg peak pixel coordinate used for the q scale. If None, the smallest
                                 coordinate returned by findBraggs(..., rspace=False) is used.
        ax          - Optional : Not used; the cuts are drawn on a new figure.
        thres       - Optional : threshold for displaying FT

    Returns:
        qscale      - image width / (image width - 2*bp); can be passed as qscale to quick_show_cut()
        cut         - list of the linecuts returned by stmpy.tools.linecut, one per cut

    Raises:
        ValueError  - if A is not 2D or 3D.

    Usage:
        import stmpy.driftcorr as dfc
        qscale, cut = dfc.quick_linecut(A)

    """
    ndim = np.ndim(A)
    if ndim not in (2, 3):
        raise ValueError("A must be a 2D or 3D array, got {}D.".format(ndim))
    Y = np.shape(A)[-2] / 2
    X = np.shape(A)[-1] / 2
    cut = []
    start = [[0, Y], [X, 0], [0, 0], [0, Y*2]]
    end = [[X*2, Y], [X, Y*2], [X*2, Y*2], [X*2, 0]]
    color = ['r', 'g', 'b', 'k']

    # Figure is created before findBraggs is called, as before, so any plotting it does
    # cannot end up on the wrong figure.
    plt.figure(figsize=[4, 4])
    # 2D image used for Bragg peak finding and display (layer average for a 3D stack)
    img = np.mean(A, axis=0) if ndim == 3 else A
    if bp is None:
        bp_x = np.min(findBraggs(img, rspace=False))
    else:
        bp_x = bp
    c_mean = np.mean(img)
    c_std = np.std(img)
    plt.imshow(img, clim=[0, c_mean + thres*c_std])

    qscale = X*2 / (X*2 - bp_x * 2)

    n_cuts = max(0, min(n, len(start)))
    for p0, p1, col in zip(start[:n_cuts], end[:n_cuts], color[:n_cuts]):
        _, cut1 = stmpy.tools.linecut(A, p0, p1,
            width=width, show=True, ax=plt.gca(), color=col)
        cut.append(cut1)
    plt.gca().set_xlim(-1, X*2+1)
    plt.gca().set_ylim(-1, Y*2+1)
    return qscale, cut

def quick_show(A, en, thres=5, rspace=True, saveon=False, qlimit=1.2, imgName='', extension='png'):
    """
    Show up to 12 evenly sampled layers of a 3D map on a 3x4 grid.

    Inputs:
        A           - Required : 3D array of real space or FT images, layers along axis 0.
        en          - Required : Sequence of energies (mV) indexed like axis 0 of A; used as panel labels.
        thres       - Optional : Color limit in units of each layer's standard deviation.
        rspace      - Optional : If True, layers are shown with clim [mean - thres*std, mean + thres*std].
                                 If False, they are treated as FT images: clim [0, mean + thres*std] and
                                 an extent scaled by the Bragg peak found with findBraggs().
        saveon      - Optional : If True, the figure is saved as "imgName.extension".
        qlimit      - Optional : Axis limits (+/- qlimit) for FT images (rspace=False).
        imgName     - Optional : File name (without extension) used when saving.
        extension   - Optional : File extension / format used when saving.

    Returns:
        N/A
    """
    layers = len(A)
    if rspace is False:
        imgsize = np.shape(A)[-1]
        bp_x = np.min(findBraggs(np.mean(A, axis=0), min_dist=int(imgsize/10), rspace=rspace))
        ext = imgsize / (imgsize - 2*bp_x)
    # Sample at most 12 layers evenly across the stack.
    if layers > 12:
        skip = layers // 12
    else:
        skip = 1
    # Layer indices to plot: never past the end of A or en, and at most 12 panels.
    # (Previously a short stack raised an IndexError mid-loop that was swallowed, which
    # also skipped plt.tight_layout().)
    indices = range(0, min(layers, len(en)), skip)[:12]
    fig, ax = plt.subplots(3, 4, figsize=[16, 12])
    for i, idx in enumerate(indices):
        panel = ax[i//4, i%4]
        c = np.mean(A[idx])
        s = np.std(A[idx])
        if rspace is True:
            panel.imshow(A[idx], clim=[c-thres*s, c+thres*s], cmap=stmpy.cm.jackyPSD)
        else:
            panel.imshow(A[idx], extent=[-ext, ext, -ext, ext], clim=[0, c+thres*s], cmap=stmpy.cm.gray_r)
            panel.set_xlim(-qlimit, qlimit)
            panel.set_ylim(-qlimit, qlimit)
        panel.set_aspect(1)
        stmpy.image.add_label("${}$ mV".format(int(en[idx])), ax=panel)
    plt.tight_layout()
    if saveon is True:
        plt.savefig("{}.{}".format(imgName, extension), bbox_inches='tight')

def quick_show_cut(A, en, qscale, thres=5, thres2=None, saveon=False, qlimit=1.2, imgName='', extension="png"):
    """
    Plot energy-momentum linecuts, one figure per cut.

    Inputs:
        A           - Required : Sequence of up to 4 2D linecuts (rows follow en, columns follow q), in the
                                 order horizontal, vertical, diagonal, diagonal (e.g. the cuts returned by
                                 quick_linecut()).
        en          - Required : 1D array of energies, one per row of each cut.
        qscale      - Required : Half range of the q axis for the first two cuts; the diagonal cuts span
                                 qscale*sqrt(2). E.g. the first value returned by quick_linecut().
        thres       - Optional : vmax = mean + thres*std for the first two cuts.
        thres2      - Optional : vmax = mean + thres2*std for the diagonal cuts; defaults to thres.
        saveon      - Optional : If True, each figure is saved as "imgName along <direction>.extension".
        qlimit      - Optional : x-axis limits (+/- qlimit) of each plot.
        imgName     - Optional : File name prefix used when saving.
        extension   - Optional : File extension / format used when saving.

    Returns:
        N/A
    """
    fname = ["M-0", "M-90", "X-45", "X-135"]
    if thres2 is None:
        thres2 = thres
    for i, ix in enumerate(A):
        if i in [0, 1]:
            qmax, t = qscale, thres
        else:
            qmax, t = qscale*np.sqrt(2), thres2
        # Build the q axis from this cut's own length so cuts of different lengths
        # (e.g. horizontal vs vertical on a non-square image) still match pcolormesh's grid.
        q = np.linspace(-qmax, qmax, num=np.shape(ix)[-1])
        plt.figure(figsize=[6, 3])
        c = np.mean(ix)
        s = np.std(ix)
        plt.pcolormesh(q, en, ix, cmap=stmpy.cm.gray_r, vmin=0, vmax=c+t*s)
        plt.gca().set_xlim(-qlimit, qlimit)
        plt.axvline(-1, linestyle='--')
        plt.axvline(1, linestyle='--')
        if saveon is True:
            plt.savefig(imgName + " along {}.{}".format(fname[i], extension), facecolor='w')

def _show_single_image(img, energy, thres, ext, rspace, saveon, qlimit, imgName, extension):
    """Plot one image in its own 5x5 figure labelled "<energy> mV", optionally saving it (helper for quick_show_single)."""
    plt.figure(figsize=[5, 5])
    c = np.mean(img)
    s = np.std(img)
    if rspace is True:
        plt.imshow(img, clim=[c-thres*s, c+thres*s], cmap=stmpy.cm.jackyPSD)
    else:
        plt.imshow(img, extent=[-ext, ext, -ext, ext], clim=[0, c+thres*s], cmap=stmpy.cm.gray_r)
        plt.xlim(-qlimit, qlimit)
        plt.ylim(-qlimit, qlimit)
    stmpy.image.add_label("${}$ mV".format(int(energy)), ax=plt.gca())
    plt.gca().axes.get_xaxis().set_visible(False)
    plt.gca().axes.get_yaxis().set_visible(False)
    plt.gca().set_frame_on(False)
    plt.gca().set_aspect(1)
    if saveon is True:
        plt.savefig("{} at {}mV.{}".format(imgName, int(energy), extension), dpi=400, bbox_inches='tight', pad_inches=0)


# Quick show single images
def quick_show_single(A, en, thres=5, qscale=None, rspace=False, saveon=False, qlimit=1.2, imgName='', extension='png'):
    """
    Show each layer of a map (or a single image) in its own figure, without axes.

    Inputs:
        A           - Required : 3D array (one figure per layer along axis 0) or 2D array (one figure).
        en          - Required : Energy label(s) in mV: a sequence indexed like axis 0 of A for 3D input,
                                 a single number for 2D input.
        thres       - Optional : Color limit in units of the image's standard deviation.
        qscale      - Optional : Half range of the q extent for FT images. If None it is computed from the
                                 Bragg peak found by findBraggs().
        rspace      - Optional : If True, images are shown with clim [mean - thres*std, mean + thres*std];
                                 if False, as FT images with clim [0, mean + thres*std] and a q extent.
        saveon      - Optional : If True, each figure is saved as "imgName at <en>mV.extension".
        qlimit      - Optional : Axis limits (+/- qlimit) for FT images.
        imgName     - Optional : File name prefix used when saving.
        extension   - Optional : File extension / format used when saving.

    Returns:
        N/A

    Raises:
        ValueError  - if A is neither 2D nor 3D (previously such input was silently ignored).
    """
    ndim = np.ndim(A)
    if ndim not in (2, 3):
        raise ValueError("A must be a 2D or 3D array, got {}D.".format(ndim))
    ext = None
    if rspace is False:
        if qscale is None:
            imgsize = np.shape(A)[-1]
            # Bragg peaks are located on the layer average for a 3D stack.
            A_topo = np.mean(A, axis=0) if ndim == 3 else A
            bp_x = np.min(findBraggs(A_topo, min_dist=int(imgsize/10), rspace=rspace))
            ext = imgsize / (imgsize - 2*bp_x)
        else:
            ext = qscale
    if ndim == 3:
        for i in range(len(A)):
            _show_single_image(A[i], en[i], thres, ext, rspace, saveon, qlimit, imgName, extension)
    else:
        _show_single_image(A, en, thres, ext, rspace, saveon, qlimit, imgName, extension)
