package cn.wisenergy.chnmuseum.party.common.video;

import javax.crypto.Cipher;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.ByteBuffer;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;

/**
 * Wraps an {@code AES/CTR/NoPadding} {@link Cipher} so that every {@link #update} call transforms
 * exactly {@code length} bytes, even when the underlying provider holds back a trailing partial
 * block.
 *
 * <p>When the provider retains bytes, the current block is completed by feeding zeros. The cipher
 * output for the retained input bytes is copied to the caller. The output for the zero bytes is
 * the keystream, which is kept and XORed manually against the first bytes of the next call. The
 * cipher can also be positioned at an arbitrary byte offset of the stream.
 *
 * <p>Instances hold mutable cipher state and perform no synchronization.
 */
public final class AesFlushingCipher {

    private final Cipher cipher;
    private final int blockSize;
    /** Block of zeros fed to the cipher to force out a buffered partial block. */
    private final byte[] zerosBlock;
    /** Output of the last forced flush; its tail holds keystream bytes not yet consumed. */
    private final byte[] flushedBlock;

    /** Number of keystream bytes at the end of {@code flushedBlock} not yet applied to input. */
    private int pendingXorBytes;

    /**
     * Creates a cipher positioned at byte {@code offset} of the AES/CTR stream.
     *
     * @param mode the {@link Cipher} operation mode passed to {@link Cipher#init}, e.g.
     *     {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
     * @param secretKey raw AES key bytes
     * @param nonce value placed in the first 8 bytes of the initialization vector
     * @param offset byte position in the stream at which the first update begins; must be
     *     non-negative
     * @throws IllegalArgumentException if {@code offset} is negative
     * @throws RuntimeException if the cipher cannot be created or initialized (for example,
     *     because the key is not valid for AES)
     */
    public AesFlushingCipher(int mode, byte[] secretKey, long nonce, long offset) {
        if (offset < 0) {
            throw new IllegalArgumentException("offset must be non-negative: " + offset);
        }
        try {
            cipher = Cipher.getInstance("AES/CTR/NoPadding");
            blockSize = cipher.getBlockSize();
            zerosBlock = new byte[blockSize];
            flushedBlock = new byte[blockSize];
            long counter = offset / blockSize;
            int startPadding = (int) (offset % blockSize);
            // "AES/CTR/NoPadding" -> "AES": the key algorithm is the transformation's first part.
            String keyAlgorithm = cipher.getAlgorithm().split("/", 2)[0];
            cipher.init(
                    mode,
                    new SecretKeySpec(secretKey, keyAlgorithm),
                    new IvParameterSpec(getInitializationVector(nonce, counter)));
            // Skip the keystream bytes of the first block that lie before the requested offset.
            if (startPadding != 0) {
                updateInPlace(new byte[startPadding], 0, startPadding);
            }
        } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException
                | InvalidAlgorithmParameterException e) {
            throw new RuntimeException("Failed to initialize AES/CTR cipher", e);
        }
    }

    /**
     * Transforms {@code length} bytes of {@code data} starting at {@code offset}, writing the
     * result back over the input.
     *
     * @param data buffer holding the input and receiving the output
     * @param offset index of the first byte to transform
     * @param length number of bytes to transform
     */
    public void updateInPlace(byte[] data, int offset, int length) {
        update(data, offset, length, data, offset);
    }

    /**
     * Transforms {@code length} bytes from {@code in} and writes exactly {@code length} bytes to
     * {@code out}. A zero-length call is a no-op.
     *
     * @param in source buffer
     * @param inOffset index of the first input byte
     * @param length number of bytes to transform
     * @param out destination buffer; must have room for {@code length} bytes at
     *     {@code outOffset}
     * @param outOffset index at which output is written
     * @throws IllegalStateException if the provider's buffering does not match a single
     *     partial block
     */
    public void update(byte[] in, int inOffset, int length, byte[] out, int outOffset) {
        // Apply keystream bytes left over from a previous forced flush. Stop as soon as the input
        // is exhausted, so a zero-length call never touches in/out.
        while (pendingXorBytes > 0 && length > 0) {
            out[outOffset] = (byte) (in[inOffset] ^ flushedBlock[blockSize - pendingXorBytes]);
            outOffset++;
            inOffset++;
            pendingXorBytes--;
            length--;
        }
        if (length == 0) {
            return;
        }

        // Bulk of the work: everything the provider is willing to emit right now.
        int written = nonFlushingUpdate(in, inOffset, length, out, outOffset);
        if (written == length) {
            return;
        }

        // The provider retained a trailing partial block. Complete it with zeros: the first
        // bytesToFlush output bytes belong to the caller. The rest are keystream (cipher output
        // of zeros), saved for XOR against the next call's real data.
        int bytesToFlush = length - written;
        if (bytesToFlush >= blockSize) {
            throw new IllegalStateException(
                    "Cipher retained " + bytesToFlush + " bytes; expected fewer than " + blockSize);
        }
        outOffset += written;
        pendingXorBytes = blockSize - bytesToFlush;
        written = nonFlushingUpdate(zerosBlock, 0, pendingXorBytes, flushedBlock, 0);
        if (written != blockSize) {
            throw new IllegalStateException(
                    "Flush produced " + written + " bytes; expected " + blockSize);
        }
        System.arraycopy(flushedBlock, 0, out, outOffset, bytesToFlush);
    }

    /**
     * Passes input straight to {@link Cipher#update(byte[], int, int, byte[], int)} and returns
     * the number of bytes the provider wrote.
     *
     * @throws RuntimeException wrapping {@link ShortBufferException} if {@code out} is too small
     */
    private int nonFlushingUpdate(byte[] in, int inOffset, int length, byte[] out, int outOffset) {
        try {
            return cipher.update(in, inOffset, length, out, outOffset);
        } catch (ShortBufferException e) {
            throw new RuntimeException("Output buffer too small for cipher update", e);
        }
    }

    /**
     * Builds the 16-byte initial counter block. It is the nonce followed by the block counter,
     * each written as a big-endian {@code long}.
     */
    private static byte[] getInitializationVector(long nonce, long counter) {
        return ByteBuffer.allocate(16).putLong(nonce).putLong(counter).array();
    }
}