#include "stdio.h"
#include <algorithm>
#include <ctime>
#include <cstdlib> 

// Wraps a CUDA runtime call: passes its result to gpuAssert together with the
// file name and line number of the call site.
#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
// Threads per block used for the kernel launches in main.
#define THREADS 64
// Number of blocks used for the kernel launches in main.
#define BLOCKS 256

// Number of hash buckets (rows) in the bucketed copy of the input.
#define HASH_ROW 131072
// Number of slots reserved for each bucket (row).
#define HASH_COL 320
// Number of 64-bit values read from the input file.
#define INPUT_SIZE 31894014
// Length of the contiguous input slice handled by one thread of `find`;
// derived from THREADS*BLOCKS, so it matches only that launch configuration.
#define _dif ((INPUT_SIZE)/(THREADS*BLOCKS)+1)


// Shorthands for the 64-bit integer types used throughout the file.
typedef unsigned long long uint64; 
typedef long long int64; 

// Reports a failed CUDA runtime call and optionally terminates the process.
//
//   code  - status returned by the CUDA call being checked
//   file  - source file of the call site (normally __FILE__)
//   line  - source line of the call site (normally __LINE__)
//   abort - when true (the default) the process exits with `code` as status
//
// `file` is taken as `const char *` because __FILE__ is a string literal;
// binding a literal to a non-const `char *` is ill-formed since C++11.
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess) 
   {
      // The message goes to stdout so it stays interleaved with the program's
      // other progress output.
      printf("GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}

// Device-resident counter; the incBase kernel increments it by one.
__device__ uint64 dev_base;
// Device-resident counter; the incIndex kernel increments it by one and main
// zeroes it before launching the kernels.
__device__ unsigned int dev_index; // index relative to the block


// Host-side timestamp of the previous gettime() call (main initialises it).
time_t curtime;

/*************************************************************/
// Montgomery multiplication 
// http://www.hackersdelight.org/MontgomeryMultiplication.pdf 
// Extended binary GCD step used to set up Montgomery multiplication (after
// the Hacker's Delight note referenced above).
//
//   a      - expected to be even (is_SPRP passes 2^63)
//   b      - expected to be odd (the modulus)
//   pu, pv - receive the two cofactors computed by the loop
//
// The loop runs once for every bit position a occupies: a is halved each
// round while the pair (u, v) is updated to compensate.
__device__ void xbinGCD(uint64 a, uint64 b, uint64 *pu, uint64 *pv)  {
  const uint64 alpha = a;   // original value of a, added into v on odd steps
  const uint64 beta = b;    // original value of b, added into u on odd steps
  uint64 cofU = 1;
  uint64 cofV = 0;

  for (; a > 0; a >>= 1) {
    if (cofU & 1) {
      // Overflow-free form of (cofU + beta) >> 1 (Dietz's averaging trick).
      cofU = ((cofU ^ beta) >> 1) + (cofU & beta);
      cofV = (cofV >> 1) + alpha;
    } else {
      // cofU is even: strip the shared factor of two from both cofactors.
      cofU >>= 1;
      cofV >>= 1;
    }
  }

  *pu = cofU;
  *pv = cofV;
}


// Remainder of the 128-bit value (x || y) divided by z, computed with 64
// rounds of shift-and-subtract long division.
//
// The caller is expected to pass x < z so that the result fits in 64 bits;
// this is not verified here.
__device__ uint64 modul64(uint64 x, uint64 y, uint64 z) {
  for (int round = 0; round < 64; ++round) {
    // Remember the bit that is about to be shifted out of x.
    const bool carryOut = (x >> 63) != 0;

    // Shift the 128-bit pair (x || y) left by one position.
    x = (x << 1) | (y >> 63);
    y <<= 1;

    // With a carry the true value exceeds 64 bits and so is certainly >= z.
    if (carryOut || x >= z) {
      x -= z;
      y += 1;   // record the quotient bit in the vacated low position
    }
  }
  return x;
}


// Full 64x64 -> 128-bit unsigned multiplication built from 32-bit halves.
//
//   u, v - factors
//   whi  - receives the upper 64 bits of u*v
//   wlo  - receives the lower 64 bits of u*v
//
// NOTE(review): CUDA provides the __umul64hi intrinsic for the upper word;
// consider it if this routine shows up in profiles.
__device__ void mulul64(uint64 u, uint64 v, uint64 *whi, uint64 *wlo)  {
  const uint64 kLow32 = 0xFFFFFFFF;

  const uint64 uLo = u & kLow32;
  const uint64 uHi = u >> 32;
  const uint64 vLo = v & kLow32;
  const uint64 vHi = v >> 32;

  // Partial products, each absorbing the carry of the previous one.
  const uint64 lowLow = uLo * vLo;
  const uint64 highLow = uHi * vLo + (lowLow >> 32);
  const uint64 lowHigh = uLo * vHi + (highLow & kLow32);

  *wlo = (lowHigh << 32) + (lowLow & kLow32);
  *whi = uHi * vHi + (highLow >> 32) + (lowHigh >> 32);
}


// Montgomery product of abar and bbar modulo m, with R = 2^64.
//
//   abar, bbar - operands in Montgomery form
//   m          - modulus
//   mprime     - cofactor produced by xbinGCD for this modulus
//
// Computes (t + ((t * mprime) mod 2^64) * m) >> 64 for t = abar * bbar and
// applies one conditional subtraction of m.
__device__ uint64 montmul(uint64 abar, uint64 bbar, uint64 m, uint64 mprime) {
  uint64 prodHi, prodLo;
  mulul64(abar, bbar, &prodHi, &prodLo);

  // Only the low 64 bits of t * mprime are needed, so prodHi is ignored.
  const uint64 factor = prodLo * mprime;

  uint64 corrHi, corrLo;
  mulul64(factor, m, &corrHi, &corrLo);

  // 128-bit addition t + factor * m, propagating the carry by hand.
  const uint64 sumLo = prodLo + corrLo;
  uint64 sumHi = prodHi + corrHi;
  if (sumLo < prodLo) ++sumHi;

  // The 128-bit sum itself may wrap; detect that by comparing against t.
  const bool wrapped = (sumHi < prodHi) || (sumHi == prodHi && sumLo < prodLo);

  // Dropping the low word is the division by 2^64.
  uint64 reduced = sumHi;
  if (wrapped || reduced >= m) {
    reduced -= m;
  }
  return reduced;
}

// Modular exponentiation by square-and-multiply in Montgomery form.
//
//   baseM - base in Montgomery form
//   e     - exponent
//   modul - modulus
//   pv    - Montgomery cofactor for modul (passed through to montmul)
//   oneM  - the value 1 in Montgomery form (the starting accumulator)
//
// Returns baseM raised to e, still in Montgomery form.
__device__ uint64 mulmodMont(uint64 baseM,uint64 e,uint64 modul,uint64 pv,uint64 oneM) {
    uint64 acc = oneM;
    for (; e != 0; e >>= 1) {
        if ((e & 1) != 0) {
            acc = montmul(baseM, acc, modul, pv);
        }
        // Square the running power for the next exponent bit.
        baseM = montmul(baseM, baseM, modul, pv);
    }
    return acc;
}
	    



// Strong probable-prime test of `modul` to the given `base`.
//
// Writes modul - 1 as d * 2^s with d odd, then returns 1 when base^d is 1 or
// when base^(d * 2^r) is -1 (mod modul) for some 0 <= r < s; returns 0
// otherwise. All arithmetic is carried out in Montgomery form.
//
// Assumes modul is odd and greater than 1: xbinGCD expects an odd second
// argument, and for modul == 1 the halving loop below would never finish.
__device__ int is_SPRP(uint64 base,uint64 modul) {
    if (base >= modul) {
        base %= modul;
    }

    uint64 cofU, cofV;
    xbinGCD(1ull<<63, modul, &cofU, &cofV);

    const uint64 baseM = modul64(base, 0, modul);   // base * 2^64 mod modul
    const uint64 oneM = modul64(1, 0, modul);       // 1 * 2^64 mod modul
    const uint64 minusOneM = modul - oneM;          // -1 in Montgomery form
    const uint64 last = modul - 1;

    // Strip the factors of two from modul - 1 to obtain the odd part.
    uint64 e = last;
    while ((e & 1) == 0) {
        e >>= 1;
    }

    uint64 power = mulmodMont(baseM, e, modul, cofV, oneM);
    if (power == oneM) {
        return 1;
    }

    // Keep squaring, looking for -1, until the exponent reaches modul - 1.
    for (; e < last; e <<= 1) {
        if (power == minusOneM) {
            return 1;
        }
        power = montmul(power, power, modul, cofV);
    }
    return 0;
}
 
// Mixes a 64-bit key with two xor-shift/multiply rounds plus a final
// xor-shift, and returns the low 17 bits, i.e. a value in [0, 131071].
// NOTE(review): 131071 presumably mirrors HASH_ROW - 1; keep them in sync.
__device__ int hashh(uint64 x) {
    x ^= x >> 32;
    x *= 0x45d9f3b3335b369;  // 0x3335b369
    x ^= x >> 32;
    x *= 0x3335b36945d9f3b;
    x ^= x >> 32;
    return x & 131071;
}
 
 
 
/*************************************************************/


// Scatters the input values into hash buckets.
//
//   in  - INPUT_SIZE input values
//   out - bucket storage, HASH_ROW rows of HASH_COL slots each
//   c   - per-bucket counters (HASH_ROW entries); must be zeroed beforehand
//
// Each thread handles one contiguous slice of `_dif` input elements, so the
// kernel must be launched with THREADS*BLOCKS threads in total for the whole
// input to be covered.
//
// A value is stored only while its bucket still has a free slot. The counter
// is incremented regardless, so c[bucket] > HASH_COL afterwards means that the
// bucket overflowed and the surplus values were dropped, instead of being
// written into the neighbouring row (or past the end of `out`).
__global__ void find(uint64 *in,uint64 *out,int *c) {
    // Widen before multiplying so the slice start cannot wrap in 32 bits.
    const long long tid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    const long long b = tid * _dif;
    long long e = b + _dif;
    if (e > INPUT_SIZE) e = INPUT_SIZE;

    for (long long i = b; i < e; ++i) {
        const uint64 value = in[i];          // read the element only once
        const int hash_num = hashh(value);
        const int pos = atomicAdd(&c[hash_num], 1);
        if (pos < HASH_COL) {
            out[(long long)hash_num * HASH_COL + pos] = value;
        }
    }
}

// find base for which all elements in input are NOT SPRP. base is from {2,..,34} stored in 32bit uint
/*
__global__ void solve(uint64 *input, uint64 *count,unsigned int *ans) {
    unsigned int dif = (count[dev_index])/(blockDim.x*gridDim.x) +1;
    unsigned int b = (threadIdx.x+blockIdx.x*blockDim.x)*dif;
    unsigned int e = b+dif>(count[dev_index])?(count[dev_index]):b+dif;
    // each thread doing its part
    uint64 *input_offset = input+dev_index*HASH_COL;
    for(uint64 j = b; j<e ; ++j) {
	//is some element is sprp base i break
	if((*ans)==0) break;
	if(is_SPRP(dev_base,input_offset[j])!=0) {
	    *ans=0;
	    //atomicExch(ans,0);
	    break;
	}
    }
    
}
*/
//							size=HASH_ROW
// For every hash bucket, searches for the smallest base in [3, 1048576) to
// which none of the bucket's values is a strong probable prime.
//
//   input - bucket storage, HASH_ROW rows of HASH_COL slots each
//   count - number of values recorded for each bucket (HASH_ROW entries)
//   bases - receives the base found for each bucket, or 0 if there is none
//   ans   - set to 0 as soon as any bucket has no valid base; never written
//           otherwise, so the caller must pre-set it to a non-zero value
//
// Buckets are distributed over the threads with a stride of the total thread
// count, so any launch configuration covers all HASH_ROW buckets.
//
// A bucket whose count exceeds HASH_COL cannot be fully stored in its row;
// it is reported as a failure rather than read past the end of the row.
__global__ void solve(uint64 *input, int *count,uint64 *bases,unsigned int *ans) {
    const int index = threadIdx.x+blockIdx.x*blockDim.x;
    const int p = blockDim.x*gridDim.x;
    for(int i=index;i<HASH_ROW;i+=p) {
        const uint64* row = input + (size_t)i*HASH_COL;
        const int n = count[i];          // loop-invariant, so read it once
        uint64 found = 0;                // 0 means "no base found"

        if (n <= HASH_COL) {
            for(uint64 base = 3; base < 1048576 ; ++base) {
                // Advance while the base keeps rejecting the bucket's values.
                int j = 0;
                while (j < n && is_SPRP(base,row[j]) == 0) ++j;
                if (j == n) {
                    found = base;
                    break;
                }
            }
        }

        if (found == 0) {
            // Every thread that gets here stores the same value, so the
            // unsynchronised write is harmless.
            *ans = 0;
        }
        bases[i] = found;
    }
}
    

// Advances the device-side dev_base by one. The update is not atomic, so the
// kernel should be launched with a single thread.
__global__ void incBase() {
    dev_base += 1;
}

// Advances the device-side dev_index by one (main resets it to zero before
// use). The update is not atomic, so launch with a single thread.
__global__ void incIndex() {
    dev_index += 1;
}


// Returns the number of seconds elapsed since the previous call (or since
// curtime was last set) and stores the current time in curtime.
int gettime(void) {
    const time_t previous = curtime;
    const time_t now = time(NULL);
    curtime = now;
    return now - previous;
}


// Program entry point.
//
//  1. Allocates the device buffers and loads INPUT_SIZE 64-bit values from
//     "spsp2.bin" onto the GPU.
//  2. Runs `find` to scatter the values into HASH_ROW buckets.
//  3. Runs `solve` to search a base for every bucket.
//  4. Writes the HASH_ROW bases (0 = none found) to "bases2.txt".
//
// Returns 0 on completion and 1 when the input file cannot be loaded. CUDA
// errors terminate the process through gpuErrchk.
int main(void) {
    // ALLOCATIONS

    // allocate buffers
    uint64 *dev_input,*dev_hash_input,*dev_bases;
    int *dev_count;
    unsigned int ans, *dev_ans;
    gpuErrchk(cudaMalloc((void**)&dev_input,sizeof(uint64)*INPUT_SIZE)); // 125M = 1mld / rounds
    gpuErrchk(cudaMalloc((void**)&dev_count,sizeof(int)*HASH_ROW));
    gpuErrchk(cudaMalloc((void**)&dev_hash_input,sizeof(uint64)*HASH_ROW*HASH_COL));
    gpuErrchk(cudaMalloc((void**)&dev_bases,sizeof(uint64)*HASH_ROW));
    gpuErrchk(cudaMalloc((void**)&dev_ans,sizeof(unsigned int)));

    // size_t must not be printed with %d; widen explicitly and use %llu.
    size_t a,b;
    gpuErrchk(cudaMemGetInfo(&a,&b));
    printf("free %llu total %llu\n",(unsigned long long)a,(unsigned long long)b);

    // Addresses of the device globals; base_adr is only needed by the legacy
    // host-driven search kept in the comment further down.
    uint64 *base_adr;
    gpuErrchk(cudaGetSymbolAddress((void**)&base_adr,dev_base));
    unsigned int *index_adr;
    gpuErrchk(cudaGetSymbolAddress((void**)&index_adr,dev_index));

    // Host copy of the per-bucket bases. HASH_ROW * 8 bytes = 1 MiB, which is
    // too large to place safely on the stack, so it lives on the heap.
    uint64 *bb = new uint64[HASH_ROW];

    printf("all done\n");

    // Load the input file; a missing file and a short read are both failures.
    uint64 *input = new uint64[INPUT_SIZE];
    size_t loaded = 0;
    FILE *f = fopen("spsp2.bin","rb");
    if (f != NULL) {
        loaded = fread(input,sizeof(uint64),INPUT_SIZE,f);
        fclose(f);
    }
    if (loaded != INPUT_SIZE) {
        printf("failed to load input file\n");
        delete[] input;
        delete[] bb;
        gpuErrchk(cudaFree(dev_input));
        gpuErrchk(cudaFree(dev_count));
        gpuErrchk(cudaFree(dev_hash_input));
        gpuErrchk(cudaFree(dev_bases));
        gpuErrchk(cudaFree(dev_ans));
        return 1;
    }

    gpuErrchk(cudaMemcpy(dev_input,input,sizeof(uint64)*INPUT_SIZE,cudaMemcpyHostToDevice));
    printf("copy done\n");
    delete[] input;

    time_t starttime = time(NULL);
    tm *timeinfo = localtime(&starttime);
    printf("Started: %02d:%02d:%02d\n",timeinfo->tm_hour,timeinfo->tm_min,timeinfo->tm_sec);

    curtime = time(NULL);

    printf("SIZE: %u   ROW: %u   COL: %u\n",HASH_ROW*HASH_COL,HASH_ROW,HASH_COL);

    /* prepare for finding solution */
    gpuErrchk(cudaMemset(index_adr,0,sizeof(unsigned int)));

    /* find solution */

    //int global_solution = 1;

    // Bucket the input: counters must start at zero.
    gpuErrchk(cudaMemset(dev_count,0,sizeof(int)*(HASH_ROW)));
    find<<<BLOCKS,THREADS>>>(dev_input,dev_hash_input,dev_count);
    gpuErrchk( cudaPeekAtLastError() );
    gpuErrchk( cudaDeviceSynchronize() );
    printf("scan done\n");

    // solve only ever clears the flag, so start from a non-zero pattern.
    gpuErrchk(cudaMemset(dev_ans,0xFF,sizeof(unsigned int)));
    solve<<<BLOCKS,THREADS>>>(dev_hash_input,dev_count,dev_bases,dev_ans);
    gpuErrchk( cudaPeekAtLastError() );
    gpuErrchk( cudaDeviceSynchronize());

    gpuErrchk(cudaMemcpy(&ans,dev_ans,sizeof(unsigned int),cudaMemcpyDeviceToHost));
    if(ans==0) {
        printf("!!!!!!!!!!!!!!no solution found!!!!!!!!!!!!!!!\n");
    }
    //else {
    // The bases are written even after a failure so that the buckets without
    // a base (value 0) can be inspected.
    gpuErrchk(cudaMemcpy(bb,dev_bases,sizeof(uint64)*HASH_ROW,cudaMemcpyDeviceToHost));
    f = fopen("bases2.txt","w");
    if (f != NULL) {
        for(int i=0;i<HASH_ROW;++i) fprintf(f,"%llu ",bb[i]);
        fclose(f);
        printf("%%%%%%%%%% FOUND %%%%%%%%%%\n");
    }
    else {
        printf("failed to open bases2.txt for writing\n");
    }
    //}


    //printf("%d seconds\n",gettime());


    //unsigned int count[256];
    //gpuErrchk(cudaMemcpy(count,dev_count,4*(HASH_ROW),cudaMemcpyDeviceToHost));
    //printf("size: %u\n",count);

    /*for(int j=0;j<HASH_ROW;++j) {
	//printf("%d count: %u\n",j,count[j]);
	
	gpuErrchk(cudaMemset(base_adr,0,7));
	gpuErrchk(cudaMemset(base_adr,0x02,1));
	int solution=0;
	for(int base=2;base<262146;++base) { // <1024 
	    gpuErrchk(cudaMemset(dev_ans,0xFF,4));
	    solve<<<512,128>>>(dev_hash_input,dev_count,dev_ans);
	    gpuErrchk( cudaPeekAtLastError() );
	    gpuErrchk( cudaDeviceSynchronize() );
	    gpuErrchk(cudaMemcpy(&ans,dev_ans,4,cudaMemcpyDeviceToHost));
	    //printf("%d\n",j);
	  
	  
	    if(ans!=0) {
	      solution=base;
	      break;
	    }
	    incBase<<<1,1>>>();
	    gpuErrchk( cudaPeekAtLastError() );
	    gpuErrchk( cudaDeviceSynchronize() );
	}
	bb[j]=solution;
      
	printf("%d:   %d      in %d seconds\n",j,solution,gettime());
      
	
	if(solution==0) {
	    //gpuErrchk(cudaFree(dev_ans));
	    //gpuErrchk(cudaFree(dev_input));
	    //gpuErrchk(cudaFree(dev_count));
	    //printf("no solution found\n");
	    //return 0;
	    
	    global_solution=0;
	    break;
	}
	
	incIndex<<<1,1>>>();
	gpuErrchk( cudaPeekAtLastError() );
	gpuErrchk( cudaDeviceSynchronize() );
    }
	   
	    
    // print bases
    if(global_solution==1) {
	FILE *f = fopen("bases.txt","w");
	for(int i=0;i<HASH_ROW;++i) fprintf(f,"%d ",bb[i]);
	fclose(f);
    }
    else printf("Solution not found\n");
    */
    time_t endtime = time(NULL);
    timeinfo = localtime(&endtime);
    printf("Ended: %02d:%02d:%02d   (total %d seconds)\n",timeinfo->tm_hour,timeinfo->tm_min,timeinfo->tm_sec,(int)(endtime-starttime));


    /* dealloc */
    delete[] bb;
    gpuErrchk(cudaFree(dev_input));
    gpuErrchk(cudaFree(dev_count));
    gpuErrchk(cudaFree(dev_hash_input));
    gpuErrchk(cudaFree(dev_bases));
    gpuErrchk(cudaFree(dev_ans));

    return 0;
}
  
  
