from glob import glob
import numpy as np
from os.path import basename, splitext
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from collections import defaultdict



def autolabel(ax, rects):
    """
    Write the integer height of every bar just above that bar.

    Adapted from the matplotlib bar chart demo:
        https://matplotlib.org/examples/api/barchart_demo.html

    ax    -- object whose ``text`` method draws the labels
    rects -- iterable of bars exposing ``get_x``, ``get_width`` and
             ``get_height``
    """
    for bar in rects:
        bar_height = bar.get_height()
        # Horizontal centre of the bar.
        x_center = bar.get_x() + bar.get_width()/2.
        # The label sits 5% above the top of the bar.
        label_y = 1.05*bar_height
        # The height is truncated to an integer for display.
        caption = '%d' % int(bar_height)
        ax.text(x_center, label_y, caption, ha='center', va='bottom')


def get_sequencies(arr, step=20):
    """
    Split ``arr`` into runs of values that follow each other by ``step``.

    A value joins the current run when it equals the previous value plus
    ``step``; any other value starts a new run. The input is walked in the
    order given, so it has to be sorted ascending for the runs to be
    meaningful.

    arr  -- iterable of numbers (list, 1-D numpy array, ...)
    step -- expected difference between neighbouring values of one run
            (default 20)

    Returns a list of lists, one inner list per run. An empty input gives
    an empty list.
    """
    sequencies = []
    for value in arr:
        # The last element of the last run is always the previous input
        # value, so this is the same test as arr[i] == arr[i - 1] + step.
        if sequencies and value == sequencies[-1][-1] + step:
            sequencies[-1].append(value)
        else:
            sequencies.append([value])
    return sequencies


def draw_bar(name, data):
    """
    Save a bar chart of ``data`` to the file ``name``.

    Bars are labelled 1..N on the x axis and each bar gets its integer
    height written above it (see ``autolabel``).

    name -- output path handed to ``Figure.savefig``
    data -- iterable of bar heights; it is materialised into a list so
            iterators (e.g. a ``map`` object) are accepted as well
    """
    data = list(data)
    ind = np.arange(len(data))

    fig, ax = plt.subplots()
    try:
        bar_plot = ax.bar(ind, data)

        # Tick positions have to be fixed before the labels are assigned,
        # otherwise the labels are attached to the automatic ticks.
        ax.set_xticks(ind)
        ax.set_xticklabels(ind + 1)
        autolabel(ax, bar_plot)
        fig.savefig(name)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def get_badly_annotated_frames(frame2p, thresh=10):
    """
    Find the frames whose annotation count is below ``thresh``.

    frame2p -- mapping of frame id -> number of annotations on that frame
    thresh  -- frames with a count strictly below this value are "bad"

    Returns a tuple ``(bad_frame, average, maximum)``:
        bad_frame -- list of bad frame ids, in the mapping's iteration order
        average   -- mean count over all frames, truncated to an int
        maximum   -- largest count over all frames
    An empty mapping gives ``([], 0, 0)``.
    """
    # np.mean / np.max are undefined on an empty sequence.
    if not frame2p:
        return [], 0, 0

    counts = list(frame2p.values())
    bad_frame = [frame for frame, count in frame2p.items() if count < thresh]
    return bad_frame, int(np.mean(counts)), np.max(counts)


# Accumulators filled while scanning the annotation and frame directories.
available_frames = []      # frame ids that have at least one annotation
annotated_frames = []      # one entry per annotation row (frame id)
missing_annotations = []   # frame ids without any annotation
one_annotation = []
frame2p = defaultdict(int)  # frame id -> number of annotation rows


# Every annotation file is a comma separated table whose third column
# (index 2) is read as the frame id; each row counts as one annotation
# for that frame.
for filename in glob('resized/annotations/*.txt'):
    # ndmin=2 keeps a single-row file two-dimensional, so the column slice
    # below works for any number of rows.
    data = np.loadtxt(filename, delimiter=',', ndmin=2)
    if data.size == 0:
        # Nothing annotated in this file.
        continue

    frames = [int(value) for value in data[:, 2]]
    for frame in frames:
        annotated_frames.append(frame)
        frame2p[frame] += 1


# Split the frame images into those that have annotations and those that
# do not. The image file name (without extension) is the frame id.

# Build the lookup once: annotated_frames holds one entry per annotation
# row, so scanning that list for every image would be quadratic.
annotated_ids = set(annotated_frames)

for frame_name in glob('resized/frames/*.jpg'):
    frame_id = int(splitext(basename(frame_name))[0])
    if frame_id in annotated_ids:
        available_frames.append(frame_id)
    else:
        missing_annotations.append(frame_id)

# np.unique returns the distinct ids already sorted ascending.
missing_annotations = np.unique(missing_annotations)

def _print_sequence_report(kind, total, seq_lens, largest):
    """
    Print the report section for one kind of problematic frames.

    kind     -- word used in the messages ('missing' or 'bad')
    total    -- number of affected frames
    seq_lens -- list with the length of every sequence
    largest  -- the longest sequence (empty when there are no sequences)
    """
    print('Total {} annotations: {}'.format(kind, total))
    print('In {} sequencies'.format(len(seq_lens)))
    print('With lengths: {}'.format(seq_lens))
    if not largest:
        # No sequences at all: nothing more to describe.
        return
    # 0.8 is the number of seconds attributed to every frame of a sequence
    # NOTE(review): presumably the sampling interval of the frames - confirm.
    print('With largest being from {} to {} ({} sec)'.format(
        largest[0], largest[-1], np.max(seq_lens) * 0.8))
    # Show at most five frames; shorter sequences are printed in full.
    sample = ' '.join('{}.jpg'.format(frame) for frame in largest[:5])
    print("Sample from {} frames: [{}]".format(kind, sample))


# Frames with fewer than 20 annotations are reported as "bad".
bad_frame, avg_annon, max_annon = get_badly_annotated_frames(frame2p,
                                                             thresh=20)

# The empty case is guarded here so the report does not depend on how
# get_sequencies treats an empty input. Lengths are real lists so they can
# be indexed, printed and plotted on both Python 2 and Python 3.
miss_seq = (get_sequencies(missing_annotations)
            if len(missing_annotations) else [])
seqs_len = [len(seq) for seq in miss_seq]
# max() returns the first longest sequence, like indexing with np.argmax.
largest_missing = max(miss_seq, key=len) if miss_seq else []

# bad_frame comes in mapping order; grouping needs ascending frame ids.
bad_seq = get_sequencies(sorted(bad_frame)) if bad_frame else []
bad_seqs_len = [len(seq) for seq in bad_seq]
largest_bad = max(bad_seq, key=len) if bad_seq else []

print("")
print('Total annotated frames: {}'.format(len(available_frames)))
print("Average number of pedestrians on image: {}".format(avg_annon))
print("With maximum number of pedestrians on image: {}".format(max_annon))
print('----------------------------------------------------------------')
_print_sequence_report('missing', len(missing_annotations),
                       seqs_len, largest_missing)
print('----------------------------------------------------------------')
_print_sequence_report('bad', len(bad_frame), bad_seqs_len, largest_bad)
print("")

draw_bar('missing_seq_len.png', seqs_len)
draw_bar('bad_seq_len.png', bad_seqs_len)

# Bad frame statistics: for every threshold from 2 to 99 write one CSV row
# with the threshold, the number of bad frames, the number of sequences
# they form and how many of those sequences consist of a single frame.
lengths = []
ones = []
with open('group_stats.csv', 'w+') as f:
    for i in range(2, 100):
        bad_frames, _, _ = get_badly_annotated_frames(frame2p, thresh=i)
        lengths.append(len(bad_frames))

        # Grouping needs ascending frame ids, and a low threshold may leave
        # no bad frames at all.
        bad_seq = get_sequencies(sorted(bad_frames)) if bad_frames else []
        # An explicit list keeps this a 1-D integer array on Python 3,
        # where map() returns an iterator.
        bad_seqs_len = np.array([len(seq) for seq in bad_seq], dtype=int)
        one_group = np.sum(bad_seqs_len == 1)
        ones.append(one_group)
        f.write('{},{},{},{}\n'.format(i, lengths[-1], len(bad_seq), one_group))

# plt.figure()
# plt.plot(np.arange(100), lengths)
# plt.savefig('bad_frame_growth.png')
