// Copyright 2017 <Jozef Brandys>
#include <sys/stat.h>
#include <unistd.h>
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iomanip>
#include <map>
#include <unordered_map>
#include <functional>

#include "pin.H"
#include "Trace.h"
#include "Trace_basic.h"
#include "Trace_seq.h"
#include "Trace_all.h"
#include "Trace_stat.h"

/* ===================================================================== */
/* Commandline Switches */
/* ===================================================================== */

// -output: directory that receives dipl_trace, dipl_data and dipl_log.
// When left empty, main() falls back to the current working directory.
KNOB<string> KnobOutput(KNOB_MODE_WRITEONCE, "pintool", "output",
  "", "output directory");
// -files: path to a whitespace-separated list of image names. When given,
// only code inside those images is instrumented; otherwise every image is.
KNOB<string> KnobFiles(KNOB_MODE_WRITEONCE, "pintool", "files",
  "", "File containing list of files to trace");

// -ret / -call / -branch (default 1 each): select which instruction kinds the
// per-instruction filter Instrument(INS) accepts.
// NOTE(review): main() only registers the per-basic-block callback, so these
// knobs are not consulted in the current configuration.
KNOB<BOOL> KnobIRET(KNOB_MODE_WRITEONCE, "pintool", "ret", "1",
  "instrument ret instructions");
KNOB<BOOL> KnobICALL(KNOB_MODE_WRITEONCE, "pintool", "call", "1",
  "instrument call instructions");
KNOB<BOOL> KnobIBRANCH(KNOB_MODE_WRITEONCE, "pintool", "branch", "1",
  "instrument branch instructions");

// -size (default 13): forwarded to the Trace_stat / Trace_all constructor.
KNOB<UINT32> KnobSize(KNOB_MODE_WRITEONCE, "pintool", "size", "13",
  "Size of tracing table.");
// -stats (default 0): when set, main() uses Trace_stat instead of Trace_all.
KNOB<BOOL> KnobStats(KNOB_MODE_WRITEONCE, "pintool", "stats", "0",
  "Print stats instead of tracing");

/* ===================================================================== */
/* Print Help Message                          */
/* ===================================================================== */

/**
 * Writes the summary of every registered knob to stderr.
 *
 * @return -1; main() returns this value when PIN_Init rejects the command line.
 */
static INT32 Usage() {
  cerr << KNOB_BASE::StringKnobSummary() << endl;
  return -1;
}

/* ===================================================================== */
/* Global Variables */
/* ===================================================================== */

// Full paths of the output files; main() builds them from the -output directory.
string trace_name, data_name, log_name;
// Output streams, opened in append mode by main() and re-opened by reopen()
// whenever they are found in a failed state.
std::ofstream trace_out, data_out, log_out;

/**
 * Makes sure output stream `f` is usable before it is written to.
 *
 * If `f` is in a failed state (failbit/badbit set), it is closed, its state is
 * cleared and `name` is reopened in append mode, so earlier content of the
 * file is preserved. A healthy stream is left untouched.
 *
 * If the file still cannot be opened, a diagnostic naming the file is written
 * to stderr and the process is aborted, so the failure is not silent.
 *
 * @param f    stream to check and, if necessary, reopen.
 * @param name path of the file backing `f`.
 */
void reopen(ofstream &f, const string &name) {
  if (f) {
    return;  // Stream is healthy; nothing to do.
  }
  f.close();  // May itself set failbit on a non-open stream; cleared below.
  f.clear();
  f.open(name.c_str(), fstream::out | fstream::app);
  if (!f) {
    cerr << "Cannot open output file: " << name << endl;
    abort();
  }
}

set<string> instrument_img;
set<string> seen_img;
bool Instrument(INS ins) {
  ADDRINT addr = INS_Address(ins);
  IMG img = IMG_FindByAddress(addr);
  if (!instrument_img.empty()) {
    if (IMG_Valid(img) && instrument_img.count(IMG_Name(img))) {
    } else {
      return false;
    }
  } else {
    if (IMG_Valid(img) && seen_img.insert(IMG_Name(img)).second) {
      reopen(data_out, data_name);
      data_out << "IMG: " << IMG_Name(img) << endl;
    }
  }
  if (INS_IsRet(ins)) {
    return KnobIRET;
  } else if ( INS_IsCall(ins) ) {
    return KnobICALL;
  } else if (INS_IsBranch(ins)) {
    return KnobIBRANCH;
  }
  return false;
}

bool Instrument(const BBL &bbl) {
  ADDRINT addr = BBL_Address(bbl);
  IMG img = IMG_FindByAddress(addr);
  if (!instrument_img.empty()) {
    if (IMG_Valid(img) && instrument_img.count(IMG_Name(img))) {
    } else {
      return false;
    }
  } else {
    if (IMG_Valid(img) && seen_img.insert(IMG_Name(img)).second) {
      reopen(data_out, data_name);
      data_out << "IMG: " << IMG_Name(img) << endl;
    }
  }
  return true;
}

// Active tracing backend. main() creates it (Trace_stat or Trace_all); the
// analysis routines below and Fini() use it.
Trace *trace = nullptr;

/* ===================================================================== */

// Analysis routine inserted before selected instructions: forwards the
// instruction's ID to the tracing backend.
VOID trace_ins(INT32 ID) { trace->trace(ID); }

// Analysis routine inserted before selected basic blocks: forwards the
// block's ID to the tracing backend.
VOID trace_bbl(INT32 ID) { trace->trace(ID); }

hash<string> hash_string;
hash<int> hash_int;

/**
 * Derives a pseudo-random ID for the code at `addr`.
 *
 * The ID depends only on the name of the containing image and on the offset
 * of `addr` from the image's low address, not on where the image was loaded.
 * It is therefore intended to stay the same between runs. This assumes
 * std::hash and the libc rand() sequence are deterministic for a given
 * build — TODO confirm.
 *
 * The combined hash is used as a seed for srand() and the first rand() value
 * becomes the ID. The client lock is held for the whole lookup.
 *
 * @return the ID, or 0 when `addr` is not inside a known image.
 *         Callers treat 0 as "no stable ID".
 *         NOTE(review): rand() itself may also yield 0.
 */
int GenerateID(ADDRINT addr) {
  INT32 id = 0;
  PIN_LockClient();
  IMG owner = IMG_FindByAddress(addr);
  if (IMG_Valid(owner)) {
    const INT32 seed =
        hash_string(IMG_Name(owner)) ^ hash_int(addr - IMG_LowAddress(owner));
    srand(seed);  // Scramble the hash through the libc PRNG.
    id = rand();
  }
  PIN_UnlockClient();
  return id;
}

// Instrumentation counters written to the log file by Fini().
INT32 instructed = 0;
INT32 no_id = 0;

/**
 * INS instrumentation callback.
 *
 * For every instruction accepted by Instrument(INS), counts it in
 * `instructed` and inserts a call to trace_ins with the ID from GenerateID.
 * Instructions that get ID 0 are counted in `no_id` and left uninstrumented.
 *
 * NOTE(review): main() registers only the TRACE callback, so this routine is
 * currently unused.
 */
VOID Instruction(INS ins, void *v) {
  if (!Instrument(ins)) {
    return;
  }
  ++instructed;
  const INT32 randomID = GenerateID(INS_Address(ins));
  if (randomID == 0) {
    // We cannot produce an ID that would stay the same between runs.
    ++no_id;
    return;
  }
  INS_InsertCall(ins, IPOINT_BEFORE, (AFUNPTR) trace_ins,
                 IARG_UINT32, randomID, IARG_END);
}

/**
 * TRACE instrumentation callback.
 *
 * For every basic block of `trace` accepted by Instrument(BBL), inserts a
 * call to trace_bbl before the block, passing the block's ID from GenerateID.
 *
 * This is consistent with Instruction():
 *  - `instructed` counts each basic block that passed the filter, not each
 *    trace callback.
 *  - Blocks with ID 0 (no stable ID) are skipped and counted in `no_id`.
 *    Tracing them would merge unrelated blocks under the shared ID 0.
 */
VOID Trace(TRACE trace, VOID *v) {
  for (BBL bbl = TRACE_BblHead(trace); BBL_Valid(bbl); bbl = BBL_Next(bbl)) {
    if (!Instrument(bbl)) continue;
    instructed++;
    INT32 randomID = GenerateID(BBL_Address(bbl));
    if (randomID == 0) {
      // We cannot produce an ID that would stay the same between runs.
      no_id++;
      continue;
    }
    // Call trace_bbl() before the block, passing the block's ID.
    BBL_InsertCall(bbl, IPOINT_BEFORE, (AFUNPTR) trace_bbl, IARG_UINT32,
        randomID, IARG_END);
  }
}


/* ===================================================================== */

/**
 * Fini callback.
 *
 * Writes the collected trace to the trace file, and the backend's debug
 * output plus the instrumentation counters to the log file. Each stream is
 * passed through reopen() first, in case it was left in a failed state.
 *
 * All streams are flushed explicitly at the end. Output must not depend on
 * the tool's global ofstream destructors running after Fini, which Pin may
 * not guarantee. Pin's own example tools close or flush their output files
 * in Fini for the same reason.
 */
VOID Fini(int n, void *v) {
  reopen(trace_out, trace_name);
  reopen(data_out, data_name);
  reopen(log_out, log_name);

  // NOTE(review): setRange(0, 0) semantics are defined by the Trace class.
  trace->setRange(0, 0);
  trace->Print(trace_out);

  trace->PrintDebug(log_out);
  log_out << "Instructed: " << instructed << endl;
  log_out << "Without id: " << no_id << endl;

  trace_out.flush();
  data_out.flush();
  log_out.flush();
}



/* ===================================================================== */
/* Main                                  */
/* ===================================================================== */

/**
 * Tool entry point.
 *
 * Steps:
 *  - Parses the knobs.
 *  - Opens dipl_trace, dipl_data and dipl_log in append mode, in the -output
 *    directory (default: the current working directory).
 *  - Records the command line in the data file.
 *  - Optionally loads the image allow-list from -files.
 *  - Creates the tracing backend, registers the callbacks and starts the
 *    application under Pin.
 *
 * @return -1 without starting the application when the command line is
 *         invalid or an output/input file cannot be opened. Otherwise
 *         PIN_StartProgram() does not return.
 */
int main(int argc, char *argv[]) {
  cerr << "Starting diplo tool" << endl;
  if (PIN_Init(argc, argv)) {
    return Usage();
  }

  // The working directory is only needed when no -output directory is given.
  string directory = KnobOutput.Value();
  if (directory.empty()) {
    char cwd[1024];
    if (getcwd(cwd, sizeof(cwd)) == NULL) {
      cerr << "Cannot determine current directory" << endl;
      abort();
    }
    directory = cwd;
  }
  trace_name = directory + "/dipl_trace";
  data_name =  directory + "/dipl_data";
  log_name = directory + "/dipl_log";

  trace_out.open(trace_name.c_str(), fstream::out | fstream::app);
  data_out.open(data_name.c_str(), fstream::out | fstream::app);
  log_out.open(log_name.c_str(), fstream::out | fstream::app);
  // Fail before the application starts; otherwise the problem only shows up
  // later as an abort() from reopen() in the middle of the run.
  if (!trace_out || !data_out || !log_out) {
    cerr << "Cannot open output files in directory: " << directory << endl;
    return -1;
  }

  for (int i = 0; i < argc; i++) {
    data_out << argv[i] << " ";
  }
  data_out << endl;

  if (!KnobFiles.Value().empty()) {
    ifstream img_names(KnobFiles.Value().c_str());
    if (!img_names) {
      // An unreadable list would leave instrument_img empty, which means
      // "instrument every image" — the opposite of what was requested.
      cerr << "Cannot open image list: " << KnobFiles.Value() << endl;
      return -1;
    }
    string img_name;
    log_out << "Allowing only: " << endl;
    while (img_names >> img_name) {
      instrument_img.insert(img_name);
      log_out << img_name << endl;
    }
  }

  if (KnobStats) {
    trace = new Trace_stat(KnobSize.Value());
  } else {
    trace = new Trace_all(KnobSize.Value());
  }
  TRACE_AddInstrumentFunction(Trace, 0);
  PIN_AddFiniFunction(Fini, 0);

  // Never returns.
  PIN_StartProgram();
  return 0;
}
