#pragma once
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <list>

// Integer 3D coordinate. Also used as a per-axis offset that can be added to
// another point.
class point {
public:
	// Builds a point from up to three components; omitted components are 0.
	// Deliberately not explicit: the implicit conversion from a single int
	// (e.g. `somePoint = 0`) is kept so existing code keeps compiling.
	point(int a = 0, int b = 0, int c = 0) : x(a), y(b), z(c) {}

	// True when all three components match.
	bool operator ==(const point& o) const { return o.x == x && o.y == y && o.z == z; }

	// Component-wise sum of the two points.
	point operator +(const point& o) const { return point(o.x + x, o.y + y, o.z + z); }

	int x, y, z;
};

// One search record: a position, the position it was reached from, and the
// two cost terms that are summed to rank it against other records.
class node {
public:
	// Two nodes are equal when they refer to the same position; parent and
	// cost terms are ignored.
	bool operator == (const node& o) const { return pos == o.pos; }

	// A node equals a point when the node sits at that position.
	bool operator == (const point& o) const { return pos == o; }

	// Orders nodes by the sum dist + cost, smaller first.
	bool operator < (const node& o) const { return dist + cost < o.dist + o.cost; }

	point pos, parent;

	// Zero-initialised so a default-constructed node can be compared safely.
	double dist = 0.0, cost = 0.0;
};

class aStar {
public:
	aStar() {
		int i = 0;
		for (int dx = -1; dx <= 1; dx++)
		{
			for (int dy = -1; dy <= 1; dy++)
			{
				for (int dz = -1; dz <= 1; dz++)
				{
					if (dx == 0 && dy == 0 && dz == 0) continue;
					neighbours[i] = point(dx, dy, dz);
					i++;
				}
			}
		}
	}

	double calcDist(point& p) {
		// need a better heuristic
		double x = end.x - p.x, y = end.y - p.y, z = end.z - p.z;
		return (x * x + y * y + z * z);
	}

	bool isValid(point& p) {
		return (p.x > -1 && p.y > -1 && p.z > -1 && p.x < xSize && p.y < ySize && p.z < zSize);
	}

	bool existPoint(point& p, double cost) {
		std::list<node>::iterator i;
		i = std::find(closed.begin(), closed.end(), p);
		if (i != closed.end()) {
			return true;
			/*if ((*i).cost + (*i).dist < cost) return true;
			else { closed.erase(i); return false; }*/
		}
		i = std::find(open.begin(), open.end(), p);
		if (i != open.end()) {
			if ((*i).cost + (*i).dist < cost) return true;
			else { open.erase(i); return false; }
		}
		return false;
	}

	bool fillOpen(node& n) {
		double stepCost, nc, dist;
		point neighbour;

		for (int x = 0; x < 26; x++) {

			neighbour = n.pos + neighbours[x];
			if (neighbour == end) return true;

			int diag = abs(neighbour.x) + abs(neighbour.y) + abs(neighbour.z);
			if (diag == 1) stepCost = 1;
			else if (diag == 2) stepCost = sqrt(2);
			else stepCost = sqrt(3);


			if (isValid(neighbour) && m[neighbour.x][neighbour.y][neighbour.z]) {
				nc = stepCost + n.cost;
				dist = calcDist(neighbour);
				if (!existPoint(neighbour, nc + dist)) {
					node m;
					m.cost = nc; m.dist = dist;
					m.pos = neighbour;
					m.parent = n.pos;
					open.push_back(m);
				}
			}
		}
		return false;
	}

	bool search(point& s, point& e, bool ***mp, double _xSize, double _ySize, double _zSize) {
		node n; end = e; start = s; m = mp;
		xSize = _xSize;
		ySize = _ySize;
		zSize = _zSize;
		n.cost = 0; n.pos = s; n.parent = 0; n.dist = calcDist(s);
		open.push_back(n);
		while (!open.empty()) {
			open.sort();
			node n = open.front();
			open.pop_front();
			closed.push_back(n);
			if (fillOpen(n)) return true;
		}
		return false;
	}

	double path(std::list<point>& path) {
		path.push_front(end);

		double cost = 0;
		path.push_front(closed.back().pos);
		point parent = closed.back().parent;


		for (std::list<node>::reverse_iterator i = closed.rbegin(); i != closed.rend(); i++) {
			if ((*i).pos == parent && !((*i).pos == start)) {
				path.push_front((*i).pos);
				parent = (*i).parent;
			}
		}
		path.push_front(start);

		point p;
		bool first = true;
		double stepCost;

		for (std::list<point>::iterator i = path.begin(); i != path.end(); i++) {
			if (first) {
				p = *i;
				first = false;
			}
			else {
				int diag = abs(p.x - (*i).x) + abs(p.y - (*i).y) + abs(p.z - (*i).z);
				if (diag == 1) stepCost = 1;
				else if (diag == 2) stepCost = sqrt(2);
				else stepCost = sqrt(3);

				cost += stepCost;
				p = *i;
			}
		}


		return cost;
	}

	bool ***m;
	double xSize, ySize, zSize;
	point end,start;
	point neighbours[26];
	std::list<node> open;
	std::list<node> closed;
};
