package org.gcube.keycloak.protocol.oidc;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.http.HttpEntity;
import org.apache.http.HttpHeaders;
import org.apache.http.NameValuePair;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils;
import org.gcube.keycloak.protocol.oidc.tip.TIPConfiguration;
import org.gcube.keycloak.protocol.oidc.tip.TIPConfiguration.Config.TipConfig.RemoteIssuer;
import org.gcube.keycloak.protocol.oidc.tip.TIPConfiguration.Config.TipConfig.RemoteIssuer.ClaimMapping;
import org.gcube.keycloak.protocol.oidc.tip.TIPConfigurationException;
import org.jboss.logging.Logger;
import org.keycloak.Config.Scope;
import org.keycloak.events.Details;
import org.keycloak.events.Errors;
import org.keycloak.events.EventBuilder;
import org.keycloak.jose.jws.JWSInput;
import org.keycloak.jose.jws.JWSInputException;
import org.keycloak.models.KeycloakSession;
import org.keycloak.protocol.oidc.AccessTokenIntrospectionProvider;
import org.keycloak.protocol.oidc.TokenIntrospectionProvider;
import org.keycloak.representations.AccessToken;
import org.keycloak.services.Urls;
import org.keycloak.util.BasicAuthHelper;
import org.keycloak.util.JsonSerialization;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;

import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;

/**
 * Access token introspection provider that delegates the introspection of tokens not issued by this
 * server to the introspection endpoint of a configured remote issuer.
 * <p>
 * When no configuration is loaded, or when the token's {@code iss} equals this server's issuer, the
 * superclass' introspection is used. Otherwise the token is POSTed to the remote issuer's introspection
 * endpoint and the returned JSON is post-processed (claim dropping, renaming and value mapping, in this
 * order) before being returned to the caller.
 *
 * @param <T> the access token type handled by the superclass
 * @author <a href="mailto:mauro.mugnaini@nubisware.com">Mauro Mugnaini</a>
 */
public class EOSCNodeAccessTokenIntrospectionProvider<T extends AccessToken> extends AccessTokenIntrospectionProvider<T>
        implements TokenIntrospectionProvider {

    /** Socket (read) timeout of the remote introspection call, in milliseconds. */
    private static final int SOCKET_TIMEOUT = 10000;

    /** Connection timeout of the remote introspection call, in milliseconds. */
    private static final int CONNECTION_TIMEOUT = 5000;

    private static final Logger logger = Logger.getLogger(EOSCNodeAccessTokenIntrospectionProvider.class);

    /**
     * Issuer URL of this server, computed from the base URI and realm of the first session that creates a
     * provider instance and then shared by all the instances.
     * <p>
     * NOTE(review): being static, the value is not recomputed for other realms or base URIs; confirm this
     * is acceptable for multi-realm / multi-hostname deployments.
     */
    protected static String serverIssuer;

    protected final Configuration eoscNodeConfiguration;

    /**
     * Creates the provider and, if not yet done, computes the {@link #serverIssuer}.
     *
     * @param session the current Keycloak session
     * @param eoscNodeConfiguration the configuration holding the remote issuers' settings
     */
    public EOSCNodeAccessTokenIntrospectionProvider(KeycloakSession session,
            Configuration eoscNodeConfiguration) {

        super(session);

        // Computing server issuer URL only once
        if (serverIssuer == null) {
            serverIssuer = Urls.realmIssuer(session.getContext().getUri().getBaseUri(),
                    session.getContext().getRealm().getName());

            logger.infof("Keycloak server instance 'issuer' is: %s", serverIssuer);
        }
        this.eoscNodeConfiguration = eoscNodeConfiguration;
    }

    @Override
    public void close() {
        super.close();
    }

    /**
     * Introspects the token locally when it is issued by this server (or when the provider is not
     * configured), remotely otherwise.
     *
     * @param accessTokenString the serialized access token to introspect
     * @param eventBuilder the event builder used to report introspection failures
     * @return the introspection response
     * @throws RuntimeException if the token cannot be parsed or the remote introspection fails
     */
    @Override
    public Response introspect(String accessTokenString, EventBuilder eventBuilder) {
        if (!eoscNodeConfiguration.isConfigured()) {
            logger.trace("Provider is not configured, continue with default introspection");
            return super.introspect(accessTokenString, eventBuilder);
        }

        // The helper methods report their failures through the 'eventBuilder' member, which is not declared
        // in this class (i.e. it is inherited): it must be set here, since on the remote path the
        // superclass' introspect() is never invoked.
        // NOTE(review): presumably the superclass assigns it only inside its own introspect() - confirm.
        this.eventBuilder = eventBuilder;

        logger.trace("Getting the issuer from the access token");
        String tokenIssuer = getAccessTokenIssuer(accessTokenString); // may be null
        logger.tracef("Access token issued by: %s", tokenIssuer);
        if (serverIssuer.equals(tokenIssuer)) {
            logger.debug("Token is issued by this server, continue introspection with superclass' method");
            return super.introspect(accessTokenString, eventBuilder);
        }

        // Is not issued by this server or is null
        logger.debug("Token is NOT issued by this server, continue introspection on remote node/MyAccess");
        return performRemoteNodeIntrospection(tokenIssuer, accessTokenString);
    }

    /**
     * Parses the serialized token, without verifying it, and returns the value of its issuer claim.
     *
     * @param accessTokenString the serialized access token
     * @return the issuer found in the token, may be <code>null</code>
     * @throws RuntimeException if the string cannot be parsed; the failure is also reported as an event
     */
    protected String getAccessTokenIssuer(String accessTokenString) {
        try {
            logger.debug("Deserializing the recevide access token and getting issuer");
            return new JWSInput(accessTokenString).readJsonContent(AccessToken.class).getIssuer();
        } catch (JWSInputException e) {
            logger.debug("Can't deserialize access token from string", e);
            eventBuilder.detail(Details.REASON,
                    "Can't deserialize access token from string. Reason: " + e.getMessage());

            eventBuilder.error(Errors.TOKEN_INTROSPECTION_FAILED);
            throw new RuntimeException("Error parsing access token string to get the issuer", e);
        }
    }

    /**
     * Introspects the token on the endpoint of the remote issuer configured for the given issuer URL and
     * applies the configured claim transformations to the response.
     *
     * @param issuerURL the issuer read from the token, may be <code>null</code>
     * @param accessTokenString the serialized access token to introspect
     * @return the transformed remote response, or the default "not active" response if no introspection
     * endpoint is known for the issuer
     * @throws RuntimeException if the remote call fails or returns a non 2xx status code; the failure is
     * also reported as an event
     */
    protected Response performRemoteNodeIntrospection(String issuerURL, String accessTokenString) {
        RemoteIssuer remoteIssuer = eoscNodeConfiguration.findRemoteIssuer(issuerURL);
        if (remoteIssuer == null || remoteIssuer.getIntrospection_endpoint() == null) {
            logger.warnf("Cannot perform remote node introspection since the endpoint URL is unknown for issuer: %s",
                    issuerURL);

            return Response.ok(Configuration.DEFAULT_RESPONSE_ON_UNKNOWN_ISSUER_ENDPOINT)
                    .type(MediaType.APPLICATION_JSON_TYPE)
                    .build();
        }

        String introspectionEndpointURL = remoteIssuer.getIntrospection_endpoint();
        logger.debugf("Remote endpoint introspection URL is: %s", introspectionEndpointURL);
        logger.trace("Creating and perfom POST with the HTTP client");
        // NOTE(review): a new HTTP client is created for every introspection; consider sharing one.
        try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
            HttpPost httpPost = createIntrospectionRequest(remoteIssuer, introspectionEndpointURL,
                    accessTokenString);

            logger.debug("Performing the request...");
            try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
                int statusCode = response.getStatusLine().getStatusCode();
                logger.debugf("Resulting status code is %d", statusCode);

                logger.trace("Getting the response entity");
                HttpEntity responseEntity = response.getEntity();

                logger.trace("Reading the response body");
                String responseBodyString = (responseEntity != null)
                        ? EntityUtils.toString(responseEntity, StandardCharsets.UTF_8)
                        : "";

                logger.debugf("Resulting body is: %s", responseBodyString);
                if (statusCode < 200 || statusCode >= 300) {
                    logger.errorf("MyAccess server response is not OK [%d]: %s", statusCode, responseBodyString);
                    eventBuilder.detail(Details.REASON, responseBodyString);
                    eventBuilder.error(Errors.TOKEN_INTROSPECTION_FAILED);
                    throw new RuntimeException("Error calling MyAccess for token introspection");
                }

                ObjectNode responseObjectNode = JsonSerialization.readValue(responseBodyString,
                        ObjectNode.class);

                // The order matters: renaming and mapping work on the claims that survived the dropping
                dropClaims(remoteIssuer, responseObjectNode);
                renameClaims(remoteIssuer, responseObjectNode);
                mapClaims(remoteIssuer, responseObjectNode);

                logger.tracef("Serializing response JSON object node as string");
                byte[] responseBytes = JsonSerialization.writeValueAsBytes(responseObjectNode);

                if (logger.isTraceEnabled()) {
                    logger.tracef("Resulting response JSON string is: %s", JsonSerialization
                            .writeValueAsPrettyString(responseObjectNode));
                }

                // TODO set info on eventBuilder if it makes sense to trace them for success external introspections
                return Response.ok(responseBytes).type(MediaType.APPLICATION_JSON_TYPE).build();
            }
        } catch (IOException e) {
            logger.error("An errord occurred performing the token introspection on MyAccess", e);
            eventBuilder.detail(Details.REASON, e.getMessage());
            eventBuilder.error(Errors.TOKEN_INTROSPECTION_FAILED);
            throw new RuntimeException("Error performing the token introspection on MyAccess", e);
        }
    }

    /**
     * Builds the POST request for the remote introspection endpoint: HTTP Basic authorization with the
     * remote issuer's client credentials, the token as URL-encoded form parameter and the timeouts.
     */
    private HttpPost createIntrospectionRequest(RemoteIssuer remoteIssuer, String introspectionEndpointURL,
            String accessTokenString) {

        logger.trace("Starting HTTP POST creation");
        HttpPost httpPost = new HttpPost(introspectionEndpointURL);

        logger.trace("Setting the Authorization header");
        httpPost.setHeader(HttpHeaders.AUTHORIZATION,
                BasicAuthHelper.createHeader(remoteIssuer.getClient_id(), remoteIssuer.getClient_secret()));

        // The token is a bearer credential: its value is intentionally not written to the log
        logger.trace("Setting the token parameter");
        List<NameValuePair> form = new ArrayList<>();
        form.add(new BasicNameValuePair("token", accessTokenString));
        httpPost.setEntity(new UrlEncodedFormEntity(form, StandardCharsets.UTF_8));

        logger.tracef("Setting connection timeout to %d millis", CONNECTION_TIMEOUT);
        logger.tracef("Setting socket timeout to %d millis", SOCKET_TIMEOUT);
        httpPost.setConfig(RequestConfig.custom()
                .setConnectTimeout(CONNECTION_TIMEOUT)
                .setSocketTimeout(SOCKET_TIMEOUT)
                .build());

        return httpPost;
    }

    /**
     * Removes from the response the claims listed in the remote issuer's drop list, if any.
     *
     * @param remoteIssuer the remote issuer configuration
     * @param introspectionResponse the introspection response, modified in place
     */
    private void dropClaims(RemoteIssuer remoteIssuer, ObjectNode introspectionResponse) {
        List<String> claimsToDrop = remoteIssuer.getDrop_claims();

        // A missing (null) list is handled as "not configured"
        if (claimsToDrop == null || claimsToDrop.isEmpty()) {
            logger.debug("Drop claims is not configured");
            return;
        }

        logger.debug("Performing the claim dropping...");
        for (String claimName : claimsToDrop) {
            if (introspectionResponse.has(claimName)) {
                logger.tracef("Dropping '%s' claim", claimName);
                introspectionResponse.remove(claimName);
            } else {
                logger.tracef("Claim not found in token: %s", claimName);
            }
        }
    }

    /**
     * Renames the claims of the response according to the remote issuer's renaming map
     * (original name to new name), if any.
     *
     * @param remoteIssuer the remote issuer configuration
     * @param introspectionResponse the introspection response, modified in place
     */
    protected void renameClaims(RemoteIssuer remoteIssuer, ObjectNode introspectionResponse) {
        Map<String, String> claimRenamings = remoteIssuer.getClaim_renaming();

        // A missing (null) map is handled as "not configured"
        if (claimRenamings == null || claimRenamings.isEmpty()) {
            logger.debug("Claim renaming is not configured");
            return;
        }

        logger.debug("Performing the claim renaming...");
        for (Map.Entry<String, String> renaming : claimRenamings.entrySet()) {
            String originalClaimName = renaming.getKey();
            String newClaimName = renaming.getValue();
            if (introspectionResponse.has(originalClaimName)) {
                logger.tracef("Renaming the '%s' claim to '%s'", originalClaimName, newClaimName);
                // Remove first and then set: a claim "renamed" to its own name is kept instead of lost
                JsonNode claimValue = introspectionResponse.remove(originalClaimName);
                introspectionResponse.set(newClaimName, claimValue);
            } else {
                logger.tracef("Claim not found in token: %s", originalClaimName);
            }
        }
    }

    /**
     * Maps the values of the string claims and of the string array claims of the response according to the
     * remote issuer's claim mapping, if any.
     *
     * @param remoteIssuer the remote issuer configuration
     * @param introspectionResponse the introspection response, modified in place
     */
    protected void mapClaims(RemoteIssuer remoteIssuer, ObjectNode introspectionResponse) {
        ClaimMapping claimMapping = remoteIssuer.getClaim_mapping();
        if (claimMapping == null) {
            logger.debug("Claim mapping is not set");
            return;
        }
        mapStringClaims(claimMapping.getStrings(), introspectionResponse);
        mapStringArrayClaims(claimMapping.getString_arrays(), introspectionResponse);
    }

    /**
     * Replaces the value of each configured textual claim with the value it is mapped to; values without a
     * mapping are left untouched.
     */
    private void mapStringClaims(Map<String, Map<String, String>> stringMapping,
            ObjectNode introspectionResponse) {

        if (stringMapping == null || stringMapping.isEmpty()) {
            return;
        }
        logger.debug("Performing the string claim mapping...");
        for (Map.Entry<String, Map<String, String>> claimEntry : stringMapping.entrySet()) {
            String claimName = claimEntry.getKey();
            Map<String, String> valueMapping = claimEntry.getValue();
            if (!introspectionResponse.has(claimName)) {
                logger.tracef("Claim not found in token: %s", claimName);
                continue;
            }
            logger.debugf("Mapping '%s' claim values...", claimName);
            JsonNode claimNode = introspectionResponse.get(claimName);
            if (!claimNode.isTextual()) {
                logger.debug("Claim node is not a textual node");
                continue;
            }
            String claimValue = claimNode.asText();
            // A claim configured without any value mapping (null) has nothing to map
            if (valueMapping != null && valueMapping.containsKey(claimValue)) {
                String newValue = valueMapping.get(claimValue);
                logger.tracef("Mapping '%s' to '%s'", claimValue, newValue);
                introspectionResponse.set(claimName, introspectionResponse.textNode(newValue));
            }
        }
    }

    /**
     * Rebuilds each configured string array claim replacing every mapped value with the list of values it
     * is mapped to; elements without a mapping are copied as they are. The rebuilt array replaces the
     * original one in the response.
     */
    private void mapStringArrayClaims(Map<String, Map<String, List<String>>> stringArraysMapping,
            ObjectNode introspectionResponse) {

        if (stringArraysMapping == null || stringArraysMapping.isEmpty()) {
            return;
        }
        logger.debug("Performing the string arrays claim mapping...");
        for (Map.Entry<String, Map<String, List<String>>> claimEntry : stringArraysMapping.entrySet()) {
            String claimName = claimEntry.getKey();
            Map<String, List<String>> valueMapping = claimEntry.getValue();
            if (!introspectionResponse.has(claimName)) {
                logger.tracef("Claim not found in token: %s", claimName);
                continue;
            }
            logger.debugf("Mapping '%s' claim array values...", claimName);
            JsonNode claimNode = introspectionResponse.get(claimName);
            if (!claimNode.isArray() || claimNode.size() == 0 || !claimNode.get(0).isTextual()) {
                logger.debug("Claim node is not an array of strings or is empty");
                continue;
            }
            if (valueMapping == null) {
                // A claim configured without any value mapping has nothing to map
                continue;
            }

            ArrayNode mappedArrayNode = introspectionResponse.arrayNode();
            for (JsonNode element : claimNode) {
                String claimValue = element.asText();
                // Only the first element is known to be textual: any other kind of node is copied as is
                if (element.isTextual() && valueMapping.containsKey(claimValue)) {
                    List<String> newValues = valueMapping.get(claimValue);
                    logger.tracef("Mapping '%s' to '%s'", claimValue, newValues);
                    if (newValues != null) {
                        for (String newValue : newValues) {
                            mappedArrayNode.add(introspectionResponse.textNode(newValue));
                        }
                    }
                } else {
                    logger.tracef("Calim value not found in array: %s", claimValue);
                    mappedArrayNode.add(element);
                }
            }
            // Without this replacement the mapping computed above would be silently discarded
            introspectionResponse.set(claimName, mappedArrayNode);
        }
    }

    /**
     * Provider configuration, loaded from the YAML file referenced by the <code>yaml-config-file</code>
     * property of the provider's configuration scope.
     */
    public static class Configuration {

        /** JSON body returned when no introspection endpoint is known for the token's issuer. */
        public static final byte[] DEFAULT_RESPONSE_ON_UNKNOWN_ISSUER_ENDPOINT = "{\"active\": false}"
                .getBytes(StandardCharsets.UTF_8);

        /** The loaded configuration, <code>null</code> if not provided or if the loading failed. */
        private final TIPConfiguration tipConfiguration;

        /**
         * Loads the configuration; a missing property or a loading failure is logged and leaves the
         * configuration in the "not configured" state.
         *
         * @param config the provider's configuration scope
         */
        public Configuration(Scope config) {
            TIPConfiguration loadedConfiguration = null;
            String yamlConfigFile = config.get("yaml-config-file");
            if (yamlConfigFile != null) {
                try {
                    loadedConfiguration = TIPConfiguration.loadFromYAML(yamlConfigFile);
                } catch (TIPConfigurationException e) {
                    logger.warn("Cannot load TIP config file from: " + yamlConfigFile, e);
                }
            } else {
                logger.info("YAML file configuration not provided");
            }
            this.tipConfiguration = loadedConfiguration;
        }

        /**
         * @return <code>true</code> if the YAML configuration has been successfully loaded
         */
        public boolean isConfigured() {
            return this.tipConfiguration != null;
        }

        /**
         * Finds the remote issuer to be used for the given issuer URL.
         *
         * @param issuerURL the issuer read from the token, may be <code>null</code>
         * @return the remote issuer having that URL or, if not found, the "unknown token issuer" fallback;
         * the "unsupported token issuer" fallback if the URL is <code>null</code>; <code>null</code> if
         * not configured. The fallbacks are returned as they are configured.
         */
        private RemoteIssuer findRemoteIssuer(String issuerURL) {
            if (tipConfiguration == null) {
                return null;
            }
            if (issuerURL == null) {
                return tipConfiguration.getTip().getFallback_issuer_unsupported_token_issuer();
            }
            // Returning the found issuer, the fallback issuer if is set, otherwise null
            return tipConfiguration.getTip().getRemote_issuers().stream()
                    .filter(remoteIssuer -> issuerURL.equals(remoteIssuer.getIssuer_url()))
                    .findFirst()
                    .orElse(tipConfiguration.getTip().getFallback_issuer_unknown_token_issuer());
        }

    }

}
