package com.finconsgroup.itserr.marketplace.usercommunication.dm.config;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.finconsgroup.itserr.marketplace.core.web.exception.WP2AuthenticationException;
import com.finconsgroup.itserr.marketplace.core.web.exception.WP2AuthorizationException;
import com.finconsgroup.itserr.marketplace.usercommunication.dm.exception.MessagingErrorResponseDto;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.springframework.lang.NonNull;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;

import java.nio.charset.StandardCharsets;
import java.util.List;

import static com.finconsgroup.itserr.marketplace.usercommunication.dm.constant.ErrorConstants.ACCESS_DENIED_CODE;
import static com.finconsgroup.itserr.marketplace.usercommunication.dm.constant.ErrorConstants.ACCESS_DENIED_MESSAGE;
import static com.finconsgroup.itserr.marketplace.usercommunication.dm.constant.ErrorConstants.INTERNAL_SERVER_ERROR;
import static com.finconsgroup.itserr.marketplace.usercommunication.dm.constant.ErrorConstants.SERVER_ERROR_CODE;
import static com.finconsgroup.itserr.marketplace.usercommunication.dm.constant.ErrorConstants.UNAUTHENTICATED_CODE;
import static com.finconsgroup.itserr.marketplace.usercommunication.dm.constant.ErrorConstants.UNAUTHENTICATED_MESSAGE;

/**
 * Error handler for Stomp errors.
 * It returns a suitable message for authorization related exceptions,
 * which otherwise tend to return a generic Failed to send to ExecutorSubscribableChannel message
 */
@Slf4j
public class WebSocketStompErrorHandler extends StompSubProtocolErrorHandler {

    private final ObjectMapper objectMapper;

    WebSocketStompErrorHandler(ObjectMapper objectMapper) {
        super();
        this.objectMapper = objectMapper;
    }

    @NonNull
    @Override
    protected Message<byte[]> handleInternal(@NonNull StompHeaderAccessor errorHeaderAccessor,
                                             @NonNull byte[] errorPayload, Throwable cause,
                                             StompHeaderAccessor clientHeaderAccessor) {
        byte[] payload;
        if (cause != null) {
            Throwable rootCause = ExceptionUtils.getRootCause(cause);
            if (rootCause instanceof WP2AuthenticationException) {
                payload = mapToErrorResponse(UNAUTHENTICATED_CODE, UNAUTHENTICATED_MESSAGE, cause, clientHeaderAccessor);
            } else if (rootCause instanceof AccessDeniedException || rootCause instanceof WP2AuthorizationException) {
                payload = mapToErrorResponse(ACCESS_DENIED_CODE, ACCESS_DENIED_MESSAGE, cause, clientHeaderAccessor);
            } else {
                payload = mapToErrorResponse(SERVER_ERROR_CODE, INTERNAL_SERVER_ERROR, cause, clientHeaderAccessor);
            }
        } else {
            payload = mapToErrorResponse(SERVER_ERROR_CODE, INTERNAL_SERVER_ERROR, null, clientHeaderAccessor);
        }

        if (errorPayload.length > 0) {
            log.error("Override error payload {} with {}", errorPayload, payload);
        }

        return MessageBuilder.createMessage(payload, errorHeaderAccessor.getMessageHeaders());
    }

    private byte[] mapToErrorResponse(int code, String message, Throwable cause, StompHeaderAccessor clientHeaderAccessor) {
        try {
            String destination = clientHeaderAccessor != null ? clientHeaderAccessor.getDestination() : null;
            String requestMessageId = clientHeaderAccessor  != null ? clientHeaderAccessor.getFirstNativeHeader("requestMessageId") : null;
            // also return stack trace if debug level logs are enabled for this class
            List<String> stackTrace = log.isDebugEnabled() ? ExceptionUtils.getRootCauseStackTraceList(cause) : null;
            var errorDto = new MessagingErrorResponseDto(code, List.of(message), destination, requestMessageId, stackTrace);
            return objectMapper.writeValueAsBytes(errorDto);
        } catch (Exception e) {
            log.error("Error converting message to error response", e);
            // fallback to construct the json string
            return ("{\"code\": " + code +
                    ", \"messages\": [\"" + message + "\"] }" +
                    ", \"destination\": \"" + clientHeaderAccessor + "\"}")
                    .getBytes(StandardCharsets.UTF_8);
        }
    }
}
