import matplotlib.pyplot as plt
import numpy as np
import math
import click
import pywt


def add_rectangle(a, t1, t2, f1, f2, v):
    """Set every cell a[f][t] with t in [t1, t2) and f in [f1, f2) to v, in place.

    The first index of `a` is the frequency row and the second the time
    column, so this works for nested lists as well as 2-D numpy arrays.
    Cells are visited time-major (all rows of column t1 first, then t1+1, ...).
    """
    cells = ((freq, t) for t in range(t1, t2) for freq in range(f1, f2))
    for freq, t in cells:
        a[freq][t] = v

def is_power_of_two(n):
    """Return True if `n` is a positive integral power of two (1, 2, 4, 8, ...).

    Uses exact integer bit arithmetic instead of floating-point logarithms, so
    it is correct for arbitrarily large integers and returns False (instead of
    raising a math domain error) for zero and negative numbers. Integral floats
    such as 8.0 are still accepted; non-integral values return False.
    """
    if isinstance(n, float):
        if not n.is_integer():  # also rejects inf / nan
            return False
        n = int(n)
    # a power of two has exactly one set bit, so clearing the lowest set bit
    # (n & (n - 1)) leaves zero
    return n > 0 and (n & (n - 1)) == 0

def _scalogram_image(v):
    """Lay out a flat Haar DWT coefficient vector as an n x n image (row = band, column = time)."""
    n = len(v)
    # exact integer power-of-2 test; also rejects n == 0 cleanly
    if n == 0 or n & (n - 1):
        raise ValueError(f'{n=} is not a power of 2')

    image = np.zeros((n, n))
    # row 0: the dc coefficient (father wavelet) spans the whole time axis
    image[0, :] = v[0]

    i = 1          # index of the next coefficient in v
    f_offset = 1   # first row of the current frequency band
    d_f = 1        # height (rows) of the current band
    d_t = n        # time width covered by one coefficient at this level
    # L = log2(n) detail levels (mother wavelet), coarsest first;
    # level k holds 2**(k-1) coefficients, each d_t samples wide
    for _ in range(n.bit_length() - 1):
        for t_offset in range(0, n, d_t):
            image[f_offset:f_offset + d_f, t_offset:t_offset + d_t] = v[i]
            i += 1
        f_offset += d_f
        d_f *= 2
        d_t //= 2
    return image


def plot_scalogram(ax, v):
    """
    Expects a vector of Haar DWT coefficients
    (dc ; then low frequencies first)
    and plots the scalogram into `ax`.

    Row 0 of the image is the dc coefficient, and each following band is
    twice as tall and has twice as many (half as wide) coefficients as the
    previous one. With imshow's default origin the dc row is drawn at the top.

    :param ax: matplotlib Axes to draw into
    :param v:  vector of length n = 2**L
    :return: the n x n scalogram image that was drawn (may be ignored)
    :raises ValueError: if len(v) is not a power of 2
    """
    scalogram_image = _scalogram_image(v)
    ax.imshow(scalogram_image, cmap='Blues', aspect='auto')
    ax.set_xlabel('time')
    ax.set_ylabel('freq')
    return scalogram_image

def haar(s):
    """Compute the full-depth Haar DWT of a time sequence.

    :param s: time sequence of length n (must be a power of 2)
    :return: list of coefficient arrays as returned by ``pywt.wavedec(s, 'haar')``
             with its default (maximum) level: the approximation (dc) array
             first, then the detail arrays from coarsest to finest.
             Concatenated, they form the vector ``plot_scalogram`` expects
             (presumably n values in total for a power-of-2 length — confirm
             if the padding mode is ever changed).
    :raises ValueError: if len(s) is not a power of 2
    """
    n = len(s)
    # explicit check instead of assert, which is stripped under `python -O`
    if not is_power_of_two(n):
        raise ValueError(f'{n=} is not a power of 2')
    return pywt.wavedec(s, 'haar')

def haar_and_plot(s, wave_name="blank"):
    """Show a signal (top) above its Haar DWT scalogram (bottom) in one figure.

    :param s: time sequence whose length is a power of 2
    :param wave_name: title placed above both panels
    """
    fig, (signal_ax, scalogram_ax) = plt.subplots(2, sharex=True)
    fig.suptitle(wave_name)

    signal_ax.plot(s)
    signal_ax.set_xlabel("time")

    # join the per-level coefficient arrays into one flat vector, coarsest first
    per_level = haar(s)
    flat = np.concatenate(per_level)
    plot_scalogram(scalogram_ax, flat.ravel())

    plt.tight_layout()
    plt.show()

@click.command()
@click.option('--duration', default=16, type=int, show_default=True,
              help='length of generated sequence (must be a power of 2)')
def main(duration):
    """Plot the Haar DWT scalogram of several synthetic test signals."""
    n = duration
    # haar() needs a power-of-2 length: fail early with a usage error instead
    # of an exception deep inside the first plot
    if n < 1 or n & (n - 1):
        raise click.BadParameter(f'{n} is not a power of 2',
                                 param_hint="'--duration'")
    half = n // 2

    single_spike = [0] * n
    single_spike[half] = 1
    haar_and_plot(single_spike, "single_spike")

    # up to 10 samples at 100 starting mid-sequence; clipped at the end so
    # short sequences (including the default n=16) do not index past the list
    prolonged_spike = [0] * n
    for i in range(half, min(half + 10, n)):
        prolonged_spike[i] = 100
    haar_and_plot(prolonged_spike, "prolonged_spike")

    # first half 1, second half 0
    large_square_wave = [1] * half + [0] * (n - half)
    haar_and_plot(large_square_wave, "square_wave")

    # alternating 1, 0, 1, 0, ... (non-zero mean)
    flip_flop_square_wave = [1 if i % 2 == 0 else 0 for i in range(n)]
    haar_and_plot(flip_flop_square_wave, "flip_flop")

    # alternating 1, -1, 1, -1, ... (zero mean)
    zero_mean_flip_flop_square_wave = [1 if i % 2 == 0 else -1 for i in range(n)]
    haar_and_plot(zero_mean_flip_flop_square_wave, "zero_mean_flip_flop")

    # one full sine period spread over the sequence
    large_sine = [math.sin(2 * math.pi * i / n) for i in range(n)]
    haar_and_plot(large_sine, "large_sine")



if __name__ == "__main__":
    # script entry point: main is a click command, so --duration is read
    # from the command line
    main()
