import time
import numpy as np
import sys


def progbar(i, n, size=16):
    """Render a fixed-width text progress bar for step ``i`` out of ``n``.

    The bar is ``size`` characters wide.  Cell ``k`` (0-based) is drawn as a
    filled block when ``k <= (i * size) // n`` and as a light block otherwise,
    so the first cell is already filled at ``i == 0``.

    Args:
        i: Current step.
        n: Total number of steps; a value of zero raises ZeroDivisionError.
        size: Width of the bar in characters.

    Returns:
        The bar as a string of ``size`` characters.
    """
    last_filled = (i * size) // n
    return ''.join('█' if cell <= last_filled else '░' for cell in range(size))


def stream(message) :
    """Overwrite the current console line with ``message``.

    The text is written to ``sys.stdout`` as a carriage return followed by the
    message wrapped in literal braces (``\\r{message}``); no newline is
    emitted and the stream is not flushed here.

    If stdout's encoding cannot represent a character in the message, the
    message is re-sent with every non-ASCII character removed.  Only
    ``UnicodeEncodeError`` triggers this fallback; any other failure
    (closed stream, KeyboardInterrupt, ...) propagates to the caller instead
    of being silently retried.

    Args:
        message: Text to display.
    """
    try:
        sys.stdout.write("\r{%s}" % message)
    except UnicodeEncodeError:
        # Remove non-ASCII characters from message and try once more
        message = ''.join(ch for ch in message if ord(ch) < 128)
        sys.stdout.write("\r{%s}" % message)


def simple_table(item_tuples) :
    """Print a single-row ASCII table of (heading, value) pairs to stdout.

    Each item contributes one column.  Both the heading and the value are
    converted with ``str()``; the shorter of the two is centred within the
    width of the longer one (when the slack is odd, the extra space goes on
    the right).  Column widths and borders are derived from the actual text
    lengths, so arbitrarily long headings or values stay aligned.

    Output is five table lines (border, headings, border, values, border)
    followed by a line containing a single space.

    Args:
        item_tuples: Iterable of indexable items where ``item[0]`` is the
            heading and ``item[1]`` is the value.
    """

    def centre(text, width) :
        # Split the slack so that any odd leftover space lands on the right.
        slack = width - len(text)
        left = slack // 2
        return ' ' * left + text + ' ' * (slack - left)

    columns = []
    for item in item_tuples :
        heading, cell = str(item[0]), str(item[1])
        width = max(len(heading), len(cell))
        columns.append((centre(heading, width), centre(cell, width)))

    # Each column renders as '| text ', i.e. text width plus three characters;
    # the border segment matches that with a '+' and width + 2 dashes.
    border = ''.join('+' + '-' * (len(heading) + 2) for heading, _ in columns)
    head = ''.join(f'| {heading} ' for heading, _ in columns)
    body = ''.join(f'| {cell} ' for _, cell in columns)

    # Close the right-hand edge only when there is at least one column.
    if columns :
        border += '+'
        head += '|'
        body += '|'

    print(border)
    print(head)
    print(border)
    print(body)
    print(border)
    print(' ')


def time_since(started) :
    """Format the wall-clock time elapsed since ``started``.

    Args:
        started: A timestamp in seconds as returned by ``time.time()``.

    Returns:
        ``'<m>m <s>s'`` while fewer than 60 whole minutes have passed,
        otherwise ``'<h>h <m>m <s>s'``.  Fractions of a second are dropped.
    """
    delta = time.time() - started
    minutes = int(delta // 60)
    seconds = int(delta % 60)

    if minutes < 60 :
        return f'{minutes}m {seconds}s'

    hours, minutes = divmod(minutes, 60)
    return f'{hours}h {minutes}m {seconds}s'

def save_attention(attn, path):
    """Save an attention matrix as a heatmap image at ``{path}.png``.

    The matrix is transposed (``attn.T``) before being drawn with
    nearest-neighbour interpolation on a 12x6 inch figure.

    Args:
        attn: Array-like exposing ``.T`` (presumably a 2-D attention
            matrix -- confirm against the caller).
        path: Output path without extension; ``.png`` is appended.
    """
    # Imported lazily so the module can be loaded without matplotlib.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(12, 6))
    try:
        plt.imshow(attn.T, interpolation='nearest', aspect='auto')
        fig.savefig(f'{path}.png', bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails, so repeated
        # calls cannot accumulate open figures.
        plt.close(fig)

def save_attention_multiple(attn, path):
    """Save several attention matrices as stacked heatmaps at ``{path}.png``.

    One subplot is drawn per element of ``attn``, top to bottom, each on a
    12x6 inch slice of the figure.  Every matrix is transposed before
    plotting; the x axis is labelled "Decoder Step" and the y axis
    "Encoder Step".

    Args:
        attn: Sized iterable of array-likes exposing ``.T``.
        path: Output path without extension; ``.png`` is appended.
    """
    # Imported lazily so the module can be loaded without matplotlib.
    import matplotlib.pyplot as plt

    num_plots = len(attn)
    fig = plt.figure(figsize=(12, 6 * num_plots))
    try:
        for i, a in enumerate(attn):
            # Subplot indices are 1-based, sequence numbers in titles 0-based.
            plt.subplot(num_plots, 1, i+1)
            plt.imshow(a.T, interpolation='nearest', aspect='auto')
            plt.xlabel("Decoder Step")
            plt.ylabel("Encoder Step")
            plt.title(f"Encoder-Decoder Alignment of No.{i} Sequence")
        fig.savefig(f'{path}.png', bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails, so repeated
        # calls cannot accumulate open figures.
        plt.close(fig)

def save_stop_tokens(stop, path):
    """Save line plots of several stop-token sequences at ``{path}.png``.

    One subplot is drawn per element of ``stop``, top to bottom, each on a
    12x6 inch slice of the figure, with "Timestep" on the x axis and
    "Stop Value" on the y axis.

    Args:
        stop: Sized iterable of sequences accepted by ``plt.plot``.
        path: Output path without extension; ``.png`` is appended.
    """
    # Imported lazily so the module can be loaded without matplotlib.
    import matplotlib.pyplot as plt

    num_plots = len(stop)
    fig = plt.figure(figsize=(12, 6 * num_plots))
    try:
        for i, s in enumerate(stop):
            # Subplot indices are 1-based, sequence numbers in titles 0-based.
            plt.subplot(num_plots, 1, i+1)
            plt.plot(s)
            plt.xlabel("Timestep")
            plt.ylabel("Stop Value")
            plt.title(f"Stop Tokens of No.{i} Sequence")
        fig.savefig(f'{path}.png', bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails, so repeated
        # calls cannot accumulate open figures.
        plt.close(fig)


def save_spectrogram(M, path, length=None):
    """Save a spectrogram as a heatmap image at ``{path}.png``.

    The array is flipped along axis 0 before plotting and, when ``length``
    is truthy, cropped to its first ``length`` columns.

    Args:
        M: Array indexable as ``M[:, :length]`` (presumably
            frequency x time -- confirm against the caller).
        path: Output path without extension; ``.png`` is appended.
        length: Optional number of leading columns to keep; ``None`` or 0
            keeps every column.
    """
    # Imported lazily so the module can be loaded without matplotlib.
    import matplotlib.pyplot as plt

    M = np.flip(M, axis=0)
    if length : M = M[:, :length]
    fig = plt.figure(figsize=(12, 6))
    try:
        plt.imshow(M, interpolation='nearest', aspect='auto')
        plt.xlabel("Time")
        plt.ylabel("Frequency")
        plt.title("Generated Mel Spectrogram")
        fig.savefig(f'{path}.png', bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails, so repeated
        # calls cannot accumulate open figures.
        plt.close(fig)


def plot(array):
    """Draw ``array`` as a line plot on a wide (30x5 inch) grey-styled figure.

    Axis labels and tick labels are set to grey at font size 23.  The figure
    is neither shown nor saved nor closed here; that is left to the caller
    or the interactive environment.

    Args:
        array: Data accepted by ``plt.plot``.
    """
    import matplotlib.pyplot as plt

    figure = plt.figure(figsize=(30, 5))
    axes = figure.add_subplot(111)

    # Style the x and y axis labels identically.
    for axis in (axes.xaxis, axes.yaxis):
        axis.label.set_color('grey')
        axis.label.set_fontsize(23)

    # Apply the same tick styling to both axes.
    for name in ('x', 'y'):
        axes.tick_params(axis=name, colors='grey', labelsize=23)

    plt.plot(array)


def plot_spec(M):
    """Show a spectrogram interactively as a heatmap.

    The array is flipped along axis 0, drawn on an 18x4 inch figure with
    nearest-neighbour interpolation, and displayed with ``plt.show()``.

    Args:
        M: Array to display (presumably frequency x time -- confirm against
            the caller).
    """
    import matplotlib.pyplot as plt

    flipped = np.flip(M, axis=0)
    plt.figure(figsize=(18, 4))
    plt.imshow(flipped, interpolation='nearest', aspect='auto')
    plt.show()