package eu.dnetlib.dhp.solr;

import com.google.common.base.Splitter;
import com.google.common.cache.*;

import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.impl.conn.PoolingClientConnectionManager;
import org.apache.http.params.BasicHttpParams;
import org.apache.http.params.DefaultedHttpParams;
import org.apache.http.params.HttpParams;
import org.apache.solr.client.solrj.impl.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.stream.StreamSupport;

public class CacheCloudSolrClient implements Serializable {

    private static final Logger log = LoggerFactory.getLogger(CacheCloudSolrClient.class);

    private static final CacheLoader<CloudClientParams, CloudSolrClient> loader = new CacheLoader<>() {
        @Override
        public CloudSolrClient load(CloudClientParams params) throws Exception {
            return getCloudSolrClient(params);
        }
    };

    private static final RemovalListener<CloudClientParams, CloudSolrClient> listener = rn -> Optional
            .ofNullable(rn.getValue())
            .ifPresent(client -> {
                try {
                    client.close();
                } catch (IOException e) {
                    throw new RuntimeException(e);
            }});

    private static final LoadingCache<CloudClientParams, CloudSolrClient> cache = CacheBuilder
            .newBuilder()
            .removalListener(listener)
            .build(loader);

    public static void invalidateCachedClient(CloudClientParams params) {
        cache.invalidate(params);
    }

    public static CloudSolrClient getCachedCloudClient(CloudClientParams params) throws ExecutionException {
        return CacheCloudSolrClient.cache.get(params);
    }

    private static CloudSolrClient getCloudSolrClient(CloudClientParams cloudClientParams) {
        if (cloudClientParams == null) {
            throw new IllegalArgumentException("CloudClientParams cannot be null");
        }

        if (cloudClientParams.getZkHost() == null || cloudClientParams.getZkHost().isEmpty()) {
            throw new IllegalArgumentException("Zookeeper host cannot be null or empty");
        }

        if (cloudClientParams.getCollection() == null || cloudClientParams.getCollection().isEmpty()) {
            throw new IllegalArgumentException("Collection name cannot be null or empty");
        }

        log.debug("Creating a new SolrCloudClient for zkhost {}", cloudClientParams.getZkHost());
        String zkHost = cloudClientParams.getZkHost();
        log.info("Creating a new SolrCloudClient for zkhost {}", zkHost);

        final List<String> zkUrlList = StreamSupport.stream(
                Splitter.on(",").split(zkHost).spliterator(),
                false
        ).toList();

        // Build HTTP params from available CloudClientParams
        HttpParams httpParams = new BasicHttpParams();

        // connection timeout (milliseconds)
        Integer connectionTimeout = cloudClientParams.getZkConnectTimeout();
        if (connectionTimeout != null) {
            httpParams.setIntParameter("http.connection.timeout", connectionTimeout);
        }

        // socket timeout (milliseconds)
        Integer socketTimeout = cloudClientParams.getHttpSocketTimeoutMillis();
        if (socketTimeout != null) {
            httpParams.setIntParameter("http.socket.timeout", socketTimeout);
        }

        PoolingClientConnectionManager connManager = new PoolingClientConnectionManager();
        connManager.setMaxTotal(cloudClientParams.getMaxConnTotal());
        connManager.setDefaultMaxPerRoute(cloudClientParams.getMaxConnPerRoute());
        HttpClient httpClient = new DefaultHttpClient(connManager, new DefaultedHttpParams(httpParams, null));

        CloudSolrClient solrClient = new CloudSolrClient.Builder(zkUrlList, Optional.empty())
                .withParallelUpdates(true)
                .withLBHttpSolrClientBuilder(
                    new LBHttpSolrClient.Builder()
                        .withHttpClient(httpClient)
                        .withConnectionTimeout(cloudClientParams.getHttpConnectTimeoutMillis())
                        .withSocketTimeout(cloudClientParams.getHttpSocketTimeoutMillis()))
                .build();

        solrClient.connect();
        log.debug("Created new SolrCloudClient for zkhost {}", zkHost);

        return solrClient;
    }

}
