package com.finconsgroup.itserr.marketplace.usercommunication.dm.security;

import com.finconsgroup.itserr.marketplace.core.web.exception.WP2AuthenticationException;
import com.finconsgroup.itserr.marketplace.core.web.security.jwt.JwtTokenHolder;
import com.finconsgroup.itserr.marketplace.core.web.security.jwt.JwtTokenVerifier;
import io.micrometer.common.util.StringUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.lang.NonNull;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;

import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;

/**
 * Interceptor to retrieve and verify the Jwt Token for authentication of the user.
 * <p>
 * Only STOMP {@code CONNECT} frames are inspected. The JWT is read from the {@code token} native
 * header and passed to {@link JwtTokenHolder} together with the configured {@link JwtTokenVerifier}.
 * The resulting {@link WebSocketAuthentication} is attached to the message as its user. All other
 * frames are passed through unchanged.
 */
@RequiredArgsConstructor
@Slf4j
public final class AuthenticationChannelInterceptor implements ChannelInterceptor {

    /** Name of the STOMP native header that carries the JWT on CONNECT. */
    private static final String TOKEN_HEADER = "token";

    /** Audience whose roles are mapped to granted authorities. */
    private final String securityAudience;
    /** Verifier handed to {@link JwtTokenHolder} when the token is set. */
    private final JwtTokenVerifier jwtTokenVerifier;

    /**
     * Authenticates the user on {@code CONNECT} frames.
     *
     * @param message the inbound message
     * @param channel the channel the message is being sent to
     * @return the (possibly user-enriched) message
     * @throws WP2AuthenticationException if a CONNECT frame carries no usable token / user id
     */
    @Override
    public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) {
        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
        // only the connect message should contain the authentication header
        if (accessor == null || !StompCommand.CONNECT.equals(accessor.getCommand())) {
            return message;
        }

        String token = accessor.getFirstNativeHeader(TOKEN_HEADER);
        // remove the token from the header BEFORE verification to avoid logging it: if verification
        // throws, the message (still carrying the token) could otherwise end up in logs or in the
        // failed message attached to the propagated exception
        accessor.removeNativeHeader(TOKEN_HEADER);

        WebSocketAuthentication authentication = mapTokenToAuthentication(token)
                .orElseThrow(WP2AuthenticationException::new);
        accessor.setUser(authentication);
        return message;
    }

    /**
     * Sets the token on {@link JwtTokenHolder}, extracts the user details and always clears the
     * holder afterwards.
     *
     * @param token the raw JWT; may be {@code null} or blank
     * @return the authentication, or empty when the token is blank or no user id is available
     */
    private Optional<WebSocketAuthentication> mapTokenToAuthentication(String token) {
        try {
            if (StringUtils.isNotBlank(token)) {
                JwtTokenHolder.setToken(token, jwtTokenVerifier);
                // Optional.map runs eagerly, so the claims are read while the token is still set
                return JwtTokenHolder.getUserId().map(this::buildAuthentication);
            }
            return Optional.empty();
        } finally {
            // unset the token after details extraction
            JwtTokenHolder.setToken(null);
        }
    }

    /**
     * Builds the authentication for the given user id from the claims currently held by
     * {@link JwtTokenHolder}. Exceptions raised while reading the claims propagate to the caller.
     *
     * @param userId the user id extracted from the token
     * @return the authentication with the user details and the roles of {@link #securityAudience}
     */
    private WebSocketAuthentication buildAuthentication(UUID userId) {
        log.info("user id - {}", userId);
        WebSocketUser user = WebSocketUser.builder()
                .userId(userId)
                .displayName(JwtTokenHolder.getNameOrThrow())
                .firstName(JwtTokenHolder.getGivenNameOrThrow())
                .lastName(JwtTokenHolder.getFamilyNameOrThrow())
                .build();
        Set<GrantedAuthority> authorities = JwtTokenHolder.getRoles(securityAudience).stream()
                .map(SimpleGrantedAuthority::new)
                .collect(Collectors.toSet());
        return new WebSocketAuthentication(user, authorities);
    }
}
