package eu.dnetlib.data.mdstore.modular.mongodb;

import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.mongodb.BasicDBObject;
import com.mongodb.DBObject;
import com.mongodb.QueryBuilder;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Sorts;
import eu.dnetlib.enabling.resultset.ResultSet;
import eu.dnetlib.enabling.resultset.ResultSetAware;
import eu.dnetlib.enabling.resultset.ResultSetListener;
import eu.dnetlib.miscutils.maps.ConcurrentSizedMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.bson.conversions.Bson;

import static com.mongodb.client.model.Filters.*;

/**
 * {@link ResultSetListener} that serves pages of records read from a MongoDB collection, optionally restricted
 * to documents whose {@code body} field matches a regular expression. Each record is turned into a string by
 * the supplied serializer.
 * <p>
 * Records are always returned in ascending {@code id} order. After serving a page, the {@code id} of its last
 * record is remembered under the position that follows the page. A later request starting at exactly that
 * position can then resume with an {@code id > lastKey} range query instead of an ever-growing {@code skip}.
 */
public class MongoResultSetListener implements ResultSetListener, ResultSetAware {

	private static final Log log = LogFactory.getLog(MongoResultSetListener.class);

	/** Last {@code id} served, keyed by the (1-based) position of the record expected to follow it. */
	private final ConcurrentSizedMap<Integer, String> lastKeys = new ConcurrentSizedMap<Integer, String>();

	/** Sort specification shared by every paging query: ascending {@code id}. */
	private final Bson sortByIdAsc = Sorts.orderBy(Sorts.ascending("id"));

	/** Converts each fetched document into the string handed back to the result set. */
	private final Function<DBObject, String> serializer;
	/** Optional regex applied to the {@code body} field; {@code null} means "no filtering". */
	private final Pattern filter;
	/** Collection the records are read from. */
	private final MongoCollection<DBObject> collection;

	/**
	 * @param collection the collection to page through
	 * @param filter     regex the {@code body} field must match, or {@code null} to return every document
	 * @param serializer converts each document into the string returned to the result set
	 */
	public MongoResultSetListener(final MongoCollection<DBObject> collection, final Pattern filter, final Function<DBObject, String> serializer) {
		this.collection = collection;
		this.filter = filter;
		this.serializer = serializer;
	}

	/**
	 * Returns the serialized records in the inclusive, 1-based position range {@code [fromPosition, toPosition]}.
	 * <p>
	 * If a previous call ended right before {@code fromPosition}, the page is fetched with an
	 * {@code id > lastKey} query; otherwise it falls back to skip/limit.
	 *
	 * @param fromPosition first position to return (1-based)
	 * @param toPosition   last position to return (inclusive)
	 * @return the serialized records; empty when the range is empty or lies past the end of the data
	 */
	@Override
	public List<String> getResult(final int fromPosition, final int toPosition) {
		final int size = (toPosition - fromPosition) + 1;
		if (size <= 0) {
			// MongoDB treats limit(0) as "no limit": without this guard an empty range would return the whole collection.
			return new ArrayList<String>();
		}

		final String lastKey = lastKeys.get(fromPosition);
		final List<DBObject> page = (lastKey != null) ? continueFrom(lastKey, size) : fetchNew(fromPosition - 1, size);

		if (!page.isEmpty()) {
			// Remember where this page ended so the next sequential request can resume by id.
			final DBObject last = page.get(page.size() - 1);
			lastKeys.put(toPosition + 1, (String) last.get("id"));
		}

		if (log.isDebugEnabled()) {
			log.debug(String.format("got %s records from %s to %s", page.size(), fromPosition, toPosition));
		}

		return Lists.newArrayList(Iterables.transform(page, serializer));
	}

	/**
	 * Builds the regex condition on the {@code body} field. It is shared by paging and counting so that
	 * {@link #getSize()} always agrees with the records actually served.
	 *
	 * @return the condition, or {@code null} when no filter is configured
	 */
	private Bson bodyFilter() {
		return (filter == null) ? null : Filters.regex("body", filter);
	}

	/**
	 * Fetches a page by offset, sorted by ascending {@code id}.
	 *
	 * @param from 0-based number of matching documents to skip (negative values are treated as 0)
	 * @param size maximum number of documents to return (must be positive)
	 */
	private List<DBObject> fetchNew(final int from, final int size) {
		final Bson query = bodyFilter();
		final FindIterable<DBObject> it = (query != null) ? collection.find(query) : collection.find();
		// MongoDB does not accept a negative skip; a request starting at position 0 or below starts from the beginning.
		return Lists.newArrayList(it.sort(sortByIdAsc).skip(Math.max(0, from)).limit(size));
	}

	/**
	 * Fetches a page made of the documents whose {@code id} sorts strictly after {@code lastKey}, sorted by
	 * ascending {@code id}.
	 *
	 * @param lastKey {@code id} of the last record served by the previous page
	 * @param size    maximum number of documents to return (must be positive)
	 */
	private List<DBObject> continueFrom(final String lastKey, final int size) {
		if (log.isDebugEnabled()) {
			log.debug("trying to continue from previous key: " + lastKey);
		}
		final Bson afterLastKey = Filters.gt("id", lastKey);
		final Bson query = (filter != null) ? Filters.and(afterLastKey, bodyFilter()) : afterLastKey;
		return Lists.newArrayList(collection.find(query).sort(sortByIdAsc).limit(size));
	}

	/**
	 * Counts the documents this listener can serve, applying the same {@code body} filter used for paging.
	 * The count is narrowed to {@code int}.
	 */
	@Override
	public int getSize() {
		final Bson query = bodyFilter();
		if (query != null) {
			return (int) collection.count(query);
		}
		return (int) collection.count();
	}

	/**
	 * Closes the supplied result set immediately, as soon as it is attached.
	 * NOTE(review): presumably this marks the result set as complete because records are served on demand
	 * through {@link #getResult(int, int)} — confirm against the ResultSet contract.
	 */
	@Override
	public void setResultSet(final ResultSet resultSet) {
		resultSet.close();
	}

}
