# Classes for building a SegmentGraph and computing features on it.

class Node:
    '''Node of the segment graph G'.

        It is either the special starting node (the square with the ball) or a
        green segment.  Edges are stored in both directions: `kids` holds the
        successors of the node and `parents` its predecessors.
    '''

    def __init__(self, id):
        self.id = id
        self.kids, self.parents = [], []

    def add_kid(self, kid):
        '''Add the edge self -> kid and register self as kid's parent.

            A None kid is silently ignored.
        '''
        if kid is not None:
            self.kids.append(kid)
            kid.add_parent(self)

    def add_parent(self, parent):
        '''Record `parent` as a predecessor of this node.'''
        self.parents += [parent]

class SegmentNode(Node):
    '''Green segment - a node in G' covering the squares from end1 to end2.

        end1 and end2 are the (row, col) squares at the two ends of the segment.
    '''

    def __init__(self, id, end1, end2):
        super().__init__(id)
        self.end1, self.end2 = end1, end2

    def is_horiz(self):
        '''Return True when both ends lie in the same row, i.e. the segment is horizontal.'''
        first_row = self.end1[0]
        second_row = self.end2[0]
        return first_row == second_row

class Checkpoint:
    '''Grid square that contains a checkpoint.

        Stores the square's position (row, col) together with the horizontal
        (row_node) and vertical (col_node) green segment passing through it;
        either segment may be None when no such segment covers the square.
    '''

    def __init__(self, row, col, row_node, col_node):
        self.row, self.col = row, col
        self.row_node, self.col_node = row_node, col_node


class SegmentGraph:
    '''SegmentGraph - contains functions to get features.

        Nodes are the green segments (runs of squares ended by walls) plus one
        special starting node.  A segment has an edge to each perpendicular
        segment passing through one of its two end squares; the starting node
        has edges to the segments passing through the ball's square.  Features
        are computed on the strongly connected components (SCCs) of the part of
        the graph reachable from the starting node.
    '''

    def __init__(self, hasWall, isCheckpoint, ball_row, ball_col):
        '''Create all segment nodes, the edges between them and the checkpoints.

            hasWall -- indexed as hasWall[k][row][col]; hasWall[0] only provides
                the grid size, a truthy hasWall[1][r][c] ends a horizontal
                segment at column c and a truthy hasWall[2][r][c] ends a
                vertical segment at row r.  The last column/row of the grid
                always ends a segment.
            isCheckpoint -- isCheckpoint[row][col] is truthy on checkpoint squares.
            ball_row, ball_col -- square the ball starts on.
        '''
        self.scc_created = False
        rows = len(hasWall[0])
        cols = len(hasWall[0][0])
        # Ids are consecutive: 0 for the starting node, then one per segment.
        count = 0
        # Special node, because it doesn't represent any segment.
        start_node = Node(count)
        count += 1

        # Every square can be intersected by at most one vertical and at most
        # one horizontal segment.
        row_node = [[None] * cols for _ in range(rows)]
        col_node = [[None] * cols for _ in range(rows)]
        nodes = set()

        # Creating nodes from horizontal segments.
        for r in range(rows):
            c = 0
            while c < cols:
                c1 = c
                # The bound keeps the scan inside the grid even when a row has
                # no wall before its last column.
                while c1 < cols - 1 and not hasWall[1][r][c1]:
                    c1 += 1
                if c1 == c:
                    # A run of a single square is not a segment.
                    c += 1
                    continue
                n = SegmentNode(count, (r, c), (r, c1))
                nodes.add(n)
                count += 1
                while c <= c1:
                    row_node[r][c] = n
                    c += 1

        # Creating nodes from vertical segments.
        for c in range(cols):
            r = 0
            while r < rows:
                r1 = r
                # Same bound as above, for the last row.
                while r1 < rows - 1 and not hasWall[2][r1][c]:
                    r1 += 1
                if r1 == r:
                    r += 1
                    continue
                n = SegmentNode(count, (r, c), (r1, c))
                nodes.add(n)
                count += 1
                while r <= r1:
                    col_node[r][c] = n
                    r += 1

        # Creating edges (in G') between nodes.  add_kid ignores None, i.e.
        # squares not covered by a segment in that direction.
        start_node.add_kid(row_node[ball_row][ball_col])
        start_node.add_kid(col_node[ball_row][ball_col])
        for node in nodes:
            # A horizontal segment links to the vertical segments through its
            # two ends, and a vertical one to the horizontal segments.
            crossing = col_node if node.is_horiz() else row_node
            for end_row, end_col in (node.end1, node.end2):
                node.add_kid(crossing[end_row][end_col])
        self.start_node = start_node
        nodes.add(start_node)
        self.nodes = nodes

        # Add segment nodes to every checkpoint.
        self.checkpoints = set()
        for r in range(rows):
            for c in range(cols):
                if isCheckpoint[r][c]:
                    self.checkpoints.add(
                        Checkpoint(r, c, row_node[r][c], col_node[r][c]))

    def __print(self):
        '''For debugging: print each node's id followed by the ids of its kids.'''
        print(self.start_node.id, [kid.id for kid in self.start_node.kids])
        for node in self.nodes:
            print(node.id, [kid.id for kid in node.kids])

    def get_scc_count(self):
        '''Return the number of SCCs reachable from the starting node,
        not counting the starting node's own component.'''
        if not self.scc_created:
            self.__create_scc()
        # No edge points to the starting node, so it always forms an SCC on its own.
        return self.count_scc - 1

    def get_scc_checkpoint_count(self):
        '''Return the number of reachable SCCs that contain at least one checkpoint.

            A checkpoint belongs to the components of its horizontal and
            vertical segments; segments not reachable from the starting node
            are ignored.
        '''
        if not self.scc_created:
            self.__create_scc()
        have_check = [False] * self.count_scc
        for check in self.checkpoints:
            for node in (check.row_node, check.col_node):
                if node is not None and node.id in self.scc:
                    have_check[self.scc[node.id]] = True
        return sum(have_check)

    def __scc_visit(self, node, L):
        '''Helping function for SCC -- the first dfs.

            Marks every not-yet-reached node reachable from `node` as reached
            and appends it to L in post-order (a node after all of its kids).
            An explicit stack is used so long chains of segments cannot hit
            Python's recursion limit.
        '''
        if node.reached:
            return
        node.reached = True
        stack = [(node, iter(node.kids))]
        while stack:
            current, kids = stack[-1]
            for kid in kids:
                if not kid.reached:
                    kid.reached = True
                    stack.append((kid, iter(kid.kids)))
                    break
            else:
                # All kids of `current` are finished.
                stack.pop()
                L.append(current)

    def __scc_assign(self, node, scc, index):
        '''Helping function for SCC -- the second dfs, following reversed edges.

            Assigns component `index` (stored in scc: node id -> index) to
            `node` and to every reached, not yet assigned node found by walking
            parent links from it.  Iterative for the same reason as __scc_visit.
        '''
        stack = [node]
        while stack:
            current = stack.pop()
            # Skip nodes not reachable from the starting node, and nodes
            # already assigned to a component.
            if not current.reached or current.green:
                continue
            current.green = True
            scc[current.id] = index
            stack.extend(current.parents)

    def __create_scc(self):
        '''Run algorithm for finding strongly connected components. Remember results.

            Standard algorithm from wikipedia (https://en.wikipedia.org/wiki/Kosaraju's_algorithm),
            restricted to the nodes reachable from the starting node.  Sets
            self.scc (node id -> component index), self.count_scc and
            self.scc_created.
        '''
        L = []
        for node in self.nodes:
            node.reached = False
        # Only nodes reachable from the starting node are visited.
        self.__scc_visit(self.start_node, L)

        for node in self.nodes:
            node.green = False

        scc = {}
        index = 0
        # Process nodes in reverse post-order of the first dfs.
        for node in reversed(L):
            if node.green or not node.reached:
                continue
            self.__scc_assign(node, scc, index)
            index += 1
        # Remember results.
        self.scc = scc
        self.count_scc = index
        self.scc_created = True
