from keras.layers import Conv3D, MaxPooling3D, UpSampling3D, UpSampling2D
from load_data import Loader, load_image_location
from keras.layers import Dense, Dropout, Flatten
from keras.models import Sequential, load_model
from keras.callbacks import EarlyStopping
from keras.optimizers import Adam
import tensorflow as tf
import keras.backend as K
from model import get_model
from glob import glob
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
from matplotlib import pyplot as plt
import os
import click


def save_model_stats(history, model_save_name, score):
    """Write per-epoch losses to a CSV-style text file and plot them.

    Two files are produced next to each other:
    ``<model_save_name>.txt`` with a ``loss,loss_val,score`` header followed
    by one row per epoch, and ``<model_save_name>.png`` with the training
    and validation loss curves.

    Args:
        history: object exposing a ``history`` mapping with ``'loss'`` and
            ``'val_loss'`` sequences (as returned by ``model.fit``).
        model_save_name: path prefix (without extension) for both files.
        score: evaluation score; it is formatted as-is and repeated in
            every row of the text file.
    """
    model_loss = history.history['loss']
    model_val_loss = history.history['val_loss']

    with open('{}.txt'.format(model_save_name), 'w') as f:
        # Each record must end with a newline, otherwise the header and
        # all rows are glued together on a single unparsable line.
        f.write("loss,loss_val,score\n")
        for loss, val_loss in zip(model_loss, model_val_loss):
            f.write('{},{},{}\n'.format(loss, val_loss, score))

    fig = plt.figure()
    plt.plot(model_loss, '-bo', label='loss')
    plt.plot(model_val_loss, '-go', label='validation loss')
    plt.legend(loc='best')
    plt.savefig('{}.png'.format(model_save_name))
    # Release the figure: this function is called once per trained model,
    # and pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)


def load_dataset(block, input_width, input_height, n, future, padding,
                 classification=False):
    """Return the ``(X, y)`` arrays, using an on-disk ``.npz`` cache.

    The cache file lives in ``data/`` and its name encodes every argument
    (with a ``class_`` prefix unless ``classification`` is ``False``). When
    the file exists it is loaded; otherwise the arrays are built with
    ``load_image_location`` from ``resized/frames/*.jpg``, converted to
    ``float16`` and written to the cache.

    Args:
        block: ``(block_height, block_width, channels)`` tuple.
        input_width: forwarded to ``load_image_location`` as ``w``.
        input_height: forwarded to ``load_image_location`` as ``h``.
        n: forwarded to ``load_image_location`` as ``x_size``.
        future: forwarded to ``load_image_location`` as ``future``.
        padding: forwarded to ``load_image_location`` as ``representation``.
        classification: forwarded to ``load_image_location``; also selects
            the cache file name prefix.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays.
    """
    images = glob('resized/frames/*.jpg')
    block_height, block_width, channels = block
    print("Loading dataset...")
    if classification is False:
        data_format_str = 'w_{}_h_{}_n_{}_f_{}_bw_{}_bh_{}_pad_{}'
    else:
        data_format_str = 'class_w_{}_h_{}_n_{}_f_{}_bw_{}_bh_{}_pad_{}'
    data_name = data_format_str.format(input_width, input_height, n, future,
                                       block_width, block_height, padding)
    filename = "data/{}.npz".format(data_name)
    if os.path.exists(filename):
        print("{} found. Using cache version".format(filename))
        # NpzFile holds an open file handle; the context manager closes it
        # once both arrays have been read into memory.
        with np.load(filename) as data:
            X, y = data['X'], data['y']
    else:
        X, y, _ = load_image_location(images, x_size=n, future=future,
                                      h=input_height, w=input_width,
                                      block_shape=block,
                                      representation=padding,
                                      classification=classification)
        X = np.array(X, dtype='float16')
        y = np.array(y, dtype='float16')
        # Make sure the cache directory exists so the (expensive) result
        # is not lost to a failed save.
        if not os.path.isdir('data'):
            os.makedirs('data')
        # np.savez appends the ".npz" extension, which matches `filename`.
        np.savez("data/{}".format(data_name), X=X, y=y)
    return X, y


def save_vols(block, input_width, input_height, n, future, padding,
              classification=False):
    """Build the third output of ``load_image_location`` and save it.

    The frames matching ``resized/frames/*.jpg`` are passed to
    ``load_image_location``; only the third returned value is kept,
    converted to ``float16`` and stored under the key ``vols`` in
    ``data/vols.npz``.

    Args:
        block: 3-element tuple, forwarded as ``block_shape``.
        input_width: forwarded as ``w``.
        input_height: forwarded as ``h``.
        n: forwarded as ``x_size``.
        future: forwarded as ``future``.
        padding: forwarded as ``representation``.
        classification: forwarded as ``classification``.
    """
    frame_paths = glob('resized/frames/*.jpg')
    # The unpacked values are unused; the unpacking itself still rejects a
    # `block` that does not have exactly three elements.
    _height, _width, _channels = block
    _X, _y, raw_vols = load_image_location(
        frame_paths,
        x_size=n,
        future=future,
        h=input_height,
        w=input_width,
        block_shape=block,
        representation=padding,
        classification=classification,
    )
    as_half = np.array(raw_vols, dtype='float16')
    np.savez("data/vols", vols=as_half)



def prep_and_split_data(X, y, f, split, classification=False):
    """Flatten the two leading sample axes and split into train/test.

    ``X`` must be 6-dimensional. Unless ``classification`` is ``False``
    the first two axes of ``X`` are merged directly and ``y`` is reshaped
    to ``(samples, f)``. Otherwise axes 1 and 2 of both arrays are swapped
    first, and ``y`` is reshaped to ``(samples, f, wx, wy, 1)``.

    Args:
        X: 6-D input array.
        y: target array with a layout matching ``X``.
        f: size of the second axis of the reshaped ``y``.
        split: fraction of samples placed in the training part; the cut
            index is ``ceil(len(X) * split)``.
        classification: selects the layout handling described above.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``; the split is a plain
        positional cut with no shuffling.
    """
    if classification is not False:
        outer, inner, depth, wx, wy, ch = X.shape
        samples = outer * inner
        X = X.reshape(samples, depth, wx, wy, ch)
        y = y.reshape(samples, f)
    else:
        outer, depth, inner, wx, wy, ch = X.shape
        samples = outer * inner
        # Bring the axis that is merged with the first one next to it.
        X = np.swapaxes(X, 1, 2).reshape(samples, depth, wx, wy, ch)
        y = np.swapaxes(y, 1, 2).reshape(samples, f, wx, wy, 1)

    cut = int(np.ceil(len(X) * split))
    return X[:cut], y[:cut], X[cut:], y[cut:]


def get_model_name(model_type, block_height, block_width, future,
                   loss, num_epoches, batch, padding, add=None,
                   ensamble=False):
    """Build the short model name and the full file-name stem for saving.

    Args:
        model_type: ``'class'`` selects the ``_classificator`` name; any
            other value is embedded as ``<model_type>_encoder``.
        block_height: embedded in the save name.
        block_width: embedded in the save name.
        future: embedded in the save name.
        loss: embedded in the save name.
        num_epoches: embedded in the save name.
        batch: embedded in the save name.
        padding: embedded in the save name.
        add: value embedded in the ``ensamble_<add>_`` prefix; only used
            when ``ensamble`` is truthy.
        ensamble: when truthy, the ensemble prefix is prepended.

    Returns:
        Tuple ``(model_name, model_save_name)``.
    """
    prefix = 'ensamble_{}_'.format(add) if ensamble else ''

    if model_type == 'class':
        # NOTE(review): this yields "_classificator" (leading underscore,
        # or a double underscore after the ensemble prefix) because the
        # model type is not part of the template. Kept as-is since the
        # resulting file names may be relied upon elsewhere -- confirm.
        model_name = '{}_classificator'.format(prefix)
    else:
        model_name = '{}{}_encoder'.format(prefix, model_type)

    name_fields = (model_name, block_height, block_width, future,
                   loss, num_epoches, batch, padding)
    model_save_name = (
        '{}_window_{}_{}_f_{}_loss_{}_epoches_{}_batch_{}_p_{}'
        .format(*name_fields)
    )
    return model_name, model_save_name


def train(block, future, batch, loss, model_type, input_width,
          input_height, n, padding, ensamble):
    """Load and split the dataset, then print class-balance statistics.

    NOTE: in its current state the function returns right after printing
    the statistics. Model construction is commented out and everything
    after the early ``return`` (sample weighting, fitting, saving and
    evaluation) is unreachable; re-enabling it requires restoring the
    ``get_model`` call so that ``model`` / ``models`` are defined.

    Args:
        block: ``(block_height, block_width, channels)`` tuple.
        future: forwarded to ``load_dataset`` and used as ``f`` when
            reshaping the targets.
        batch: batch size used by the (currently unreachable) fit and
            evaluate calls.
        loss: loss identifier; only used in the saved model name by the
            currently unreachable code.
        model_type: ``'class'`` switches dataset loading and reshaping to
            classification mode.
        input_width: forwarded to ``load_dataset``.
        input_height: forwarded to ``load_dataset``.
        n: forwarded to ``load_dataset``.
        padding: forwarded to ``load_dataset``.
        ensamble: selects the per-model ensemble loop in the currently
            unreachable code.

    Returns:
        Currently always ``(0, None, "")``. The unreachable tail would
        return ``(score, model, model_save_name)``.
    """
    split = 0.8
    classification = model_type == 'class'

    block_height, block_width, channels = block

    X, y = load_dataset(block, input_width, input_height, n, future, padding,
                        classification)

    # Only needed by the commented-out get_model(...) call below.
    input_depth = X[0].shape[-1]

    X_train, y_train, X_test, y_test = prep_and_split_data(X, y, future, split,
                                                           classification)

    print("\nTraining set size: {} with shape: {}".
          format(X_train.shape[0], X_train.shape[1:]))

    print("Training labels size: {} with shape: {}".
          format(y_train.shape[0], y_train.shape[1:]))

    # Report the held-out arrays here (previously the training arrays were
    # printed a second time under the "Test" labels).
    print("Test set size: {} with shape: {}".
          format(X_test.shape[0], X_test.shape[1:]))

    print("Test labels size: {} with shape: {}".
          format(y_test.shape[0], y_test.shape[1:]))

    print("------------------------------------------------------------------")
    num_epoches = 30

    # if ensamble:
    #     models = get_model('ensamble', block_width, block_height,
    #                        input_depth, n, future, loss, model_type)
    # else:
    #     model = get_model(model_type, block_width, block_height,
    #                       input_depth, n, future, loss)

    # Indices of training samples whose target has at least one non-zero.
    weights_idx = np.unique(np.nonzero(y_train)[0])
    non_zero = np.count_nonzero(y_train)
    print("")
    print("padding: {}".format(padding))
    print("One samples: {} | All samples: {} | ratio: {}".
          format(weights_idx.shape, X_train.shape[0],
                 weights_idx.shape[0] / X_train.shape[0]))
    # NOTE(review): `non_zero` is counted in y_train while `all_pixels` is
    # the element count of X_train -- confirm this is the intended ratio.
    all_pixels = np.prod(X_train.shape)
    print("all pixels: {} | nonzero: {} | ratio: {}".
          format(all_pixels, non_zero, non_zero / all_pixels))
    print("")
    # Early exit: only the statistics above are produced. Everything below
    # is unreachable until this return is removed and the model
    # construction above is restored.
    return 0, None, ""

    # Samples with a non-zero target get weight 1.0, the rest 0.3.
    sample_weights = np.zeros(X_train.shape[0]) + 0.3
    sample_weights[weights_idx] += 0.7
    early_stopper = EarlyStopping(patience=3)

    if ensamble:
        # One model per predicted step: model i is fitted on column i of y.
        for i, model in enumerate(models):
            if model_type == 'class':
                lu, fu = y_train.shape
                y_train_ens = y_train[:, i].reshape(lu, 1)
                lu, _ = y_test.shape
                y_test_ens = y_test[:, i].reshape(lu, 1)
            else:
                lu, fu, wx, wy, ch = y_train.shape
                y_train_ens = y_train[:, i].reshape(lu, 1, wx, wy, ch)
                lu, _, _, _, _ = y_test.shape
                y_test_ens = y_test[:, i].reshape(lu, 1, wx, wy, ch)

            print("=========================================================")
            print("Ensamble {} train shape: {}".format(i, y_train_ens.shape))
            print("=========================================================")

            history = model.fit(X_train, y_train_ens, batch_size=batch,
                                epochs=num_epoches, validation_split=0.1,
                                verbose=True, sample_weight=sample_weights,
                                callbacks=[early_stopper])

            model.history = history
            model_name, model_save_name = get_model_name(model_type,
                                                         block_height,
                                                         block_width,
                                                         future,
                                                         loss,
                                                         num_epoches,
                                                         batch,
                                                         padding, i, ensamble)
            print("=========================================================")
            print("\tSaving model: {}".format(model_save_name))
            print("=========================================================")
            model.save('{}.h5'.format(model_save_name))
            score = model.evaluate(X_test, y_test_ens, batch_size=batch)
            save_model_stats(history, model_save_name, score)
    else:
        history = model.fit(X_train, y_train, batch_size=batch,
                            epochs=num_epoches, validation_split=0.1,
                            verbose=True, sample_weight=sample_weights,
                            callbacks=[early_stopper])

        model.history = history
        model_name, model_save_name = get_model_name(model_type,
                                                     block_height,
                                                     block_width,
                                                     future,
                                                     loss,
                                                     num_epoches,
                                                     batch,
                                                     padding,
                                                     None,
                                                     ensamble)

        print("=========================================================")
        print("\tSaving model: {}".format(model_save_name))
        print("=========================================================")
        model.save('{}.h5'.format(model_save_name))
        score = model.evaluate(X_test, y_test, batch_size=batch)

        save_model_stats(history, model_save_name, score)
        print("Final loss: {}\nFinal val_loss: {}\nScore: {}".
              format(history.history['loss'],
                     history.history['val_loss'],
                     score))
    # In the ensemble case these refer to the last model of the loop.
    return score, model, model_save_name


@click.command()
@click.option('--batch', default=256, type=int)
@click.option('--loss', default='mse', type=str)
@click.option('--type', default='column', type=str)
@click.option('--ensamble', default=False, is_flag=True)
def main(batch, loss, type, ensamble):
    """Command-line entry point: run ``train`` for each configuration.

    ``train`` is called once for every combination of the hard-coded
    padding and future values below; its return value is not used.

    Options:
        --batch: passed to ``train`` as ``batch`` (default 256).
        --loss: passed to ``train`` as ``loss`` (default ``'mse'``).
        --type: passed to ``train`` as ``model_type`` (default
            ``'column'``).
        --ensamble: flag passed to ``train`` as ``ensamble``.
    """
    frame_width, frame_height = 160, 90
    # Number of past frames the network is exposed to.
    past_frames = 5
    # Numbers of frames predicted into the future.
    horizons = [5]
    # Pedestrian representation (padding) values to try.
    paddings = [6]
    block_shape = (9, 8, 3)

    for padding in paddings:
        for horizon in horizons:
            print("==================================================")
            print("\t\t   padding: {}".format(padding))
            print("==================================================")
            _score, _model, _name = train(
                block_shape, horizon, batch, loss, type,
                frame_width, frame_height, past_frames, padding,
                ensamble,
            )

# Run the click command only when executed as a script, not on import.
if __name__ == '__main__':
    main()
