#!/usr/bin/python3
import readline
import sys
import re
import typing
import colorama
from colorama import Fore, Back
from collections import deque
from functools import lru_cache
import cmd
from glob import glob
import pickle
from subprocess import call

# Matches a line containing "{" where the text before that brace has no
# double quotes other than adjacent pairs ("").  parse_data() treats such a
# line as the start of a block and saves the current shape.
# NOTE(review): only empty "" pairs are skipped, not full "quoted strings";
# confirm whether '"[^"]*"' was intended for the quoted part.
block_begin_regex = re.compile('^([^"]*""[^"]*)*[^"]*({)')  # {
# Same pattern for the closing brace; parse_data() restores the saved shape.
block_end_regex = re.compile('^([^"]*""[^"]*)*[^"]*(})')  # }

# Opening/closing bracket pair used when rendering a node, keyed by shape name.
shape_style = dict(
    rectangle=("[", "]"),
    ellipse=("(", ")"),
    diamond=("<", ">"),
    note=("\u25f0", "\u25f3"),
)
# Bracket pair used for any shape that has no entry in ``shape_style``.
SHAPE_STYLE_ERROR = "?", "?"

# The patterns are raw strings so that "\w" and "\[" reach the regex engine
# verbatim instead of being parsed as (invalid) string escape sequences.

# Captures the word following ``shape =``.
shape_regex = re.compile(r"shape *= *(\w+)")
# Captures (node id, label text) from ``<id> [label="<text>"]``.
node_regex = re.compile(r'(\w+) *\[ *label *= *"(.*)" *\]')
# Captures (source id, comma-separated target ids, last ", <id>" repetition)
# from ``<id> -> <id>, <id>, ...``.
edge_regex = re.compile(r"(\w+) *-> *(\w+( *, *\w+)*)")


class Node:
    """A graph node with an id, a label, a shape name and outgoing edges.

    Attributes:
        nid: node identifier.
        label: label text of the node.
        shape: shape name, used to pick the brackets in ``__str__``.
        children: list of nodes reachable through outgoing edges.
        is_root: ``True`` until something clears it (no incoming edge known).
        included: traversal flag, cleared by ``reset_node``.
        visited: traversal flag, cleared by ``reset_node``.
    """

    # Class-level defaults; an instance gets its own value on first assignment.
    children = None
    is_root = True
    included = False
    visited = False

    def __init__(self, nid, label, shape):
        self.nid = nid
        self.label = label
        self.shape = shape
        self.children = list()

    def __str__(self):
        """Return ``<open><nid>:<label><close>`` with brackets chosen by shape."""
        opening, closing = shape_style.get(self.shape, SHAPE_STYLE_ERROR)
        return opening + str(self.nid) + ":" + str(self.label) + closing

    def get_children_str(self) -> str:
        """Return the ids of the direct children joined by ``;``."""
        return ";".join(x.nid for x in self.children)

    def get_parents_str(self, nodes: typing.Iterable) -> str:
        """Return the ids of those *nodes* that list this node as a child, joined by ``;``."""
        return ";".join(x.nid for x in nodes if self in x.children)

    def reset_node(self):
        """Clear ``included`` and ``visited`` on this node and all descendants.

        The walk is iterative and remembers which nodes it has already
        handled, so graphs containing cycles (or very deep chains) are reset
        without running into the recursion limit.
        """
        seen = set()
        pending = [self]
        while pending:
            node = pending.pop()
            # Identity, not equality: two distinct nodes must both be reset.
            if id(node) in seen:
                continue
            seen.add(id(node))
            node.included = False
            node.visited = False
            pending.extend(node.children)




def parse_data(data):
    """Build the node dictionary from DOT-like text.

    The text is processed line by line:

    * a line matching ``block_begin_regex`` pushes the current shape,
    * a line matching ``block_end_regex`` restores the last pushed shape
      (or ``None`` when nothing was pushed),
    * any other line may set the current shape, declare nodes and declare
      edges, in that order.

    :param data: whole file content as one string
    :return: dict mapping node id to ``Node``
    :raises KeyError: when an edge names an id that has not been declared yet
    """
    nodes = {}
    current_shape = None
    saved_shapes = []

    for line in data.split("\n"):
        if block_begin_regex.match(line) is not None:
            saved_shapes.append(current_shape)
            continue
        if block_end_regex.match(line) is not None:
            current_shape = saved_shapes.pop() if saved_shapes else None
            continue

        # The last shape attribute on the line wins.
        found_shapes = shape_regex.findall(line)
        if found_shapes:
            current_shape = found_shapes[-1]

        for nid, label in node_regex.findall(line):
            nodes[nid] = Node(nid, label, current_shape)

        for source_id, targets, _ in edge_regex.findall(line):
            for target_id in (part.strip() for part in targets.split(",")):
                child = nodes[target_id]
                child.is_root = False  # the target now has a parent
                nodes[source_id].children.append(child)

    return nodes


def get_roots(d: typing.Dict[str, Node]):
    """Return the values of *d* whose ``is_root`` flag is set, in dict order."""
    roots = []
    for candidate in d.values():
        if candidate.is_root:
            roots.append(candidate)
    return roots


def c_num(text):
    """Return ``str(text)`` wrapped in light-blue foreground colour codes."""
    fore = colorama.Fore
    return "".join((fore.LIGHTBLUE_EX, str(text), fore.RESET))


def c_node(node: typing.Union[Node, str]):
    """Return a coloured one-line rendering of *node*.

    A plain string is wrapped in white.  A ``Node`` is rendered as
    ``<open><nid>:<label><close>`` where the brackets depend on the shape and
    every ``\\n`` / ``\\l`` escape in the label is replaced by a space.
    """
    if isinstance(node, str):
        return "".join((Fore.WHITE, node, Fore.RESET))

    opening, closing = shape_style.get(node.shape, SHAPE_STYLE_ERROR)
    flat_label = re.sub("\\\\[nl]", " ", node.label)
    pieces = (
        Fore.LIGHTYELLOW_EX, opening,
        Fore.LIGHTBLUE_EX, str(node.nid),
        Fore.LIGHTYELLOW_EX, ":",
        Fore.WHITE, flat_label,
        Fore.LIGHTYELLOW_EX, closing,
        Fore.RESET,
    )
    return "".join(pieces)


def include(node: Node, out: typing.Set, diamond):
    """Mark *node* and its descendants as included, collecting results.

    Nodes are handled in depth-first pre-order.  A node that is already
    ``included`` or ``visited`` is skipped.  A ``note`` node is added to
    *out*.  A ``diamond`` node is appended to *diamond* and neither marked
    nor descended into.

    :param node: starting node
    :param out: set receiving the note nodes
    :param diamond: container (needs ``append``) receiving the diamond nodes
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.included or current.visited:
            continue
        if current.shape == "note":
            out.add(current)
        if current.shape == "diamond":
            diamond.append(current)
            continue
        current.included = True
        # Reversed so that the first child is popped (and fully explored) first.
        pending.extend(reversed(current.children))


def find_nodes(n: typing.Union[str, typing.Iterable[str]], nodes: typing.Dict[str, Node]) -> typing.List[Node]:
    """Look up one id or several ids in *nodes*.

    :param n: a single node id, or an iterable of node ids
    :param nodes: dict mapping node id to node
    :return: list of the matching nodes, in the order the ids were given
    :raises KeyError: when an id is not a key of *nodes*
    """
    wanted = [n] if isinstance(n, str) else n
    return [nodes[node_id] for node_id in wanted]


def error(msg: str):
    """Print *msg* wrapped in red foreground colour codes."""
    fore = colorama.Fore
    print("".join((fore.RED, msg, fore.RESET)))

def warning(msg: str):
    """Print *msg* wrapped in yellow foreground colour codes."""
    fore = colorama.Fore
    print("".join((fore.YELLOW, msg, fore.RESET)))


class Console(cmd.Cmd):
    """Interactive shell for walking the parsed graph and collecting a profile.

    Instance state:

    * ``graph`` - dict of node id to node, parsed from the file ``sys.argv[1]``
    * ``roots`` - nodes of ``graph`` flagged as roots
    * ``current`` - node being shown, or the string ``"ROOT"`` before any ``next``
    * ``to_search`` - deque of nodes waiting to be shown by ``next``
    * ``included`` - set of note nodes collected for the profile
    * ``manually_included`` - nodes the user passed to ``keep``
    """

    def __init__(self):
        super().__init__()
        colorama.init()
        with open(sys.argv[1]) as f:
            self.data_s = f.read()
        self.graph = parse_data(self.data_s)
        self.to_search = deque()
        self.included = set()
        self.roots = get_roots(self.graph)
        self.current = "ROOT"
        self.manually_included = list()
        # Only a space separates words for tab completion, so node ids and
        # path fragments are completed as whole words.
        readline.set_completer_delims(" ")

    @staticmethod
    def file_glob(line):
        """
        Search paths beginning with chosen string.

        :param line: whole command line; everything after the first space is the beginning of path
        :return: list of found paths, each reduced to the part that belongs to the last word of the line
        """

        path = line[line.find(" ") + 1:]
        # Completion replaces only the last space-separated word, so the
        # leading words of every match (already typed) are dropped.
        return [" ".join(x.split(" ")[(line.count(" ") - 1):]) for x in glob(path + "*")]

    def _complete_node_ids(self, text):
        """Return ids of graph nodes starting with *text* (shared tab completer)."""
        return [n for n in self.graph if n.startswith(text)]

    def _nodes_from_args(self, args):
        """
        Resolve space separated node ids to nodes of the graph.

        Prints an error and returns None when no id was given or an id is unknown.
        """
        if len(args) == 0:
            error("Command needs list of nodes!")
            return None
        try:
            return find_nodes(filter(None, args.split(" ")), self.graph)
        except KeyError:
            error("Invalid argument!")
            return None

    def mod_color(self, x: Node) -> str:
        """
        Get background color if node is missing from graph, or text is different in corresponding node in graph.

        :param x: Node to check
        :return: string representing background color
        """
        if x.nid not in self.graph:
            return Back.RED
        if self.graph[x.nid].label != x.label:
            return Back.LIGHTBLACK_EX
        return ""

    def do_keep(self, args):
        """
        Recursively search from selected nodes and keep A-nodes in profile.

        :param args: node ids
        """

        keep = self._nodes_from_args(args)
        if keep is None:
            return
        for n in keep:
            self.manually_included.append(n)
            include(n, self.included, self.to_search)

    def complete_keep(self, text, line, begidx, endidx):
        return self._complete_node_ids(text)

    def do_search(self, args):
        """
        Append selected nodes to queue to make decision about them.

        :param args: node ids
        """

        search = self._nodes_from_args(args)
        if search is None:
            return
        self.to_search.extend(search)

    def complete_search(self, text, line, begidx, endidx):
        return self._complete_node_ids(text)

    def do_current(self, args):
        """
        Shows currently processed node and nodes connected by outgoing edges.
        """

        if len(args) > 0:
            error("Command requires no arguments!")
            return
        print(c_num(len(self.to_search)) + (
            " nodes" if len(self.to_search) != 1 else " node") + " to check after this one")
        print("Current node: " + c_node(self.current))
        print("               \u2503")
        if self.current == "ROOT":
            children = self.roots
        else:
            assert isinstance(self.current, Node)
            children = self.current.children
        for i, ch in enumerate(children):
            # Green marker: child already included; yellow: already visited.
            back = ""
            if ch.included:
                back = Back.GREEN
            elif ch.visited:
                back = Back.YELLOW
            # The last child closes the tree branch, the others continue it.
            ech = "\u2517" if i == len(children) - 1 else "\u2523"
            print("            " + back + " " + Back.RESET + "  " + ech + "\u2501 " + c_node(ch))

    def do_next(self, args):
        """
        Gets first node from queue and consider it "current".
        Then, command `current` is executed automatically to show this node.
        """

        if len(args) > 0:
            error("Command requires no arguments!")
            return
        if len(self.to_search) == 0:
            if self.current == "ROOT":
                error("Queue is empty. Use keep or search at least once.")
            else:
                error("Nothing left in the queue.")
            return
        self.current = self.to_search.popleft()
        self.do_current(args=args)
        # Marked after printing, so the node is shown in its previous state.
        self.current.visited = True

    def emptyline(self):
        """
        Same as `next` command.
        """

        self.do_next(args="")

    def do_exit(self, args):
        """
        Exits program without saving anything.
        """

        # A true return value stops cmdloop().
        return True

    def do_print(self, args):
        """
        If needed, it asks user for profile name, then shows generated profile.

        """

        if len(args) > 0:
            error("Printing does not require any arguments!")
            return
        print(gen_profile(sys.argv[2], self.included))

    def do_export(self, args):
        """
        Same as `print` but exports profile to file.

        :param args: file path
        """

        if len(args) == 0:
            error("Enter file location for export!")
            return
        try:
            with open(args, "w") as f:
                f.write(gen_profile(sys.argv[2], self.included))
        except OSError:
            # Covers a missing directory as well as permission problems or a
            # path that is a directory, none of which should end the session.
            error("Enter correct filename!")

    def complete_export(self, text: str, line: str, begidx, endidx):
        return self.file_glob(line)

    def do_save(self, args):
        """
        Saves current state to file.

        :param args: file path
        """

        if len(args) == 0:
            error("Enter file location for pickled file!")
            return
        try:
            with open(args, "wb") as f:
                pickle.dump((self.graph, self.current, self.included, self.to_search, self.manually_included), f)
        except (OSError, pickle.PicklingError):
            error("Save failed!")

    def complete_save(self, text: str, line: str, begidx, endidx):
        return self.file_glob(line)

    def do_load(self, args):
        """
        Loads saved state. Current node, nodes included for keeping and search queue is recovered.
        Information about visited and included nodes is loaded to current graph.
        List of manually included nodes is extended by loaded nodes.
        Nothing is changed when the file cannot be read or does not hold a saved state.

        :param args: file path
        """

        if len(args) == 0:
            error("You must provide file to load!")
            return
        try:
            with open(args, "rb") as f:
                # SECURITY: unpickling can execute arbitrary code; only load
                # files written by `save` from a trusted source.
                state = pickle.load(f)
            g, current, included, to_search, m = state
        except (OSError, pickle.PickleError, EOFError, AttributeError,
                ImportError, IndexError, ValueError, TypeError):
            # Unreadable file, damaged pickle, or a pickle that is not the
            # five-item state written by `save`.
            error("Load failed!")
            return

        # Assigned only after the whole state was read successfully.
        self.current = current
        self.included = included
        self.to_search = to_search
        self.manually_included.extend(m)
        for nid, node in g.items():
            if nid in self.graph:
                if node.visited:
                    self.graph[nid].visited = True
                if node.included:
                    self.graph[nid].included = True

    def complete_load(self, text: str, line: str, begidx, endidx):
        return self.file_glob(line)

    def do_list(self, args):
        """
        Lists nodes manually included by `keep` command (by user).
        """

        if len(args) > 0:
            error("This command lists every decision. No argument needed!")
            return
        print(", ".join(
            (self.mod_color(x) + c_node(x) + Back.RESET
             for x in self.manually_included)))

    def do_remove(self, args):
        """
        Removes nodes from list of nodes manually included by `keep` command (by user).
        """

        if len(args) == 0:
            error("Provide node id to remove!")
            return
        remove = set(filter(None, args.split(" ")))
        self.manually_included = [x for x in self.manually_included if x.nid not in remove]

    def complete_remove(self, text, line, begidx, endidx):
        return [n.nid for n in self.manually_included if n.nid.startswith(text)]

    def do_apply(self, args):
        """
        Includes all nodes from list of manually included nodes. Use `list` command to see this list.
        (Similar to running `keep` command on every node from list.)
        """

        if len(args) > 0:
            error("This command applies `keep` to everything listed by `list`. No argument needed!")
            return
        if len(self.included) > 0:
            warning("Applying to non-empty profile. You can use `reset` to clean it before using `apply`.")
        for node in self.manually_included:
            # Diamond nodes found on the way are collected in a throwaway
            # list, i.e. they are not added to the search queue.
            include(node, self.included, [])

    def do_reset(self, args):
        """
        Resets everything except list of manually included nodes by user.
        """

        if len(args) > 0:
            error("No argument needed!")
            return
        self.included.clear()
        for n in self.graph.values():
            n.included = False
            n.visited = False
        self.to_search.clear()
        self.current = "ROOT"
        # Manually included nodes may come from a loaded state rather than
        # from the current graph, so they are reset separately.
        for node in self.manually_included:
            node.reset_node()

    def do_man(self, args):
        """
        Opens documentation about specific node using external command.

        :param args: node id
        """

        search = list(filter(None, args.split(" ")))
        if len(search) != 1:
            error("Expecting exactly one node id.")
            return

        try:
            call(["./documentation.sh", search[0]])
        except OSError:
            # Script missing or not executable: report it, keep the session.
            error("Cannot run ./documentation.sh!")

    def complete_man(self, text, line, begidx, endidx):
        return self._complete_node_ids(text)


@lru_cache(maxsize=1)
def user_get_app_name() -> str:
    """Ask the user for the application names and return the typed line.

    The result is cached, so the prompt is shown only on the first call.
    """
    print("Enter possible names of app separated by space:")
    answer = input()
    return answer


def gen_profile(path: str, nodes: typing.Iterable[Node], abstractions=True):
    """Render the profile text for *path* from the note nodes in *nodes*.

    The output starts with the global tunables include and the ``@{APPNAME}``
    variable (asked from the user through ``user_get_app_name``), followed by
    a ``path { ... }`` block.  Every node whose shape is ``note`` contributes
    its label, indented by four spaces; other nodes are ignored.

    :param path: text placed in front of the opening brace of the profile
    :param nodes: nodes to render
    :param abstractions: when true, a label of the form ``<...>`` is emitted
        as ``#include <...>``
    :return: the whole profile as one string
    """
    # A run of "\n" / "\l" escapes becomes one line break.  The repetition
    # applies to the whole escape, so letters that follow it (the "n" of
    # "\nnetwork", the "l" of "\llink") are left in the rule text.
    line_break = re.compile(r"(?:\\[nl])+")

    result = ["include <tunables/global>\n", "@{APPNAME} = " + user_get_app_name() + "\n", path + " {\n"]
    for node in nodes:
        if node.shape != "note":
            continue
        label = node.label
        text = line_break.sub("\n    ", label) + "\n"
        is_abstraction = abstractions and len(label) >= 2 and label[0] == "<" and label[-1] == ">"
        prefix = "#include " if is_abstraction else ""
        result.append("    " + prefix + text)
    result.append("}\n")
    return "".join(result)


if __name__ == '__main__':
    # Exactly two arguments are required: sys.argv[1] is read by Console()
    # and sys.argv[2] by the print/export commands.
    if len(sys.argv) != 3:
        print("Usage: " + sys.argv[0] + " <dot file> <path to binary file>")
        # sys.exit rather than the interactive-only exit() helper from site.
        sys.exit(1)
    console = Console()
    console.cmdloop()
