from matplotlib import pyplot as plt
from glob import glob
import pandas as pd
import numpy as np
import cv2 as cv


def get_single_metrics(mask, pred):
    """Compute pixel-wise binary classification metrics for one image pair.

    Every value > 0 in ``mask`` / ``pred`` counts as foreground, everything
    else as background.

    Args:
        mask: ground-truth image (array-like).
        pred: predicted image (array-like), same shape as ``mask``.

    Returns:
        Tuple ``(acc, precision, recall, f1)`` of Python floats. A ratio whose
        denominator is zero (e.g. precision when nothing was predicted as
        foreground, or recall when the mask is empty) is reported as 0.0
        instead of raising ZeroDivisionError.
    """
    mask = np.asarray(mask) > 0
    pred = np.asarray(pred) > 0
    tp = np.count_nonzero(mask & pred)
    tn = np.count_nonzero(~mask & ~pred)
    fp = np.count_nonzero(~mask & pred)
    fn = np.count_nonzero(mask & ~pred)

    def _ratio(num, den):
        # 0/0 (degenerate image) is scored as 0.0 rather than crashing.
        return float(num) / float(den) if den else 0.0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    acc = _ratio(tp + tn, tp + tn + fp + fn)
    f1 = _ratio(2 * (precision * recall), precision + recall)
    return acc, precision, recall, f1


def get_metrics(masks, predictions):
    """Score each (mask, prediction) pair and print a one-line summary.

    Items of both sequences are indexed as ``item[0]`` for the image; items
    of ``predictions`` also supply ``item[1]``, used as the label of the
    printed line. Pairs are matched positionally with ``zip``, so surplus
    items in the longer sequence are ignored.

    Returns:
        list of ``np.ndarray``, one ``[acc, precision, recall, f1]`` array
        per pair, in input order.
    """
    line_fmt = ("{} - acc: {:.2f} | precision: {:.2f} | "
                "recall: {:.2f} | f1: {:.2f}")
    collected = []
    for pair_mask, pair_pred in zip(masks, predictions):
        scores = get_single_metrics(pair_mask[0], pair_pred[0])
        label = pair_pred[1]
        print(line_fmt.format(label, *scores))
        collected.append(np.array(scores))
    return collected


def get_positions(mask, height=90.0, width=160.0):
    """Build a 2-channel position volume from the blobs in ``mask``.

    For every contour found in ``mask``, take the top-left corner (row, col)
    of its bounding box. At that pixel the volume stores
    ``[row / height, col / width]``; every other pixel stays 0.

    Args:
        mask: single-channel 8-bit image accepted by ``cv.findContours``
            (e.g. the result of ``cv.imread(path, 0)``).
        height: divisor applied to the row coordinate (default 90.0, the
            previously hard-coded value).
        width: divisor applied to the column coordinate (default 160.0).

    Returns:
        float64 array of shape ``(mask.shape[0], mask.shape[1], 2)``.
    """
    vol = np.zeros((mask.shape[0], mask.shape[1], 2))
    # OpenCV 3.x returns (image, contours, hierarchy), while 2.x and 4.x return
    # (contours, hierarchy); [-2] selects the contours in every version.
    # The copy protects the caller's image: OpenCV < 3.2 modified it in place.
    contours = cv.findContours(mask.copy(), cv.RETR_LIST,
                               cv.CHAIN_APPROX_SIMPLE)[-2]
    for cont in contours:
        # Contour points are stored as (x, y) = (col, row); the per-axis
        # minimum is the top-left corner of the bounding box.
        col, row = np.min(cont, axis=0)[0]
        vol[row, col] = [row / height, col / width]
    return vol


def get_single_diff(true_vol, pred):
    """Return ``true_vol`` minus the position volume built from ``pred``.

    ``pred`` is converted with ``get_positions`` and subtracted element-wise
    from the reference volume ``true_vol``.
    """
    return true_vol - get_positions(pred)


def get_sample_mse(true_vol, predictions):
    """Mean squared position error over all predictions, multiplied by 100.

    Prediction ``k`` (image at ``pred[0]``) is compared against the channel
    pair ``true_vol[:, :, 2k:2k + 2]``. The difference volumes of all
    predictions are concatenated along their first axis before averaging.

    Returns:
        ``mean(diff ** 2) * 100``.
    """
    stacked_rows = []
    for k, pred in enumerate(predictions):
        first = 2 * k
        target = true_vol[:, :, first:first + 2]
        stacked_rows += list(get_single_diff(target, pred[0]))
    squared_err = np.array(stacked_rows) ** 2
    return squared_err.mean() * 100


def get_vols(x_size=5, future=5, path='class_models/vols.npz'):
    """Load the stored volumes and cut them into consecutive windows.

    Args:
        x_size: number of consecutive entries per window.
        future: number of entries that must remain after each window; it only
            reduces how many windows are produced.
        path: ``.npz`` archive containing a ``'vols'`` array (defaults to the
            previously hard-coded location).

    Returns:
        list of arrays ``vols[i:i + x_size]`` (copies), for
        ``i in range(len(vols) - (x_size + future))``.
    """
    # The NpzFile keeps a file handle open; the context manager closes it
    # once the array has been read into memory.
    with np.load(path) as archive:
        vols = archive['vols']
    # NOTE(review): the range stops one short of
    # len(vols) - x_size - future + 1, so the last window that still has
    # `future` entries after it is skipped — confirm this is intended.
    return [np.array(vols[i:i + x_size])
            for i in range(len(vols) - (x_size + future))]

# Script: compare 'true*.png' masks with 'prediction*.png' outputs found in
# the working directory (paired by sorted filename order), plot them side by
# side, print per-image and mean metrics, then print the position-volume MSE
# against the first window returned by get_vols().
true_files = sorted(glob('true*.png'))
prediction_files = sorted(glob('prediction*.png'))

# One subplot row per image, at least 5 rows (the original fixed layout);
# a fixed 5-row grid raised IndexError for more than 5 images.
f, ax = plt.subplots(max(5, len(true_files), len(prediction_files)), 2,
                     squeeze=False)

masks = []
predictions = []
for i, true in enumerate(true_files):
    print(true)
    masks.append((cv.imread(true, 0), true))
    ax[i, 0].imshow(masks[-1][0], cmap='gray')
    ax[i, 0].set_title(true)
    ax[i, 0].axis('off')

for i, prediction in enumerate(prediction_files):
    predictions.append((cv.imread(prediction, 0), prediction))
    ax[i, 1].imshow(predictions[-1][0], cmap='gray')
    ax[i, 1].set_title(prediction)
    ax[i, 1].axis('off')

metrics = get_metrics(masks, predictions)

metrics = pd.DataFrame(metrics, columns=['acc', 'precision', 'recall', 'f1'])
print("")
print(metrics.mean())

# Only the first window is needed; index the list directly instead of
# stacking every window into one array first.
test_vol = get_vols(future=1)[0]
# Move axis 0 to position 2, then merge the trailing axes so that channel
# pair (2k, 2k+1) belongs to entry k, which is how get_sample_mse slices it.
# presumably (entries, 90, 160, 2) -> (90, 160, entries * 2) — TODO confirm.
test_vol = np.rollaxis(test_vol, 0, 3)
print(test_vol.shape)
test_vol = np.reshape(test_vol, (90, 160, 10))
print(test_vol.shape)
mse = get_sample_mse(test_vol, predictions)
print("\nMSE: {} %".format(mse))

plt.savefig('out.png')
plt.show()
