/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.memoryoptsearch.faiss;

import java.io.IOException;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.Version;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.KNNVectorSimilarityFunction;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.memoryoptsearch.VectorSearcher;
import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSW;
import org.opensearch.knn.memoryoptsearch.faiss.FaissHnswGraph;
import org.opensearch.knn.memoryoptsearch.faiss.FaissIdMapIndex;
import org.opensearch.knn.memoryoptsearch.faiss.FaissIndex;
import org.opensearch.knn.memoryoptsearch.faiss.FlatVectorsScorerProvider;
import org.opensearch.knn.memoryoptsearch.faiss.cagra.FaissCagraHNSW;

/**
 * {@link VectorSearcher} that loads a FAISS index from a Lucene {@link IndexInput} and runs k-NN search
 * over it with Lucene's {@link HnswGraphSearcher}, reading the graph and vectors through slices of that input.
 *
 * <p>The FAISS index must be an IDMap index wrapping an HNSW index; otherwise construction fails.
 * When the backing HNSW is a {@link FaissCagraHNSW}, the graph search is seeded with randomly chosen
 * entry points (see {@link RandomEntryPointsKnnSearchStrategy}).
 */
public class FaissMemoryOptimizedSearcher implements VectorSearcher {
    private final IndexInput indexInput;
    private final FaissIndex faissIndex;
    private final FlatVectorsScorer flatVectorsScorer;
    private final FaissHNSW hnsw;
    // Null for HAMMING, which is not mapped to a Lucene VectorSimilarityFunction here.
    private final VectorSimilarityFunction vectorSimilarityFunction;
    private final long fileSize;
    // True when the field's quantization config enables ADC: queries are float, stored vectors are bytes.
    private final boolean isAdc;

    /**
     * Loads the FAISS index from {@code indexInput} and prepares the scorer and HNSW graph.
     *
     * @param indexInput input positioned over the FAISS index file; it is closed by {@link #close()}
     * @param fieldInfo  field metadata used to read the quantization config; may be {@code null},
     *                   in which case ADC is treated as disabled
     * @throws IOException              if reading the index fails
     * @throws IllegalArgumentException if the loaded index is not an IDMap index exposing an HNSW graph
     */
    public FaissMemoryOptimizedSearcher(IndexInput indexInput, FieldInfo fieldInfo) throws IOException {
        this.indexInput = indexInput;
        this.fileSize = indexInput.length();
        this.faissIndex = FaissIndex.load(indexInput);

        final KNNVectorSimilarityFunction knnVectorSimilarityFunction = faissIndex.getVectorSimilarityFunction();
        this.vectorSimilarityFunction = knnVectorSimilarityFunction != KNNVectorSimilarityFunction.HAMMING
            ? knnVectorSimilarityFunction.getVectorSimilarityFunction()
            : null;

        boolean adcEnabled = false;
        SpaceType spaceType = null;
        if (fieldInfo != null) {
            final QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo, Version.LATEST);
            adcEnabled = quantizationConfig.isEnableADC();
            // The space type is only needed (and only read) for ADC scoring.
            if (adcEnabled) {
                spaceType = SpaceType.getSpace(fieldInfo.getAttribute("spaceType"));
            }
        }
        this.isAdc = adcEnabled;

        this.flatVectorsScorer = FlatVectorsScorerProvider.getFlatVectorsScorer(knnVectorSimilarityFunction, isAdc, spaceType);
        this.hnsw = extractFaissHnsw(faissIndex);
    }

    /**
     * Returns the HNSW structure nested in an IDMap index.
     *
     * @throws IllegalArgumentException if {@code faissIndex} is not a {@link FaissIdMapIndex}
     */
    private static FaissHNSW extractFaissHnsw(FaissIndex faissIndex) {
        if (faissIndex instanceof FaissIdMapIndex) {
            final FaissIdMapIndex idMapIndex = (FaissIdMapIndex) faissIndex;
            return idMapIndex.getFaissHnsw();
        }
        throw new IllegalArgumentException("Faiss index [" + faissIndex.getIndexType() + "] does not have HNSW as an index.");
    }

    /**
     * Searches with a float query. In ADC mode the stored vectors are read as byte values while the query
     * stays float; otherwise float values are read.
     */
    @Override
    public void search(float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        search(VectorEncoding.FLOAT32, () -> {
            final KnnVectorValues vectorValues = isAdc
                ? faissIndex.getByteValues(getSlicedIndexInput())
                : faissIndex.getFloatValues(getSlicedIndexInput());
            return flatVectorsScorer.getRandomVectorScorer(vectorSimilarityFunction, vectorValues, target);
        }, knnCollector, acceptDocs);
    }

    /** Searches with a byte query against the index's byte vector values. */
    @Override
    public void search(byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        search(
            VectorEncoding.BYTE,
            () -> flatVectorsScorer.getRandomVectorScorer(vectorSimilarityFunction, faissIndex.getByteValues(getSlicedIndexInput()), target),
            knnCollector,
            acceptDocs
        );
    }

    /** Closes the {@link IndexInput} passed to the constructor. */
    @Override
    public void close() throws IOException {
        indexInput.close();
    }

    /**
     * Shared search path for both query encodings.
     *
     * <p>Returns immediately when the index is empty or {@code k == 0}. When {@code k} is smaller than the
     * number of vectors, the HNSW graph is searched; otherwise every accepted vector is scored exhaustively.
     *
     * @throws IllegalArgumentException if ADC is disabled and the query encoding differs from the index encoding
     */
    private void search(
        final VectorEncoding vectorEncoding,
        final IOSupplier<RandomVectorScorer> scorerSupplier,
        final KnnCollector knnCollector,
        final Bits acceptDocs
    ) throws IOException {
        if (faissIndex.getTotalNumberOfVectors() == 0 || knnCollector.k() == 0) {
            return;
        }
        // ADC deliberately pairs a float query with byte-encoded stored vectors, so the check is skipped.
        if (!isAdc && faissIndex.getVectorEncoding() != vectorEncoding) {
            throw new IllegalArgumentException(
                "Search for vector encoding ["
                    + String.valueOf(vectorEncoding)
                    + "] is not supported in an index vector whose encoding is ["
                    + String.valueOf(faissIndex.getVectorEncoding())
                    + "]"
            );
        }

        final RandomVectorScorer scorer = scorerSupplier.get();
        final KnnCollector collector = createKnnCollector(knnCollector, scorer);
        final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);

        if (knnCollector.k() < scorer.maxOrd()) {
            // Approximate search: the graph searcher works on ordinals; `collector` maps them to doc ids.
            HnswGraphSearcher.search(scorer, collector, new FaissHnswGraph(hnsw, getSlicedIndexInput()), acceptedOrds);
        } else {
            // k covers every vector: score them all. Doc ids are translated here, so the caller's
            // collector is used directly rather than the ordinal-translating wrapper.
            for (int ord = 0; ord < scorer.maxOrd(); ++ord) {
                if (acceptedOrds != null && !acceptedOrds.get(ord)) {
                    continue;
                }
                if (knnCollector.earlyTerminated()) {
                    break;
                }
                knnCollector.incVisitedCount(1);
                knnCollector.collect(scorer.ordToDoc(ord), scorer.score(ord));
            }
        }
    }

    /**
     * Returns a new slice spanning the entire index file. Presumably each reader gets its own slice so it
     * does not share a file pointer with other readers (TODO confirm intent).
     */
    private IndexInput getSlicedIndexInput() throws IOException {
        return indexInput.slice("FaissMemoryOptimizedSearcher", 0L, fileSize);
    }

    /**
     * Wraps {@code knnCollector} so that graph ordinals are translated to doc ids via {@code scorer}.
     * For CAGRA-built graphs the wrapper also reports a {@link RandomEntryPointsKnnSearchStrategy},
     * which seeds the search with random entry points and keeps the caller's strategy as the original.
     */
    private KnnCollector createKnnCollector(final KnnCollector knnCollector, final RandomVectorScorer scorer) {
        final KnnCollector ordinalTranslatedKnnCollector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
        if (hnsw instanceof FaissCagraHNSW) {
            final FaissCagraHNSW cagraHNSW = (FaissCagraHNSW) hnsw;
            return new KnnCollector.Decorator(ordinalTranslatedKnnCollector) {
                @Override
                public KnnSearchStrategy getSearchStrategy() {
                    return new RandomEntryPointsKnnSearchStrategy(
                        cagraHNSW.getNumBaseLevelSearchEntryPoints(),
                        cagraHNSW.getTotalNumberOfVectors(),
                        knnCollector.getSearchStrategy()
                    );
                }
            };
        }
        return ordinalTranslatedKnnCollector;
    }

    /**
     * Seeded search strategy whose entry points are ordinals drawn uniformly at random from
     * {@code [0, totalNumberOfVectors)}. Draws are independent, so entry points may repeat.
     */
    static class RandomEntryPointsKnnSearchStrategy extends KnnSearchStrategy.Seeded {

        /**
         * @param numberOfEntryPoints  how many random entry points to generate
         * @param totalNumberOfVectors exclusive upper bound of the generated ordinals; must fit in an int
         * @param originalStrategy     strategy being decorated
         * @throws ArithmeticException if {@code totalNumberOfVectors} overflows an int
         */
        public RandomEntryPointsKnnSearchStrategy(int numberOfEntryPoints, long totalNumberOfVectors, KnnSearchStrategy originalStrategy) {
            super(generateRandomEntryPoints(numberOfEntryPoints, Math.toIntExact(totalNumberOfVectors)), numberOfEntryPoints, originalStrategy);
        }

        /**
         * Builds a forward-only iterator that yields {@code numberOfEntryPoints} random ordinals and then
         * {@link DocIdSetIterator#NO_MORE_DOCS}. Only {@link DocIdSetIterator#nextDoc()} is supported.
         */
        private static DocIdSetIterator generateRandomEntryPoints(final int numberOfEntryPoints, final int totalNumberOfVectors) {
            return new DocIdSetIterator() {
                private int numPopulatedVectors = 0;

                @Override
                public int docID() {
                    throw new UnsupportedOperationException("DISI in RandomEntryPointsKnnSearchStrategy does not support docID()");
                }

                @Override
                public int nextDoc() {
                    if (numPopulatedVectors < numberOfEntryPoints) {
                        ++numPopulatedVectors;
                        return ThreadLocalRandom.current().nextInt(totalNumberOfVectors);
                    }
                    return NO_MORE_DOCS;
                }

                @Override
                public int advance(int targetDoc) {
                    throw new UnsupportedOperationException("DISI in RandomEntryPointsKnnSearchStrategy does not support advance(int)");
                }

                @Override
                public long cost() {
                    throw new UnsupportedOperationException("DISI in RandomEntryPointsKnnSearchStrategy does not support cost()");
                }
            };
        }
    }
}

