# This file contains helper functions for Jupyter notebooks.

import os, datetime, json
import numpy as np
import pandas as pd

def get_new_log_dir():
    '''Create a new directory for log files, named after the current timestamp.

    The directory is ``<cwd>/logs/<YYYYmmdd>/<HHMMSS>``; it and any missing
    parents are created if they do not exist yet.

    :return: path of the log directory.
    '''
    # Take a single timestamp so the date and time components always agree
    # (two separate now() calls could straddle midnight).
    now = datetime.datetime.now()
    log_dir = os.path.join(os.getcwd(), 'logs',
                           now.strftime('%Y%m%d'), now.strftime('%H%M%S'))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir

### Functions for styling DataFrames in Jupyter. ###

def color_not_false_or_zero(val):
    '''Return a CSS text-color rule: gray for False, zero or NaN, black otherwise.'''
    if val in [False, 0]:
        color = 'gray'
    elif np.isnan(val):
        color = 'gray'
    else:
        color = 'black'
    return 'color: %s' % color

def highlight_max(s):
    '''Give every cell of ``s`` equal to its maximum a green background.

    Intended for ``Styler.apply``: returns one CSS string per element
    (empty string for non-maximal cells).
    '''
    top = s.max()
    styles = []
    for hit in s == top:
        styles.append('background-color: rgb(152, 223, 138)' if hit else '')
    return styles

def highlight_min(s):
    '''Give a blue background to every cell of ``s`` equal to its smallest non-False value.

    Entries comparing equal to False (which includes 0 / 0.0) are left out
    when looking for the minimum. Intended for ``Styler.apply``: returns one
    CSS string per element.
    '''
    # Element-wise pandas comparison; `is not False` would not work here.
    smallest = s[s != False].min()  # noqa: E712
    shade = 'background-color: rgb(174, 199, 232)'
    return [shade if matches else '' for matches in (s == smallest)]

def bold_max(s):
    '''Mark every cell of ``s`` equal to its maximum as bold (one CSS string per cell).'''
    peak = s.max()

    def _css(is_peak):
        return 'font-weight: bold' if is_peak else ''

    return list(map(_css, s == peak))

def format_floats(x):
    '''Format a numeric value with two decimals for display.

    A boolean ``False`` (which this file stores in place of a coefficient for
    features a model did not select) is returned unchanged, as is any value
    that cannot be formatted as a float (strings, ``None``, ...).

    :param x: value to format.
    :return: ``'%.2f'``-formatted string, or ``x`` itself.
    '''
    # Check the type rather than ``x == False``: that comparison is also true
    # for 0 / 0.0, which would skip formatting and render as "0" or "0.0".
    if isinstance(x, (bool, np.bool_)) and not x:
        return x
    try:
        return '%.2f' % x
    except (TypeError, ValueError):
        # Not a number: show it as-is.
        return x

def get_data_frame(infos, data, styled=True):
    '''Build a table comparing models' scores and per-feature values.

    One row per entry of ``infos``, indexed by model name, with the columns
    ``train score``, ``test score``, one column per feature (underscores in
    ``data.features`` replaced by spaces) and ``type``. When a model has a
    ``support`` mask, features it did not select hold ``False`` and the
    model's values fill the selected features in order.

    :param infos: list of model dicts with ``'name'``, ``'train_score'``,
        ``'test_score'`` and ``'coefficients'`` and/or
        ``'feature_importances'``; optionally ``'support'``.
    :param data: object whose ``features`` attribute lists the feature names.
    :param styled: if true, return a pandas ``Styler`` that greys out
        False/zero/NaN feature cells, highlights each row's minimum (blue) and
        maximum (green) feature value and bolds the best scores; otherwise
        return the plain DataFrame.
    :return: ``pandas.DataFrame`` or ``Styler``.
    '''
    feature_names = [feat.replace('_', ' ') for feat in data.features]
    table = []
    for model in infos:
        row = {
            'name': model['name'],
            'test score': model['test_score'],
            'train score': model['train_score'],
        }
        for coef_type in ['coefficients', 'feature_importances']:
            if coef_type not in model:
                continue
            values = model[coef_type]
            row['type'] = coef_type.replace('_', ' ')
            # Without a support mask every feature counts as selected.
            support = model['support'] if 'support' in model else [True] * len(values)
            n_skipped = 0
            for i, selected in enumerate(support):
                if not selected:
                    # The model's values only cover selected features, hence
                    # the running offset used for the selected ones below.
                    n_skipped += 1
                    row[feature_names[i]] = False
                else:
                    row[feature_names[i]] = values[i - n_skipped]
        table.append(row)

    order_cols = ['name', 'train score', 'test score'] + feature_names + ['type']
    df = pd.DataFrame(table)[order_cols].set_index('name')
    if not styled:
        return df

    styler = df.style.set_properties(**{'white-space': 'normal'})
    # Styler.applymap was deprecated in pandas 2.1 in favour of Styler.map;
    # use whichever this pandas version provides.
    if hasattr(styler, 'map'):
        styler = styler.map(color_not_false_or_zero, subset=feature_names)
    else:
        styler = styler.applymap(color_not_false_or_zero, subset=feature_names)
    return styler.\
        apply(highlight_min, axis=1, subset=feature_names).\
        apply(highlight_max, axis=1, subset=feature_names).\
        apply(bold_max, axis=0, subset=['test score', 'train score'])

### Helper functions for generating LaTeX tables automatically. ###

def get_table_header(cols, resize=True):
    '''Return the opening lines of a centered LaTeX ``table``/``tabular``.

    :param cols: tabular column specification, e.g. ``'l|cc'``.
    :param resize: also open a resizebox that scales the tabular to the text
        width (the matching footer must then close it).
    :return: LaTeX source ending with a newline.
    '''
    lines = [r'\begin{table}[H]', r'\centering']
    if resize:
        lines.append(r'\resizebox{\textwidth}{!}{')
    lines.append(r'    \begin{tabular}{%s}' % cols)
    return '\n'.join(lines) + '\n'

def get_table_footer(name, resize=True):
    '''Return the closing lines of a LaTeX table with a caption and a label.

    The label is ``tab:`` followed by ``name`` with spaces removed and
    lower-cased.

    :param name: caption text.
    :param resize: also close the resizebox opened by the header.
    :return: LaTeX source starting and ending with a newline.
    '''
    label = name.replace(' ', '').lower()
    pieces = ['', r'    \end{tabular}']
    if resize:
        pieces.append('}')
    pieces.append(r'\caption{%s} \label{tab:%s}' % (name, label))
    pieces.append(r'\end{table}')
    return '\n'.join(pieces) + '\n'

def _format_best_params(best_params):
    '''Render ``{name: value}`` pairs as one comma-separated LaTeX cell.

    Numeric values get two decimals, other values are printed with ``%s``;
    underscores are escaped for LaTeX.
    '''
    parts = []
    for key, value in best_params.items():
        try:
            parts.append(r'\pkg{%s}: %.2f' % (key, value))
        except (TypeError, ValueError):
            # Non-numeric parameter values (strings, None, tuples, ...).
            parts.append(r'\pkg{%s}: %s' % (key, value))
    # Raw string: '\_' in a normal literal is an invalid escape sequence.
    return ', '.join(parts).replace('_', r'\_')


def get_latex_infos_start(infos, params=True):
    '''Return the opening of a LaTeX table listing models and their scores.

    The result holds the table header, a column-title row and one row per
    model. The highest train and test scores are set in bold. The table is
    left open so that a footer can be appended.

    :param infos: list of dicts with ``'name'``, ``'train_score'`` and
        ``'test_score'``; ``'best_params'`` (a dict) is also read when
        ``params`` is true.
    :param params: include a "Parameters" column with each model's best
        parameters.
    :return: LaTeX source as a string.
    '''
    headers = ['Model', 'Parameters', 'Train score', 'Test score']
    col_spec = 'l|lcc'
    if not params:
        headers.remove('Parameters')
        # One column fewer, so the score columns stay centered.
        col_spec = 'l|cc'
    res = get_table_header(col_spec, resize=False)
    res += '        %s\\\\ \\hline\n' % (' & '.join(headers))
    if not infos:
        return res

    # Take the best scores from the data instead of seeding with 0.0, so that
    # negative scores (e.g. negated errors, R^2 below zero) are bolded too.
    best_train = max(info['train_score'] for info in infos)
    best_test = max(info['test_score'] for info in infos)

    for info in infos:
        cells = [r'\pkg{%s}' % info['name']]
        if params:
            cells.append(_format_best_params(info['best_params']))
        if info['train_score'] == best_train:
            cells.append(r'\textbf{%.2f}' % info['train_score'])
        else:
            cells.append('%.2f' % info['train_score'])
        if info['test_score'] == best_test:
            cells.append(r'\textbf{%.4f}' % info['test_score'])
        else:
            cells.append('%.4f' % info['test_score'])
        res += '        %s\\\\\n' % '   & '.join(cells)
    return res

def get_latex_infos_grid(infos, puzzle):
    '''GridSearch -- complete LaTeX table with each model's best parameters and
    its train and test scores, captioned "<puzzle> -- GridSearch parameters and scores".'''
    header_and_rows = get_latex_infos_start(infos)
    caption = '%s -- GridSearch parameters and scores' % puzzle
    return ''.join((header_and_rows, get_table_footer(caption, resize=False)))

def get_latex_infos_rfecv(infos, puzzle):
    '''RFECV -- complete LaTeX table with train and test scores (no parameter
    column), captioned "<puzzle> -- RFECV scores".'''
    header_and_rows = get_latex_infos_start(infos, params=False)
    caption = '%s -- RFECV scores' % puzzle
    return ''.join((header_and_rows, get_table_footer(caption, resize=False)))

def get_float(val):
    '''Format ``val`` for a LaTeX table cell via ``format_floats``.

    False, zero and NaN values are wrapped in a ``tabgray`` color group;
    everything else is returned as formatted.
    '''
    muted = val in [False, 0] or np.isnan(val)
    text = format_floats(val)
    if not muted:
        return text
    return r'{\color{%s} %s}' % ('tabgray', text)

def get_coefficients_table_start(features, infos):
    r'''Return the opening of a LaTeX table of per-feature model values.

    The table has one column per feature (name wrapped in ``\rot``,
    underscores replaced by spaces) and one row per model. Each cell is
    rendered with ``get_float``. Features a model did not select (false
    ``support`` entry) get the value ``False``. In each row, the largest value
    among the selected features is set in bold. The table is left open so
    that a footer can be appended.

    :param features: feature names, in the order of the models' values.
    :param infos: list of model dicts with ``'name'`` and ``'coefficients'``
        and/or ``'feature_importances'``; optionally a ``'support'`` mask.
    :return: LaTeX source as a string.
    '''
    feature_names = [feat.replace('_', ' ') for feat in features]
    col_spec = 'l|%s' % ('c' * len(feature_names))
    # Own header (95% width, no centering) rather than get_table_header.
    res = r'''\begin{table}[H]
\resizebox{0.95\textwidth}{!}{
    \begin{tabular}{%s}
''' % col_spec
    rotated = [r'\rot{%s}' % f for f in feature_names]
    res += '        %s\\\\ \\hline\n' % ('Model & ' + ' & '.join(rotated))

    for model in infos:
        row = {}
        for coef_type in ['coefficients', 'feature_importances']:
            if coef_type not in model:
                continue
            values = model[coef_type]
            # Without a support mask every feature counts as selected.
            support = model['support'] if 'support' in model else [True] * len(values)
            n_skipped = 0
            for i, selected in enumerate(support):
                if not selected:
                    # The model's values only cover selected features.
                    n_skipped += 1
                    row[feature_names[i]] = False
                else:
                    row[feature_names[i]] = values[i - n_skipped]

        # Only selected features compete for bold: False compares as 0 and
        # would otherwise beat every negative coefficient. max() keeps the
        # first of equal values, as the original left-to-right scan did.
        candidates = [f for f in feature_names if row[f] is not False]
        best = max(candidates, key=lambda f: row[f]) if candidates else None

        cells = [r'\pkg{%s}' % model['name']]
        for f in feature_names:
            cell = get_float(row[f])
            if f == best:
                cells.append(r'\textbf{%s}' % cell)
            else:
                cells.append(r'%s' % cell)
        res += '        %s\\\\ \n' % ' & '.join(cells)
    return res

def get_coefficients_table_grid(features, infos, puzzle):
    '''GridSearch - complete LaTeX table of per-feature model values, captioned
    "<puzzle> -- GridSearch coefficients".'''
    body = get_coefficients_table_start(features, infos)
    caption = '%s -- GridSearch coefficients' % puzzle
    return ''.join((body, get_table_footer(caption, resize=True)))

def get_coefficients_table_rfecv(features, infos, puzzle):
    '''RFECV - complete LaTeX table of per-feature model values, captioned
    "<puzzle> -- RFECV coefficients".'''
    body = get_coefficients_table_start(features, infos)
    caption = '%s -- RFECV coefficients' % puzzle
    return ''.join((body, get_table_footer(caption, resize=True)))

def _json_default(obj):
    '''``json.dumps`` fallback that turns numpy arrays/scalars into Python objects.'''
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError('Object of type %s is not JSON serializable'
                    % type(obj).__name__)


def write_infos_to_log(log_dir, infos, data, name):
    '''Save model infos and dataset details as JSON into ``<log_dir>/<name>.log``.

    The JSON object has the keys ``predict_type`` (``data.predict_col``),
    ``models`` (``infos``) and ``dataset_actions`` (``data.actions``), written
    with sorted keys. numpy arrays and scalars (e.g. fitted coefficients or
    support masks) are converted to plain lists/numbers. The path of the log
    file is printed.

    :raises TypeError: if the data contains objects JSON cannot encode; the
        log file is not opened (so not created or truncated) in that case.
    '''
    path = os.path.join(log_dir, '%s.log' % name)
    print(path)
    result = {
        'predict_type': data.predict_col,
        'models': infos,
        'dataset_actions': data.actions,
    }
    # Serialize before opening the file, so a failure does not leave an
    # empty or truncated log behind.
    result_text = json.dumps(result, sort_keys=True, default=_json_default)
    with open(path, 'w') as out:
        out.write(result_text)
