package org.gcube.common.keycloak.model;

import java.util.Arrays;
import java.util.Base64;
import java.util.List;

import org.gcube.com.fasterxml.jackson.annotation.JsonInclude.Include;
import org.gcube.com.fasterxml.jackson.core.JsonProcessingException;
import org.gcube.com.fasterxml.jackson.databind.ObjectMapper;
import org.gcube.com.fasterxml.jackson.databind.ObjectWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for the token model: JSON serialization of model objects and decoding of the
 * parts (header, payload, signature) of a JWT in compact serialization form
 * ({@code header.payload.signature}).
 * <p>
 * These helpers only decode tokens; nothing in this class verifies a token's signature.
 */
public class ModelUtils {

    protected static final Logger logger = LoggerFactory.getLogger(ModelUtils.class);

    /** Audience entry skipped when deriving a client id from the 'aud' claim. */
    private static final String ACCOUNT_AUDIENCE_RESOURCE = "account";

    /** Authorization-header scheme prefix; compared case-insensitively. */
    private static final String BEARER_PREFIX = "bearer ";

    /** Shared mapper, configured in the static initializer to omit null-valued properties on output. */
    private static final ObjectMapper mapper = new ObjectMapper();

    static {
        mapper.setSerializationInclusion(Include.NON_NULL);
    }

    /**
     * Serializes an object to compact JSON, omitting null-valued properties.
     *
     * @param object the object to serialize
     * @return the JSON text, or {@code null} if serialization failed (the error is logged)
     */
    public static String toJSONString(Object object) {
        return toJSONString(object, false);
    }

    /**
     * Serializes an object to JSON, omitting null-valued properties.
     *
     * @param object      the object to serialize
     * @param prettyPrint {@code true} to use Jackson's default pretty printer
     * @return the JSON text, or {@code null} if serialization failed (the error is logged)
     */
    public static String toJSONString(Object object, boolean prettyPrint) {
        ObjectWriter writer = prettyPrint ? mapper.writerWithDefaultPrettyPrinter() : mapper.writer();
        try {
            return writer.writeValueAsString(object);
        } catch (JsonProcessingException e) {
            // Shared by the compact and the pretty-printed path, so don't mention pretty printing
            logger.error("Cannot serialize object to JSON", e);
            return null;
        }
    }

    /**
     * Returns the decoded payload of the response's access token as pretty-printed JSON.
     *
     * @param tokenResponse the token response holding the access token
     * @return the payload JSON, or {@code null} if re-serialization failed
     * @throws Exception if the access token cannot be decoded or parsed
     */
    public static String getAccessTokenPayloadJSONStringFrom(TokenResponse tokenResponse) throws Exception {
        return getAccessTokenPayloadJSONStringFrom(tokenResponse, true);
    }

    /**
     * Returns the decoded payload of the response's access token as JSON.
     *
     * @param tokenResponse the token response holding the access token
     * @param prettyPrint   {@code true} to pretty print the output
     * @return the payload JSON, or {@code null} if re-serialization failed
     * @throws Exception if the access token cannot be decoded or parsed
     */
    public static String getAccessTokenPayloadJSONStringFrom(TokenResponse tokenResponse, boolean prettyPrint)
            throws Exception {
        // Bound to a generic Object (not AccessToken) so the output is not limited to AccessToken's properties
        return toJSONString(getAccessTokenFrom(tokenResponse, Object.class), prettyPrint);
    }

    /**
     * Decodes the access token contained in a token response.
     *
     * @param tokenResponse the token response holding the access token
     * @return the parsed access token payload
     * @throws Exception if the access token cannot be decoded or parsed
     */
    public static AccessToken getAccessTokenFrom(TokenResponse tokenResponse) throws Exception {
        return getAccessTokenFrom(tokenResponse, AccessToken.class);
    }

    /**
     * Decodes an access token given either as a raw compact JWT or as an HTTP Authorization header
     * value of the form {@code Bearer <jwt>} (scheme matched case-insensitively).
     *
     * @param authorizationHeaderOrBase64EncodedJWT the header value or the raw JWT
     * @return the parsed access token payload
     * @throws Exception if the token cannot be decoded or parsed
     */
    public static AccessToken getAccessTokenFrom(String authorizationHeaderOrBase64EncodedJWT) throws Exception {
        return readPayload(stripBearerPrefix(authorizationHeaderOrBase64EncodedJWT), AccessToken.class);
    }

    /** Removes a leading, case-insensitive {@code "bearer "} scheme prefix, if present. */
    private static String stripBearerPrefix(String value) {
        // A prefix test is required: a whole-string regex match would never match a real header value
        if (value.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length())) {
            return value.substring(BEARER_PREFIX.length());
        }
        return value;
    }

    /** Decodes the response's access token payload into the requested type. */
    private static <T> T getAccessTokenFrom(TokenResponse tokenResponse, Class<T> clazz) throws Exception {
        return readPayload(tokenResponse.getAccessToken(), clazz);
    }

    /**
     * Returns the decoded payload of the response's refresh token as pretty-printed JSON.
     *
     * @param tokenResponse the token response holding the refresh token
     * @return the payload JSON, or {@code null} if re-serialization failed
     * @throws Exception if the refresh token cannot be decoded or parsed
     */
    public static String getRefreshTokenPayloadStringFrom(TokenResponse tokenResponse) throws Exception {
        return getRefreshTokenPayloadStringFrom(tokenResponse, true);
    }

    /**
     * Returns the decoded payload of the response's refresh token as JSON.
     *
     * @param tokenResponse the token response holding the refresh token
     * @param prettyPrint   {@code true} to pretty print the output
     * @return the payload JSON, or {@code null} if re-serialization failed
     * @throws Exception if the refresh token cannot be decoded or parsed
     */
    public static String getRefreshTokenPayloadStringFrom(TokenResponse tokenResponse, boolean prettyPrint)
            throws Exception {
        return toJSONString(getRefreshTokenFrom(tokenResponse, Object.class), prettyPrint);
    }

    /**
     * Decodes the refresh token contained in a token response.
     *
     * @param tokenResponse the token response holding the refresh token
     * @return the parsed refresh token payload
     * @throws Exception if the refresh token cannot be decoded or parsed
     */
    public static RefreshToken getRefreshTokenFrom(TokenResponse tokenResponse) throws Exception {
        return getRefreshTokenFrom(tokenResponse.getRefreshToken());
    }

    /**
     * Decodes a refresh token given as a compact JWT.
     *
     * @param base64EncodedJWT the compact JWT
     * @return the parsed refresh token payload
     * @throws Exception if the token cannot be decoded or parsed
     */
    public static RefreshToken getRefreshTokenFrom(String base64EncodedJWT) throws Exception {
        return readPayload(base64EncodedJWT, RefreshToken.class);
    }

    /** Decodes the response's refresh token payload into the requested type. */
    private static <T> T getRefreshTokenFrom(TokenResponse tokenResponse, Class<T> clazz) throws Exception {
        return readPayload(tokenResponse.getRefreshToken(), clazz);
    }

    /** Base64-decodes the payload segment of a compact JWT and binds the resulting JSON to {@code clazz}. */
    private static <T> T readPayload(String encodedJWT, Class<T> clazz) throws Exception {
        return mapper.readValue(getDecodedPayload(encodedJWT), clazz);
    }

    /**
     * Decodes a Base64 string. Both the URL-safe alphabet ({@code -}, {@code _}) that JWT segments
     * use and the standard alphabet ({@code +}, {@code /}) are accepted; padding is optional.
     *
     * @param string the encoded text
     * @return the decoded bytes
     * @throws IllegalArgumentException if the text is not valid Base64
     */
    protected static byte[] getBase64Decoded(String string) {
        // JWT segments are base64url (RFC 7515): map the standard alphabet onto the URL-safe one so
        // both kinds of input decode to the same bytes.
        return Base64.getUrlDecoder().decode(string.replace('+', '-').replace('/', '_'));
    }

    /**
     * Returns the {@code index}-th dot-separated segment of a compact JWT.
     *
     * @param encodedJWT the compact JWT
     * @param index      0 = header, 1 = payload, 2 = signature
     * @return the segment, or {@code null} if the value does not have exactly three segments
     */
    protected static String splitAndGet(String encodedJWT, int index) {
        // Limit -1 keeps trailing empty segments: "header.payload." (empty signature) has 3 parts,
        // while "a.b.c." is correctly reported as 4 parts instead of being silently accepted.
        String[] split = encodedJWT.split("\\.", -1);
        return split.length == 3 ? split[index] : null;
    }

    /**
     * @param value a compact JWT
     * @return the Base64-decoded header segment
     * @throws IllegalArgumentException if the value is not a three-segment JWT or is not valid Base64
     */
    public static byte[] getDecodedHeader(String value) {
        return decodeSegment(getEncodedHeader(value));
    }

    /**
     * @param encodedJWT a compact JWT
     * @return the still-encoded header segment, or {@code null} if not a three-segment JWT
     */
    public static String getEncodedHeader(String encodedJWT) {
        return splitAndGet(encodedJWT, 0);
    }

    /**
     * @param value a compact JWT
     * @return the Base64-decoded payload segment
     * @throws IllegalArgumentException if the value is not a three-segment JWT or is not valid Base64
     */
    public static byte[] getDecodedPayload(String value) {
        return decodeSegment(getEncodedPayload(value));
    }

    /**
     * @param encodedJWT a compact JWT
     * @return the still-encoded payload segment, or {@code null} if not a three-segment JWT
     */
    public static String getEncodedPayload(String encodedJWT) {
        return splitAndGet(encodedJWT, 1);
    }

    /**
     * @param value a compact JWT
     * @return the Base64-decoded signature segment (empty for an unsigned token)
     * @throws IllegalArgumentException if the value is not a three-segment JWT or is not valid Base64
     */
    public static byte[] getDecodedSignature(String value) {
        return decodeSegment(getEncodedSignature(value));
    }

    /**
     * @param encodedJWT a compact JWT
     * @return the still-encoded signature segment, or {@code null} if not a three-segment JWT
     */
    public static String getEncodedSignature(String encodedJWT) {
        return splitAndGet(encodedJWT, 2);
    }

    /** Decodes one JWT segment, failing with a descriptive error instead of an NPE when it is missing. */
    private static byte[] decodeSegment(String encodedSegment) {
        if (encodedSegment == null) {
            throw new IllegalArgumentException(
                    "Value is not a JWT in compact serialization form (header.payload.signature)");
        }
        return getBase64Decoded(encodedSegment);
    }

    /**
     * Derives a client id from an access token: the authorized party ({@code azp}) if present,
     * otherwise the first audience ({@code aud}) entry other than {@code "account"}, otherwise "".
     *
     * @param accessToken the parsed access token
     * @return the client id; never {@code null}
     */
    public static String getClientIdFromToken(AccessToken accessToken) {
        logger.debug("Client id not provided, using authorized party field (azp)");
        String clientId = accessToken.getIssuedFor();
        if (clientId == null) {
            logger.warn("Issued for field (azp) not present, getting first of the audience field (aud)");
            clientId = getFirstAudienceNoAccount(accessToken);
        }
        return clientId;
    }

    /** Returns the first non-null audience entry other than "account", or "" if there is none. */
    private static String getFirstAudienceNoAccount(AccessToken accessToken) {
        if (accessToken.getAudience() != null) {
            // Arrays.asList is a fixed-size view: calling remove() on it throws
            // UnsupportedOperationException, so skip 'account' while iterating instead.
            List<String> tokenAud = Arrays.asList(accessToken.getAudience());
            for (String audience : tokenAud) {
                if (audience != null && !ACCOUNT_AUDIENCE_RESOURCE.equals(audience)) {
                    return audience;
                }
            }
        }
        // Setting it to empty string to avoid NPE in encoding
        return "";
    }
}
