package tech.espublico.pades.server.signers.sign.helper;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Signature;
import java.security.cert.CertificateException;
import java.security.cert.CertificateExpiredException;
import java.security.cert.CertificateFactory;
import java.security.cert.CertificateNotYetValidException;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Objects;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;

import org.bouncycastle.asn1.ASN1InputStream;
import org.bouncycastle.asn1.DEREncodable;
import org.bouncycastle.asn1.DERObject;
import org.bouncycastle.asn1.DERObjectIdentifier;
import org.bouncycastle.asn1.cms.Attribute;
import org.bouncycastle.asn1.cms.ContentInfo;
import org.bouncycastle.asn1.cms.Time;
import org.bouncycastle.asn1.util.ASN1Dump;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cms.CMSSignedData;
import org.bouncycastle.cms.SignerInformation;
import org.bouncycastle.cms.SignerInformationStore;
import org.bouncycastle.util.Store;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import tech.espublico.pades.server.helper.IOHelper;
import tech.espublico.pades.server.models.PDFSignerModel.ByteMode;
import tech.espublico.pades.server.signers.sign.exceptions.InvalidSignatureException;
import tech.espublico.pades.server.signers.sign.exceptions.InvalidSignatureException.InvalidSignatureExceptionReason;

/**
 * Static helpers that verify a PKCS#7 (CMS SignedData) signature over a given byte array and
 * return the certificate of the signer.
 */
public class SignatureHelper {

	private static final Logger log = LoggerFactory.getLogger(SignatureHelper.class);

	/** OID of the PKCS#9 signing-time signed attribute. */
	private static final DERObjectIdentifier SIGNING_TIME_OID = new DERObjectIdentifier("1.2.840.113549.1.9.5");

	/**
	 * Validates a PKCS#7 signature against {@code bytesToSign} and returns the signer certificate.
	 * <p>
	 * For each signer in the CMS structure, the method does the following:
	 * <ol>
	 * <li>Reads the signing-time signed attribute.</li>
	 * <li>Decodes the first certificate in the embedded store that matches the signer id.</li>
	 * <li>Checks that certificate's validity at the signing time.</li>
	 * <li>Verifies the signature value according to {@code byteMode}.</li>
	 * </ol>
	 * The first matching certificate is decisive. If any check on it fails, an exception is thrown
	 * instead of trying further signers. Signers with no matching certificate are skipped.
	 *
	 * @param bytesToSign the bytes the signature is expected to cover; in {@link ByteMode#HASH}
	 *                    mode they must equal the RSA-decrypted signature value byte-for-byte
	 * @param pkcs7       DER-encoded PKCS#7 ContentInfo holding the SignedData structure
	 * @param byteMode    selects how the signature value is checked against {@code bytesToSign}
	 * @return the certificate of the signer whose signature was verified
	 * @throws InvalidSignatureException if the PKCS#7 cannot be parsed, the signing time is
	 *         missing or malformed, the certificate is invalid at signing time, the signature does
	 *         not verify, or no signer certificate is found
	 * @throws NullPointerException if any argument is {@code null}
	 */
	public static X509Certificate validatePKCS7Signature(byte[] bytesToSign, byte[] pkcs7, ByteMode byteMode) throws InvalidSignatureException {
		Objects.requireNonNull(bytesToSign, "bytesToSign");
		Objects.requireNonNull(pkcs7, "pkcs7");
		Objects.requireNonNull(byteMode, "byteMode");

		CMSSignedData signedData = parseSignedData(pkcs7);
		SignerInformationStore signerInformationStore = signedData.getSignerInfos();

		Store certStore;
		try {
			log.trace("Reading certificate store from signerInformation");
			certStore = signedData.getCertificates();
		} catch (Exception e) {
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_CERTIFICATE, e);
		}

		for (Object signerObj : signerInformationStore.getSigners()) {
			SignerInformation signer = (SignerInformation) signerObj;
			Time time = getSigningTime(signer);

			if (log.isDebugEnabled()) {
				log.debug("Signing time: {}", time.getTime());
			}

			try {
				for (Object signerCertificate : certStore.getMatches(signer.getSID())) {
					X509CertificateHolder certHolder = (X509CertificateHolder) signerCertificate;

					// Convert the BouncyCastle holder into a JCA certificate for validity/key access.
					X509Certificate cert = null;
					ByteArrayInputStream stream = null;
					try {
						stream = new ByteArrayInputStream(certHolder.getEncoded());
						cert = (X509Certificate) CertificateFactory.getInstance("X.509").generateCertificate(stream);
					} catch (IOException e) {
						throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_CERTIFICATE, e);
					} finally {
						IOHelper.closeQuietly(stream);
					}

					if (log.isDebugEnabled())
						log.debug("Checking certificate validity\n{}", cert.toString());

					// Validity is evaluated at the signing time taken from the signed attributes,
					// not at the current time.
					cert.checkValidity(time.getDate());

					switch (byteMode) {
						case HASH:
							validateHashMode(cert, signer.getSignature(), bytesToSign);
							break;
						case RAW:
							validateRawMode(cert, byteMode.getSignatureAlgorithm(), signer.getSignature(), bytesToSign);
							break;
						default:
							throw new IllegalArgumentException("Unexpected byteModel property");
					}

					log.debug("Signature is valid");
					return cert;
				}

			} catch (CertificateExpiredException e) {
				throw new InvalidSignatureException(InvalidSignatureExceptionReason.CERTIFICATE_EXPIRED, e);
			} catch (CertificateNotYetValidException e) {
				throw new InvalidSignatureException(InvalidSignatureExceptionReason.CERTIFICATE_NOT_VALID_YET, e);
			} catch (CertificateException e) {
				throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_CERTIFICATE, e);
			}
		}

		throw new InvalidSignatureException(InvalidSignatureExceptionReason.SIGNER_NOT_FOUND);
	}

	/**
	 * Decodes the DER bytes into a {@link CMSSignedData}.
	 * <p>
	 * Every parse failure maps to {@code COULD_NOT_PARSE_PKCS7}. This covers I/O-level ASN.1 errors
	 * and also the {@link IllegalArgumentException} that the BouncyCastle factories may throw on
	 * structurally wrong input.
	 *
	 * @param pkcs7 DER-encoded PKCS#7 ContentInfo
	 * @return the parsed signed data
	 * @throws InvalidSignatureException with reason {@code COULD_NOT_PARSE_PKCS7} on any parse failure
	 */
	private static CMSSignedData parseSignedData(byte[] pkcs7) throws InvalidSignatureException {
		try {
			log.trace("Reading pkcs7 signature content");
			ContentInfo signatureData = getSignatureContentInfo(pkcs7);
			return new CMSSignedData(signatureData);
		} catch (IOException | IllegalArgumentException e) {
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.COULD_NOT_PARSE_PKCS7, e);
		}
	}

	/**
	 * Checks a HASH-mode signature.
	 * <p>
	 * The signature value is RSA-decrypted with the certificate's public key, and the result must
	 * equal {@code bytesToSign} exactly.
	 *
	 * @throws InvalidSignatureException if decryption fails or the bytes differ
	 */
	private static void validateHashMode(X509Certificate certificate, byte[] signatureBytes, byte[] bytesToSign) throws InvalidSignatureException {

		byte[] decrypted = decryptSignature(certificate, signatureBytes);
		if (log.isDebugEnabled()) {
			log.debug("Source bytes to sign {}", bytesToSign);
			log.debug("Decrypted bytes      {}", decrypted);
		}

		// isEqual rejects arrays of different length and does not short-circuit on the first
		// mismatching byte.
		if (!MessageDigest.isEqual(decrypted, bytesToSign))
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_SIGNATURE);
	}

	/**
	 * Checks a RAW-mode signature with {@link Signature} using the given JCA algorithm name.
	 *
	 * @throws InvalidSignatureException with reason {@code INVALID_SIGNATURE} if the signature does
	 *         not verify or verification cannot be performed (the cause is preserved)
	 */
	private static void validateRawMode(X509Certificate certificate, String signatureAlgorithm, byte[] signatureBytes, byte[] bytesToSign)
			throws InvalidSignatureException {

		boolean verified;
		try {
			Signature sig = Signature.getInstance(signatureAlgorithm);
			sig.initVerify(certificate.getPublicKey());
			sig.update(bytesToSign);
			verified = sig.verify(signatureBytes);
		} catch (GeneralSecurityException | RuntimeException e) {
			// Unchecked exceptions are caught too, so that any failure inside the provider
			// (e.g. java.security.ProviderException) still maps to INVALID_SIGNATURE.
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_SIGNATURE, e);
		}

		// Checked outside the try so this exception is not caught and re-wrapped above.
		if (!verified)
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_SIGNATURE);
	}

	/**
	 * RSA-decrypts (PKCS#1 v1.5 padding) the raw signature value with the certificate's public key.
	 *
	 * @return the decrypted signature payload
	 * @throws InvalidSignatureException with reason {@code INVALID_KEY} if the certificate key cannot
	 *         be used with the RSA cipher, otherwise {@code INVALID_SIGNATURE}
	 */
	private static byte[] decryptSignature(X509Certificate cert, byte[] signatureBytes) throws InvalidSignatureException {
		Objects.requireNonNull(cert, "cert");
		Objects.requireNonNull(signatureBytes, "Signature Bytes");

		try {
			Cipher rsaCipher = Cipher.getInstance("RSA/ECB/PKCS1Padding");
			rsaCipher.init(Cipher.DECRYPT_MODE, cert);
			return rsaCipher.doFinal(signatureBytes);
		} catch (InvalidKeyException e) {
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_KEY, e);
		} catch (NoSuchAlgorithmException | NoSuchPaddingException | IllegalBlockSizeException | BadPaddingException e) {
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_SIGNATURE, e);
		}
	}

	/**
	 * Extracts the signing-time signed attribute of a signer.
	 *
	 * @throws InvalidSignatureException with {@code NO_SIGNIN_TIME} if there are no signed
	 *         attributes or no signing-time attribute, and with {@code INVALID_SIGNED_ATTRIBUTES}
	 *         if the attribute does not hold exactly one valid time value
	 */
	private static Time getSigningTime(SignerInformation signer) throws InvalidSignatureException {
		Objects.requireNonNull(signer, "signer information");

		// A signer without signed attributes cannot carry a signing time. Report it as such
		// instead of failing with a NullPointerException on malformed input.
		if (signer.getSignedAttributes() == null)
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.NO_SIGNIN_TIME);

		Attribute signingTimeAttribute = signer.getSignedAttributes().get(SIGNING_TIME_OID);
		if (signingTimeAttribute == null)
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.NO_SIGNIN_TIME);

		if (signingTimeAttribute.getAttrValues().size() != 1)
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_SIGNED_ATTRIBUTES);

		DEREncodable signingTime = signingTimeAttribute.getAttrValues().getObjectAt(0);
		try {
			return Time.getInstance(signingTime.getDERObject());
		} catch (IllegalArgumentException e) {
			// The attribute value is not a time type the factory understands.
			throw new InvalidSignatureException(InvalidSignatureExceptionReason.INVALID_SIGNED_ATTRIBUTES, e);
		}
	}

	/**
	 * Reads the first ASN.1 object of {@code pkcs7} and interprets it as a CMS ContentInfo.
	 *
	 * @param pkcs7 DER-encoded PKCS#7 structure
	 * @return the parsed ContentInfo (never {@code null})
	 * @throws IOException if the bytes cannot be read as ASN.1 or contain no object at all
	 */
	private static ContentInfo getSignatureContentInfo(byte[] pkcs7) throws IOException {
		Objects.requireNonNull(pkcs7, "Encoded PKCS7 Signature");

		if (log.isTraceEnabled())
			log.trace("Parsing PKCS7 signature: \n{}", Base64.getEncoder().encodeToString(pkcs7));

		ContentInfo contentInfo;
		try (ASN1InputStream asn1Stream = new ASN1InputStream(pkcs7)) {
			DERObject derPKCS7 = asn1Stream.readObject();
			// readObject() returns null when the input holds no object (e.g. an empty array).
			if (derPKCS7 == null)
				throw new IOException("PKCS7 signature byte array contains no ASN.1 object");
			contentInfo = ContentInfo.getInstance(derPKCS7);
		} catch (IOException e) {
			log.error("Could not read PKCS7 signagure byte array", e);
			throw e;
		}

		if (log.isDebugEnabled()) {
			log.debug("Content Info\n{}", ASN1Dump.dumpAsString(contentInfo));
		}
		return contentInfo;
	}

}
