from skimage.util.shape import view_as_blocks, view_as_windows
from skimage.measure import block_reduce
from glob import glob
import numpy as np
import cv2 as cv
import sys
import re


class Loader:
    """Index pedestrian annotations and build per-frame volumes from them.

    On construction every file matched by ``annotation_folder`` is read with
    ``np.loadtxt(..., delimiter=',')``. For each row, annotation columns 0
    and 1 are kept as a location and ``int(column_2 / 20)`` as the frame
    index. The annotation file path itself is used as the pedestrian id.
    """

    def __init__(self, annotation_folder='resized/annotations/*.txt',
                 size=(1080, 1920, 3), representation=5, batch=6001):
        self.h, self.w, self.c = size
        self.batch = batch

        # frame index -> pedestrian ids (annotation paths), one entry per row
        self.frame2p = {}
        # pedestrian id -> [frame_idx, col0, col1] rows in file order
        self.p2loc = {}
        # filled (and returned) by get_displacements()
        self.displacements = []
        self.annotation_folder = annotation_folder
        # side length of the square marked per pedestrian in
        # mark_locations_for_frame
        self.representation = representation

        self.load_data(annotation_folder)

    def get_last_n_frames(self, locs, frame, n, w, h):
        """Return the locations recorded in frames ``[frame - n, frame]``.

        Returns ``(last_locs, cur_loc)``: ``last_locs`` is a list of
        ``[col0, col1]`` pairs in the order they appear in ``locs``;
        ``cur_loc`` is the pair from the last matching entry, or ``None``
        when nothing matches. ``w`` and ``h`` are accepted for interface
        compatibility but unused.
        """
        last_locs = []
        # Initialised so an empty window yields None rather than raising
        # UnboundLocalError at the return.
        cur_loc = None
        for loc in locs:
            if frame - n <= loc[0] <= frame:
                cur_loc = (loc[1], loc[2])
                last_locs.append([loc[1], loc[2]])
        return last_locs, cur_loc

    def get_frame(self, locs, frame):
        """Return ``loc[1:]`` of the first entry recorded at ``frame``, else None."""
        for loc in locs:
            if loc[0] == frame:
                return loc[1:]
        return None

    def get_pedestrian_vector(self, p_name, frame, w=1080, h=1920, n=5):
        """Return displacements of ``p_name`` over frames ``[frame - n, frame]``.

        Returns ``(diff, (i, j))`` where ``diff`` is a ``(len, 2)`` array of
        each windowed location minus the last windowed location, and
        ``(i, j)`` are the int-truncated annotation columns 0 and 1 of that
        last location. Returns ``(None, None)`` unless more than ``n``
        locations fall in the window. ``w`` and ``h`` are unused.
        """
        all_locs = self.p2loc[p_name]
        locs, last = self.get_last_n_frames(all_locs, frame, n, w, h)
        if len(locs) <= n:
            return None, None
        diff = np.array(locs) - np.array(locs[-1])
        return diff, (int(last[0]), int(last[1]))

    def get_pedestrian_location(self, p_name, frame, h=1920, w=1080, n=5):
        """Return raw locations of ``p_name`` over frames ``[frame - n, frame]``.

        Returns ``(locs, (row, col))`` where ``locs`` is a ``(len, 2)`` array
        and ``(row, col)`` is the last windowed location with its annotation
        columns swapped (column 1 first, then column 0). Returns
        ``(None, None)`` unless more than ``n`` locations fall in the window.
        ``h`` and ``w`` are unused.
        """
        all_locs = self.p2loc[p_name]
        locs, last = self.get_last_n_frames(all_locs, frame, n, w, h)
        if len(locs) <= n:
            return None, None
        return np.array(locs), (int(last[1]), int(last[0]))

    def get_frame_idx(self, path):
        """Return the frame bucket for ``path``: its first digit run ``// 20``.

        Floor division keeps the result an int, matching the
        ``int(column_2 / 20)`` bucketing in ``load_data``. True division
        yields a float under Python 3 and silently misses every frame whose
        number is not a multiple of 20.
        """
        return int(re.findall(r'\d+', path)[0]) // 20

    def get_displacement_for_frame(self, frame, h=1920, w=1080, n=5):
        """Build an ``(h, w, 2 * (n + 1))`` float16 displacement volume.

        Each pedestrian annotated in the frame bucket of ``frame`` that has
        more than ``n`` windowed locations writes its flattened displacement
        vector at index ``[col0, col1]`` (annotation columns 0 and 1).
        Returns None when the frame bucket has no annotations.
        """
        frame_idx = self.get_frame_idx(frame)
        if frame_idx not in self.frame2p:
            return None
        disp_vol = np.zeros((h, w, 2 * (n + 1)), dtype='float16')
        for p in self.frame2p[frame_idx]:
            # Forward n so the vector length matches the 2 * (n + 1) channels.
            p_vec, vol_idx = self.get_pedestrian_vector(p, frame_idx, w, h,
                                                        n=n)
            if p_vec is None:
                continue
            # NOTE(review): indexes [col0, col1], whereas
            # mark_locations_for_frame uses [col1, col0]; confirm which axis
            # order is intended.
            disp_vol[vol_idx[0], vol_idx[1], :] = p_vec.flatten()
        return disp_vol

    def get_locations_for_frame(self, frame, h=1920, w=1080, n=5):
        """Build an ``(h, w, 2 * (n + 1))`` float16 volume of raw locations.

        Each pedestrian annotated in the frame bucket of ``frame`` that has
        more than ``n`` windowed locations writes its flattened location
        history at ``[col1, col0]``. Returns None when the frame bucket has
        no annotations.
        """
        frame_idx = self.get_frame_idx(frame)
        if frame_idx not in self.frame2p:
            return None
        disp_vol = np.zeros((h, w, 2 * (n + 1)), dtype='float16')
        for p in self.frame2p[frame_idx]:
            # Forward n so the history length matches the channel count.
            p_vec, vol_idx = self.get_pedestrian_location(p, frame_idx,
                                                          h=h, w=w, n=n)
            if p_vec is None:
                continue
            disp_vol[vol_idx[0], vol_idx[1], :] = p_vec.flatten()
        return disp_vol

    def mark_locations_for_frame(self, frame, h=1920, w=1080, n=5,
                                 block_size=(9, 20, 3), classification=False):
        """Mark pedestrian positions for ``frame`` on an ``(h, w)`` grid.

        For every pedestrian with more than ``n`` windowed locations, a
        ``representation x representation`` square starting at its
        ``(row, col)`` is set to 1 (clipped at the grid edge), and
        ``true_vol[row, col]`` is set to ``[row / h, col / w]``.

        Returns ``(marks, true_vol)`` or None when the frame bucket has no
        annotations. When ``classification`` is False, ``marks`` is split
        with ``view_as_blocks`` into ``(num_blocks, block_size[0],
        block_size[1])``; otherwise it stays ``(h, w)``.
        """
        frame_idx = self.get_frame_idx(frame)
        if frame_idx not in self.frame2p:
            return None
        disp_vol = np.zeros((h, w), dtype='float16')
        true_vol = np.zeros((h, w, 2), dtype='float16')
        side = self.representation
        for p in self.frame2p[frame_idx]:
            p_vec, vol_idx = self.get_pedestrian_location(p, frame_idx,
                                                          h=h, w=w, n=n)
            if p_vec is None:
                continue
            row, col = vol_idx
            # Slicing truncates at the array edge, so the square is clipped
            # to the real (h, w) instead of a hard-coded 90x160 frame.
            disp_vol[row:row + side, col:col + side] = 1
            true_vol[row, col] = [row / h, col / w]

        if classification is False:
            disp_vol = view_as_blocks(disp_vol, (block_size[0], block_size[1]))
            disp_vol = disp_vol.reshape(-1, block_size[0], block_size[1])

        return disp_vol, true_vol

    def load_data(self, annotation_folder):
        """Populate ``frame2p`` and ``p2loc`` from annotation files.

        Each file matched by the ``annotation_folder`` glob is read as
        comma-separated numbers; empty files are skipped.
        """
        for annon in glob(annotation_folder):
            data = np.loadtxt(annon, delimiter=',')
            if data.size == 0:
                continue
            if data.ndim == 1:
                # a single row loads as 1-D; treat it as a one-row table
                data = data.reshape((1, -1))
            for line in data:
                frame_idx = int(line[2] / 20)
                self.frame2p.setdefault(frame_idx, []).append(annon)
                self.p2loc.setdefault(annon, []).append(
                    [frame_idx, line[0], line[1]])

    def get_displacements(self):
        """Return displacement volumes and paths for annotated frames.

        Scans ``resized/frames/*.jpg``; frames without annotations are
        skipped. ``self.displacements`` is reset and refilled on every call
        so the two returned lists always stay aligned.
        """
        self.displacements = []
        files = []
        for img_id in glob('resized/frames/*.jpg'):
            disp = self.get_displacement_for_frame(img_id)
            if disp is None:
                continue
            self.displacements.append(disp)
            files.append(img_id)
        return self.displacements, files

    def get_displacements_from_files(self, files, h=1920, w=1080):
        """Return displacement volumes for the annotated paths in ``files``.

        Paths without annotations are skipped, so the result may be shorter
        than ``files``.
        """
        displacements = []
        for img_id in files:
            # NOTE(review): the positional order is (frame, h, w), so this
            # passes w as h and h as w. It is left unchanged because
            # get_displacement_for_frame indexes [col0, col1], and existing
            # callers may rely on the swap; confirm before "fixing".
            disp = self.get_displacement_for_frame(img_id, w, h)
            if disp is not None:
                displacements.append(disp)
        return displacements

    def get_locations(self, stop=6001):
        """Return location volumes and paths for up to ``stop`` frames.

        Scans ``resized/frames/*.jpg``, considering at most the first
        ``stop`` paths; unannotated frames are skipped.
        """
        files = []
        locations = []
        for i, img_id in enumerate(glob('resized/frames/*.jpg')):
            # Check the limit before building a volume that would be dropped.
            if i >= stop:
                break
            locs = self.get_locations_for_frame(img_id)
            if locs is None:
                continue
            locations.append(locs)
            files.append(img_id)
        return locations, files

    def get_images(self, files, stop=6001, h=1920, w=1080, n=1,
                   block_size=(9, 20, 3), classification=False,
                   rotate=False):
        """Load frames and pair them with their location marks.

        For each of the first ``stop`` paths in ``files`` that has
        annotations, the image is read with ``cv.imread`` and scaled by
        1/255. The following are then appended:

        * the image split into non-overlapping blocks (``classification``
          False, via ``view_as_blocks``), or into stride-4 windows over a
          zero-padded copy (``classification`` True, via
          ``view_as_windows``);
        * the marks from ``mark_locations_for_frame`` (max-pooled by the
          stride in classification mode);
        * the ``true_vol`` from ``mark_locations_for_frame``.

        With ``rotate=True`` a left-right flipped copy of each triple is
        appended as well. Returns ``(images, locations, vols)`` lists of
        equal length.
        """
        images = []
        locations = []
        vols = []
        for i, img_id in enumerate(files):
            # Check the limit first so no image is read only to be dropped.
            if i >= stop:
                break
            marks = self.mark_locations_for_frame(img_id, h, w,
                                                  n=n, block_size=block_size,
                                                  classification=classification)
            if marks is None:
                continue
            img = cv.imread(img_id)
            locs, vol = marks
            if classification is False:
                img_vol = np.squeeze(view_as_blocks(img / 255.0, block_size))
                # -1 lets numpy infer the block count from the actual image.
                img_vol = img_vol.reshape(-1,
                                          block_size[0],
                                          block_size[1],
                                          block_size[2])
            else:
                pad_rows = block_size[0] // 2
                pad_cols = block_size[1] // 2
                stride = 4
                img_pad = np.pad(img,
                                 ((pad_rows, pad_rows), (pad_cols, pad_cols),
                                  (0, 0)),
                                 mode='constant', constant_values=0)
                img_vol = np.squeeze(view_as_windows(img_pad / 255.0,
                                                     block_size,
                                                     step=stride))
                # NOTE(review): dropping the last window column appears to
                # align the window grid with block_reduce's output below;
                # confirm for frame sizes other than the usual one.
                img_vol = img_vol[:, :-1].reshape(-1, block_size[0],
                                                  block_size[1],
                                                  block_size[2])
                img_vol = img_vol.astype('float16')
                locs = block_reduce(locs, (stride, stride), np.max)

            images.append(img_vol)
            locations.append(locs)
            vols.append(vol)
            if rotate is True:
                images.append(np.fliplr(img_vol))
                locations.append(np.fliplr(locs))
                vols.append(np.fliplr(vol))
        return images, locations, vols


def load_image_location(files, x_size=5, future=1, block_shape=(9, 20, 3),
                        stop=6001, h=1920, w=1080, representation=5,
                        classification=False):
    """Build sliding-window (input, target) pairs from frame images.

    ``files`` is sorted in place (callers see the sorted order). Then
    ``Loader.get_images`` produces per-frame image volumes and location
    marks. Window ``i`` uses images ``[i, i + x_size)`` as input and marks
    ``[i + x_size, i + x_size + future)`` as target; every complete window is
    produced, including the one ending at the last image.

    In classification mode the input window is passed through
    ``np.rollaxis(..., 1)`` and the target is reshaped to
    ``(cells, future)``.

    Returns:
        ``(x, y, vols)``: input windows, targets, and the per-frame ``vols``
        from ``get_images`` (not windowed). ``x`` and ``y`` are empty when no
        annotated image was found.
    """
    loader = Loader(batch=stop, representation=representation)
    files.sort()
    images, locations, vols = loader.get_images(files, stop, h, w,
                                                block_size=block_shape,
                                                classification=classification)
    if not images:
        # Nothing annotated: avoid IndexError on images[0] below.
        return [], [], vols
    print("images: {} with shape {}".format(len(images), images[0].shape))
    print("locations: {} with shape {}".format(len(locations),
                                               locations[0].shape))
    x = []
    y = []
    # +1 so the last complete window (ending at len(images)) is included.
    num_windows = len(images) - (x_size + future) + 1
    for i in range(num_windows):
        start = i + x_size
        if classification is True:
            image = np.rollaxis(np.array(images[i:start]), 1)
            location = np.array(locations[start:start + future])
            location = location.reshape(future, -1).T
            x.append(image)
            y.append(location)
        else:
            x.append(np.array(images[i:start]))
            y.append(locations[start:start + future])
    return x, y, vols
 

def save_data(filename, batch_size=6001, save_files=False):
    """Compute displacement volumes for all frames and save them in batches.

    Paths from ``resized/frames/*.jpg`` are grouped into batches of
    ``batch_size``; the final batch may be smaller and is saved too. Each
    batch is written with ``np.savez`` as ``<filename>_<i>.npz``, where
    ``i`` is the enumeration index of the batch's last path. Frames without
    annotations contribute no volume. When ``save_files`` is True, every
    processed path (annotated or not) is written, one per line, to
    ``<filename>.txt``.
    """
    loader = Loader(batch=batch_size)

    def _save_batch(batch, last_index):
        # One .npz per batch, named after the index of its last frame.
        displacements = loader.get_displacements_from_files(batch)
        dis = np.array(displacements, dtype='float16')
        np.savez('{}_{}.npz'.format(filename, last_index), dis)

    processed = []
    batch = []
    last_index = -1
    for last_index, img_id in enumerate(glob('resized/frames/*.jpg')):
        batch.append(img_id)
        if (last_index + 1) % batch_size == 0:
            _save_batch(batch, last_index)
            processed.extend(batch)
            batch = []
    # Flush the trailing partial batch so no frame is silently dropped.
    if batch:
        _save_batch(batch, last_index)
        processed.extend(batch)

    if save_files is True:
        # Written from the accumulated list, not the per-batch list that is
        # reset after each save.
        with open('{}.txt'.format(filename), 'w+') as f:
            f.writelines('{}\n'.format(sf) for sf in processed)


def save_locations(filename, stop=6001, save_files=False):
    """Compute location volumes for annotated frames and save them.

    Writes ``<filename>.npz`` via ``np.savez``. When ``save_files`` is True,
    also writes ``<filename>.txt`` listing one frame path per line.
    """
    location_vols, frame_paths = Loader().get_locations(stop)
    np.savez('{}.npz'.format(filename), location_vols)
    if save_files is not True:
        return
    with open('{}.txt'.format(filename), 'w+') as out:
        out.writelines('{}\n'.format(path) for path in frame_paths)


def load_data(filename):
    """Return whatever ``np.load`` reads from ``filename``.

    For ``.npz`` archives this is numpy's lazy archive object (presumably
    the caller should close it when done — confirm at call sites).
    """
    return np.load(filename)


def get_sample_label_pairs(volumes):
    """Split each volume along its last axis into (input, target) views.

    The input drops the last two channels and the target drops the first
    two, so each keeps ``C - 2`` channels. ``volumes`` is iterated once.
    Returns two lists ``(X, y)`` in the order of ``volumes``.
    """
    pairs = [(vol[:, :, :-2], vol[:, :, 2:]) for vol in volumes]
    return [pair[0] for pair in pairs], [pair[1] for pair in pairs]


def get_train_test(files):
    """Return ``(X, y)`` input/target pairs built from ``files``.

    Displacement volumes come from a freshly constructed ``Loader`` and are
    split by ``get_sample_label_pairs``.
    """
    volumes = Loader().get_displacements_from_files(files)
    return get_sample_label_pairs(volumes)
