import py2048
import math
from keras.models import Model, Sequential, Input
from keras.layers import Dense, merge
from keras.optimizers import Adam
from keras import regularizers
import keras.backend as K
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
from collections import deque
from PER.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from PER.schedules import LinearSchedule

# When True, main() prints a move label after every step and waits for Enter.
debug = False
# If restore == True, the saved model and training state are loaded at start-up;
# otherwise training starts from scratch and the saved data get overwritten.
restore = False

# define some parameters
# NOTE(review): n_games is not referenced anywhere in the code visible here.
n_games = 1000
# Minibatch size; main() also trains once every batch_size environment steps.
batch_size = 32
# Print a progress line every display_step games.
display_step = 50
# Copy the eval network's weights into the target network every copy_step steps.
copy_step = 50
# Initial exploration rate for epsilon-greedy action selection (read by choose_action).
epsilon = 1.0
# Multiplicative decay applied to epsilon after each training step while epsilon > 0.1.
eps_decay = 0.9999
# NOTE(review): only saved to / restored from dump.pkl; not otherwise used here.
eps_decay_steps = 58
# Network input length: the 4x4 board flattened to 16 values.
n_input = 16

class DDDQNAgent:
    """Dueling Double DQN agent backed by a prioritized experience replay buffer.

    Two networks with identical architecture are kept: ``eval_model`` is trained
    and used to choose actions; ``target_model`` is a copy refreshed by
    ``update_target()`` and used to score the bootstrapped next-state action.
    """

    # File used by save_model()/load_model() for the eval network's weights.
    weights_path = "./saved_models/ddqn_agent.save"

    def __init__(self, n_input, learning_rate):
        """Build both networks and the prioritized replay buffer.

        Args:
            n_input: length of the flat state vector fed to the networks.
            learning_rate: Adam learning rate for the eval network.
        """
        self.n_input = n_input
        self.learning_rate = learning_rate
        self.gamma = 0.995  # discount factor for the bootstrapped target
        self.max_timesteps = 20000000
        self.eval_model = self._build_model()
        self.target_model = self._build_model()
        self.update_target()
        self.prioritized_replay_alpha = 0.6
        self.prioritized_replay_beta0 = 0.4
        self.prioritized_replay_beta_iters = None
        self.prioritized_replay_eps = 1e-6  # keeps every priority strictly positive

        # Create the replay buffer (capacity 10000 transitions).
        self.replay_buffer = PrioritizedReplayBuffer(10000, alpha=self.prioritized_replay_alpha)
        if self.prioritized_replay_beta_iters is None:
            self.prioritized_replay_beta_iters = self.max_timesteps

        # Importance-sampling exponent, presumably interpolated from beta0 to 1.0
        # over beta_iters steps by LinearSchedule — confirm against PER.schedules.
        self.beta_schedule = LinearSchedule(self.prioritized_replay_beta_iters,
                                            initial_p=self.prioritized_replay_beta0,
                                            final_p=1.0)

    def _build_model(self):
        """Return a compiled dueling Q-network mapping a state to 4 action values."""
        reg_rate = 0.00000001
        input_layer = Input(shape=(self.n_input,))
        hidden = input_layer
        # Shared trunk of L1-regularised ReLU layers.
        for units in (300, 300, 200, 200, 100):
            hidden = Dense(units, activation='relu',
                           kernel_regularizer=regularizers.l1(reg_rate))(hidden)
        # Layers are created in the same order as before so weight files saved
        # by earlier versions of this class still load positionally.
        adv_hidden = Dense(100)(hidden)
        advantage = Dense(4)(adv_hidden)        # A(s, a) stream
        value_hidden = Dense(60)(hidden)
        value = Dense(1)(value_hidden)          # V(s) stream
        # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). The mean must be taken per
        # sample over the action axis (axis=1), not over the whole batch.
        policy = merge([advantage, value],
                       mode=lambda x: x[0] - K.mean(x[0], axis=1, keepdims=True) + x[1],
                       output_shape=(4,))

        opt = Adam(lr=self.learning_rate)
        model = Model(inputs=[input_layer], outputs=[policy])
        model.compile(loss='mse', optimizer=opt)
        return model

    def choose_action(self, state):
        """Return an action index in [0, 3] using epsilon-greedy exploration.

        Reads the module-level ``epsilon``, which main() decays during training.
        """
        if np.random.rand(1) < epsilon:
            return random.randint(0, 3)

        values = self.eval_model.predict(state)
        return np.argmax(values[0])

    def store_experience(self, state, action, reward, new_state, done):
        """Add one transition to the prioritized replay buffer."""
        self.replay_buffer.add(state, action, reward, new_state, done)

    def train(self, batch_size, t):
        """Fit the eval network on one prioritized minibatch and refresh its priorities.

        Args:
            batch_size: number of transitions to sample.
            t: global step counter, used to look up the current beta.
        """
        beta = self.beta_schedule.value(t)
        (states, actions, rewards, next_states, dones,
         weights, batch_idxes) = self.replay_buffer.sample(batch_size, beta=beta)

        # Stack the stored states into one batch so each network is queried
        # once per minibatch rather than once per transition.
        X = np.asarray(states, dtype=np.float64).reshape((batch_size, self.n_input))
        X_next = np.asarray(next_states, dtype=np.float64).reshape((batch_size, self.n_input))
        actions = np.asarray(actions, dtype=np.int64).reshape(batch_size)
        rewards = np.asarray(rewards, dtype=np.float64).reshape(batch_size)
        terminal = np.asarray(dones, dtype=bool).reshape(batch_size)
        rows = np.arange(batch_size)

        Y = np.array(self.eval_model.predict(X), dtype=np.float64)
        # Current estimate Q(s, a) for the action actually taken; this (not
        # max_a Q(s, a)) is what the TD error used for the priority compares to.
        q_taken = Y[rows, actions].copy()

        # Double DQN: the eval net picks the next action, the target net scores it.
        next_actions = np.argmax(self.eval_model.predict(X_next), axis=1)
        q_next = np.asarray(self.target_model.predict(X_next))[rows, next_actions]
        td_target = np.where(terminal, rewards, rewards + self.gamma * q_next)

        Y[rows, actions] = td_target
        # Importance-sampling weights from PER correct the bias introduced by
        # non-uniform sampling; without them the beta schedule has no effect.
        sample_weight = np.asarray(weights, dtype=np.float64).reshape(batch_size)
        self.eval_model.fit(X, Y, epochs=1, verbose=0, batch_size=batch_size,
                            sample_weight=sample_weight)

        new_priorities = np.abs(td_target - q_taken) + self.prioritized_replay_eps
        self.replay_buffer.update_priorities(batch_idxes, new_priorities)

    def update_target(self):
        """Copy the eval network's weights into the target network."""
        self.target_model.set_weights(self.eval_model.get_weights())

    def save_model(self):
        """Write the eval network's weights to ``weights_path``, creating its directory if needed."""
        import os
        directory = os.path.dirname(self.weights_path)
        if directory and not os.path.isdir(directory):
            os.makedirs(directory)
        self.eval_model.save_weights(self.weights_path)

    def load_model(self):
        """Load eval-network weights from ``weights_path`` and sync the target network."""
        self.eval_model.load_weights(self.weights_path)
        self.update_target()

       
def encode(state, max_exponent=11.0):
    """Scale a flattened board into a single-row float network input.

    Args:
        state: flat sequence/array of board values. NOTE(review): presumably
            tile exponents, with 11 corresponding to the 2048 tile — confirm
            against py2048.
        max_exponent: divisor used to normalise the values (default 11).

    Returns:
        numpy float array of shape (1, len(state)).
    """
    # Force float so integer boards are never floor-divided (Python 2 semantics).
    return (np.asarray(state, dtype=np.float64) / max_exponent).reshape(1, -1)


def _save_checkpoint(agent, step, game, highest, board_sums):
    """Persist model weights, learning state (dump.pkl) and score history (scores.pkl)."""
    agent.save_model()
    # Learning state: counters and the exploration schedule (module globals).
    with open('dump.pkl', 'wb') as f:
        pickle.dump([step, game, epsilon, eps_decay, eps_decay_steps], f)
    # Score history is kept in a separate file.
    with open('scores.pkl', 'wb') as f:
        pickle.dump([highest, board_sums], f)


def main():
    """Train the agent on 2048 indefinitely.

    A checkpoint is written every 1000 games and again when the run is
    interrupted with Ctrl+C.
    """
    global epsilon
    global eps_decay
    global eps_decay_steps

    board_sums = []
    highest = []
    last_scores = deque(maxlen=500)
    step = 0
    game = 0
    gameEnv = py2048.GameBoard(4, 4)
    agent = DDDQNAgent(n_input, learning_rate=0.0001)
    if restore:
        agent.load_model()
        # Learning state (counters and exploration schedule) lives in dump.pkl.
        # These pickles are written by this script; never load untrusted ones.
        with open('dump.pkl', 'rb') as f:
            step, game, epsilon, eps_decay, eps_decay_steps = pickle.load(f)
        # Score history lives in scores.pkl.
        with open('scores.pkl', 'rb') as f:
            highest, board_sums = pickle.load(f)

    try:
        while True:
            gameEnv.reset()
            moves = 0  # moves made in the current game

            while True:
                observation = gameEnv.board
                # Periodically sync the target network with the eval network.
                if step % copy_step == 0:
                    agent.update_target()

                state = encode(observation.flatten())

                # Choose and perform an action, then store the transition.
                action = agent.choose_action(state)
                observation, reward, done = gameEnv.step(action)
                new_state = encode(observation.flatten())
                agent.store_experience(state, action, reward, new_state, done)

                if moves > 0 and step > 0 and step % batch_size == 0:
                    agent.train(batch_size, step)
                    # Reduce the chance of a random action, down to a 0.1 floor.
                    if epsilon > 0.1:
                        epsilon *= eps_decay

                moves += 1
                step += 1

                if done:
                    board_exp = gameEnv.exponentiate()
                    maxi = np.max(board_exp)
                    board_sum = np.sum(board_exp)
                    last_scores.append(board_sum)
                    # Only record history once the 500-game window is full.
                    if len(last_scores) == 500:
                        board_sums.append(board_sum)
                        highest.append(maxi)
                    if game % display_step == 0:
                        print("game: {} | score: {} | max: {} | time: {} | last_scores_avg: {} | eps: {}".format(
                            game, gameEnv.score, maxi, moves, np.mean(list(last_scores)), epsilon))
                    break

                if debug:
                    # Print the chosen move (Slovak labels: left, up, right,
                    # down) and wait for Enter. The action is no longer
                    # overwritten with -1, which always printed "dole".
                    if action == 0:
                        print("dolava")
                    elif action == 1:
                        print("hore")
                    elif action == 2:
                        print("doprava")
                    else:
                        print("dole")
                    input()

            game += 1

            if game % 1000 == 0:
                _save_checkpoint(agent, step, game, highest, board_sums)

    except KeyboardInterrupt:
        _save_checkpoint(agent, step, game, highest, board_sums)


# Start training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
