#include <stdio.h>
#include <sys/time.h>

#include <iostream>
#include <vector>

#include <NTL/RR.h>
#include <NTL/ZZ.h>
#include <NTL/ZZ_p.h>
#include <NTL/ZZ_pX.h>

NTL_CLIENT

void PrintOutPoly(ZZ_pX& poly);
void PrintOutBinary(int *binary, int binSize, bool bigEndian);

int main(int argc, char **argv)
{
	ZZ N, copyN, NmodR, sqrtN, R_zz;
	RR floatN, logN;
	long R;
	ZZ bound, base, powRes, gcd_tmp;
	long expo = 2;
	int *bin;
	int binSize = 0;
	long cnt = 0, cntSqr = 0, cntCoef = 0;
	bool ok = false;
	struct timeval stime, ftime;

	cout << "Zadaj N: ";
	cin >> floatN;

	TruncToZZ(N, floatN);
	if (!IsOdd(N)) {
		cout << "je to parne cislo\n";
		exit(0);
	}

	copyN = N;
	sqrtN = SqrRoot(N);			// integer square root

	logN = log(floatN)/log(to_RR(2.0));
	bin = (int*)malloc(NumBits(copyN) * sizeof(int));
	
	while (copyN > 0) {
		bin[binSize] = copyN % 2;
		binSize++;
		copyN /= 2;
	}

	cout << "N = " << N << " = (little endian) ";
	PrintOutBinary(bin, binSize, false);
	cout << "logN = " << logN << "\n";

/*	STEP 1	*/
	gettimeofday(&stime, NULL); 

	RR expRR, baseRR, tmpRR;
	
	cnt = 2; baseRR = to_RR(2);
	while(cnt <= binSize) {
		expRR = to_RR(1)/baseRR;
		tmpRR = pow(floatN, expRR);
		powRes = FloorToZZ(tmpRR);
		powRes = power(powRes, cnt);
		if (powRes == N) {
			cout << "je to mocnina " << powRes << "^" << cnt << "\n";
			exit(0);
		}
		cnt++;
		baseRR++;
	}

	gettimeofday(&ftime, NULL); 
	long k1 = (unsigned long long)1000*(ftime.tv_sec-stime.tv_sec)+(ftime.tv_usec-stime.tv_usec)/1000;
	cout << "Krok 1: cas  t = " << k1 << " ms\n";


/* STEP 2 */
	gettimeofday(&stime, NULL); 

	bound = TruncToZZ(4 * sqr(logN));
//	cout << "Dolne hranica radu: " << bound << "\n";

	R = 2; R_zz = 2; ok = false;
	while(!ok) {
		gcd_tmp = GCD(R_zz, N);
		if (gcd_tmp > 1) {
			if (gcd_tmp == N)
				cout << N << " je prvocislo\n";
			else
				cout << N << " je delitelne " << R << "\n";
			exit(0);
		}
		expo = 1;
		ok = true;
		while(ok && expo <= bound) {
			copyN = N % R;
			powRes = PowerMod(copyN, expo, R_zz);
			if (powRes == 1)
				ok = false;
			expo++;
		}
		R++; R_zz++;
	}
	R--;
	cout << "Vyhovujuce R: " << R << "\n";

	gettimeofday(&ftime, NULL); 
	k1 = (unsigned long long)1000*(ftime.tv_sec-stime.tv_sec)+(ftime.tv_usec-stime.tv_usec)/1000;
	cout << "Krok 2: cas:  t = " << k1 << " ms\n";

/*	STEP 3 */
/*	left out, already done while looking for suitable R
	line: gcd_tmp = GCD(R_zz, N); */

/*	STEP 4 */
	if (N <= R)
		cout << "Zlozene cislo\n";

/*	STEP 5 */
	ZZ_p::init(N);
	ZZ_pX tmp, rs;
	ZZ_p A;

	NmodR = N % R;
	A = 3;
	FloorToZZ(bound, to_RR(2)*SqrRoot(to_RR(R - 1))*logN);
//	cout << "N mod R = " << NmodR << "\n";
	
	gettimeofday(&stime, NULL); 
	cout << "Overuje sa " << bound << " rovnosti (X+a)^n = X^n+a (mod X^r-1, n)\n";
	cout << "	chvilu to potrva\n";
		
	for(cnt = 1; cnt <= bound; cnt++) {
		set(rs);
		clear(tmp);
		SetCoeff(tmp, 0, A);
		SetCoeff(tmp, 1, 1);

		for(cntSqr = 0; cntSqr < binSize; cntSqr++) {
			if (bin[cntSqr] == 1) {
				rs *= tmp;
				if (deg(rs) >= R) {
					for(cntCoef = R; cntCoef <= deg(rs); cntCoef++) {
						SetCoeff(rs, cntCoef - R, coeff(rs, cntCoef - R) + coeff(rs, cntCoef));
						SetCoeff(rs, cntCoef, 0);
					}
				}
			}
			tmp = sqr(tmp);
			if (deg(tmp) >= R) {
				for(cntCoef = R; cntCoef <= deg(tmp); cntCoef++) {
					SetCoeff(tmp, cntCoef - R, coeff(tmp, cntCoef - R) + coeff(tmp, cntCoef));
					SetCoeff(tmp, cntCoef, 0);
				}
			}
		}

		if ((deg(rs) != NmodR) || (LeadCoeff(rs) != 1) || (ConstTerm(rs) != A)) {
			cout << "Lava a prava strana sa nezhoduju v jednom z nasledujucich:\n";
			cout << "	rozny stupen\n";
			cout << "	rozny absolutny clen\n";
			cout << "	rozny veduci koeficient\n";
			break;
		}
		for(cntCoef = deg(rs) - 1; cntCoef > 1; cntCoef--)
			if (coeff(rs, cntCoef) != 0) {
				cout << "Lava a prava strana sa nezhoduju v jednom z nasledujucich:\n";
				cout << "	Koeficient " << cntCoef << " nie je nula\n";
				break;
			}

		A++;
	}

	gettimeofday(&ftime, NULL); 
	
	k1 = (unsigned long long)1000*(ftime.tv_sec-stime.tv_sec)+(ftime.tv_usec-stime.tv_usec)/1000;
	long k5 = (long)--cnt;

	if ((cnt + 1) > bound) {
		cout << "Krok 5: cas:  t = " << k1 << " ms\n";
		cout << "Priemerny cas jednej iteracie:  t = " << k1/k5 << " ms\n";
		cout << "\nJe to prvocislo\n";
	} else {
		cout << "Nerovnost v " << cnt << " iteracii\n";
		cout << "Priemerny cas jednej iteracie:  t = " << k1/k5 << " ms\n";
		cout << "\nNie je to prvocislo\n";
	}

	return 0;
}

// vypis pola, obsahujuceho binarny rozvoj
// bigEndian = true - zacne najvyznamnejsim bitom
void PrintOutBinary(int *binary, int binSize, bool bigEndian)
{
	// Prints the first binSize entries of `binary` as "[...]" plus a newline.
	// With bigEndian == true the entries are printed from the last index down
	// to index 0 (most significant bit first when index 0 holds the least
	// significant bit); otherwise they are printed from index 0 upward.
	const int step = bigEndian ? -1 : 1;
	int idx = bigEndian ? binSize - 1 : 0;

	cout << "[";
	for (int left = binSize; left > 0; left--, idx += step)
		cout << binary[idx];
	cout << "]\n";
}

// na vypis polynomu ZZ_pX v citatelnej forme (na debug)
void PrintOutPoly(ZZ_pX& poly)
{
	// Debug helper: writes `poly` to stdout from the highest degree down,
	// skipping zero coefficients. Each term is "+c" followed by "x^k "
	// (k > 1), "x " (k == 1) or nothing (constant term). The output always
	// ends with a newline, so the zero polynomial prints an empty line.
	for (long k = deg(poly); k >= 0; --k) {
		const ZZ_p term = coeff(poly, k);
		if (term != 0) {
			cout << "+" << term;
			if (k > 1)
				cout << "x^" << k << " ";
			else if (k == 1)
				cout << "x ";
		}
	}

	cout << "\n";
}
