package tech.espublico.pades.server.rest.auth;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Date;
import java.util.Objects;

import javax.servlet.http.HttpServletRequest;

import org.apache.commons.codec.binary.Hex;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.codec.digest.HmacUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import spark.Request;
import tech.espublico.pades.server.rest.PadesServerHeaders;
import tech.espublico.pades.server.di.Service;
import tech.espublico.pades.server.di.ServiceLocator;
import tech.espublico.pades.server.services.GsonService;
import tech.espublico.pades.server.services.PropertyValidatorException;
import tech.espublico.pades.server.services.ConfigService;

/**
 * Validates HMAC-signed auth tokens attached to incoming requests.
 *
 * <p>A token has the form {@code base64(json) + "." + base64(mac)}. The JSON part is deserialized
 * into an {@link AuthModel} that carries a hash type and a birth time. The mac part must equal the
 * lowercase hex HMAC-SHA256, keyed with the {@code auth.super-secret} setting, of the UTF-8 string
 * {@code birthTime + hashType + sha1Hex(requestData)}.
 *
 * <p>{@code requestData} is built from these parts, in order:
 * <ul>
 * <li>the session id;</li>
 * <li>the request URL;</li>
 * <li>when present, the SHA-1 hex of a local file;</li>
 * <li>when present, a request model (raw bytes, a string, or the model's JSON form).</li>
 * </ul>
 *
 * <p>A token is rejected once {@code birthTime + auth.max-life} is not after the current time
 * in epoch milliseconds.
 */
@Service
public class AuthService {

	private static final Logger log = LoggerFactory.getLogger(AuthService.class);

	public static AuthService instance() {
		return ServiceLocator.INSTANCE.getInstance(AuthService.class);
	}

	private final GsonService gsonService;

	/** HMAC key, read from the {@code auth.super-secret} configuration property. */
	private final String superSecret;
	/** Token lifetime added to the token's birth time, read from {@code auth.max-life}. */
	private final long maxLife;

	public AuthService() throws PropertyValidatorException {
		this.gsonService = GsonService.instance();
		this.superSecret = ConfigService.instance().getString("auth.super-secret");
		this.maxLife = ConfigService.instance().getLong("auth.max-life");
	}

	/**
	 * Validates only the sessionId and the uri from the request.
	 *
	 * <p>The token is read from the {@code AUTH_HEADER} header. The uri is the servlet request URL,
	 * which does not include the query string.
	 *
	 * @param request
	 * 		Spark request
	 *
	 * @throws UnauthorizedException
	 * 		if the token is missing, malformed, has a wrong signature or is expired
	 */
	public void validateWithoutData(Request request) throws UnauthorizedException {
		String sessionId = request.headers(PadesServerHeaders.SESSION_ID_HEADER);
		HttpServletRequest servletRequest = request.raw();
		String requestUrl = servletRequest.getRequestURL().toString();
		log.trace("Auth uri [{}]", requestUrl);
		log.trace("Auth sessionId [{}]", sessionId);
		String authToken = request.headers(PadesServerHeaders.AUTH_HEADER);

		this.validate(sessionId, requestUrl, authToken);
	}

	/**
	 * Validates the auth token of a POST request against the sessionId, the uri and the raw body.
	 *
	 * <p>The token is read from the {@code AUTH_HEADER_POST} header. The body bytes are fed to the
	 * digest unchanged.
	 *
	 * @param request
	 * 		Spark request
	 *
	 * @throws UnauthorizedException
	 * 		if the token is missing, malformed, has a wrong signature or is expired
	 */
	public void validatePostData(Request request) throws UnauthorizedException {
		String sessionId = request.headers(PadesServerHeaders.SESSION_ID_HEADER);
		HttpServletRequest servletRequest = request.raw();
		String requestUrl = servletRequest.getRequestURL().toString();
		String authToken = request.headers(PadesServerHeaders.AUTH_HEADER_POST);
		byte[] body = request.bodyAsBytes();
		log.trace("Auth uri [{}]", requestUrl);
		log.trace("Auth sessionId [{}]", sessionId);
		this.validate(sessionId, requestUrl, body, authToken);
	}

	/**
	 * Validates the auth token of a request against the sessionId, the uri and a local file.
	 *
	 * <p>The token is read from the {@code AUTH_HEADER_POST} header. The file's SHA-1 hex is part
	 * of the digest.
	 *
	 * @param request
	 * 		Spark request
	 * @param path
	 * 		local file whose content hash must match the one covered by the token
	 *
	 * @throws UnauthorizedException
	 * 		if the token is invalid or the file cannot be read
	 */
	public void validate(Request request, Path path) throws UnauthorizedException {
		String sessionId = request.headers(PadesServerHeaders.SESSION_ID_HEADER);
		HttpServletRequest servletRequest = request.raw();
		String requestUrl = servletRequest.getRequestURL().toString();
		String authToken = request.headers(PadesServerHeaders.AUTH_HEADER_POST);

		this.validate(sessionId, requestUrl, path, authToken);
	}

	/**
	 * Validates {@code token} against a session id and uri only.
	 *
	 * @throws UnauthorizedException
	 * 		if the token is invalid
	 */
	public void validate(String sessionId, String uri, String token) throws UnauthorizedException {
		AuthValidationModel<?> authValidationModel = new AuthValidationModel<>(sessionId, uri);
		this.validate(authValidationModel, token);
	}

	/**
	 * Validates {@code token} against a session id, uri and request model.
	 *
	 * <p>A {@code byte[]} model is digested as-is, a {@code String} as UTF-8, and anything else
	 * as its JSON form.
	 *
	 * @throws UnauthorizedException
	 * 		if the token is invalid
	 */
	public <M> void validate(String sessionId, String uri, M model, String token) throws UnauthorizedException {
		AuthValidationModel<M> authValidationModel = new AuthValidationModel<>(sessionId, uri, model);
		this.validate(authValidationModel, token);
	}

	/**
	 * Validates {@code token} against a session id, uri and the SHA-1 of a local file.
	 *
	 * @throws UnauthorizedException
	 * 		if the token is invalid or the file cannot be read
	 */
	public void validate(String sessionId, String uri, Path path, String token) throws UnauthorizedException {
		AuthValidationModel<?> authValidationModel = new AuthValidationModel<>(sessionId, uri, path);
		this.validate(authValidationModel, token);
	}

	/**
	 * Decodes {@code token} and checks its signature and age against {@code authValidationModel}.
	 *
	 * @param authValidationModel
	 * 		request data the token must cover
	 * @param token
	 * 		raw token header value; may be {@code null}, which is rejected
	 *
	 * @throws UnauthorizedException
	 * 		if the token is missing, malformed, has a wrong signature or is expired
	 */
	public void validate(AuthValidationModel<?> authValidationModel, String token) throws UnauthorizedException {
		Pair<AuthModel, byte[]> authModelPair = decodeAuth(token);
		AuthModel authModel = authModelPair.getLeft();
		byte[] hmac = authModelPair.getRight();
		String hashType = authModel.getHash().hash();

		long time;
		try {
			time = Long.valueOf(authModel.getBirthTime());
		} catch (NumberFormatException e) {
			// A non-numeric birth time is a malformed token, i.e. a client error.
			throw new UnauthorizedException("Auth token invalid");
		}

		if (!validate(authValidationModel, hashType, time, hmac)) {
			throw new UnauthorizedException("Unauthorized");
		}
	}

	/**
	 * Splits a {@code base64(json).base64(mac)} token into its parsed model and raw mac bytes.
	 *
	 * <p>Every way a client-supplied token can be malformed is reported as
	 * {@link UnauthorizedException}, so bad input yields an auth failure rather than a server error.
	 * This covers:
	 * <ul>
	 * <li>a missing token;</li>
	 * <li>a wrong number of parts;</li>
	 * <li>bad Base64;</li>
	 * <li>unparsable or empty JSON;</li>
	 * <li>a missing hash field.</li>
	 * </ul>
	 */
	private Pair<AuthModel, byte[]> decodeAuth(String token) throws UnauthorizedException {
		if (token == null) {
			throw new UnauthorizedException("Auth token missing");
		}
		log.trace("Token arrive {}", token);

		String[] parts = token.split("\\.");
		if (parts.length != 2) {
			log.error("Illegal desklet token {}", token);
			throw new UnauthorizedException("Auth token invalid");
		}

		log.trace("Token data [{}] and token [{}]", parts[0], parts[1]);

		String dataToValidate;
		byte[] signature;
		try {
			dataToValidate = new String(Base64.getDecoder().decode(parts[0]), StandardCharsets.UTF_8);
			signature = Base64.getDecoder().decode(parts[1]);
		} catch (IllegalArgumentException e) {
			// Base64.Decoder.decode throws IllegalArgumentException on invalid input.
			log.error("Illegal desklet token encoding {}", token);
			throw new UnauthorizedException("Auth token invalid");
		}

		AuthModel authModel;
		try {
			authModel = this.gsonService.fromJson(dataToValidate, AuthModel.class);
		} catch (RuntimeException e) {
			// NOTE(review): the concrete parse exception depends on GsonService; any failure here
			// means the client sent a malformed payload.
			log.error("Illegal desklet token payload {}", token, e);
			throw new UnauthorizedException("Auth token invalid");
		}
		// An empty JSON part deserializes to null; a token without a hash type cannot be checked.
		if (authModel == null || authModel.getHash() == null) {
			log.error("Incomplete desklet token {}", token);
			throw new UnauthorizedException("Auth token invalid");
		}

		return Pair.of(authModel, signature);
	}

	/**
	 * Recomputes the request digest and checks it against a decoded token.
	 *
	 * <p>The SHA-1 digest covers, in order:
	 * <ul>
	 * <li>the session id;</li>
	 * <li>the uri;</li>
	 * <li>the optional file hash;</li>
	 * <li>the optional model.</li>
	 * </ul>
	 *
	 * @param authValidationModel
	 * 		request data the token must cover
	 * @param hashType
	 * 		hash type string taken from the token; it is part of the signed payload
	 * @param time
	 * 		token birth time in epoch milliseconds
	 * @param hmac
	 * 		hex MAC bytes taken from the token
	 *
	 * @return {@code true} if the MAC matches and the token has not expired
	 *
	 * @throws UnauthorizedException
	 * 		if the digest cannot be created or a referenced file cannot be read
	 */
	public boolean validate(AuthValidationModel<?> authValidationModel, String hashType, long time, byte[] hmac)
			throws UnauthorizedException {
		MessageDigest messageDigest;
		try {
			// "SHA-1" is the standard algorithm name every Java platform must support.
			messageDigest = MessageDigest.getInstance("SHA-1");
		} catch (NoSuchAlgorithmException e) {
			throw new UnauthorizedException("Cannot instantiate MessageDigest");
		}

		digest(messageDigest, authValidationModel.getSessionId());
		digest(messageDigest, authValidationModel.getUri());

		if (authValidationModel.getPath() != null) {
			digest(messageDigest, authValidationModel.getPath());
		}

		Object model = authValidationModel.getModel();
		if (model instanceof byte[]) {
			messageDigest.update((byte[]) model);
		} else if (model instanceof String) {
			digest(messageDigest, (String) model);
		} else if (model != null) {
			digest(messageDigest, model);
		}

		String realHash = Hex.encodeHexString(messageDigest.digest());

		return verify(realHash, hashType, time, hmac) && !isOutOfDate(time);
	}

	/** Feeds the external form of {@code uri} into the digest as UTF-8. */
	public void digest(MessageDigest messageDigest, URL uri) {
		digest(messageDigest, uri.toExternalForm());
	}

	/**
	 * Feeds the SHA-1 hex of the file at {@code path} into the digest.
	 *
	 * @throws UnauthorizedException
	 * 		if the file cannot be opened or read
	 */
	public void digest(MessageDigest messageDigest, Path path) throws UnauthorizedException {
		try (InputStream inputStream = Files.newInputStream(path, StandardOpenOption.READ)) {
			String pathHash = DigestUtils.sha1Hex(inputStream);
			log.debug("File hash [{}]", pathHash);
			digest(messageDigest, pathHash);
		} catch (IOException e) {
			throw new UnauthorizedException("Cannot find local file to validate");
		}
	}

	/** Feeds the UTF-8 bytes of {@code string} into the digest. */
	public void digest(MessageDigest messageDigest, String string) {
		// Log the string itself; logging the byte[] would only print its identity hash.
		log.trace("Digest string {}", string);
		messageDigest.update(string.getBytes(StandardCharsets.UTF_8));
	}

	/** Feeds the JSON serialization of {@code o} (via the injected GsonService) into the digest. */
	public void digest(MessageDigest messageDigest, Object o) {
		digest(messageDigest, this.gsonService.toJson(o));
	}

	/** Returns {@code true} once {@code birthTime + maxLife} is not after the current epoch millis. */
	private boolean isOutOfDate(long birthTime) {
		return (birthTime + maxLife) <= System.currentTimeMillis();
	}

	/**
	 * Checks the token MAC against HMAC-SHA256(superSecret, time + hashType + realHash).
	 *
	 * @return {@code true} if {@code hmac} equals the lowercase hex MAC computed here
	 */
	private boolean verify(String realHash, String hashType, long time, byte[] hmac) {

		log.trace("Auth calc hash [{}]", realHash);
		log.trace("Auth time [{}]", time);
		log.trace("Hash type [{}]", hashType);

		byte[] signedPayload = (String.valueOf(time) + hashType + realHash).getBytes(StandardCharsets.UTF_8);
		String expectedMac = HmacUtils.hmacSha256Hex(superSecret.getBytes(StandardCharsets.UTF_8), signedPayload);

		// Constant-time comparison: String.equals returns early and leaks how many leading
		// characters matched, which enables byte-by-byte MAC forgery.
		boolean matches = MessageDigest.isEqual(expectedMac.getBytes(StandardCharsets.UTF_8), hmac);

		// Do not log expectedMac: with it, anyone reading the logs could forge a token.
		log.trace("Received mac {} matches [{}]", new String(hmac, StandardCharsets.UTF_8), matches);

		return matches;
	}
}
