package com.finconsgroup.itserr.marketplace.institutionalpage.dm.util;

import jakarta.persistence.Column;
import jakarta.persistence.JoinColumn;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.lang.NonNull;

import java.lang.reflect.Field;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Utility class for Spring Data related functions.
 *
 * <p>Currently provides helpers to translate {@link Sort} property names (entity field names) into database column
 * names, which is needed when a {@link Pageable} is passed to repository methods backed by native queries.
 */
public final class SpringDataUtils {

    /*
     * Separator used in cache keys. '#' cannot appear in Java identifiers, so "<class name>#<field name>" is
     * unambiguous (an '_' separator could collide, e.g. class "a.B_c" + field "d" vs class "a.B" + field "c_d").
     */
    private static final String CACHE_KEY_SEPARATOR = "#";

    /*
     * Cache of resolved column names keyed by "<entity class name>#<field name>". Only names taken from a
     * @Column / @JoinColumn annotation are stored: they do not depend on the failNonColumns flag, whereas the
     * plain-field-name fallback does (with failNonColumns=true the same field must be rejected).
     */
    private static final Map<String, String> COLUMN_NAME_BY_ENTITY_FIELD = new ConcurrentHashMap<>();

    private SpringDataUtils() {
        throw new UnsupportedOperationException("SpringDataUtils cannot be instantiated");
    }

    /**
     * Maps the sort using column name which is needed for repository methods that use native queries.
     *
     * <p>Each {@link Sort.Order} keeps its direction, ignore-case flag and null handling; only its property is
     * replaced by the resolved column name. Page number and size are preserved, and an unpaged input stays unpaged.
     *
     * @param pageable       the pageable to be mapped
     * @param entityClass    the class representing the entity to map field name to column name
     * @param failNonColumns if {@code true}, fail when a field has no {@link Column}/{@link JoinColumn} annotation
     *                       with a non-blank name; if {@code false}, fall back to the field name
     * @return {@code pageable} itself when it is unsorted, otherwise a new {@link Pageable} whose sort uses column
     * names instead of field names
     * @throws IllegalArgumentException if a sort property does not match a field in the entity class hierarchy, or
     *                                  if {@code failNonColumns} is {@code true} and the field has no usable
     *                                  annotation
     */
    public static Pageable mapSortToColumnName(Pageable pageable, Class<?> entityClass, boolean failNonColumns) {
        if (pageable.getSort().isUnsorted()) {
            return pageable;
        }

        Sort sort = mapSortToColumnName(entityClass, pageable.getSort(), failNonColumns);
        return pageable.isUnpaged()
                ? Pageable.unpaged(sort)
                : PageRequest.of(pageable.getPageNumber(), pageable.getPageSize(), sort);
    }

    /*
     * Get the database column name for a given entity field (searched in the class and its superclasses).
     *
     * Resolution order: non-blank @Column name, then non-blank @JoinColumn name. If neither is present, throws
     * IllegalArgumentException when failNonColumns is true, otherwise returns the raw field name.
     *
     * Throws IllegalArgumentException if no field with the given name exists in the class hierarchy.
     */
    private static String getColumnName(Class<?> entityClass, String fieldName, boolean failNonColumns) {
        String key = entityClass.getName() + CACHE_KEY_SEPARATOR + fieldName;
        String cached = COLUMN_NAME_BY_ENTITY_FIELD.get(key);
        if (cached != null) {
            return cached;
        }

        Field field;
        try {
            field = getDeclaredField(entityClass, fieldName);
        } catch (NoSuchFieldException e) {
            throw new IllegalArgumentException(
                    "Field '" + fieldName + "' not found in " + entityClass.getSimpleName(), e
            );
        }

        String annotatedName = getAnnotatedColumnName(field);
        if (annotatedName != null) {
            // Annotation-derived names are independent of failNonColumns, so they are always safe to cache.
            COLUMN_NAME_BY_ENTITY_FIELD.putIfAbsent(key, annotatedName);
            return annotatedName;
        }

        if (failNonColumns) {
            throw new IllegalArgumentException(
                    "Field '" + fieldName + "' does not have supported annotation in " + entityClass.getSimpleName()
            );
        }
        // Fallback: use the field name as column name. Not cached because a later call with failNonColumns=true
        // must still reject this field.
        // NOTE(review): if a physical naming strategy (e.g. camelCase -> snake_case) is configured, the actual
        // column name may differ from the field name — confirm for entities relying on this fallback.
        return field.getName();
    }

    /*
     * Returns the non-blank name declared by @Column on the field, else the non-blank name declared by @JoinColumn,
     * else null.
     */
    private static String getAnnotatedColumnName(Field field) {
        Column column = field.getAnnotation(Column.class);
        if (column != null && !column.name().isBlank()) {
            return column.name();
        }
        JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
        if (joinColumn != null && !joinColumn.name().isBlank()) {
            return joinColumn.name();
        }
        return null;
    }

    /*
     * Maps the sort property names to the column names that can be used to sort the entities using native queries
     * and returns the new sort. Order of the sort entries is preserved; each order keeps its direction, ignore-case
     * flag and null handling.
     *
     * @param entityClass    the type of entity
     * @param sort           the original sort
     * @param failNonColumns the flag to indicate if it should fail when field name cannot be mapped to column name
     * @return Sort with property name mapped to column name
     */
    @NonNull
    private static Sort mapSortToColumnName(@NonNull Class<?> entityClass, @NonNull Sort sort, boolean failNonColumns) {
        return Sort.by(sort.stream()
                .map(order -> order.withProperty(getColumnName(entityClass, order.getProperty(), failNonColumns)))
                .toList());
    }

    /*
     * Finds the field with given name in the provided class or its parent classes.
     * The search starts at clazz and walks up getSuperclass() until it returns null.
     *
     * @param clazz the class to find field for
     * @param fieldName the field to find
     * @return the field found
     * @throws NoSuchFieldException in case if field could not be found
     */
    @NonNull
    private static Field getDeclaredField(@NonNull Class<?> clazz, @NonNull String fieldName) throws NoSuchFieldException {
        Class<?> current = clazz;
        while (current != null) {
            try {
                return current.getDeclaredField(fieldName);
            } catch (NoSuchFieldException e) {
                // Not declared here: move to the parent class
                current = current.getSuperclass();
            }
        }
        throw new NoSuchFieldException("Field '" + fieldName + "' not found in class hierarchy for " + clazz.getSimpleName());
    }

}
