from load_data import Loader, get_sample_label_pairs
from keras.models import Model, load_model
import keras.backend as K
import tensorflow as tf
from glob import glob
import numpy as np
import cv2 as cv
import click
import h5py
import os


# Module-level defaults.
# NOTE(review): none of the functions visible in this file read these globals
# directly -- they take same-named parameters or, in main(), define their own
# locals -- so they may only matter to importers of this module; confirm
# before removing.
split = 0.8          # presumably the train fraction for prep_and_split_data
input_width = 160    # same value as the "w_160" in load_dataset's file names
input_height = 90    # same value as the "h_90" in load_dataset's file names
n = 5                # same value as the "n_5" in load_dataset's file names


def load_dataset(block, pad, f, classification):
    """Load the stored ``X``/``y`` arrays for one dataset configuration.

    The archive is read from the ``data/`` directory (relative to the current
    working directory). Its name hard-codes ``w_160_h_90_n_5`` and encodes
    ``f``, the block width, the block height and ``pad``; classification
    archives carry an additional ``class_`` prefix.

    Args:
        block: Indexable whose entry 0 is used as the ``bh`` field and whose
            entry 1 is used as the ``bw`` field of the file name.
        pad: Value of the ``pad`` field of the file name.
        f: Value of the ``f`` field of the file name.
        classification: Exactly ``False`` selects the un-prefixed archive;
            any other value selects the ``class_``-prefixed one.

    Returns:
        Tuple ``(X, y)`` holding the arrays stored under the ``'X'`` and
        ``'y'`` keys of the archive.
    """
    if classification is False:
        name_format = "w_160_h_90_n_5_f_{}_bw_{}_bh_{}_pad_{}.npz"
    else:
        name_format = "class_w_160_h_90_n_5_f_{}_bw_{}_bh_{}_pad_{}.npz"
    name = name_format.format(f, block[1], block[0], pad)
    # np.load on an .npz returns a lazy NpzFile that keeps the file open;
    # the context manager closes it once both arrays have been materialised.
    with np.load("data/{}".format(name)) as data:
        X, y = data['X'], data['y']
    return X, y


def prep_image(X, y, f):
    """Bring axis 1 of a sample and of its label to the front.

    Args:
        X: 5-D array; any other rank raises ``ValueError`` on the shape
            unpack below.
        y: Array with at least two axes.
        f: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(X, y)`` where, in each array, the original axis 1 is now
        axis 0 and the original axis 0 is now axis 1.
    """
    # The unpack doubles as a rank check: X must have exactly five axes.
    _d0, _d1, _d2, _d3, _d4 = X.shape
    sample = np.moveaxis(X, 1, 0)
    label = np.moveaxis(y, 1, 0)
    return sample, label


def prep_and_split_data(X, y, f, split, classification=False):
    """Flatten the two sample-indexing axes and split into train/test parts.

    With ``classification`` exactly ``False`` the arrays are read as
    ``(l, n, s, wx, wy, ch)`` and axis 2 is moved in front of axis 1 before
    flattening; otherwise X is read as ``(l, s, n, wx, wy, ch)`` and is
    flattened directly. In both cases the result has ``l * s`` leading
    entries.

    Args:
        X: 6-D sample array.
        y: Label array, reshaped to ``(l * s, f, wx, wy, 1)`` in the
            regression case and to ``(l * s, f)`` otherwise.
        f: Size of the second axis of the reshaped labels.
        split: Fraction of entries (rounded up) placed in the training part.
        classification: Selects the layout as described above.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``; the split is a plain
        head/tail cut with no shuffling.
    """
    if classification is False:
        groups, frames, per_group, wx, wy, channels = X.shape
        total = groups * per_group
        X = np.moveaxis(X, 2, 1).reshape(total, frames, wx, wy, channels)
        y = np.moveaxis(y, 2, 1).reshape(total, f, wx, wy, 1)
    else:
        groups, per_group, frames, wx, wy, channels = X.shape
        total = groups * per_group
        X = X.reshape(total, frames, wx, wy, channels)
        y = y.reshape(total, f)

    # Everything before `cut` is training data, everything after is test data.
    cut = int(np.ceil(len(X) * split))
    return X[:cut], y[:cut], X[cut:], y[cut:]


def recover(data, block, input_width, input_height, n):
    """Stitch per-block arrays back together into full frames.

    Args:
        data: Block array. When ``n != 1`` its axis 1 is first moved to the
            front; when ``n == 1`` it is used as given.
        block: ``(block_height, block_width, channels)``.
        input_width: Frame width; ``int(input_width / block[1])`` blocks are
            laid out per row.
        input_height: Frame height; ``int(input_height / block[0])`` block
            rows are laid out per frame.
        n: Size of the second axis of the result.

    Returns:
        Array of shape ``(i, n, rows * block[0], cols * block[1], channels)``
        where ``i`` is whatever the element count leaves over.

    Side effects:
        Prints the shape of the array right before it is reshaped.
    """
    grid_rows = int(input_height / block[0])
    grid_cols = int(input_width / block[1])

    blocks = data if n == 1 else np.moveaxis(data, 1, 0)
    print(blocks.shape)

    tiled = blocks.reshape(-1, n, grid_rows, grid_cols,
                           block[0], block[1], block[2])
    # (i, n, rows, cols, bh, bw, ch) -> (i, n, rows, bh, cols, bw, ch) so that
    # merging neighbouring axes yields contiguous image rows and columns.
    tiled = tiled.transpose(0, 1, 2, 4, 3, 5, 6)
    count, _, nbh, bh, nbw, bw, ch = tiled.shape
    return tiled.reshape(count, n, nbh * bh, nbw * bw, ch)

# def recover_class(data, input_width, input_height):
#     ni, f = data.shape
#     stride = 8
#     im = np.zeros((input_height * input_width, f))
#     for i, vec in enumerate(data):
#         # vol = np.tile(np.tile(vec, stride), stride).reshape(stride, stride, f)
#         vol = np.tile(vec, stride).reshape(stride, f)
#         im[i*stride:i*stride+stride] = vol
#     # im = im.reshape(input_width, input_height, f)
#     im = im.reshape(input_height, input_width, f)
#     im = np.rollaxis(im, 2)
#     print(im.shape)
#     return np.array(im * 255, dtype='uint8')


def recover_class(data, input_width, input_height):
    """Turn per-cell class scores into one 8-bit map per output column.

    Args:
        data: 2-D array of shape ``(cells, f)``; ``cells`` must equal
            ``ceil(input_height / 4) * ceil(input_width / 4)``.
        input_width: Frame width, divided by the fixed stride of 4
            (rounded up) to get the number of columns.
        input_height: Frame height, divided by the fixed stride of 4
            (rounded up) to get the number of rows.

    Returns:
        ``uint8`` array of shape ``(f, rows, cols)`` holding the input
        values multiplied by 255.
    """
    stride = 4
    # The unpack also enforces that `data` is exactly 2-D.
    _cells, maps = data.shape
    cols = int(np.ceil(input_width / stride))
    rows = int(np.ceil(input_height / stride))
    scaled = data.T.reshape(maps, rows, cols) * 255
    return scaled.astype('uint8')


def save(data, output):
    """Write each leading-axis entry of ``data`` to ``<output>_<i>.png``.

    Singleton axes are squeezed away first. An entry whose maximum exceeds 1
    is cast to ``uint8`` as is; any other entry is multiplied by 255 before
    the cast.

    Args:
        data: Array-like of images.
        output: File-name prefix for the written PNGs.

    Side effects:
        Prints one line per file and writes the files via ``cv.imwrite``.
    """
    frames = np.squeeze(data)
    for index, frame in enumerate(frames):
        # Values above 1 are taken to be on a 0..255 scale already.
        pixels = frame if np.max(frame) > 1 else frame * 255
        image = np.array(pixels, dtype='uint8')

        print("writing {}_{}.png".format(output, index))
        cv.imwrite("{}_{}.png".format(output, index), image)

def recover_and_save(data, block, input_width, input_height, n, output):
    """Reassemble block data with ``recover`` and write it with ``save``.

    Args:
        data, block, input_width, input_height, n: Forwarded unchanged to
            ``recover``.
        output: File-name prefix forwarded to ``save``.
    """
    save(recover(data, block, input_width, input_height, n), output)

def save_class(data, filename, input_width, input_height):
    """Write the maps produced by ``recover_class`` as ``<filename>_<i>.png``.

    Args:
        data: Forwarded to ``recover_class`` together with the frame size.
        filename: File-name prefix for the written PNGs.
        input_width: Frame width forwarded to ``recover_class``.
        input_height: Frame height forwarded to ``recover_class``.
    """
    maps = recover_class(data, input_width, input_height)
    for index, class_map in enumerate(maps):
        target = "{}_{}.png".format(filename, index)
        cv.imwrite(target, class_map)


@click.command()
@click.argument('filename')
def main(filename):
    """Run every matching saved model on the first sample and save PNGs.

    FILENAME must contain exactly seven underscore-separated integer fields
    (extension ignored); fields 2, 3, 4 and 7 are read as block height,
    block width, f and pad and select the dataset archive to load.
    """
    input_width = 160
    input_height = 90
    n = 5
    num_channels = 3
    thresh = 0.5
    classification = True

    name = os.path.splitext(filename)[0]
    dig = [int(s) for s in name.split("_") if s.isdigit()]
    if len(dig) != 7:
        # Fail with a usage error instead of an opaque tuple-unpack error.
        raise click.BadParameter(
            "expected 7 underscore-separated integer fields, found {}: {}"
            .format(len(dig), dig), param_hint='filename')
    _, bh, bw, f, _, _, pad = dig

    block = (bh, bw, num_channels)
    X, y = load_dataset(block, pad, f, classification)
    print("Loaded dataset {} | {}".format(X.shape, y.shape))

    if classification is False:
        ret_data = recover(X[0], block, input_width, input_height, n)
        print("Recover shape: {}".format(ret_data.shape))
        block_x, block_y = prep_image(X[0], y[0], n)
        print("Blocks: {} | {}".format(block_x.shape, block_y.shape))

        recover_and_save(block_x, block, input_width, input_height,
                         n, 'image')
        # NOTE: saving the ground truth (block_y) is not done on this path;
        # the original call was commented out.
        model_input = block_x
        pattern = 'ensamble_[0-9]__window*.h5'
    else:
        save_class(y[0], 'true', input_width, input_height)
        model_input = X[0]
        pattern = 'ensamble_*_classificator*.h5'

    models = sorted(glob(pattern))
    if not models:
        print("No model files matching {}".format(pattern))
        return

    # Predictions are single-channel, so they use a 1-channel block shape.
    pred_block = (block[0], block[1], 1)
    out = []
    raw = []

    for model_name in models:
        print(model_name)
        model = load_model(model_name)
        y_pred = model.predict(model_input)

        print("count before: {}".format(np.count_nonzero(y_pred)))
        print("MIN: {} | MAX: {}".format(np.min(y_pred), np.max(y_pred)))

        # Threshold the raw network output; each of the two arrays is
        # reassembled exactly once, from the un-reassembled prediction.
        y_mask = y_pred > thresh
        print("thresh {} | count after: {}".format(thresh,
                                               np.count_nonzero(y_mask)))

        if classification is False:
            raw_img = recover(y_pred, pred_block,
                              input_width, input_height, 1) * 255
            mask_img = recover(y_mask, pred_block,
                               input_width, input_height, 1)
        else:
            print(y_mask.shape)
            # recover_class already returns uint8 scaled by 255; multiplying
            # again would wrap around in uint8 and corrupt the image.
            raw_img = recover_class(y_pred, input_width, input_height)
            mask_img = recover_class(y_mask, input_width, input_height)

        raw.append(raw_img)
        out.append(mask_img)

    save(np.array(raw), 'predictions')
    save(np.array(out), 'thresholded')
        
    
    
# Script entry point: main is a click command, so click parses the command
# line and supplies the `filename` argument.
if __name__ == '__main__':
    main()
