from load_data import Loader, get_sample_label_pairs, load_image_location
from keras.models import Model, load_model
import keras.backend as K
import tensorflow as tf
from glob import glob
import numpy as np
import cv2 as cv
import click
import h5py
import os


# Module-level defaults.  The functions in this file take these values as
# parameters (and main() defines its own locals), so these globals are not
# referenced by the functions below.
split = 0.8  # train fraction; cf. the ``split`` argument of prep_and_split_data
input_width, input_height = 160, 90  # frame size; matches "w_160_h_90" in data file names
n = 5  # frames per input sequence; matches "n_5" in data file names


def get_single_metrics(mask, pred):
    """Compute pixel-wise binary segmentation metrics for one mask pair.

    Any value > 0 in ``mask`` / ``pred`` counts as positive.

    Returns a tuple ``(acc, precision, recall, f1)`` of floats.  Ratios whose
    denominator is zero (e.g. no predicted positives, no true positives, or
    an empty input) are reported as 0.0 instead of raising
    ZeroDivisionError.
    """
    truth = np.asarray(mask) > 0
    guess = np.asarray(pred) > 0

    tp = int(np.count_nonzero(truth & guess))
    tn = int(np.count_nonzero(~truth & ~guess))
    fp = int(np.count_nonzero(~truth & guess))
    fn = int(np.count_nonzero(truth & ~guess))

    precision = tp / float(tp + fp) if tp + fp else 0.0
    recall = tp / float(tp + fn) if tp + fn else 0.0
    total = tp + tn + fp + fn
    acc = (tp + tn) / float(total) if total else 0.0
    denom = precision + recall
    f1 = 2 * (precision * recall) / denom if denom else 0.0
    return acc, precision, recall, f1


def get_metrics(masks, predictions):
    """Print and collect per-pair segmentation metrics.

    ``masks`` and ``predictions`` are walked in lockstep.  Element ``[0]`` of
    each item is the image handed to ``get_single_metrics``; element ``[1]``
    of each prediction item is printed as the row label.

    Returns a list with one 4-element array ``(acc, precision, recall, f1)``
    per pair.
    """
    template = "{} - acc: {:.2f} | precision: {:.2f} | recall: {:.2f} | f1: {:.2f}"
    collected = []
    for pair in zip(masks, predictions):
        truth, guess = pair
        scores = get_single_metrics(truth[0], guess[0])
        print(template.format(guess[1], *scores))
        collected.append(np.array(scores))
    return collected


def get_positions(mask):
    """Build a sparse position map from the contours in ``mask``.

    For every contour found by ``cv.findContours`` the bounding-box top-left
    corner (minimum column x, minimum row y) is located.  At that pixel the
    returned map stores ``[y / 90.0, x / 160.0]``; every other entry is zero.

    ``mask`` must be an image accepted by ``cv.findContours`` (eval_data
    passes uint8 arrays).  A copy is passed because older OpenCV versions
    modify their input.

    Returns a float64 array of shape ``(mask.shape[0], mask.shape[1], 2)``.
    """
    vol = np.zeros((mask.shape[0], mask.shape[1], 2))
    # findContours returns (image, contours, hierarchy) on OpenCV 3.x but
    # (contours, hierarchy) on 2.x and 4.x; [-2] is the contours in all.
    contours = cv.findContours(mask.copy(), cv.RETR_LIST,
                               cv.CHAIN_APPROX_SIMPLE)[-2]
    for cont in contours:
        # Each contour is an array of (x, y) points; the per-axis minimum
        # is the top-left corner of its bounding box.
        x, y = np.min(cont, axis=0)[0]
        vol[y, x] = [y / 90.0, x / 160.0]
    return vol


def get_single_diff(true_vol, pred):
    """Return ``true_vol`` minus the position map extracted from ``pred``."""
    return true_vol - get_positions(pred)


def get_sample_mse(true_vol, predictions):
    """Mean squared position error, scaled by 100, over a sequence of masks.

    Frame ``k`` of ``predictions`` is compared with ``true_vol[k]`` through
    ``get_single_diff``.  The difference rows of all frames are pooled
    before the mean is taken.
    """
    pooled = []
    for idx, prediction in enumerate(predictions):
        # extend() appends the diff array row by row, pooling every frame.
        pooled.extend(get_single_diff(true_vol[idx], prediction))
    squared = np.square(np.asarray(pooled))
    return np.mean(squared) * 100


def load_dataset(block, pad, f, classification):
    """Load a cached ``(X, y)`` dataset from the ``data/`` directory.

    The file name encodes the future-frame count ``f``, the block width and
    height (``block[1]`` and ``block[0]``) and the padding mode ``pad``.  The
    frame size (160x90) and ``n`` (5) are fixed in the name.  Only
    ``classification is False`` selects the regression file; any other value
    selects the ``class_``-prefixed variant.

    Returns the ``X`` and ``y`` arrays stored in the ``.npz`` archive.
    """
    prefix = "" if classification is False else "class_"
    name = "{}w_160_h_90_n_5_f_{}_bw_{}_bh_{}_pad_{}.npz".format(
        prefix, f, block[1], block[0], pad)
    # Use the archive as a context manager so the file handle is released.
    # Indexing an NpzFile reads the array fully into memory first.
    with np.load("data/{}".format(name)) as data:
        X, y = data['X'], data['y']
    return X, y

# def load_dataset(block, input_width, input_height, n, future, padding,
#                  classification=False):
#     images = glob('resized/frames/*.jpg')
#     block_height, block_width, channels = block
#     print("Loading dataset...")
#     if classification is False:
#         data_format_str = 'w_{}_h_{}_n_{}_f_{}_bw_{}_bh_{}_pad_{}'
#     else:
#         data_format_str = 'class_w_{}_h_{}_n_{}_f_{}_bw_{}_bh_{}_pad_{}'
#     data_name = data_format_str.format(input_width, input_height, n, future,
#                                        block_width, block_height, padding)
#     filename = "data/{}.npz".format(data_name)
#     if os.path.exists(filename):
#         print("{} found. Using cache version".format(filename))
#         data = np.load(filename)
#         X, y = data['X'], data['y']
#     else:
#         X, y, _ = load_image_location(images, x_size=n, future=future,
#                                          h=input_height, w=input_width,
#                                          block_shape=block,
#                                          representation=padding,
#                                          classification=classification)
#         X = np.array(X, dtype='float16')
#         y = np.array(y, dtype='float16')
#         np.savez("data/{}".format(data_name), X=X, y=y)
#     return X, y


def prep_image(X, y, f):
    """Move axis 1 of a single sample's inputs and targets to the front.

    ``X`` must be 5-D; its shape is unpacked, so any other rank raises
    ValueError.  ``y`` only needs at least two axes.  ``f`` is accepted for
    call-site compatibility but is not used.

    Returns the pair ``(X, y)`` with their first two axes swapped.
    """
    # The unpacking doubles as a rank check on X.
    _frames, _steps, _width, _height, _channels = X.shape
    swapped_x = np.moveaxis(X, 1, 0)
    swapped_y = np.moveaxis(y, 1, 0)
    return swapped_x, swapped_y


def prep_and_split_data(X, y, f, split, classification=False):
    """Flatten the sample/step axes and split into train and test parts.

    When ``classification is False``, ``X`` is laid out as
    ``(l, n, s, wx, wy, ch)``.  Axes 1 and 2 of ``X`` and ``y`` are swapped
    before flattening ``l * s`` samples, and ``y`` becomes
    ``(l * s, f, wx, wy, 1)``.  Otherwise ``X`` is ``(l, s, n, wx, wy, ch)``
    and ``y`` is flattened to ``(l * s, f)``.

    The first ``ceil(len(X) * split)`` samples form the training set.

    Returns ``(X_train, y_train, X_test, y_test)``.
    """
    if classification is not False:
        length, steps, frames, wx, wy, channels = X.shape
        X = X.reshape(length * steps, frames, wx, wy, channels)
        y = y.reshape(length * steps, f)
    else:
        length, frames, steps, wx, wy, channels = X.shape
        X = np.moveaxis(X, 2, 1).reshape(length * steps, frames, wx, wy,
                                          channels)
        y = np.moveaxis(y, 2, 1).reshape(length * steps, f, wx, wy, 1)

    cut = int(np.ceil(len(X) * split))
    return X[:cut], y[:cut], X[cut:], y[cut:]


def recover(data, block, input_width, input_height, n):
    """Reassemble block-tiled data into full frames.

    ``data`` is indexed ``[block_index, frame, ...]``; axis 1 is moved to the
    front and the result is reshaped into an
    ``int(input_height / block[0])`` x ``int(input_width / block[1])`` grid of
    ``block[0]`` x ``block[1]`` x ``block[2]`` tiles.  The tiles are then
    interleaved back into images.  The intermediate shape is printed.

    Returns an array of shape ``(count, n, H, W, block[2])``.
    """
    tile_h, tile_w, channels = block[0], block[1], block[2]
    grid_rows = int(input_height / tile_h)
    grid_cols = int(input_width / tile_w)

    frames_first = np.moveaxis(data, 1, 0)
    print(frames_first.shape)
    tiles = frames_first.reshape(-1, n, grid_rows, grid_cols,
                                 tile_h, tile_w, channels)
    # Bring the in-tile row axis next to the grid-row axis so rows and
    # columns can each be merged by a plain reshape.
    tiles = np.moveaxis(tiles, 3, 4)
    count, _, nbh, bh, nbw, bw, ch = tiles.shape
    return tiles.reshape(count, n, nbh * bh, nbw * bw, ch)

# def recover_class(data, input_width, input_height):
#     ni, f = data.shape
#     stride = 8
#     im = np.zeros((input_height * input_width, f))
#     for i, vec in enumerate(data):
#         # vol = np.tile(np.tile(vec, stride), stride).reshape(stride, stride, f)
#         vol = np.tile(vec, stride).reshape(stride, f)
#         im[i*stride:i*stride+stride] = vol
#     # im = im.reshape(input_width, input_height, f)
#     im = im.reshape(input_height, input_width, f)
#     im = np.rollaxis(im, 2)
#     print(im.shape)
#     return np.array(im * 255, dtype='uint8')

def recover_class(data, input_width, input_height, stride=4):
    """Reshape per-cell classification outputs into per-frame uint8 grids.

    ``data`` is a 2-D array ``(cells, f)``: one row per grid cell and one
    column per frame.  The frame is tiled by ``stride`` x ``stride`` cells,
    so the grid is ``ceil(input_height / stride)`` rows by
    ``ceil(input_width / stride)`` columns.  Values are multiplied by 255 and
    cast to uint8, so 0/1 (or boolean) inputs become 0/255.

    Returns a uint8 array of shape ``(f, grid_rows, grid_cols)``.
    """
    _, f = data.shape
    # Explicit float division keeps the ceiling correct regardless of
    # Python 2/3 integer-division semantics.
    grid_cols = int(np.ceil(input_width / float(stride)))
    grid_rows = int(np.ceil(input_height / float(stride)))
    im = np.rollaxis(data, 1).reshape(f, grid_rows, grid_cols) * 255
    return im.astype('uint8')


def recover_and_save(data, block, input_width, input_height, n, output):
    """Reassemble block-tiled frames with ``recover`` and write them as PNGs.

    Every frame is scaled by 255, cast to uint8 and written to
    ``"{output}_{i}.png"``, numbered from 0.  Single-channel frames are
    written as 2-D images.
    """
    out = recover(data, block, input_width, input_height, n)
    # recover() returns (count, n, H, W, ch).  Flattening the two leading
    # axes explicitly keeps one entry per frame even when n == 1, where
    # np.squeeze would drop the frame axis and we would iterate image rows.
    frames = out.reshape((-1,) + out.shape[2:])
    if frames.shape[-1] == 1:
        frames = frames[..., 0]
    for i, frame in enumerate(frames):
        image = np.array(frame * 255, dtype='uint8')
        path = "{}_{}.png".format(output, i)
        print("writing {}".format(path))
        cv.imwrite(path, image)

def save_class(data, filename, input_width, input_height):
    """Write every frame from ``recover_class`` to ``"{filename}_{i}.png"``."""
    images = recover_class(data, input_width, input_height)
    for index, image in enumerate(images):
        target = "{}_{}.png".format(filename, index)
        cv.imwrite(target, image)


def get_true_vols(x_size=5, future=5):
    """Return sliding windows of ``x_size`` consecutive ground-truth volumes.

    Reads the ``vols`` array from ``data/vols.npz``.  Window ``i`` is
    ``vols[i:i + x_size]`` and ``len(vols) - (x_size + future)`` windows are
    produced (none if that is not positive).

    NOTE(review): the last window whose future frames still fit
    (``i == len(vols) - x_size - future``) is excluded.  This is kept
    unchanged because sample alignment with the cached dataset is defined
    elsewhere; confirm against load_image_location.
    """
    # Context manager closes the archive; the array is read eagerly.
    with np.load('data/vols.npz') as archive:
        vols = archive['vols']
    return [np.array(vols[start:start + x_size])
            for start in range(len(vols) - (x_size + future))]


def resize_mask(mask, stride=4, height=90, width=160):
    """Upsample a coarse cell mask to a ``height`` x ``width`` float image.

    Each cell ``mask[i, j]`` is replicated into a ``stride`` x ``stride``
    square.  The last row and column of squares are clipped at the image
    border; with the defaults, 23x40 cells cover 90x160 pixels.  Only the
    top-left ``ceil(height / stride)`` x ``ceil(width / stride)`` cells are
    used.

    Raises IndexError if ``mask`` does not provide that many cells.
    Returns a float64 array of shape ``(height, width)``.
    """
    rows = int(np.ceil(height / float(stride)))
    cols = int(np.ceil(width / float(stride)))
    cells = np.asarray(mask)[:rows, :cols]
    if cells.shape != (rows, cols):
        raise IndexError("mask must provide at least {}x{} cells, got {}"
                         .format(rows, cols, np.shape(mask)))
    # Repeating along both axes turns every cell into a stride x stride tile.
    tiled = np.repeat(np.repeat(cells, stride, axis=0), stride, axis=1)
    out = np.zeros((height, width))
    out[:, :] = tiled[:height, :width]
    return out

def eval_data(X, y, model, thresh, input_width, input_height, f):
    """Print the mean position MSE of thresholded classification predictions.

    For each sample ``X[i]``:
    1. the model prediction is thresholded at ``thresh``;
    2. it is regridded with ``recover_class``;
    3. each of its ``f`` frames is upsampled to 90x160 with ``resize_mask``;
    4. it is scored with ``get_sample_mse`` against ``vols[i]``.

    ``vols`` comes from ``get_true_vols(future=5)[:, 3]``.  ``y`` is accepted
    for interface compatibility but is currently unused.
    """
    vols = np.array(get_true_vols(future=5))
    vols = vols[:, 3]
    print("vols", np.array(vols).shape)
    mses = []
    for i, sample in enumerate(X):
        y_pred = model.predict(sample) > thresh
        pred = recover_class(y_pred, input_width, input_height)
        resize_pred = np.zeros((f, 90, 160), dtype='uint8')
        # Use a separate index for frames.  Reusing ``i`` here previously
        # clobbered the sample index, so every sample was scored against
        # vols[f - 1].
        for j in range(f):
            resize_pred[j] = resize_mask(pred[j])
        mses.append(get_sample_mse(vols[i], resize_pred))
    print("MSE: {}".format(np.mean(mses)))


@click.command()
@click.argument('filename')
def main(filename):
    # Visualise one sample: write the input frames, the ground-truth masks
    # and the thresholded model predictions as PNG files.  (Comments rather
    # than a docstring, so the click --help text stays unchanged.)
    num_channels = 3
    n = 5
    input_height = 90
    input_width = 160
    thresh = 0.5
    classification = False

    # Hyper-parameters are the integer fields of the model file name, in the
    # order (bh, bw, f, e, <unused>, pad).
    stem = os.path.splitext(filename)[0]
    numbers = [int(token) for token in stem.split("_") if token.isdigit()]
    bh, bw, f, e, _, pad = numbers

    block = (bh, bw, num_channels)
    mask_block = (bh, bw, 1)
    X, y = load_dataset(block, pad, f, classification)
    print("Loaded dataset {} | {}".format(X.shape, y.shape))

    block_x, block_y = prep_image(X[0], y[0], n)
    print("Blocks: {} | {}".format(block_x.shape, block_y.shape))
    recover_and_save(block_x, block, input_width, input_height, n, 'image')
    recover_and_save(block_y, mask_block, input_width, input_height, f,
                     'true')

    model = load_model(filename)
    scores = np.squeeze(model.predict(block_x))
    print("count before: {}".format(np.count_nonzero(scores)))
    print("MIN: {} | MAX: {}".format(np.min(scores), np.max(scores)))
    binary = scores > thresh
    print("thresh {} | count after: {}".format(thresh,
                                               np.count_nonzero(binary)))
    recover_and_save(binary, mask_block, input_width, input_height, f,
                     'prediction')
    # eval_data(X[-1200:], y[-1200:], model, thresh, input_width, input_height, f)


if __name__ == '__main__':
    main()
