// Copyright 2017 <Jozef Brandys>
#ifndef PINTOOL_CLUSTER_CLASIFIER_SEQ_BINARY_H_
#define PINTOOL_CLUSTER_CLASIFIER_SEQ_BINARY_H_

#include <glog/logging.h>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

#include "Cluster_clasifier.h"
#include "Typedefs.h"

class ClusterClasifierSeqBinary : public ClusterClasifier {
 protected:
  int size_;
  std::vector<CoMatrix> occ_first_, occ_last_;

 public:
  ClusterClasifierSeqBinary() {}
  explicit ClusterClasifierSeqBinary(std::istream &in) {
    Read(in);
  }

  void Clear() {
    occ_first_.resize(0);
    occ_last_.resize(0);
  }

  void Train(const std::vector<Trace_binary> &traces) {
    Clear();

    int traces_count = traces.size();
    size_ = traces_count == 0 ? 1 : traces[0].size();
    std::vector<std::vector<char16_t>> occ_count_first(size_,
        std::vector<char16_t> (size_, 0));
    std::vector<std::vector<char16_t>> occ_count_last(size_,
        std::vector<char16_t> (size_, 0));

    LOG(INFO) << "Putting traces into maps";
    for (auto &trace : traces) {
      std::vector<int> seq = trace.GetSorted();
      for (uint i = 0; i < seq.size(); i++) {
        for (uint j = 0; j <= i; j++) {
          occ_count_first[seq[i]][seq[j]]++;
        }
      }
      seq = trace.GetSortedReverse();
      for (uint i = 0; i < seq.size(); i++) {
        for (uint j = 0; j <= i; j++) {
          occ_count_last[seq[i]][seq[j]]++;
        }
      }
    }
    LOG(INFO) << "Putting traces into maps finished.";

    occ_first_.resize(size_, CoMatrix(size_));
    occ_last_.resize(size_, CoMatrix(size_));
    for (int i = 0; i < size_; i++) {
      if (occ_count_first[i][i] < 5) continue;
      for (int j = 0; j < size_; j++) {
        if (occ_count_first[i][j] == occ_count_first[i][i]) {
          occ_first_[i][j] = true;
        }
      }
    }
    for (int i = 0; i < size_; i++) {
      if (occ_count_last[i][i] < 5) continue;
      for (int j = 0; j < size_; j++) {
        if (occ_count_last[i][j] == occ_count_last[i][i]) {
          occ_last_[i][j] = true;
        }
      }
    }
  }

  bool Accepts(const std::vector<int> &seq, const std::vector<CoMatrix> &occ) {
    CoMatrix visited(size_);
    for (auto event : seq) {
      visited[event] = true;
      if ((occ[event] & visited) != occ[event]) {
        LOG(INFO) << "Failed pre/postrequisite";
        return false;
      }
    }
    return true;
  }

  bool AcceptsTrace(const Trace_binary &trace) {
    if (size_ == 1) return true;
    return Accepts(trace.GetSorted(), occ_first_) &&
           Accepts(trace.GetSortedReverse(), occ_last_);
  }

  void Print(std::ostream &out) {
    out << size_ << "\n";
    for (auto &r : occ_first_) {
      out << r << "\n";
    }
    for (auto &r : occ_last_) {
      out << r << "\n";
    }
  }
  void Read(std::istream &in) {
    in >> size_;
    occ_first_.resize(size_);
    occ_last_.resize(size_);
    for (auto &r : occ_first_) {
      in >> r;
    }
    for (auto &r : occ_last_) {
      in >> r;
    }
  }
};

#endif  // PINTOOL_CLUSTER_CLASIFIER_SEQ_BINARY_H_
