from keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, Concatenate, Add, Dropout, Flatten, Dense
from keras.models import Model


def get_normal_encoder(input_width, input_height, input_depth,
                       vol_depth, future, loss_type):
    """Build and compile a plain 3D convolutional encoder-decoder.

    The network takes an input of shape
    ``(vol_depth, input_height, input_width, input_depth)``, runs three
    64-filter ReLU convolutions interleaved with two max-pooling steps,
    then two more convolutions interleaved with two up-sampling steps, and
    ends in a single-filter ReLU convolution.

    Args:
        input_width: Third entry of the input shape tuple.
        input_height: Second entry of the input shape tuple.
        input_depth: Last entry of the input shape tuple; also used as the
            third kernel dimension of the first convolution.
        vol_depth: First entry of the input shape tuple.
        future: Up-sampling factor applied to the first axis in the last
            up-sampling layer.
        loss_type: Loss passed to ``Model.compile``.

    Returns:
        The compiled Keras ``Model`` (optimizer ``'adam'``).
    """
    # NOTE(review): the third kernel dimension is ``input_depth`` / 64, which
    # looks like a channel count rather than a spatial extent -- confirm this
    # is intended before changing it.
    relu_same = dict(activation='relu', padding='same')
    wide_kernel = (3, 3, 64)

    input_volume = Input(
        shape=(vol_depth, input_height, input_width, input_depth))

    # Encoder: conv -> conv -> pool -> conv -> pool.
    net = Conv3D(64, (3, 3, input_depth), **relu_same)(input_volume)
    net = Conv3D(64, wide_kernel, **relu_same)(net)
    net = MaxPooling3D(pool_size=(2, 2, 1), strides=(2, 2, 1))(net)
    net = Conv3D(64, wide_kernel, **relu_same)(net)
    encoded = MaxPooling3D(pool_size=(2, 3, 1), strides=(2, 3, 1))(net)

    # Decoder: conv -> upsample -> conv -> upsample -> 1-filter conv.
    net = Conv3D(64, wide_kernel, **relu_same)(encoded)
    net = UpSampling3D((1, 2, 1))(net)
    net = Conv3D(64, wide_kernel, **relu_same)(net)
    net = UpSampling3D((future, 3, 1))(net)
    decoded = Conv3D(1, wide_kernel, **relu_same)(net)

    autoencoder = Model(input_volume, decoded)
    autoencoder.compile(loss=loss_type, optimizer='adam')
    return autoencoder


def get_column_encoder(input_width, input_height, input_depth,
                       vol_depth, future, loss_type):
    """Build and compile the multi-branch ("column") 3D encoder-decoder.

    Three parallel 16-filter convolutions with kernels
    ``(vol_depth, k, k)`` for ``k`` in 7, 5, 3 are applied to the same
    input and concatenated. The result goes through two conv/dropout/conv
    + max-pool stages, then a decoder of two conv + up-sampling stages and
    a final single-filter sigmoid convolution.

    Args:
        input_width: Third entry of the input shape tuple.
        input_height: Second entry of the input shape tuple.
        input_depth: Last entry of the input shape tuple.
        vol_depth: First entry of the input shape tuple; also the first
            kernel dimension of the three parallel convolutions.
        future: Up-sampling factor applied to the first axis in the last
            up-sampling layer.
        loss_type: Loss passed to ``Model.compile``.

    Returns:
        The compiled Keras ``Model`` (optimizer ``'adam'``).
    """
    relu_same = dict(activation='relu', padding='same')
    pooling = dict(pool_size=(2, 3, 2), strides=(2, 3, 2))

    input_volume = Input(
        shape=(vol_depth, input_height, input_width, input_depth))

    # Parallel branches, created in the order 7x7, 5x5, 3x3.
    branches = [
        Conv3D(16, (vol_depth, side, side), **relu_same)(input_volume)
        for side in (7, 5, 3)
    ]
    merged = Concatenate()(branches)

    # Encoder stage 1.
    net = Conv3D(32, (5, 3, 3), **relu_same)(merged)
    net = Dropout(0.2)(net)
    net = Conv3D(32, (5, 3, 3), **relu_same)(net)
    net = MaxPooling3D(**pooling)(net)

    # Encoder stage 2.
    net = Conv3D(64, (3, 3, 3), **relu_same)(net)
    net = Dropout(0.2)(net)
    net = Conv3D(128, (3, 3, 3), **relu_same)(net)
    encoded = MaxPooling3D(**pooling)(net)

    # Decoder.
    net = Conv3D(64, (3, 3, 3), **relu_same)(encoded)
    net = UpSampling3D((1, 3, 2))(net)
    net = Conv3D(64, (5, 3, 3), **relu_same)(net)
    net = UpSampling3D((future, 3, 2))(net)
    decoded = Conv3D(1, (5, 3, 3),
                     activation='sigmoid', padding='same')(net)

    autoencoder = Model(input_volume, decoded)
    autoencoder.compile(loss=loss_type, optimizer='adam')
    return autoencoder


def get_classificator(input_width, input_height, input_depth,
                      vol_depth, future, loss_type):
    """Build and compile the multi-branch 3D convolutional classifier.

    Three parallel 16-filter convolutions with kernels
    ``(vol_depth, k, k)`` for ``k`` in 7, 5, 3 are concatenated and passed
    through two conv/dropout/conv + max-pool stages. The pooled features
    are flattened and fed to a 128-unit ReLU dense layer followed by a
    ``future``-unit sigmoid dense layer.

    Args:
        input_width: Third entry of the input shape tuple.
        input_height: Second entry of the input shape tuple.
        input_depth: Last entry of the input shape tuple.
        vol_depth: First entry of the input shape tuple; also the first
            kernel dimension of the three parallel convolutions.
        future: Number of units in the sigmoid output layer.
        loss_type: Loss passed to ``Model.compile``.

    Returns:
        The compiled Keras ``Model`` (optimizer ``'adam'``).
    """
    relu_same = dict(activation='relu', padding='same')
    pooling = dict(pool_size=(2, 3, 2), strides=(2, 3, 2))

    input_volume = Input(
        shape=(vol_depth, input_height, input_width, input_depth))

    # Parallel branches, created in the order 7x7, 5x5, 3x3.
    branches = [
        Conv3D(16, (vol_depth, side, side), **relu_same)(input_volume)
        for side in (7, 5, 3)
    ]
    merged = Concatenate()(branches)

    # Feature extractor stage 1.
    features = Conv3D(32, (5, 3, 3), **relu_same)(merged)
    features = Dropout(0.2)(features)
    features = Conv3D(32, (5, 3, 3), **relu_same)(features)
    features = MaxPooling3D(**pooling)(features)

    # Feature extractor stage 2.
    features = Conv3D(64, (3, 3, 3), **relu_same)(features)
    features = Dropout(0.2)(features)
    features = Conv3D(128, (3, 3, 3), **relu_same)(features)
    encoded = MaxPooling3D(**pooling)(features)

    # Dense head.
    flat = Flatten()(encoded)
    hidden = Dense(128, activation='relu')(flat)
    out = Dense(future, activation='sigmoid')(hidden)

    classificator = Model(input_volume, out)
    classificator.compile(loss=loss_type, optimizer='adam')
    return classificator

def get_single_column_encoder(input_width, input_height, input_depth,
                              vol_depth, future, loss_type):
    """Build and compile the lighter multi-branch 3D encoder-decoder.

    Three parallel 16-filter convolutions with kernels
    ``(vol_depth, k, k)`` for ``k`` in 7, 5, 3 are concatenated and passed
    through two conv + max-pool stages (no dropout), then a decoder of two
    conv + up-sampling stages and a final single-filter sigmoid
    convolution. Both up-sampling layers use a factor of 1 on the first
    axis.

    Args:
        input_width: Third entry of the input shape tuple.
        input_height: Second entry of the input shape tuple.
        input_depth: Last entry of the input shape tuple.
        vol_depth: First entry of the input shape tuple; also the first
            kernel dimension of the three parallel convolutions.
        future: Unused by this builder; accepted so the signature matches
            the other builders.
        loss_type: Loss passed to ``Model.compile``.

    Returns:
        The compiled Keras ``Model`` (optimizer ``'adam'``).
    """
    relu_same = dict(activation='relu', padding='same')
    pooling = dict(pool_size=(2, 3, 2), strides=(2, 3, 2))

    input_volume = Input(
        shape=(vol_depth, input_height, input_width, input_depth))

    # Parallel branches, created in the order 7x7, 5x5, 3x3.
    branches = [
        Conv3D(16, (vol_depth, side, side), **relu_same)(input_volume)
        for side in (7, 5, 3)
    ]
    merged = Concatenate()(branches)

    # Encoder.
    net = Conv3D(32, (5, 3, 3), **relu_same)(merged)
    net = MaxPooling3D(**pooling)(net)
    net = Conv3D(64, (3, 3, 3), **relu_same)(net)
    encoded = MaxPooling3D(**pooling)(net)

    # Decoder.
    net = Conv3D(64, (3, 3, 3), **relu_same)(encoded)
    net = UpSampling3D((1, 3, 2))(net)
    net = Conv3D(64, (5, 3, 3), **relu_same)(net)
    net = UpSampling3D((1, 3, 2))(net)
    decoded = Conv3D(1, (5, 3, 3),
                     activation='sigmoid', padding='same')(net)

    autoencoder = Model(input_volume, decoded)
    # Honour the caller's loss like the other builders do; previously the
    # loss was hard-coded to 'binary_crossentropy' and loss_type was ignored.
    autoencoder.compile(loss=loss_type, optimizer='adam')
    return autoencoder

def get_ensamble(input_width, input_height, input_depth,
                 vol_depth, future, loss_type, ensamble_type):
    """Build a list of ``future`` independently constructed models.

    Args:
        input_width: Forwarded to the member builder.
        input_height: Forwarded to the member builder.
        input_depth: Forwarded to the member builder.
        vol_depth: Forwarded to the member builder.
        future: Number of ensemble members to build.
        loss_type: Forwarded to the member builder.
        ensamble_type: ``'class'`` builds members with
            ``get_classificator`` (each with ``future`` fixed to 1); any
            other value builds them with ``get_single_column_encoder``.

    Returns:
        A list with ``future`` models; empty when ``future`` is not
        positive.
    """
    use_classifier = ensamble_type == 'class'

    def build_member():
        # Each call constructs a brand-new model for one ensemble slot.
        if use_classifier:
            return get_classificator(input_width, input_height,
                                     input_depth, vol_depth,
                                     1, loss_type)
        return get_single_column_encoder(input_width, input_height,
                                         input_depth, vol_depth,
                                         future, loss_type)

    return [build_member() for _ in range(future)]
    


def get_model(model, input_width, input_height, input_depth,
              vol_depth, future, loss_type, ensamble_type=None):
    """Build the network selected by ``model``.

    Args:
        model: One of ``'normal'``, ``'column'``, ``'class'`` or
            ``'ensamble'``.
        input_width: Forwarded to the selected builder.
        input_height: Forwarded to the selected builder.
        input_depth: Forwarded to the selected builder.
        vol_depth: Forwarded to the selected builder.
        future: Forwarded to the selected builder.
        loss_type: Forwarded to the selected builder.
        ensamble_type: Only used when ``model == 'ensamble'``; forwarded
            to ``get_ensamble``.

    Returns:
        Whatever the selected builder returns: a single compiled model, or
        a list of models for ``'ensamble'``.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    known_models = ('normal', 'column', 'class', 'ensamble')
    # Fail early with a clear message instead of hitting an unbound local
    # at the return statement.
    if model not in known_models:
        raise ValueError(
            'Unknown model type {!r}; expected one of {}'.format(
                model, ', '.join(known_models)))

    if model == 'normal':
        return get_normal_encoder(input_width, input_height, input_depth,
                                  vol_depth, future, loss_type)
    if model == 'column':
        return get_column_encoder(input_width, input_height, input_depth,
                                  vol_depth, future, loss_type)
    if model == 'class':
        return get_classificator(input_width, input_height, input_depth,
                                 vol_depth, future, loss_type)
    # Only 'ensamble' remains after the membership check above.
    return get_ensamble(input_width, input_height, input_depth,
                        vol_depth, future, loss_type, ensamble_type)
