from collections import deque
from heapq import heappush as hpush, heappop as hpop
from queue import Queue

class Graph:
    '''State graph for a level.

    The ball slides in one of four cardinal directions until it is stopped by a
    wall (or the edge of the grid).  A search state is ``(row, col, mask)`` where
    ``mask`` is the bitmask of checkpoints collected so far; the level is solved
    once every checkpoint bit is set (``self.winning``).

    All statistics are computed lazily on first request and cached on the
    instance.
    '''

    # (row delta, col delta) indexed by direction; hasWall[d] uses the same
    # index d.  With row 0 at the top: 0 = up, 1 = right, 2 = down, 3 = left.
    _DELTAS = ((-1, 0), (0, 1), (1, 0), (0, -1))

    def __init__(self, hasWall, isCheckpoint, ball_row, ball_col):
        '''Build the move graph of the level.

        Args:
            hasWall: indexed as ``hasWall[d][r][c]``; truthy when the ball on
                tile ``(r, c)`` cannot move one step in direction ``d``
                (direction order as in ``_DELTAS``).  The grid boundary is
                always treated as a wall as well.
            isCheckpoint: ``rows x cols`` grid, truthy on checkpoint tiles.
            ball_row, ball_col: starting tile of the ball.
        '''
        self.bfs_computed = False
        self.viable_computed = False
        self.tiles_computed = False

        rows = len(isCheckpoint)
        cols = len(isCheckpoint[0])

        # Number every checkpoint 0..count-1 in row-major order; -1 = plain tile.
        checks = [[-1] * cols for _ in range(rows)]
        count = 0
        for r in range(rows):
            for c in range(cols):
                if isCheckpoint[r][c]:
                    checks[r][c] = count
                    count += 1

        # Every tile gets at most four edges (one per cardinal direction).
        # Each edge is labelled with the bitmask of checkpoints the ball passes
        # over during that slide (stopping tile included, starting tile not).
        G = {}
        for r in range(rows):
            for c in range(cols):
                neighs = {}
                for d, (dr, dc) in enumerate(self._DELTAS):
                    i, j = r, c
                    mask = 0
                    while not hasWall[d][i][j]:
                        ni, nj = i + dr, j + dc
                        # A missing border wall must not wrap around through
                        # negative indices (or loop forever): stop at the edge.
                        if not (0 <= ni < rows and 0 <= nj < cols):
                            break
                        i, j = ni, nj
                        if checks[i][j] >= 0:
                            mask |= 1 << checks[i][j]
                    if (i, j) != (r, c):
                        neighs[(i, j)] = mask
                G[(r, c)] = neighs
        self.G = G
        self.ball = (ball_row, ball_col)
        # Mask with every checkpoint bit set (0 when there are no checkpoints).
        self.winning = (1 << count) - 1

    def __bfs(self):
        '''BFS over states from the start: reachable states/tiles and shortest solution.

        Sets ``shortest`` (-1 if unsolvable), ``reachable_tile_count``,
        ``reachable_states_count``, ``reached``, ``distance`` (BFS depth of each
        reached state), ``previous`` (reverse adjacency: state -> list of states
        with an edge into it) and ``final_states`` (every winning state reached).
        Winning states are terminal and are never expanded.
        '''
        if self.bfs_computed:
            return
        self.bfs_computed = True

        start_state = self.ball + (0,)
        distance = {start_state: 0}
        reached = {self.ball}
        previous = {start_state: []}
        final_states = []
        shortest = -1

        queue = deque()
        if start_state[2] == self.winning:
            # No checkpoints at all: the level is solved before any move.
            shortest = 0
            final_states.append(start_state)
        else:
            queue.append(start_state)

        while queue:
            state = queue.popleft()
            r, c, ch = state
            dis = distance[state]
            for neigh, val in self.G[(r, c)].items():
                mask = ch | val
                next_state = neigh + (mask,)
                previous.setdefault(next_state, []).append(state)
                if next_state in distance:
                    continue
                distance[next_state] = dis + 1
                reached.add(neigh)
                if mask != self.winning:
                    queue.append(next_state)
                else:
                    # BFS discovers states in depth order, so the first winning
                    # state found gives the length of the shortest solution.
                    if shortest == -1:
                        shortest = dis + 1
                    # Keep every winning state: the viable-state and tile
                    # searches must start from all of them, not just the first.
                    final_states.append(next_state)

        self.shortest = shortest
        self.reachable_tile_count = len(reached)
        self.reachable_states_count = len(distance)

        self.reached = reached
        self.distance = distance
        self.previous = previous
        self.final_states = final_states

    def __viable_bfs(self):
        '''Count viable states: reachable states from which a winning state is reachable.

        Traverses the transposed graph (``previous``) starting from all winning
        states; the winning states themselves are counted as viable.
        '''
        if self.viable_computed:
            return
        self.viable_computed = True

        self.__bfs()

        queue = deque(self.final_states)
        reached = set(self.final_states)
        while queue:
            state = queue.popleft()
            for prev_state in self.previous[state]:
                if prev_state not in reached:
                    reached.add(prev_state)
                    queue.append(prev_state)

        self.viable_states_count = len(reached)

    def _shortest_path_states(self):
        '''Return the set of states lying on at least one shortest solution.

        Walks the transposed graph backwards from the winning states at depth
        ``shortest``, following only edges that go exactly one BFS layer closer
        to the start.  Requires ``__bfs`` to have run.
        '''
        if self.shortest == -1:
            return set()

        targets = [s for s in self.final_states if self.distance[s] == self.shortest]
        on_shortest = set(targets)
        queue = deque(targets)
        while queue:
            state = queue.popleft()
            wanted = self.distance[state] - 1
            for prev_state in self.previous[state]:
                if self.distance[prev_state] != wanted or prev_state in on_shortest:
                    continue
                on_shortest.add(prev_state)
                queue.append(prev_state)
        return on_shortest

    def __shortest_tiles_bfs(self):
        '''Find the minimum number of tiles travelled on any shortest solution.

            * do BFS
            * collect the states lying on some shortest solution
            * run Dijkstra's algorithm on the DAG of shortest-solution edges
              (edges between such states that advance exactly one BFS layer)
                * weight of an edge is the number of tiles the ball slides over

        Sets ``shortest_path_tiles`` (-1 if the level is unsolvable).
        '''
        if self.tiles_computed:
            return
        self.tiles_computed = True

        self.__bfs()

        on_shortest = self._shortest_path_states()
        start_state = self.ball + (0,)
        best = -1

        if start_state in on_shortest:
            pq = [(0, start_state)]
            d = {start_state: 0}
            while pq:
                state_dis, state = hpop(pq)
                if d[state] < state_dis:
                    continue  # stale heap entry
                r, c, ch = state
                if ch == self.winning:
                    # Dijkstra pops states in non-decreasing cost order, so the
                    # first winning state popped has the minimum tile count.
                    best = state_dis
                    break
                next_layer = self.distance[state] + 1
                for neigh, val in self.G[(r, c)].items():
                    next_state = neigh + (ch | val,)
                    # Stay on edges that belong to some shortest solution.
                    if next_state not in on_shortest or self.distance[next_state] != next_layer:
                        continue
                    nr, nc = neigh
                    alt_dis = state_dis + abs(r - nr) + abs(c - nc)
                    if next_state not in d or alt_dis < d[next_state]:
                        d[next_state] = alt_dis
                        hpush(pq, (alt_dis, next_state))

        self.shortest_path_tiles = best

    def get_shortest_path(self):
        '''Get the length (number of moves) of the shortest solution, or -1 if unsolvable.'''
        self.__bfs()
        return self.shortest

    def get_reachable_states_count(self):
        '''Get the number of reachable (row, col, mask) states, including the start.'''
        self.__bfs()
        return self.reachable_states_count

    def get_reachable_tile_count(self):
        '''Get the number of reachable tiles, counting only tiles where the ball can stop.'''
        self.__bfs()
        return self.reachable_tile_count

    def get_viable_states_count(self):
        '''Get the number of reachable states from which the level can still be solved.'''
        self.__viable_bfs()
        return self.viable_states_count

    def get_shortest_path_tiles(self):
        '''Get the minimum number of tiles travelled on any shortest solution, or -1.'''
        self.__shortest_tiles_bfs()
        return self.shortest_path_tiles

