package com.finconsgroup.itserr.marketplace.usercommunication.dm.service.impl;

import com.finconsgroup.itserr.marketplace.usercommunication.dm.entity.User;
import com.finconsgroup.itserr.marketplace.usercommunication.dm.repository.UserRepository;
import com.finconsgroup.itserr.marketplace.usercommunication.dm.security.WebSocketUser;
import com.finconsgroup.itserr.marketplace.usercommunication.dm.service.SessionManagementService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.messaging.simp.user.SimpUser;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

/**
 * The external broker implementation for {@link SessionManagementService}.
 * <p>
 * Registered only when the property
 * {@code user-communication.dm.web-socket.enable-broker-relay} is {@code true}.
 * Session counts are read from the injected {@link SimpUserRegistry}, matching registry
 * users by name against the string form of the user id. A local map additionally remembers
 * which {@link WebSocketUser} opened each session, so the user can be resolved when only
 * the session id is known.
 */
@ConditionalOnProperty(prefix = "user-communication.dm.web-socket", value = "enable-broker-relay", havingValue = "true")
@Service
@Slf4j
public class ExternalBrokerSessionManagementService extends AbstractSessionManagementService {

    /** Session id to the user that opened the session. */
    private final Map<String, WebSocketUser> userBySessionId = new ConcurrentHashMap<>();

    /** Registry queried for the sessions currently associated with a user name. */
    private final SimpUserRegistry userRegistry;

    /**
     * Creates the service.
     *
     * @param userRegistry   registry used to look up users and their sessions
     * @param userRepository repository handed to the parent class
     */
    public ExternalBrokerSessionManagementService(final SimpUserRegistry userRegistry, final UserRepository userRepository) {
        super(userRepository);
        this.userRegistry = userRegistry;
    }

    /**
     * Records a newly created session and, when the registry lists no other session for the
     * user, flags the user as online via {@code updateUserOnlineStatus} (defined outside this
     * file).
     *
     * @param sessionId id of the session that was created
     * @param user      user owning the session
     * @return the result of {@code updateUserOnlineStatus(user, true)} when this is the user's
     *         only session, otherwise {@code null}
     */
    @Transactional(propagation = Propagation.REQUIRED, rollbackFor = Exception.class)
    @Override
    public User onSessionCreated(@NonNull String sessionId, @NonNull WebSocketUser user) {
        final UUID userId = user.getUserId();

        // Remember who owns the session so it can be resolved again on destruction.
        userBySessionId.put(sessionId, user);

        // The new session is excluded from the count, so the result does not depend on
        // whether the registry already lists it.
        final int others = getSessionCount(userId, sessionId);
        if (others != 0) {
            log.info("User {} (session: {}, total sessions: {})", userId, sessionId, others + 1);
            return null;
        }

        log.info("User {} came online with session {}", userId, sessionId);
        return updateUserOnlineStatus(user, true);
    }

    /**
     * Forgets a destroyed session and, when the registry lists no other session for its user,
     * flags the user as offline via {@code updateUserOnlineStatus} (defined outside this file).
     *
     * @param sessionId id of the session that was destroyed
     * @return the result of {@code updateUserOnlineStatus(user, false)} when this was the
     *         user's last session; {@code null} when the session id is not tracked or other
     *         sessions remain
     */
    @Transactional(propagation = Propagation.REQUIRED, rollbackFor = Exception.class)
    @Override
    public User onSessionDestroyed(@NonNull String sessionId) {
        final WebSocketUser owner = userBySessionId.remove(sessionId);
        if (owner == null) {
            // Session was never tracked here (or was already removed): nothing to update.
            return null;
        }

        final UUID userId = owner.getUserId();
        final int remaining = getSessionCount(userId, sessionId);
        if (remaining != 0) {
            log.info("User {} disconnected session {} (remaining sessions: {})", userId, sessionId,
                    remaining);
            return null;
        }

        log.info("User {} went offline (session: {})", userId, sessionId);
        return updateUserOnlineStatus(owner, false);
    }

    /**
     * Tells whether the registry currently knows a user named after the given id.
     *
     * @param userId id of the user to check
     * @return {@code true} when the registry returns a user for {@code userId.toString()}
     */
    @Override
    public boolean isUserOnline(@NonNull UUID userId) {
        final SimpUser registered = userRegistry.getUser(userId.toString());
        return registered != null;
    }

    /**
     * Counts every session the registry lists for the given user.
     *
     * @param userId id of the user
     * @return number of sessions found in the registry for that user
     */
    @Override
    public int getUserSessionCount(@NonNull UUID userId) {
        return getSessionCount(userId, null);
    }

    /**
     * Sums the sessions of all registry users whose name equals the string form of
     * {@code userId}, optionally leaving one session id out of the count.
     *
     * @param userId          id of the user whose sessions are counted
     * @param ignoreSessionId session id to leave out, or {@code null} to count all sessions
     * @return number of matching sessions
     */
    private int getSessionCount(@NonNull UUID userId, @Nullable String ignoreSessionId) {
        final String expectedName = userId.toString();
        return userRegistry.getUsers().stream()
                .filter(candidate -> expectedName.equals(candidate.getName()))
                .mapToInt(candidate -> countSessions(candidate, ignoreSessionId))
                .sum();
    }

    /**
     * Counts the sessions of a single registry user, skipping {@code ignoreSessionId} when given.
     */
    private int countSessions(SimpUser registryUser, @Nullable String ignoreSessionId) {
        if (ignoreSessionId == null) {
            return registryUser.getSessions().size();
        }
        return (int) registryUser.getSessions().stream()
                .filter(session -> !ignoreSessionId.equals(session.getId()))
                .count();
    }
}
