package james;

import android.text.TextUtils;
import android.util.Base64;
import android.util.Log;

import net.f5.crypt.F5Random;
import net.f5.crypt.Permutation;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;

import info.guardianproject.f5android.Extract;
import sun.security.provider.SecureRandom;

/**
 * Complementary embedding ("CE") of a byte message into quantized DCT coefficients, with each
 * message bit repeated {@code rep} times.
 *
 * <p>Coefficient indices are visited in a pseudo-random order derived from {@link #randomSeed}.
 * Indices that are multiples of 64 are excluded (presumably the DC term of each 8x8 block;
 * confirm against the coefficient layout produced by the encoder). Zero coefficients are
 * skipped while embedding and extracting. The resulting order is split at
 * {@code alpha * length}. The first {@code alpha} fraction of the message goes into region one
 * with embedding type 1, and the remainder goes into region two with embedding type 2. Each part
 * starts with a {@code maxLengthBits}-bit length header: little-endian bytes, each byte written
 * least-significant bit first.
 *
 * <p>{@link #embedBit} never turns a non-zero coefficient into zero, and the type-1 correction
 * only maps -2 to 1. The embedder and the extractor therefore skip exactly the same zero
 * positions.
 *
 * <p>Created by askar on 22.4.2016.
 */
public class DCTStegoCERep implements DCTSteganography {
    // complementary embedding, rep times

    // Fraction of the usable coefficients (and of the message) assigned to part one;
    // discussed in part 4.4 (for S-family attack resistance).
    private final double alpha = 4.0 / 5.0;
    // Fraction of part one's message bits whose coefficients get the type-1 "-2 -> 1" correction.
    private final double beta = 1.0 / 3.0;
    // Width of each part's length header in bits (a whole number of bytes).
    private final int maxLengthBits = 8 * 2;
    // Upper bound on the number of coefficients that are processed.
    private final int maxCoeffs = 50000000;
    // Magnitude limit intended for getSmallCoeffs; its filter is currently disabled.
    private final int maxCoeffRange = 255 / 2;
    // Number of coefficients each message bit is written into (decoded by majority vote).
    private final int rep = 1;
    // Maximum number of entries printed by the debug coefficient dumps.
    private final int debugDumpLength = 1000;

    String randomSeed = "a";

    public DCTStegoCERep() {
        // intentionally left blank
    }

    /** Runs all self-tests and logs their results. */
    public void testAll () {
        Log.d(Jpeg.LOG, "testAll() begin");
        boolean t1 = testEncodeBits2Byte();
        boolean t2 = testEncodeInt2Bytes();
        boolean t3 = testEmbedExtractBit();
        boolean t4 = testRandom();
        Log.d(Jpeg.LOG, String.format("TESTS: t1 %b t2 %b t3 %b t4 %b", t1, t2, t3, t4));
        Log.d(Jpeg.LOG, "testAll() end");
    }

    /** Checks that every byte value 0..255 survives a bits round trip. */
    public boolean testEncodeBits2Byte () {
        int[] bits = new int[8];
        for (int i = 0; i < 256; i++) {
            this.encodeByteToBits(i, bits, 0);
            final int j = this.encodeBitsToByte(bits, 0);
            if (i != j) {
                Log.wtf(Jpeg.LOG, String.format("incorrect bs2b: i: %d j: %d", i, j));
                return false;
            }
        }
        return true;
    }

    /** Checks that every value representable in the length header survives a bytes round trip. */
    public boolean testEncodeInt2Bytes () {
        final int byteCount = this.maxLengthBits / 8;
        int[] bytes = new int[byteCount];
        for (int i = 0; i < (1 << this.maxLengthBits); i++) {
            this.encodeIntToBytes(i, bytes, 0, byteCount);
            final int j = this.encodeBytesToInt(bytes, 0, byteCount);
            if (i != j) {
                Log.wtf(Jpeg.LOG, String.format("incorrect i2bs: i: %d j: %d", i, j));
                return false;
            }
        }
        return true;
    }

    /**
     * Checks embedBit/extractBit for both types on coefficients in [-10000, 10000] \ {0}.
     * It also logs any embedding that produced a zero.
     */
    public boolean testEmbedExtractBit () {
        for (int type = 1; type <= 2; type++) {
            for (int c = -10000; c <= 10000; c++) {
                if (c == 0) continue;
                for (int bit = 0; bit <= 1; bit++) {
                    final int d = this.embedBit(type, c, bit);
                    if (d == 0) {
                        Log.wtf(Jpeg.LOG, String.format("0 generated by embedding: type: %d c: %d bit: %d d: %d", type, c, bit, d));
                    }
                    final int e = this.extractBit(type, d);
                    if (bit != e) {
                        Log.wtf(Jpeg.LOG, String.format("embedding b: type: %d c: %d bit: %d d: %d e:%d", type, c, bit, d, e));
                        return false;
                    }
                }
            }
        }
        return true;
    }

    /** Checks that the same seed yields the same permutation on repeated calls. */
    public boolean testRandom () {
        int n = 40;
        int m = 1000;

        for (int j = 0; j < m; j++) {
            int[] a = getRandomPermutation(n);
            int[] b = getRandomPermutation(n);

            for (int i = 0; i < n; i++) {
                if (a[i] != b[i]) {
                    Log.d(Jpeg.LOG, "different randoms");
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Sets the seed used for the coefficient permutation.
     *
     * @param s new seed; {@code null} is rejected
     * @return {@code true} if the seed was updated
     */
    public boolean loadKeyFromString (final String s) {
        if (s == null) {
            return false;
        }
        this.randomSeed = s;
        return true;
    }

    // Parity via the low bit; for negative values this matches the magnitude's parity.
    private boolean isOdd (final int x) {
        return (x & 1) != 0;
    }

    private boolean isEven (final int x) {
        return !isOdd(x);
    }

    /** Returns the indices (in ascending order) of all non-zero entries of {@code coeff}. */
    public int[] selectNonZeroCoeffs (final int[] coeff) {
        int[] allNonZeroCoeffs = new int[coeff.length];
        int nonZeroCount = 0;
        for (int i = 0; i < coeff.length; i++) {
            if (coeff[i] != 0) {
                allNonZeroCoeffs[nonZeroCount] = i;
                nonZeroCount++;
            }
        }
        int[] res = new int[nonZeroCount];
        System.arraycopy(allNonZeroCoeffs, 0, res, 0, nonZeroCount);
        return res;
    }

    /** Returns 60 bytes from {@link java.security.SecureRandom}, Base64-encoded without line wraps. */
    public String generateStegoKeyString () {
        byte[] a = new byte[3 * 20];
        java.security.SecureRandom srand = new java.security.SecureRandom();
        srand.nextBytes(a);
        return Base64.encodeToString(a, Base64.NO_WRAP);
    }

    /**
     * Returns {@code permutation.getShuffled(i)} for i in [0, amount). The permutation is driven
     * by an F5Random seeded with {@link #randomSeed}, so equal seeds give equal results.
     * Presumably the result is a permutation of 0..amount-1 (verify against Permutation).
     */
    public int[] getRandomPermutation (final int amount) {
        final F5Random random = new F5Random(this.randomSeed);
        final Permutation permutation = new Permutation(amount, random);

        int[] perm = new int[amount];
        for (int i = 0; i < amount; i++) {
            perm[i] = permutation.getShuffled(i);
        }
        return perm;
    }

    /** Returns the indices of the non-zero coefficients, reordered by the keyed permutation. */
    public int[] getPermutedNonZeroCoeffs (final int[] coeff) {
        int[] nonZeroCoeffs = this.selectNonZeroCoeffs(coeff);
        int[] perm = this.getRandomPermutation(nonZeroCoeffs.length);

        int[] qperm = new int[perm.length];
        for (int i = 0; i < qperm.length; i++) {
            qperm[i] = nonZeroCoeffs[perm[i]];
        }
        return qperm;
    }

    /** Like getPermutedNonZeroCoeffs, but drops indices that are multiples of 64. */
    public int[] getNon64PermutedNonZeroCoeffs (final int[] coeff) {
        int[] qperm = this.getPermutedNonZeroCoeffs(coeff);
        int non64count = 0;
        for (int i = 0; i < qperm.length; i++) {
            if (qperm[i] % 64 != 0) {
                // good one: compact in place
                qperm[non64count] = qperm[i];
                non64count++;
            }
        }

        int[] res = new int[non64count];
        System.arraycopy(qperm, 0, res, 0, non64count);
        return res;
    }

    /**
     * Returns getNon64PermutedNonZeroCoeffs(coeff) and logs the value range it covers.
     * The maxCoeffRange filter is disabled, so every coefficient is kept; {@code type} is unused.
     */
    public int[] getSmallCoeffs(final int[] coeff, final int type) {
        // Freshly allocated by the callee, so returning it directly does not alias any state.
        final int[] qperm = this.getNon64PermutedNonZeroCoeffs(coeff);

        int maxCoeff = -100000;
        int minCoeff = 100000;
        for (final int q : qperm) {
            if (coeff[q] > maxCoeff) maxCoeff = coeff[q];
            if (coeff[q] < minCoeff) minCoeff = coeff[q];
        }

        Log.d(Jpeg.LOG, String.format("maxCoeff %d minCoeff %d", maxCoeff, minCoeff));
        return qperm;
    }

    /** Permutes all indices first, then keeps those with a non-zero, non-multiple-of-64 position. */
    public int[] getNon64NonZeroPermutedCoeffs(final int[] coeff) {
        int[] perm = this.getRandomPermutation(coeff.length);
        int[] nonZeroNon64perm = new int[coeff.length];
        int goodCoeffCount = 0;
        for (int i = 0; i < coeff.length; i++) {
            final int q = perm[i];
            if (coeff[q] == 0) continue;
            if (q % 64 == 0) continue;
            nonZeroNon64perm[goodCoeffCount] = q;
            goodCoeffCount++;
        }

        int[] res = new int[goodCoeffCount];
        System.arraycopy(nonZeroNon64perm, 0, res, 0, goodCoeffCount);
        return res;
    }

    /**
     * Returns the permuted indices of all coefficients except multiples of 64.
     * Zero coefficients are NOT removed here; embedPart/extractPart skip them while walking the
     * order, because that order must stay identical before and after embedding.
     */
    public int[] getGoodCoeffs(final int[] coeff) {
        int[] a = getRandomPermutation(coeff.length);
        int non64 = 0;
        for (int i = 0; i < a.length; i++) {
            if (a[i] % 64 != 0) {
                a[non64] = a[i];
                non64++;
            }
        }
        int[] res = new int[non64];
        System.arraycopy(a, 0, res, 0, non64);
        return res;
    }

    /** Writes the 8 low bits of {@code x} into dst[pos..pos+7], least-significant bit first. */
    public void encodeByteToBits (final int x, int[] dst, final int pos) {
        int q = x;
        for (int i = 0; i < 8; i++) {
            final int res = q % 2;
            dst[pos + i] = res;
            q /= 2;
        }
    }

    /** Inverse of encodeByteToBits: reads src[pos..pos+7] as bits, least-significant first. */
    public int encodeBitsToByte (int[] src, final int pos) {
        int q = 0;
        for (int i = 7; i >= 0; i--) {
            q = (2 * q + src[pos + i]);
        }
        return q;
    }

    /** Writes the {@code length} low bytes of {@code x} into dst[pos..], least-significant first. */
    public void encodeIntToBytes (final int x, int[] dst, final int pos, final int length) {
        for (int i = 0; i < length; i++) {
            dst[pos + i] = (x >> (8 * i)) & ((1 << 8) - 1);
        }
    }

    /** Inverse of encodeIntToBytes: little-endian bytes src[pos..pos+length-1] to an int. */
    public int encodeBytesToInt (final int[] src, final int pos, final int length) {
        int x = 0;
        for (int i = 0; i < length; i++) {
            x += src[pos + i] << (8*i);
        }
        return x;
    }

    /**
     * Returns coefficient {@code c} adjusted so that {@code extractBit(type, result) == bit}.
     * The adjustment changes c by at most 2. Type 1 stores bit 1 in positive-odd/negative-even
     * values and type 2 is the complement. For non-zero c and bit in {0, 1} the result is never
     * zero: the +/-1 cases jump over zero instead. Returns 0 (and logs) for c == 0, an unknown
     * type, or a bit outside {0, 1}.
     */
    public int embedBitUnsafe (final int type, final int c, final int bit) {
        if (c == 0) {
            Log.wtf(Jpeg.LOG, "zero coeff!?!");
            // @TODO add correct handling
            return 0;
        }
        if (type == 1) {
            if (c > 0 && isOdd(c)) {
                if (bit == 0 && c-1 == 0) return c-2;
                else if (bit == 0 && c-1 != 0) return c-1;
                else if (bit == 1) return c;
            }
            else if (c > 0 && isEven(c)) {
                if (bit == 1) return c-1;
                else if (bit == 0) return c;
            }
            else if (c < 0 && isOdd(c)) {
                if (bit == 1) return c-1;
                else if (bit == 0) return c;
            }
            else if (c < 0 && isEven(c)) {
                if (bit == 0) return c-1;
                    // in paper the condition is (c == 1), but it doesn't make much sense
                else if (bit == 1) return c;
            }
        }
        else if (type == 2) {
            if (c > 0 && isOdd(c)) {
                if (bit == 1) return c+1;
                else if (bit == 0) return c;
            }
            else if (c > 0 && isEven(c)) {
                if (bit == 0) return c+1;
                else if (bit == 1) return c;
            }
            else if (c < 0 && isOdd(c)) {
                if (bit == 0 && c + 1 == 0) return c+2;
                else if (bit == 0 && c + 1 != 0) return c+1;
                else if (bit == 1) return c;
            }
            else if (c < 0 && isEven(c)) {
                if (bit == 1) return c+1;
                else if (bit == 0) return c;
            }
        }
        else {
            Log.wtf(Jpeg.LOG, "type != {1, 2}");
            return 0;
        }

        Log.wtf(Jpeg.LOG, "embed bit: this should never occur");
        return 0;
    }

    /** embedBitUnsafe plus a log entry if the result is zero. */
    public int embedBit (final int type, final int c, final int bit) {
        final int res = this.embedBitUnsafe(type, c, bit);
        if (res == 0) {
            Log.wtf(Jpeg.LOG, String.format("we embedded zero?!! type: %d c: %d bit: %d", type, c, bit));
        }
        return res;
    }

    /** Reads the bit stored in coefficient {@code c} for the given embedding type (0 on error). */
    public int extractBit (final int type, final int c) {
        if (c == 0) {
            Log.wtf(Jpeg.LOG, "zero coeff!?!");
            // @TODO add correct handling
            return 0;
        }
        if (type == 1) {
            if (c > 0 && isEven(c)) return 0;
            if (c < 0 && isOdd(c))  return 0;
            if (c > 0 && isOdd(c))  return 1;
            if (c < 0 && isEven(c)) return 1;
        }
        else if (type == 2) {
            if (c > 0 && isOdd(c))  return 0;
            if (c < 0 && isEven(c)) return 0;
            if (c > 0 && isEven(c)) return 1;
            if (c < 0 && isOdd(c))  return 1;
        }
        else {
            Log.wtf(Jpeg.LOG, "type != {1, 2}");
            return 0;
        }

        Log.wtf(Jpeg.LOG, "extract bit: this should never occur");
        return 0;
    }

    /**
     * Majority-decodes one bit from c[pos..pos+rep-1]. Returns 1 when more than rep/2
     * (integer division) of the copies read as 1, so an even split decodes as 0.
     */
    public int extractBitRep (final int type, final int[] c, final int pos) {
        if (c.length < rep + pos) {
            Log.wtf(Jpeg.LOG, "incorrect tuple length!!");
            // @TODO correct handling
        }
        int sum = 0;
        for (int z = 0; z < rep; z++) {
            sum += this.extractBit(type, c[pos + z]);
        }

        if (sum > (rep/2)) {
            return 1;
        }
        else {
            return 0;
        }
    }

    /**
     * Returns the first position p with from <= p < limit whose coefficient
     * coeff[qperm[p]] is non-zero, or -1 if there is none.
     */
    private int nextNonZero (final int[] coeff, final int[] qperm, final int from, final int limit) {
        int p = from;
        while (p < limit && coeff[qperm[p]] == 0) {
            p++;
        }
        return p < limit ? p : -1;
    }

    /**
     * Embeds the values of M (bytes, including the length header) into the non-zero coefficients
     * at qperm[qbegin..qend). Each bit is written into {@code rep} consecutive usable
     * coefficients. For type 1, the first (int) (beta * M.length * 8 * rep) non-zero coefficients
     * of the region that equal -2 are then set to 1, which keeps their bit (both read as 1).
     *
     * @throws IllegalStateException if the region runs out of non-zero coefficients. Previously
     *         the walk spilled into the following region or ran off qperm.
     */
    public void embedPart (final int type, int[] coeff, final int[] qperm, final int qbegin, final int qend, final int[] M) {
        Log.d(Jpeg.LOG, String.format("embedPart: type: %d qbegin: %d qend: %d Mlength: %d", type, qbegin, qend, M.length));
        if (rep * 8 * M.length > qend - qbegin) {
            Log.e(Jpeg.LOG, "message is too long");
            // @TODO add proper handling
        }
        final int limit = Math.min(qend, qperm.length);

        int[] octet = new int[8];
        // Cursor over qperm; zero coefficients are skipped and never consume a bit slot.
        int cursor = qbegin;
        final int byteCount = Math.min((qend - qbegin) / (8 * rep), M.length);
        for (int j = 0; j < byteCount; j++) {
            this.encodeByteToBits(M[j], octet, 0);
            for (int k = 0; k < 8; k++) {
                for (int z = 0; z < rep; z++) {
                    final int qi = this.nextNonZero(coeff, qperm, cursor, limit);
                    if (qi < 0) {
                        throw new IllegalStateException(String.format(
                                "embedPart: type %d ran out of non-zero coefficients in [%d, %d) at byte %d of %d",
                                type, qbegin, qend, j, M.length));
                    }
                    final int i = qperm[qi];
                    coeff[i] = this.embedBit(type, coeff[i], octet[k]);
                    cursor = qi + 1;
                }
            }
        }

        if (type == 1) {
            cursor = qbegin;
            final int correctionCount = (int) (this.beta * M.length * 8 * rep);
            for (int j = 0; j < correctionCount; j++) {
                final int qi = this.nextNonZero(coeff, qperm, cursor, limit);
                if (qi < 0) {
                    break;
                }
                final int i = qperm[qi];
                if (coeff[i] == -2) coeff[i] = 1;
                cursor = qi + 1;
            }
        }
    }

    /** Extracts one part and appends its bytes to {@code fos}. */
    public void extractPart (final int type, final int[] coeff, final int[] qperm, final int qbegin, final int qend, ByteArrayOutputStream fos) {
        this.extractPart(type, coeff, qperm, qbegin, qend, fos, null);
    }

    /**
     * Extracts one part embedded by embedPart over qperm[qbegin..qend). It first decodes the
     * maxLengthBits-bit length L, then up to L bytes. The bytes go into S when S is non-null
     * (stopping at S.length), otherwise they are written to fos. Extraction stops early, with
     * a log, if the region runs out of non-zero coefficients.
     */
    public void extractPart (final int type, final int[] coeff, final int[] qperm, int qbegin, final int qend, ByteArrayOutputStream fos, int[] S) {
        Log.d(Jpeg.LOG, String.format("extractPart: type: %d qbegin: %d qend: %d", type, qbegin, qend));
        int[] lengthCoeffsRep = new int[maxLengthBits * rep];
        if (qend - qbegin < maxLengthBits * rep) {
            Log.wtf(Jpeg.LOG, "Coeffs are too small for length encoding?!!");
            return;
        }
        final int limit = Math.min(qend, qperm.length);

        // Same cursor walk as embedPart, continuing across the header and the data.
        int cursor = qbegin;
        for (int i = 0; i < lengthCoeffsRep.length; i++) {
            final int qi = this.nextNonZero(coeff, qperm, cursor, limit);
            if (qi < 0) {
                Log.e(Jpeg.LOG, "extractPart: ran out of non-zero coefficients while reading the length");
                return;
            }
            lengthCoeffsRep[i] = coeff[qperm[qi]];
            cursor = qi + 1;
        }
        int[] lengthBits = new int[maxLengthBits];
        for (int i = 0; i < maxLengthBits; i++) {
            lengthBits[i] = this.extractBitRep(type, lengthCoeffsRep, rep * i);
        }
        int[] lengthBytes = new int[maxLengthBits/8];
        for (int i = 0; i < maxLengthBits/8; i++) {
            lengthBytes[i] = this.encodeBitsToByte(lengthBits, 8*i);
        }

        final int L = this.encodeBytesToInt(lengthBytes, 0, maxLengthBits/8);
        Log.d(Jpeg.LOG, String.format("L: %d", L));

        // Mirrors embedPart's bound, which counted the length header as part of M.
        final int dataBegin = qbegin + maxLengthBits * rep;

        int[] octet = new int[8];
        int[] tuple = new int[rep];
        final int byteCount = Math.min(L, (qend - dataBegin) / (8 * rep));
        for (int j = 0; j < byteCount; j++) {
            for (int k = 0; k < 8; k++) {
                for (int z = 0; z < rep; z++) {
                    final int qi = this.nextNonZero(coeff, qperm, cursor, limit);
                    if (qi < 0) {
                        Log.e(Jpeg.LOG, String.format("extractPart: ran out of non-zero coefficients at byte %d of %d", j, L));
                        return;
                    }
                    tuple[z] = coeff[qperm[qi]];
                    cursor = qi + 1;
                }
                octet[k] = this.extractBitRep(type, tuple, 0);
            }
            final int b = this.encodeBitsToByte(octet, 0);

            if (S != null) {
                if (j >= S.length) {
                    Log.e(Jpeg.LOG, String.format("extractPart: length %d exceeds output buffer %d", L, S.length));
                    return;
                }
                S[j] = b;
            }
            else {
                fos.write((byte) b);
            }
        }
    }

    /** Returns the split point (int) (alpha * length) between part one and part two. */
    public int getQSep (final int length) {
        return (int) (this.alpha * length);
    }

    /**
     * Logs "label" followed by up to debugDumpLength space-separated values. If order is
     * non-null, values[order[i]] is logged, otherwise values[i]. The dump is bounded by the
     * array length, so short inputs no longer throw.
     */
    private void logCoeffs (final String label, final int[] values, final int[] order) {
        final int available = (order != null) ? order.length : values.length;
        final int count = Math.min(this.debugDumpLength, available);
        final StringBuilder sb = new StringBuilder(label);
        for (int i = 0; i < count; i++) {
            sb.append(order != null ? values[order[i]] : values[i]).append(' ');
        }
        Log.d(Jpeg.LOG, sb.toString());
    }

    /** Reads {@code in} until EOF and returns every byte as an unsigned value 0..255. */
    private int[] readAllBytes (final InputStream in) throws IOException {
        final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        final byte[] chunk = new byte[4096];
        int n;
        while ((n = in.read(chunk)) != -1) {
            buffer.write(chunk, 0, n);
        }
        final byte[] bytes = buffer.toByteArray();
        final int[] res = new int[bytes.length];
        for (int i = 0; i < bytes.length; i++) {
            res[i] = bytes[i] & 0xFF;
        }
        return res;
    }

    /** Returns true if extracted[0..count) equals M[offset..offset+count). */
    private boolean partMatches (final int[] extracted, final int[] M, final int offset, final int count) {
        for (int i = 0; i < count; i++) {
            if (extracted[i] != M[offset + i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Embeds all bytes of {@code embeddedData} into {@code coeffOrig} in place. The stream is
     * read to EOF and is not closed here. Nothing is written back if any step fails; failures
     * are logged rather than thrown. After writing, both parts are re-extracted and compared as
     * a self-check.
     */
    public void embed (int coeffOrig[], final InputStream embeddedData) {
        // Deliberately outside the try, as before: a null coefficient array is not swallowed.
        this.logCoeffs("coeffOrig: ", coeffOrig, null);
        try {
            final int Clength = Math.min(this.maxCoeffs, coeffOrig.length);
            int[] coeff = new int[Clength];
            System.arraycopy(coeffOrig, 0, coeff, 0, Clength);

            Log.d(Jpeg.LOG, "Embedding started");
            int[] qperm = this.getGoodCoeffs(coeff);
            Log.d(Jpeg.LOG, String.format("qperm length: %d", qperm.length));
            this.logCoeffs("coeff: ", coeff, qperm);
            this.logCoeffs("qperm: ", qperm, null);

            final int Qsep = this.getQSep(qperm.length);

            // Read to EOF: available() is only an estimate, and the old fixed 50000-entry
            // buffer overflowed on larger payloads.
            final int[] S = this.readAllBytes(embeddedData);
            final int Slength = S.length;

            final int headerBytes = maxLengthBits / 8;
            final int maxPartLength = (1 << maxLengthBits) - 1;
            final int Ssep = this.getQSep(Slength);
            if (Ssep > maxPartLength || Slength - Ssep > maxPartLength) {
                // A longer part would silently wrap its length header.
                throw new IllegalStateException(String.format(
                        "message of %d bytes does not fit %d-bit part length headers", Slength, maxLengthBits));
            }

            int[] M1 = new int[headerBytes + Ssep];
            System.arraycopy(S, 0, M1, headerBytes, Ssep);
            this.encodeIntToBytes(Ssep, M1, 0, headerBytes);

            int[] M2 = new int[headerBytes + Slength - Ssep];
            System.arraycopy(S, Ssep, M2, headerBytes, Slength - Ssep);
            this.encodeIntToBytes(Slength - Ssep, M2, 0, headerBytes);

            embedPart(1, coeff, qperm, 0, Qsep, M1);
            embedPart(2, coeff, qperm, Qsep, qperm.length, M2);
            System.arraycopy(coeff, 0, coeffOrig, 0, Clength);
            Log.d(Jpeg.LOG, "Embedding ended");

            // Self-check: extract both parts again and compare them with what was embedded.
            int[] ES1 = new int[Slength + 4];
            this.extractPart(1, coeff, qperm, 0, Qsep, null, ES1);
            if (!this.partMatches(ES1, M1, headerBytes, Ssep)) {
                Log.e(Jpeg.LOG, "first part embedded incorrectly :(((");
            }

            int[] ES2 = new int[Slength + 4];
            this.extractPart(2, coeff, qperm, Qsep, qperm.length, null, ES2);
            if (!this.partMatches(ES2, M2, headerBytes, Slength - Ssep)) {
                Log.e(Jpeg.LOG, "second part embedded incorrectly :(((");
            }
        }
        catch (Exception e) {
            Log.wtf(Jpeg.LOG, String.format("Something went brutally wrong during embedding :/ Error: %s", e.toString()), e);
        }
    }

    /**
     * Extracts both parts from {@code coeffOrig} into {@code fos}, then passes {@code fos} to
     * the listener. Failures are logged and the listener is then not called. Only the size of
     * the recovered message is logged, never its content.
     */
    public void extract (int coeffOrig[],  ByteArrayOutputStream fos, Extract.ExtractionListener listener) {
        {
            // @MEGADEBUG3000
            boolean isNonZero = false;
            for (int i = 0; i < coeffOrig.length; i++) {
                if (coeffOrig[i] != 0) {
                    isNonZero = true;
                    break;
                }
            }
            Log.d(Jpeg.LOG, String.format("isNonZero: %b", isNonZero));
        }

        this.logCoeffs("coeffOrig: ", coeffOrig, null);

        try {
            final int Clength = Math.min(this.maxCoeffs, coeffOrig.length);
            int[] coeff = new int[Clength];
            System.arraycopy(coeffOrig, 0, coeff, 0, Clength);

            Log.d(Jpeg.LOG, "Extraction started");
            int[] qperm = this.getGoodCoeffs(coeff);
            Log.d(Jpeg.LOG, String.format("qperm length: %d", qperm.length));
            this.logCoeffs("coeff: ", coeff, qperm);
            this.logCoeffs("qperm: ", qperm, null);

            final int Qsep = this.getQSep(qperm.length);

            extractPart(1, coeff, qperm, 0, Qsep, fos);
            extractPart(2, coeff, qperm, Qsep, qperm.length, fos);
            Log.d(Jpeg.LOG, "Extraction ended");
            // Log only the size: the content is the hidden message and must not reach logcat.
            Log.d(Jpeg.LOG, String.format("extracted bytes: %d", fos != null ? fos.size() : 0));
            listener.onExtractionResult(fos);
        }
        catch (Exception e) {
            Log.wtf(Jpeg.LOG, String.format("Something went brutally wrong during extracting :/ Error: %s", e.toString()), e);
        }
    }
}

